sachin commited on
Commit
e9cc529
·
1 Parent(s): c66a631

improve memory management

Browse files
Files changed (1) hide show
  1. intruct.py +157 -529
intruct.py CHANGED
@@ -1,231 +1,155 @@
1
- from fastapi import FastAPI, File, UploadFile, Form
2
- from fastapi.responses import StreamingResponse
3
  import io
4
  import math
5
- from PIL import Image, ImageOps, ImageDraw
6
  import torch
7
- from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionInpaintPipeline
8
- from fastapi import FastAPI, Response
9
- from fastapi.responses import FileResponse
10
- import torch
11
- from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
 
 
 
12
  from huggingface_hub import hf_hub_download, login
13
  from safetensors.torch import load_file
14
- from io import BytesIO
15
- import os
16
- import base64
17
- from typing import List
18
- from fastapi import FastAPI, File, UploadFile, HTTPException
19
- from fastapi.responses import StreamingResponse
20
- from PIL import Image, ImageDraw, ImageFilter
21
- import io
22
- import torch
23
- import numpy as np
24
- from diffusers import StableDiffusionInpaintPipeline
25
- import cv2
26
-
27
- from fastapi import FastAPI, File, UploadFile, HTTPException
28
- from fastapi.responses import StreamingResponse, JSONResponse
29
- import torch
30
- from PIL import Image
31
- import io
32
- import numpy as np
33
- import cv2
34
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
35
  from sam2.sam2_image_predictor import SAM2ImagePredictor
36
-
37
-
38
-
39
 
40
  # Initialize FastAPI app
41
  app = FastAPI()
42
 
 
43
  device = "cuda" if torch.cuda.is_available() else "cpu"
44
 
45
- # Load Grounding DINO model and processor at startup
46
- dino_model_id = "IDEA-Research/grounding-dino-base"
47
- dino_processor = AutoProcessor.from_pretrained(dino_model_id)
48
- dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(dino_model_id).to(device)
 
 
 
 
49
 
50
- # Load SAM 2 model at startup
51
- #sam_checkpoint = "sam2.1_hiera_tiny.pt" # Replace with your checkpoint path
52
- sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
53
- sam_predictor.model.to(device)
54
-
55
- # Default text query
56
  DEFAULT_TEXT_QUERY = "a tank."
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def process_image_with_dino(image: Image.Image, text_query: str = DEFAULT_TEXT_QUERY):
59
- """Detect objects using Grounding DINO."""
60
- inputs = dino_processor(images=image, text=text_query, return_tensors="pt").to(device)
61
  with torch.no_grad():
62
- outputs = dino_model(**inputs)
63
-
64
- # Post-process results
65
- results = dino_processor.post_process_grounded_object_detection(
66
- outputs,
67
- inputs.input_ids,
68
- threshold=0.4,
69
- text_threshold=0.3,
70
- target_sizes=[image.size[::-1]] # [width, height]
71
  )
72
- return results[0] # Single image result
73
 
74
  def segment_with_sam(image: Image.Image, boxes: list):
75
- """Segment detected objects using SAM 2 and return a mask."""
76
  image_np = np.array(image)
77
- sam_predictor.set_image(image_np)
78
-
79
  if not boxes:
80
- return np.zeros(image_np.shape[:2], dtype=bool) # Empty mask if no boxes
81
-
82
- # Convert boxes to [x_min, y_min, x_max, y_max] tensor and move to device
83
  boxes_tensor = torch.tensor(
84
  [[box["x_min"], box["y_min"], box["x_max"], box["y_max"]] for box in boxes],
85
  dtype=torch.float32
86
  ).to(device)
87
-
88
- # Predict with SAM 2 using boxes directly
89
- masks, _, _ = sam_predictor.predict(
90
- point_coords=None,
91
- point_labels=None,
92
- box=boxes_tensor, # Use 'box' argument instead of 'boxes'
93
- multimask_output=False
94
- )
95
- return masks[0] # Return the first mask directly (already a NumPy array)
96
 
97
  def create_background_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
98
- """Create an RGB mask for background removal (object preserved)."""
99
- mask_inv = np.logical_not(mask).astype(np.uint8) * 255 # Invert mask (background is white)
100
- mask_rgb = cv2.cvtColor(mask_inv, cv2.COLOR_GRAY2RGB) # Convert to RGB
101
  return mask_rgb
102
 
103
  def create_object_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
104
- """Create an RGB mask for object removal (background preserved)."""
105
- mask_rgb = cv2.cvtColor(mask.astype(np.uint8) * 255, cv2.COLOR_GRAY2RGB) # Object is white, background black
106
  return mask_rgb
107
 
108
-
109
-
110
-
111
- model_id_runway = "runwayml/stable-diffusion-inpainting"
112
- device = "cuda" if torch.cuda.is_available() else "cpu"
113
-
114
- try:
115
- pipe_runway = StableDiffusionInpaintPipeline.from_pretrained(model_id_runway)
116
- pipe_runway.to(device)
117
- except Exception as e:
118
- raise RuntimeError(f"Failed to load model: {e}")
119
-
120
-
121
-
122
- # Load the pre-trained InstructPix2Pix model for editing
123
- model_id = "timbrooks/instruct-pix2pix"
124
- pipe_edit = StableDiffusionInstructPix2PixPipeline.from_pretrained(
125
- model_id, torch_dtype=torch.float16, safety_checker=None
126
- ).to("cuda")
127
-
128
- # Load the pre-trained Inpainting model
129
- inpaint_model_id = "stabilityai/stable-diffusion-2-inpainting"
130
- pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
131
- inpaint_model_id, torch_dtype=torch.float16, safety_checker=None
132
- ).to("cuda")
133
-
134
- # Default configuration values
135
- DEFAULT_STEPS = 50
136
- DEFAULT_TEXT_CFG = 7.5
137
- DEFAULT_IMAGE_CFG = 1.5
138
- DEFAULT_SEED = 1371
139
-
140
- HF_TOKEN = os.getenv("HF_TOKEN")
141
-
142
- def load_model():
143
- try:
144
- # Login to Hugging Face if token is provided
145
- if HF_TOKEN:
146
- login(token=HF_TOKEN)
147
-
148
- base = "stabilityai/stable-diffusion-xl-base-1.0"
149
- repo = "ByteDance/SDXL-Lightning"
150
- ckpt = "sdxl_lightning_4step_unet.safetensors"
151
-
152
- # Load model with explicit error handling
153
- unet = UNet2DConditionModel.from_config(
154
- base,
155
- subfolder="unet"
156
- ).to("cuda", torch.float16)
157
-
158
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
159
- pipe = StableDiffusionXLPipeline.from_pretrained(
160
- base,
161
- unet=unet,
162
- torch_dtype=torch.float16,
163
- variant="fp16"
164
- ).to("cuda")
165
-
166
- # Configure scheduler
167
- pipe.scheduler = EulerDiscreteScheduler.from_config(
168
- pipe.scheduler.config,
169
- timestep_spacing="trailing"
170
- )
171
-
172
- return pipe
173
-
174
- except Exception as e:
175
- raise Exception(f"Failed to load model: {str(e)}")
176
-
177
- # Load model at startup with error handling
178
- try:
179
- pipe_generate = load_model()
180
- except Exception as e:
181
- print(f"Model initialization failed: {str(e)}")
182
- raise
183
-
184
- @app.get("/generate")
185
- async def generate_image(prompt: str):
186
- try:
187
- # Generate image
188
- image = pipe_generate(
189
- prompt,
190
- num_inference_steps=4,
191
- guidance_scale=0
192
- ).images[0]
193
-
194
- # Save image to buffer
195
- buffer = BytesIO()
196
- image.save(buffer, format="PNG")
197
- buffer.seek(0)
198
-
199
- return Response(content=buffer.getvalue(), media_type="image/png")
200
-
201
- except Exception as e:
202
- return {"error": str(e)}
203
-
204
-
205
- @app.get("/health")
206
- async def health_check():
207
- return {"status": "healthy"}
208
-
209
  def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int):
210
- """
211
- Process the input image with the given instruction using InstructPix2Pix.
212
- """
213
- # Resize image to fit model requirements
214
  width, height = input_image.size
215
  factor = 512 / max(width, height)
216
  factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
217
  width = int((width * factor) // 64) * 64
218
  height = int((height * factor) // 64) * 64
219
  input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
220
-
221
  if not instruction:
222
  return input_image
223
-
224
- # Set the random seed for reproducibility
225
  generator = torch.manual_seed(seed)
226
-
227
- # Generate the edited image
228
- edited_image = pipe_edit(
229
  instruction,
230
  image=input_image,
231
  guidance_scale=text_cfg_scale,
@@ -233,9 +157,25 @@ def process_image(input_image: Image.Image, instruction: str, steps: int, text_c
233
  num_inference_steps=steps,
234
  generator=generator,
235
  ).images[0]
236
-
237
  return edited_image
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  @app.post("/edit-image/")
240
  async def edit_image(
241
  file: UploadFile = File(...),
@@ -245,79 +185,43 @@ async def edit_image(
245
  image_cfg_scale: float = Form(default=DEFAULT_IMAGE_CFG),
246
  seed: int = Form(default=DEFAULT_SEED)
247
  ):
248
- """
249
- Endpoint to edit an image based on a text instruction.
250
- """
251
- # Read and convert the uploaded image
252
- image_data = await file.read()
253
- input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
254
-
255
- # Process the image
256
- edited_image = process_image(input_image, instruction, steps, text_cfg_scale, image_cfg_scale, seed)
257
-
258
- # Convert the edited image to bytes
259
- img_byte_arr = io.BytesIO()
260
- edited_image.save(img_byte_arr, format="PNG")
261
- img_byte_arr.seek(0)
262
-
263
- # Return the image as a streaming response
264
- return StreamingResponse(img_byte_arr, media_type="image/png")
265
 
266
- # New endpoint for inpainting
267
  @app.post("/inpaint/")
268
  async def inpaint_image(
269
  file: UploadFile = File(...),
270
  prompt: str = Form(...),
271
- mask_coordinates: str = Form(...), # Format: "x1,y1,x2,y2" (top-left and bottom-right of the rectangle to inpaint)
272
  steps: int = Form(default=DEFAULT_STEPS),
273
  guidance_scale: float = Form(default=7.5),
274
  seed: int = Form(default=DEFAULT_SEED)
275
  ):
276
- """
277
- Endpoint to perform inpainting on an image.
278
- - file: The input image to inpaint.
279
- - prompt: The text prompt describing what to generate in the inpainted area.
280
- - mask_coordinates: Coordinates of the rectangular area to inpaint (format: "x1,y1,x2,y2").
281
- - steps: Number of inference steps.
282
- - guidance_scale: Guidance scale for the inpainting process.
283
- - seed: Random seed for reproducibility.
284
- """
285
  try:
286
- # Read and convert the uploaded image
287
  image_data = await file.read()
288
  input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
289
-
290
- # Resize image to fit model requirements (must be divisible by 8 for inpainting)
291
  width, height = input_image.size
292
  factor = 512 / max(width, height)
293
  factor = math.ceil(min(width, height) * factor / 8) * 8 / min(width, height)
294
  width = int((width * factor) // 8) * 8
295
  height = int((height * factor) // 8) * 8
296
  input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
297
-
298
- # Create a mask for inpainting
299
- mask = Image.new("L", (width, height), 0) # Black image (0 = no inpainting)
300
  draw = ImageDraw.Draw(mask)
301
-
302
- # Parse the mask coordinates
303
- try:
304
- x1, y1, x2, y2 = map(int, mask_coordinates.split(","))
305
- # Adjust coordinates based on resized image
306
- x1 = int(x1 * factor)
307
- y1 = int(y1 * factor)
308
- x2 = int(x2 * factor)
309
- y2 = int(y2 * factor)
310
- except ValueError:
311
- return {"error": "Invalid mask coordinates format. Use 'x1,y1,x2,y2'."}
312
-
313
- # Draw a white rectangle on the mask (255 = area to inpaint)
314
  draw.rectangle([x1, y1, x2, y2], fill=255)
315
-
316
- # Set the random seed for reproducibility
317
  generator = torch.manual_seed(seed)
318
-
319
- # Perform inpainting
320
- inpainted_image = pipe_inpaint(
321
  prompt=prompt,
322
  image=input_image,
323
  mask_image=mask,
@@ -325,332 +229,56 @@ async def inpaint_image(
325
  guidance_scale=guidance_scale,
326
  generator=generator,
327
  ).images[0]
328
-
329
- # Convert the inpainted image to bytes
330
  img_byte_arr = io.BytesIO()
331
  inpainted_image.save(img_byte_arr, format="PNG")
332
  img_byte_arr.seek(0)
333
-
334
- # Return the image as a streaming response
335
  return StreamingResponse(img_byte_arr, media_type="image/png")
336
-
337
- except Exception as e:
338
- return {"error": str(e)}
339
-
340
- @app.get("/")
341
- async def root():
342
- """
343
- Root endpoint for basic health check.
344
- """
345
- return {"message": "InstructPix2Pix API is running. Use POST /edit-image/ or /inpaint/ to edit images."}
346
-
347
-
348
-
349
- # Helper functions
350
- def prepare_guided_image(original_image: Image, reference_image: Image, mask_image: Image) -> Image:
351
- original_array = np.array(original_image)
352
- reference_array = np.array(reference_image)
353
- mask_array = np.array(mask_image) / 255.0
354
- mask_array = mask_array[:, :, np.newaxis]
355
- blended_array = original_array * (1 - mask_array) + reference_array * mask_array
356
- return Image.fromarray(blended_array.astype(np.uint8))
357
-
358
- def soften_mask(mask_image: Image, softness: int = 5) -> Image:
359
- from PIL import ImageFilter
360
- return mask_image.filter(ImageFilter.GaussianBlur(radius=softness))
361
-
362
- def generate_rectangular_mask(image_size: tuple, x1: int = 100, y1: int = 100, x2: int = 200, y2: int = 200) -> Image:
363
- mask = Image.new("L", image_size, 0)
364
- draw = ImageDraw.Draw(mask)
365
- draw.rectangle([x1, y1, x2, y2], fill=255)
366
- return mask
367
-
368
- def segment_tank(tank_image: Image) -> tuple[Image, Image]:
369
- tank_array = np.array(tank_image.convert("RGB"))
370
- tank_array = cv2.cvtColor(tank_array, cv2.COLOR_RGB2BGR)
371
- hsv = cv2.cvtColor(tank_array, cv2.COLOR_BGR2HSV)
372
- lower_snow = np.array([0, 0, 180])
373
- upper_snow = np.array([180, 50, 255])
374
- snow_mask = cv2.inRange(hsv, lower_snow, upper_snow)
375
- tank_mask = cv2.bitwise_not(snow_mask)
376
- kernel = np.ones((5, 5), np.uint8)
377
- tank_mask = cv2.erode(tank_mask, kernel, iterations=1)
378
- tank_mask = cv2.dilate(tank_mask, kernel, iterations=1)
379
- tank_mask_image = Image.fromarray(tank_mask, mode="L")
380
- tank_array_rgb = np.array(tank_image.convert("RGB"))
381
- mask_array = tank_mask / 255.0
382
- mask_array = mask_array[:, :, np.newaxis]
383
- segmented_tank = (tank_array_rgb * mask_array).astype(np.uint8)
384
- alpha = tank_mask
385
- segmented_tank_rgba = np.zeros((tank_image.height, tank_image.width, 4), dtype=np.uint8)
386
- segmented_tank_rgba[:, :, :3] = segmented_tank
387
- segmented_tank_rgba[:, :, 3] = alpha
388
- segmented_tank_image = Image.fromarray(segmented_tank_rgba, mode="RGBA")
389
- return segmented_tank_image, tank_mask_image
390
-
391
- async def apply_camouflage_to_tank(tank_image: Image) -> Image:
392
- segmented_tank, tank_mask = segment_tank(tank_image)
393
- segmented_tank.save("segmented_tank.png")
394
- tank_mask.save("tank_mask.png")
395
- camouflaged_tank = pipe_runway(
396
- prompt="Apply a grassy camouflage pattern with shades of green and brown to the tank, preserving its structure.",
397
- image=segmented_tank.convert("RGB"),
398
- mask_image=tank_mask,
399
- strength=0.5,
400
- guidance_scale=8.0,
401
- num_inference_steps=50,
402
- negative_prompt="snow, ice, rock, stone, boat, unrelated objects"
403
- ).images[0]
404
- camouflaged_tank_rgba = np.zeros((camouflaged_tank.height, camouflaged_tank.width, 4), dtype=np.uint8)
405
- camouflaged_tank_rgba[:, :, :3] = np.array(camouflaged_tank)
406
- camouflaged_tank_rgba[:, :, 3] = np.array(tank_mask)
407
- camouflaged_tank_image = Image.fromarray(camouflaged_tank_rgba, mode="RGBA")
408
- camouflaged_tank_image.save("camouflaged_tank.png")
409
- return camouflaged_tank_image
410
-
411
- def fit_image_to_mask(original_image: Image, reference_image: Image, mask_x1: int, mask_y1: int, mask_x2: int, mask_y2: int) -> tuple:
412
- mask_width = mask_x2 - mask_x1
413
- mask_height = mask_y2 - mask_y1
414
- if mask_width <= 0 or mask_height <= 0:
415
- raise ValueError("Mask dimensions must be positive")
416
- ref_width, ref_height = reference_image.size
417
- aspect_ratio = ref_width / ref_height
418
- if mask_width / mask_height > aspect_ratio:
419
- new_height = mask_height
420
- new_width = int(new_height * aspect_ratio)
421
- else:
422
- new_width = mask_width
423
- new_height = int(new_width / aspect_ratio)
424
- reference_image_resized = reference_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
425
- guided_image = original_image.copy().convert("RGB")
426
- paste_x = mask_x1 + (mask_width - new_width) // 2
427
- paste_y = mask_y1 + (mask_height - new_height) // 2
428
- guided_image.paste(reference_image_resized, (paste_x, paste_y), reference_image_resized)
429
- mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2)
430
- return guided_image, mask_image
431
-
432
- # Endpoints
433
- @app.post("/inpaint/")
434
- async def inpaint_image(
435
- image: UploadFile = File(...),
436
- mask: UploadFile = File(...),
437
- prompt: str = "Fill the masked area with appropriate content."
438
- ):
439
- try:
440
- image_bytes = await image.read()
441
- mask_bytes = await mask.read()
442
- original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
443
- mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
444
- if original_image.size != mask_image.size:
445
- raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")
446
- result = pipe_runway(prompt=prompt, image=original_image, mask_image=mask_image).images[0]
447
- result_bytes = io.BytesIO()
448
- result.save(result_bytes, format="PNG")
449
- result_bytes.seek(0)
450
- return StreamingResponse(
451
- result_bytes,
452
- media_type="image/png",
453
- headers={"Content-Disposition": "attachment; filename=inpainted_image.png"}
454
- )
455
- except Exception as e:
456
- raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}")
457
-
458
- @app.post("/inpaint-with-reference/")
459
- async def inpaint_with_reference(
460
- image: UploadFile = File(...),
461
- reference_image: UploadFile = File(...),
462
- prompt: str = "Integrate the reference content naturally into the masked area, matching style and lighting.",
463
- mask_x1: int = 100,
464
- mask_y1: int = 100,
465
- mask_x2: int = 200,
466
- mask_y2: int = 200
467
- ):
468
- try:
469
- image_bytes = await image.read()
470
- reference_bytes = await reference_image.read()
471
- original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
472
- reference_image = Image.open(io.BytesIO(reference_bytes)).convert("RGB")
473
- if original_image.size != reference_image.size:
474
- reference_image = reference_image.resize(original_image.size, Image.Resampling.LANCZOS)
475
- mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2)
476
- softened_mask = soften_mask(mask_image, softness=5)
477
- guided_image = prepare_guided_image(original_image, reference_image, softened_mask)
478
- result = pipe_runway(
479
- prompt=prompt,
480
- image=guided_image,
481
- mask_image=softened_mask,
482
- strength=0.75,
483
- guidance_scale=7.5
484
- ).images[0]
485
- result_bytes = io.BytesIO()
486
- result.save(result_bytes, format="PNG")
487
- result_bytes.seek(0)
488
- return StreamingResponse(
489
- result_bytes,
490
- media_type="image/png",
491
- headers={"Content-Disposition": "attachment; filename=natural_inpaint_image.png"}
492
- )
493
  except Exception as e:
494
- raise HTTPException(status_code=500, detail=f"Error during natural inpainting: {e}")
495
-
496
- @app.post("/fit-image-to-mask/")
497
- async def fit_image_to_mask_endpoint(
498
- image: UploadFile = File(...),
499
- reference_image: UploadFile = File(...),
500
- mask_x1: int = 200,
501
- mask_y1: int = 200,
502
- mask_x2: int = 500,
503
- mask_y2: int = 500
504
- ):
505
- try:
506
- image_bytes = await image.read()
507
- reference_bytes = await reference_image.read()
508
- original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
509
- reference_image = Image.open(io.BytesIO(reference_bytes)).convert("RGB")
510
- camouflaged_tank = await apply_camouflage_to_tank(reference_image)
511
- guided_image, mask_image = fit_image_to_mask(original_image, camouflaged_tank, mask_x1, mask_y1, mask_x2, mask_y2)
512
- guided_image.save("guided_image_before_blending.png")
513
- softened_mask = soften_mask(mask_image, softness=2)
514
- result = pipe_runway(
515
- prompt="Blend the camouflaged tank into the grassy field with trees, ensuring a non-snowy environment, matching the style, lighting, and surroundings.",
516
- image=guided_image,
517
- mask_image=softened_mask,
518
- strength=0.2,
519
- guidance_scale=7.5,
520
- num_inference_steps=50,
521
- negative_prompt="snow, ice, rock, stone, boat, unrelated objects"
522
- ).images[0]
523
- result_bytes = io.BytesIO()
524
- result.save(result_bytes, format="PNG")
525
- result_bytes.seek(0)
526
- return StreamingResponse(
527
- result_bytes,
528
- media_type="image/png",
529
- headers={"Content-Disposition": "attachment; filename=fitted_image.png"}
530
- )
531
- except ValueError as ve:
532
- raise HTTPException(status_code=400, detail=f"ValueError in processing: {str(ve)}")
533
- except Exception as e:
534
- raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}")
535
-
536
 
537
  @app.post("/detect-json/")
538
- async def detect_json(
539
- file: UploadFile = File(..., description="Image file to process"),
540
- text_query: str = DEFAULT_TEXT_QUERY
541
- ):
542
- """Endpoint to detect objects and return bounding box information as JSON."""
543
  try:
544
- # Read and convert the uploaded image
545
  image_data = await file.read()
546
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
547
-
548
- # Process with Grounding DINO
549
  results = process_image_with_dino(image, text_query)
550
-
551
- # Format results as JSON-compatible data
552
- detections = []
553
- for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
554
- x_min, y_min, x_max, y_max = box.tolist()
555
- detections.append({
556
  "label": label,
557
- "score": float(score), # Convert tensor to float
558
- "box": {
559
- "x_min": x_min,
560
- "y_min": y_min,
561
- "x_max": x_max,
562
- "y_max": y_max
563
- }
564
- })
565
-
566
  return JSONResponse(content={"detections": detections})
567
  except Exception as e:
568
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
569
 
570
  @app.post("/segment-image/")
571
- async def segment_image(
572
- file: UploadFile = File(..., description="Image file to process"),
573
- text_query: str = DEFAULT_TEXT_QUERY
574
- ):
575
- """Endpoint to segment objects and return the image with background removed."""
576
  try:
577
- # Read and convert the uploaded image
578
  image_data = await file.read()
579
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
580
-
581
- # Detect objects with Grounding DINO
582
  results = process_image_with_dino(image, text_query)
583
-
584
- # Extract boxes for segmentation, move to CPU
585
  boxes = [
586
  {"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
587
- for box in results["boxes"].cpu() # Move tensor to CPU here
588
  ]
589
-
590
- # Segment with SAM 2
591
  mask = segment_with_sam(image, boxes)
592
-
593
- # Create background mask and apply it
594
  image_np = np.array(image)
595
  background_mask = create_background_mask(image_np, mask)
596
  segmented_image = cv2.bitwise_and(image_np, background_mask)
597
-
598
- # Convert to PIL Image and save to bytes
599
  output_image = Image.fromarray(segmented_image)
600
  img_byte_arr = io.BytesIO()
601
  output_image.save(img_byte_arr, format="PNG")
602
  img_byte_arr.seek(0)
603
-
604
- return StreamingResponse(
605
- img_byte_arr,
606
- media_type="image/png",
607
- headers={"Content-Disposition": "attachment; filename=segmented_image.png"}
608
- )
609
- except Exception as e:
610
- raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
611
-
612
- @app.post("/mask-object/")
613
- async def mask_object(
614
- file: UploadFile = File(..., description="Image file to process"),
615
- text_query: str = DEFAULT_TEXT_QUERY
616
- ):
617
- """Endpoint to mask the detected object and return the image with the object removed."""
618
- try:
619
- # Read and convert the uploaded image
620
- image_data = await file.read()
621
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
622
-
623
- # Detect objects with Grounding DINO
624
- results = process_image_with_dino(image, text_query)
625
-
626
- # Extract boxes for segmentation, move to CPU
627
- boxes = [
628
- {"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
629
- for box in results["boxes"].cpu() # Move tensor to CPU here
630
- ]
631
-
632
- # Segment with SAM 2
633
- mask = segment_with_sam(image, boxes)
634
-
635
- # Create object mask and apply it
636
- image_np = np.array(image)
637
- object_mask = create_object_mask(image_np, mask)
638
- masked_image = cv2.bitwise_and(image_np, object_mask)
639
-
640
- # Convert to PIL Image and save to bytes
641
- output_image = Image.fromarray(masked_image)
642
- img_byte_arr = io.BytesIO()
643
- output_image.save(img_byte_arr, format="PNG")
644
- img_byte_arr.seek(0)
645
-
646
- return StreamingResponse(
647
- img_byte_arr,
648
- media_type="image/png",
649
- headers={"Content-Disposition": "attachment; filename=masked_object_image.png"}
650
- )
651
  except Exception as e:
652
- raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
653
 
 
654
 
655
  if __name__ == "__main__":
656
  import uvicorn
 
1
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
2
+ from fastapi.responses import StreamingResponse, JSONResponse, Response
3
  import io
4
  import math
5
+ from PIL import Image, ImageOps, ImageDraw, ImageFilter
6
  import torch
7
+ import numpy as np
8
+ from diffusers import (
9
+ StableDiffusionInstructPix2PixPipeline,
10
+ StableDiffusionInpaintPipeline,
11
+ StableDiffusionXLPipeline,
12
+ UNet2DConditionModel,
13
+ EulerDiscreteScheduler,
14
+ )
15
  from huggingface_hub import hf_hub_download, login
16
  from safetensors.torch import load_file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
18
  from sam2.sam2_image_predictor import SAM2ImagePredictor
19
+ import cv2
20
+ import os
21
+ from typing import Optional
22
 
23
  # Initialize FastAPI app
24
  app = FastAPI()
25
 
26
+ # Device configuration
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
+ # Model variables (initially None, loaded lazily)
30
+ pipe_edit = None # InstructPix2Pix
31
+ pipe_inpaint = None # Stable Diffusion Inpainting
32
+ pipe_generate = None # Stable Diffusion XL
33
+ pipe_runway = None # Runway Inpainting
34
+ dino_processor = None # Grounding DINO processor
35
+ dino_model = None # Grounding DINO model
36
+ sam_predictor = None # SAM 2 predictor
37
 
38
+ # Default configuration values
39
+ DEFAULT_STEPS = 50
40
+ DEFAULT_TEXT_CFG = 7.5
41
+ DEFAULT_IMAGE_CFG = 1.5
42
+ DEFAULT_SEED = 1371
 
43
  DEFAULT_TEXT_QUERY = "a tank."
44
+ HF_TOKEN = os.getenv("HF_TOKEN")
45
 
46
+ # Helper functions for lazy loading
47
+ def load_instruct_pix2pix() -> StableDiffusionInstructPix2PixPipeline:
48
+ global pipe_edit
49
+ if pipe_edit is None:
50
+ model_id = "timbrooks/instruct-pix2pix"
51
+ pipe_edit = StableDiffusionInstructPix2PixPipeline.from_pretrained(
52
+ model_id, torch_dtype=torch.float16, safety_checker=None
53
+ ).to(device)
54
+ return pipe_edit
55
+
56
+ def load_inpaint_pipeline() -> StableDiffusionInpaintPipeline:
57
+ global pipe_inpaint
58
+ if pipe_inpaint is None:
59
+ inpaint_model_id = "stabilityai/stable-diffusion-2-inpainting"
60
+ pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
61
+ inpaint_model_id, torch_dtype=torch.float16, safety_checker=None
62
+ ).to(device)
63
+ return pipe_inpaint
64
+
65
+ def load_generate_pipeline() -> StableDiffusionXLPipeline:
66
+ global pipe_generate
67
+ if pipe_generate is None:
68
+ try:
69
+ if HF_TOKEN:
70
+ login(token=HF_TOKEN)
71
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
72
+ repo = "ByteDance/SDXL-Lightning"
73
+ ckpt = "sdxl_lightning_4step_unet.safetensors"
74
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16)
75
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
76
+ pipe_generate = StableDiffusionXLPipeline.from_pretrained(
77
+ base, unet=unet, torch_dtype=torch.float16, variant="fp16"
78
+ ).to(device)
79
+ pipe_generate.scheduler = EulerDiscreteScheduler.from_config(
80
+ pipe_generate.scheduler.config, timestep_spacing="trailing"
81
+ )
82
+ except Exception as e:
83
+ raise RuntimeError(f"Failed to load generate pipeline: {str(e)}")
84
+ return pipe_generate
85
+
86
+ def load_runway_inpaint() -> StableDiffusionInpaintPipeline:
87
+ global pipe_runway
88
+ if pipe_runway is None:
89
+ model_id_runway = "runwayml/stable-diffusion-inpainting"
90
+ pipe_runway = StableDiffusionInpaintPipeline.from_pretrained(model_id_runway).to(device)
91
+ return pipe_runway
92
+
93
+ def load_dino() -> tuple[AutoProcessor, AutoModelForZeroShotObjectDetection]:
94
+ global dino_processor, dino_model
95
+ if dino_processor is None or dino_model is None:
96
+ dino_model_id = "IDEA-Research/grounding-dino-base"
97
+ dino_processor = AutoProcessor.from_pretrained(dino_model_id)
98
+ dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(dino_model_id).to(device)
99
+ return dino_processor, dino_model
100
+
101
+ def load_sam() -> SAM2ImagePredictor:
102
+ global sam_predictor
103
+ if sam_predictor is None:
104
+ sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
105
+ sam_predictor.model.to(device)
106
+ return sam_predictor
107
+
108
+ # Image processing helper functions (unchanged, included for completeness)
109
  def process_image_with_dino(image: Image.Image, text_query: str = DEFAULT_TEXT_QUERY):
110
+ processor, model = load_dino()
111
+ inputs = processor(images=image, text=text_query, return_tensors="pt").to(device)
112
  with torch.no_grad():
113
+ outputs = model(**inputs)
114
+ results = processor.post_process_grounded_object_detection(
115
+ outputs, inputs.input_ids, threshold=0.4, text_threshold=0.3, target_sizes=[image.size[::-1]]
 
 
 
 
 
 
116
  )
117
+ return results[0]
118
 
119
  def segment_with_sam(image: Image.Image, boxes: list):
120
+ predictor = load_sam()
121
  image_np = np.array(image)
122
+ predictor.set_image(image_np)
 
123
  if not boxes:
124
+ return np.zeros(image_np.shape[:2], dtype=bool)
 
 
125
  boxes_tensor = torch.tensor(
126
  [[box["x_min"], box["y_min"], box["x_max"], box["y_max"]] for box in boxes],
127
  dtype=torch.float32
128
  ).to(device)
129
+ masks, _, _ = predictor.predict(point_coords=None, point_labels=None, box=boxes_tensor, multimask_output=False)
130
+ return masks[0]
 
 
 
 
 
 
 
131
 
132
  def create_background_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
133
+ mask_inv = np.logical_not(mask).astype(np.uint8) * 255
134
+ mask_rgb = cv2.cvtColor(mask_inv, cv2.COLOR_GRAY2RGB)
 
135
  return mask_rgb
136
 
137
  def create_object_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
138
+ mask_rgb = cv2.cvtColor(mask.astype(np.uint8) * 255, cv2.COLOR_GRAY2RGB)
 
139
  return mask_rgb
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int):
 
 
 
 
142
  width, height = input_image.size
143
  factor = 512 / max(width, height)
144
  factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
145
  width = int((width * factor) // 64) * 64
146
  height = int((height * factor) // 64) * 64
147
  input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
 
148
  if not instruction:
149
  return input_image
 
 
150
  generator = torch.manual_seed(seed)
151
+ pipe = load_instruct_pix2pix()
152
+ edited_image = pipe(
 
153
  instruction,
154
  image=input_image,
155
  guidance_scale=text_cfg_scale,
 
157
  num_inference_steps=steps,
158
  generator=generator,
159
  ).images[0]
 
160
  return edited_image
161
 
162
+ # Endpoints
163
+ @app.get("/generate")
164
+ async def generate_image(prompt: str):
165
+ try:
166
+ pipe = load_generate_pipeline()
167
+ image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
168
+ buffer = io.BytesIO()
169
+ image.save(buffer, format="PNG")
170
+ buffer.seek(0)
171
+ return Response(content=buffer.getvalue(), media_type="image/png")
172
+ except Exception as e:
173
+ raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
174
+
175
+ @app.get("/health")
176
+ async def health_check():
177
+ return {"status": "healthy"}
178
+
179
  @app.post("/edit-image/")
180
  async def edit_image(
181
  file: UploadFile = File(...),
 
185
  image_cfg_scale: float = Form(default=DEFAULT_IMAGE_CFG),
186
  seed: int = Form(default=DEFAULT_SEED)
187
  ):
188
+ try:
189
+ image_data = await file.read()
190
+ input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
191
+ edited_image = process_image(input_image, instruction, steps, text_cfg_scale, image_cfg_scale, seed)
192
+ img_byte_arr = io.BytesIO()
193
+ edited_image.save(img_byte_arr, format="PNG")
194
+ img_byte_arr.seek(0)
195
+ return StreamingResponse(img_byte_arr, media_type="image/png")
196
+ except Exception as e:
197
+ raise HTTPException(status_code=500, detail=f"Error editing image: {str(e)}")
 
 
 
 
 
 
 
198
 
 
199
  @app.post("/inpaint/")
200
  async def inpaint_image(
201
  file: UploadFile = File(...),
202
  prompt: str = Form(...),
203
+ mask_coordinates: str = Form(...),
204
  steps: int = Form(default=DEFAULT_STEPS),
205
  guidance_scale: float = Form(default=7.5),
206
  seed: int = Form(default=DEFAULT_SEED)
207
  ):
 
 
 
 
 
 
 
 
 
208
  try:
 
209
  image_data = await file.read()
210
  input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
 
 
211
  width, height = input_image.size
212
  factor = 512 / max(width, height)
213
  factor = math.ceil(min(width, height) * factor / 8) * 8 / min(width, height)
214
  width = int((width * factor) // 8) * 8
215
  height = int((height * factor) // 8) * 8
216
  input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
217
+ mask = Image.new("L", (width, height), 0)
 
 
218
  draw = ImageDraw.Draw(mask)
219
+ x1, y1, x2, y2 = map(int, mask_coordinates.split(","))
220
+ x1, y1, x2, y2 = int(x1 * factor), int(y1 * factor), int(x2 * factor), int(y2 * factor)
 
 
 
 
 
 
 
 
 
 
 
221
  draw.rectangle([x1, y1, x2, y2], fill=255)
 
 
222
  generator = torch.manual_seed(seed)
223
+ pipe = load_inpaint_pipeline()
224
+ inpainted_image = pipe(
 
225
  prompt=prompt,
226
  image=input_image,
227
  mask_image=mask,
 
229
  guidance_scale=guidance_scale,
230
  generator=generator,
231
  ).images[0]
 
 
232
  img_byte_arr = io.BytesIO()
233
  inpainted_image.save(img_byte_arr, format="PNG")
234
  img_byte_arr.seek(0)
 
 
235
  return StreamingResponse(img_byte_arr, media_type="image/png")
236
+ except ValueError:
237
+ raise HTTPException(status_code=400, detail="Invalid mask coordinates format. Use 'x1,y1,x2,y2'.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  except Exception as e:
239
+ raise HTTPException(status_code=500, detail=f"Error inpainting image: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  @app.post("/detect-json/")
242
+ async def detect_json(file: UploadFile = File(...), text_query: str = DEFAULT_TEXT_QUERY):
 
 
 
 
243
  try:
 
244
  image_data = await file.read()
245
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
 
 
246
  results = process_image_with_dino(image, text_query)
247
+ detections = [
248
+ {
 
 
 
 
249
  "label": label,
250
+ "score": float(score),
251
+ "box": {"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
252
+ }
253
+ for box, label, score in zip(results["boxes"].cpu(), results["labels"], results["scores"])
254
+ ]
 
 
 
 
255
  return JSONResponse(content={"detections": detections})
256
  except Exception as e:
257
  raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
258
 
259
  @app.post("/segment-image/")
260
+ async def segment_image(file: UploadFile = File(...), text_query: str = DEFAULT_TEXT_QUERY):
 
 
 
 
261
  try:
 
262
  image_data = await file.read()
263
  image = Image.open(io.BytesIO(image_data)).convert("RGB")
 
 
264
  results = process_image_with_dino(image, text_query)
 
 
265
  boxes = [
266
  {"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
267
+ for box in results["boxes"].cpu()
268
  ]
 
 
269
  mask = segment_with_sam(image, boxes)
 
 
270
  image_np = np.array(image)
271
  background_mask = create_background_mask(image_np, mask)
272
  segmented_image = cv2.bitwise_and(image_np, background_mask)
 
 
273
  output_image = Image.fromarray(segmented_image)
274
  img_byte_arr = io.BytesIO()
275
  output_image.save(img_byte_arr, format="PNG")
276
  img_byte_arr.seek(0)
277
+ return StreamingResponse(img_byte_arr, media_type="image/png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  except Exception as e:
279
+ raise HTTPException(status_code=500, detail=f"Error segmenting image: {str(e)}")
280
 
281
+ # Add other endpoints (e.g., /mask-object/, /fit-image-to-mask/) with similar lazy loading patterns as needed
282
 
283
  if __name__ == "__main__":
284
  import uvicorn