sachin committed
Commit e546f76 · 1 Parent(s): cb5e4aa

add inpainting

Files changed (1)
  1. intruct.py +91 -25
intruct.py CHANGED
@@ -2,9 +2,9 @@ from fastapi import FastAPI, File, UploadFile, Form
 from fastapi.responses import StreamingResponse
 import io
 import math
-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageDraw
 import torch
-from diffusers import StableDiffusionInstructPix2PixPipeline
+from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionInpaintPipeline
 from fastapi import FastAPI, Response
 from fastapi.responses import FileResponse
 import torch
@@ -13,27 +13,30 @@ from huggingface_hub import hf_hub_download, login
 from safetensors.torch import load_file
 from io import BytesIO
 import os
-import base64  # Added for encoding images as base64
-from typing import List  # Added for type hinting the list of prompts
-
-
+import base64
+from typing import List
 
 # Initialize FastAPI app
 app = FastAPI()
 
-# Load the pre-trained model once at startup
+# Load the pre-trained InstructPix2Pix model for editing
 model_id = "timbrooks/instruct-pix2pix"
-pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+pipe_edit = StableDiffusionInstructPix2PixPipeline.from_pretrained(
     model_id, torch_dtype=torch.float16, safety_checker=None
 ).to("cuda")
 
+# Load the pre-trained inpainting model
+inpaint_model_id = "stabilityai/stable-diffusion-2-inpainting"
+pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
+    inpaint_model_id, torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+
 # Default configuration values
 DEFAULT_STEPS = 50
 DEFAULT_TEXT_CFG = 7.5
 DEFAULT_IMAGE_CFG = 1.5
 DEFAULT_SEED = 1371
 
-
 HF_TOKEN = os.getenv("HF_TOKEN")
 
 def load_model():
@@ -73,18 +76,16 @@ def load_model():
 
 # Load model at startup with error handling
 try:
-    pipe = load_model()
+    pipe_generate = load_model()
 except Exception as e:
     print(f"Model initialization failed: {str(e)}")
     raise
 
-
-
 @app.get("/generate")
 async def generate_image(prompt: str):
     try:
         # Generate image
-        image = pipe(
+        image = pipe_generate(
             prompt,
             num_inference_steps=4,
             guidance_scale=0
@@ -100,7 +101,6 @@ async def generate_image(prompt: str):
     except Exception as e:
         return {"error": str(e)}
 
-# New endpoint to handle a list of prompts
 @app.get("/generate_multiple")
 async def generate_multiple_images(prompts: List[str]):
     try:
@@ -109,7 +109,7 @@ async def generate_multiple_images(prompts: List[str]):
 
         # Generate an image for each prompt
         for prompt in prompts:
-            image = pipe(
+            image = pipe_generate(
                 prompt,
                 num_inference_steps=4,
                 guidance_scale=0
@@ -136,8 +136,6 @@ async def generate_multiple_images(prompts: List[str]):
 async def health_check():
     return {"status": "healthy"}
 
-
-
 def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int):
     """
     Process the input image with the given instruction using InstructPix2Pix.
@@ -157,7 +155,7 @@ def process_image(input_image: Image.Image, instruction: str, steps: int, text_c
     generator = torch.manual_seed(seed)
 
     # Generate the edited image
-    edited_image = pipe(
+    edited_image = pipe_edit(
         instruction,
         image=input_image,
         guidance_scale=text_cfg_scale,
@@ -179,12 +177,6 @@ async def edit_image(
 ):
     """
     Endpoint to edit an image based on a text instruction.
-    - file: The input image to edit.
-    - instruction: The text instruction for editing the image.
-    - steps: Number of inference steps.
-    - text_cfg_scale: Text CFG weight.
-    - image_cfg_scale: Image CFG weight.
-    - seed: Random seed for reproducibility.
     """
     # Read and convert the uploaded image
     image_data = await file.read()
@@ -201,12 +193,86 @@ async def edit_image(
     # Return the image as a streaming response
     return StreamingResponse(img_byte_arr, media_type="image/png")
 
+# New endpoint for inpainting
+@app.post("/inpaint/")
+async def inpaint_image(
+    file: UploadFile = File(...),
+    prompt: str = Form(...),
+    mask_coordinates: str = Form(...),  # Format: "x1,y1,x2,y2" (top-left and bottom-right of the rectangle to inpaint)
+    steps: int = Form(default=DEFAULT_STEPS),
+    guidance_scale: float = Form(default=7.5),
+    seed: int = Form(default=DEFAULT_SEED)
+):
+    """
+    Endpoint to perform inpainting on an image.
+    - file: The input image to inpaint.
+    - prompt: The text prompt describing what to generate in the inpainted area.
+    - mask_coordinates: Coordinates of the rectangular area to inpaint (format: "x1,y1,x2,y2").
+    - steps: Number of inference steps.
+    - guidance_scale: Guidance scale for the inpainting process.
+    - seed: Random seed for reproducibility.
+    """
+    try:
+        # Read and convert the uploaded image
+        image_data = await file.read()
+        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+        # Resize image to fit model requirements (must be divisible by 8 for inpainting)
+        width, height = input_image.size
+        factor = 512 / max(width, height)
+        factor = math.ceil(min(width, height) * factor / 8) * 8 / min(width, height)
+        width = int((width * factor) // 8) * 8
+        height = int((height * factor) // 8) * 8
+        input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
+
+        # Create a mask for inpainting
+        mask = Image.new("L", (width, height), 0)  # Black image (0 = no inpainting)
+        draw = ImageDraw.Draw(mask)
+
+        # Parse the mask coordinates
+        try:
+            x1, y1, x2, y2 = map(int, mask_coordinates.split(","))
+            # Adjust coordinates based on resized image
+            x1 = int(x1 * factor)
+            y1 = int(y1 * factor)
+            x2 = int(x2 * factor)
+            y2 = int(y2 * factor)
+        except ValueError:
+            return {"error": "Invalid mask coordinates format. Use 'x1,y1,x2,y2'."}
+
+        # Draw a white rectangle on the mask (255 = area to inpaint)
+        draw.rectangle([x1, y1, x2, y2], fill=255)
+
+        # Set the random seed for reproducibility
+        generator = torch.manual_seed(seed)
+
+        # Perform inpainting
+        inpainted_image = pipe_inpaint(
+            prompt=prompt,
+            image=input_image,
+            mask_image=mask,
+            num_inference_steps=steps,
+            guidance_scale=guidance_scale,
+            generator=generator,
+        ).images[0]
+
+        # Convert the inpainted image to bytes
+        img_byte_arr = io.BytesIO()
+        inpainted_image.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+
+        # Return the image as a streaming response
+        return StreamingResponse(img_byte_arr, media_type="image/png")
+
+    except Exception as e:
+        return {"error": str(e)}
+
 @app.get("/")
 async def root():
     """
     Root endpoint for basic health check.
     """
-    return {"message": "InstructPix2Pix API is running. Use POST /edit-image/ to edit images."}
+    return {"message": "InstructPix2Pix API is running. Use POST /edit-image/ or /inpaint/ to edit images."}
 
 if __name__ == "__main__":
     import uvicorn
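
One subtlety worth tracing in the new handler: the `factor` computed for resizing is also used to map the caller's `mask_coordinates` onto the resized canvas, which keeps mask and image aligned. A minimal sketch of that arithmetic, assuming a hypothetical 1024×768 upload and the rectangle "100,100,300,300":

import math

# Trace of the /inpaint/ resize logic for a hypothetical 1024x768 upload.
width, height = 1024, 768
factor = 512 / max(width, height)          # 512 / 1024 = 0.5
factor = math.ceil(min(width, height) * factor / 8) * 8 / min(width, height)
                                           # ceil(768 * 0.5 / 8) * 8 / 768 = 0.5
width = int((width * factor) // 8) * 8     # 512
height = int((height * factor) // 8) * 8   # 384 -- both divisible by 8

# mask_coordinates "100,100,300,300" are scaled by the same factor,
# so the white inpainting rectangle becomes (50, 50, 150, 150).
assert (width, height) == (512, 384)
assert [int(c * factor) for c in (100, 100, 300, 300)] == [50, 50, 150, 150]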
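And a minimal client sketch for exercising the new endpoint, assuming the app is served via `uvicorn` at http://localhost:8000 and the `requests` package is installed; the file names and prompt below are placeholders:

import requests

# Call the new /inpaint/ endpoint with a placeholder image and prompt.
with open("photo.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/inpaint/",
        files={"file": ("photo.png", f, "image/png")},
        data={
            "prompt": "a red sports car",
            "mask_coordinates": "100,100,300,300",  # x1,y1,x2,y2 in original-image pixels
            "steps": "50",
            "guidance_scale": "7.5",
            "seed": "1371",
        },
    )

# Success streams back PNG bytes; failures return a JSON {"error": ...} payload.
if resp.headers.get("content-type", "").startswith("image/png"):
    with open("inpainted.png", "wb") as out:
        out.write(resp.content)
else:
    print(resp.json())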