Spaces: Paused
sachin committed · Commit e9cc529 · Parent(s): c66a631
improve memory management
Browse files: intruct.py (+157 −529)

intruct.py CHANGED
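The change replaces module-level model construction with cached, lazily-built pipelines: the globals start as None and each model is only instantiated the first time an endpoint needs it. A minimal sketch of that lazy-loading pattern (illustrative names, not taken from the diff):

```python
import torch
from diffusers import StableDiffusionInpaintPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
_pipe = None  # cached pipeline, built on first request rather than at import time

def get_pipe() -> StableDiffusionInpaintPipeline:
    """Return the cached pipeline, loading it only when first needed."""
    global _pipe
    if _pipe is None:
        _pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float16
        ).to(device)
    return _pipe
```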
@@ -1,231 +1,155 @@
-from fastapi import FastAPI, File, UploadFile, Form
-from fastapi.responses import StreamingResponse
 import io
 import math
-from PIL import Image, ImageOps, ImageDraw
 import torch
-
-from
-
-
-
 from huggingface_hub import hf_hub_download, login
 from safetensors.torch import load_file
-from io import BytesIO
-import os
-import base64
-from typing import List
-from fastapi import FastAPI, File, UploadFile, HTTPException
-from fastapi.responses import StreamingResponse
-from PIL import Image, ImageDraw, ImageFilter
-import io
-import torch
-import numpy as np
-from diffusers import StableDiffusionInpaintPipeline
-import cv2
-
-from fastapi import FastAPI, File, UploadFile, HTTPException
-from fastapi.responses import StreamingResponse, JSONResponse
-import torch
-from PIL import Image
-import io
-import numpy as np
-import cv2
 from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
 from sam2.sam2_image_predictor import SAM2ImagePredictor
-
-
-

 # Initialize FastAPI app
 app = FastAPI()

 device = "cuda" if torch.cuda.is_available() else "cpu"

-#
-
-
-

-#
-
-
-
-
-# Default text query
 DEFAULT_TEXT_QUERY = "a tank."

 def process_image_with_dino(image: Image.Image, text_query: str = DEFAULT_TEXT_QUERY):
-
-    inputs =
     with torch.no_grad():
-        outputs =
-
-
-    results = dino_processor.post_process_grounded_object_detection(
-        outputs,
-        inputs.input_ids,
-        threshold=0.4,
-        text_threshold=0.3,
-        target_sizes=[image.size[::-1]]  # [width, height]
     )
-    return results[0]

 def segment_with_sam(image: Image.Image, boxes: list):
-
     image_np = np.array(image)
-
-
     if not boxes:
-        return np.zeros(image_np.shape[:2], dtype=bool)
-
-    # Convert boxes to [x_min, y_min, x_max, y_max] tensor and move to device
     boxes_tensor = torch.tensor(
         [[box["x_min"], box["y_min"], box["x_max"], box["y_max"]] for box in boxes],
         dtype=torch.float32
     ).to(device)
-
-
-    masks, _, _ = sam_predictor.predict(
-        point_coords=None,
-        point_labels=None,
-        box=boxes_tensor,  # Use 'box' argument instead of 'boxes'
-        multimask_output=False
-    )
-    return masks[0]  # Return the first mask directly (already a NumPy array)

 def create_background_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
-
-
-    mask_rgb = cv2.cvtColor(mask_inv, cv2.COLOR_GRAY2RGB)  # Convert to RGB
     return mask_rgb

 def create_object_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
-
-    mask_rgb = cv2.cvtColor(mask.astype(np.uint8) * 255, cv2.COLOR_GRAY2RGB)  # Object is white, background black
     return mask_rgb

-
-
-
-model_id_runway = "runwayml/stable-diffusion-inpainting"
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-try:
-    pipe_runway = StableDiffusionInpaintPipeline.from_pretrained(model_id_runway)
-    pipe_runway.to(device)
-except Exception as e:
-    raise RuntimeError(f"Failed to load model: {e}")
-
-
-
-# Load the pre-trained InstructPix2Pix model for editing
-model_id = "timbrooks/instruct-pix2pix"
-pipe_edit = StableDiffusionInstructPix2PixPipeline.from_pretrained(
-    model_id, torch_dtype=torch.float16, safety_checker=None
-).to("cuda")
-
-# Load the pre-trained Inpainting model
-inpaint_model_id = "stabilityai/stable-diffusion-2-inpainting"
-pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
-    inpaint_model_id, torch_dtype=torch.float16, safety_checker=None
-).to("cuda")
-
-# Default configuration values
-DEFAULT_STEPS = 50
-DEFAULT_TEXT_CFG = 7.5
-DEFAULT_IMAGE_CFG = 1.5
-DEFAULT_SEED = 1371
-
-HF_TOKEN = os.getenv("HF_TOKEN")
-
-def load_model():
-    try:
-        # Login to Hugging Face if token is provided
-        if HF_TOKEN:
-            login(token=HF_TOKEN)
-
-        base = "stabilityai/stable-diffusion-xl-base-1.0"
-        repo = "ByteDance/SDXL-Lightning"
-        ckpt = "sdxl_lightning_4step_unet.safetensors"
-
-        # Load model with explicit error handling
-        unet = UNet2DConditionModel.from_config(
-            base,
-            subfolder="unet"
-        ).to("cuda", torch.float16)
-
-        unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
-        pipe = StableDiffusionXLPipeline.from_pretrained(
-            base,
-            unet=unet,
-            torch_dtype=torch.float16,
-            variant="fp16"
-        ).to("cuda")
-
-        # Configure scheduler
-        pipe.scheduler = EulerDiscreteScheduler.from_config(
-            pipe.scheduler.config,
-            timestep_spacing="trailing"
-        )
-
-        return pipe
-
-    except Exception as e:
-        raise Exception(f"Failed to load model: {str(e)}")
-
-# Load model at startup with error handling
-try:
-    pipe_generate = load_model()
-except Exception as e:
-    print(f"Model initialization failed: {str(e)}")
-    raise
-
-@app.get("/generate")
-async def generate_image(prompt: str):
-    try:
-        # Generate image
-        image = pipe_generate(
-            prompt,
-            num_inference_steps=4,
-            guidance_scale=0
-        ).images[0]
-
-        # Save image to buffer
-        buffer = BytesIO()
-        image.save(buffer, format="PNG")
-        buffer.seek(0)
-
-        return Response(content=buffer.getvalue(), media_type="image/png")
-
-    except Exception as e:
-        return {"error": str(e)}
-
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy"}
-
 def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int):
-    """
-    Process the input image with the given instruction using InstructPix2Pix.
-    """
-    # Resize image to fit model requirements
     width, height = input_image.size
     factor = 512 / max(width, height)
     factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
     width = int((width * factor) // 64) * 64
     height = int((height * factor) // 64) * 64
     input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
-
     if not instruction:
         return input_image
-
-    # Set the random seed for reproducibility
     generator = torch.manual_seed(seed)
-
-
-    edited_image = pipe_edit(
         instruction,
         image=input_image,
         guidance_scale=text_cfg_scale,
@@ -233,9 +157,25 @@ def process_image(input_image: Image.Image, instruction: str, steps: int, text_c
         num_inference_steps=steps,
         generator=generator,
     ).images[0]
-
     return edited_image

 @app.post("/edit-image/")
 async def edit_image(
     file: UploadFile = File(...),
@@ -245,79 +185,43 @@ async def edit_image(
     image_cfg_scale: float = Form(default=DEFAULT_IMAGE_CFG),
     seed: int = Form(default=DEFAULT_SEED)
 ):
-
-
-
-
-
-
-
-
-
-
-    # Convert the edited image to bytes
-    img_byte_arr = io.BytesIO()
-    edited_image.save(img_byte_arr, format="PNG")
-    img_byte_arr.seek(0)
-
-    # Return the image as a streaming response
-    return StreamingResponse(img_byte_arr, media_type="image/png")

-# New endpoint for inpainting
 @app.post("/inpaint/")
 async def inpaint_image(
     file: UploadFile = File(...),
     prompt: str = Form(...),
-    mask_coordinates: str = Form(...),
     steps: int = Form(default=DEFAULT_STEPS),
     guidance_scale: float = Form(default=7.5),
     seed: int = Form(default=DEFAULT_SEED)
 ):
-    """
-    Endpoint to perform inpainting on an image.
-    - file: The input image to inpaint.
-    - prompt: The text prompt describing what to generate in the inpainted area.
-    - mask_coordinates: Coordinates of the rectangular area to inpaint (format: "x1,y1,x2,y2").
-    - steps: Number of inference steps.
-    - guidance_scale: Guidance scale for the inpainting process.
-    - seed: Random seed for reproducibility.
-    """
     try:
-        # Read and convert the uploaded image
         image_data = await file.read()
         input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
-
-        # Resize image to fit model requirements (must be divisible by 8 for inpainting)
         width, height = input_image.size
         factor = 512 / max(width, height)
         factor = math.ceil(min(width, height) * factor / 8) * 8 / min(width, height)
         width = int((width * factor) // 8) * 8
         height = int((height * factor) // 8) * 8
         input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
-
-        # Create a mask for inpainting
-        mask = Image.new("L", (width, height), 0)  # Black image (0 = no inpainting)
         draw = ImageDraw.Draw(mask)
-
-
-        try:
-            x1, y1, x2, y2 = map(int, mask_coordinates.split(","))
-            # Adjust coordinates based on resized image
-            x1 = int(x1 * factor)
-            y1 = int(y1 * factor)
-            x2 = int(x2 * factor)
-            y2 = int(y2 * factor)
-        except ValueError:
-            return {"error": "Invalid mask coordinates format. Use 'x1,y1,x2,y2'."}
-
-        # Draw a white rectangle on the mask (255 = area to inpaint)
         draw.rectangle([x1, y1, x2, y2], fill=255)
-
-        # Set the random seed for reproducibility
         generator = torch.manual_seed(seed)
-
-
-        inpainted_image = pipe_inpaint(
             prompt=prompt,
             image=input_image,
             mask_image=mask,
@@ -325,332 +229,56 @@ async def inpaint_image(
             guidance_scale=guidance_scale,
             generator=generator,
         ).images[0]
-
-        # Convert the inpainted image to bytes
         img_byte_arr = io.BytesIO()
         inpainted_image.save(img_byte_arr, format="PNG")
         img_byte_arr.seek(0)
-
-        # Return the image as a streaming response
         return StreamingResponse(img_byte_arr, media_type="image/png")
-
-
-        return {"error": str(e)}
-
-@app.get("/")
-async def root():
-    """
-    Root endpoint for basic health check.
-    """
-    return {"message": "InstructPix2Pix API is running. Use POST /edit-image/ or /inpaint/ to edit images."}
-
-
-
-# Helper functions
-def prepare_guided_image(original_image: Image, reference_image: Image, mask_image: Image) -> Image:
-    original_array = np.array(original_image)
-    reference_array = np.array(reference_image)
-    mask_array = np.array(mask_image) / 255.0
-    mask_array = mask_array[:, :, np.newaxis]
-    blended_array = original_array * (1 - mask_array) + reference_array * mask_array
-    return Image.fromarray(blended_array.astype(np.uint8))
-
-def soften_mask(mask_image: Image, softness: int = 5) -> Image:
-    from PIL import ImageFilter
-    return mask_image.filter(ImageFilter.GaussianBlur(radius=softness))
-
-def generate_rectangular_mask(image_size: tuple, x1: int = 100, y1: int = 100, x2: int = 200, y2: int = 200) -> Image:
-    mask = Image.new("L", image_size, 0)
-    draw = ImageDraw.Draw(mask)
-    draw.rectangle([x1, y1, x2, y2], fill=255)
-    return mask
-
-def segment_tank(tank_image: Image) -> tuple[Image, Image]:
-    tank_array = np.array(tank_image.convert("RGB"))
-    tank_array = cv2.cvtColor(tank_array, cv2.COLOR_RGB2BGR)
-    hsv = cv2.cvtColor(tank_array, cv2.COLOR_BGR2HSV)
-    lower_snow = np.array([0, 0, 180])
-    upper_snow = np.array([180, 50, 255])
-    snow_mask = cv2.inRange(hsv, lower_snow, upper_snow)
-    tank_mask = cv2.bitwise_not(snow_mask)
-    kernel = np.ones((5, 5), np.uint8)
-    tank_mask = cv2.erode(tank_mask, kernel, iterations=1)
-    tank_mask = cv2.dilate(tank_mask, kernel, iterations=1)
-    tank_mask_image = Image.fromarray(tank_mask, mode="L")
-    tank_array_rgb = np.array(tank_image.convert("RGB"))
-    mask_array = tank_mask / 255.0
-    mask_array = mask_array[:, :, np.newaxis]
-    segmented_tank = (tank_array_rgb * mask_array).astype(np.uint8)
-    alpha = tank_mask
-    segmented_tank_rgba = np.zeros((tank_image.height, tank_image.width, 4), dtype=np.uint8)
-    segmented_tank_rgba[:, :, :3] = segmented_tank
-    segmented_tank_rgba[:, :, 3] = alpha
-    segmented_tank_image = Image.fromarray(segmented_tank_rgba, mode="RGBA")
-    return segmented_tank_image, tank_mask_image
-
-async def apply_camouflage_to_tank(tank_image: Image) -> Image:
-    segmented_tank, tank_mask = segment_tank(tank_image)
-    segmented_tank.save("segmented_tank.png")
-    tank_mask.save("tank_mask.png")
-    camouflaged_tank = pipe_runway(
-        prompt="Apply a grassy camouflage pattern with shades of green and brown to the tank, preserving its structure.",
-        image=segmented_tank.convert("RGB"),
-        mask_image=tank_mask,
-        strength=0.5,
-        guidance_scale=8.0,
-        num_inference_steps=50,
-        negative_prompt="snow, ice, rock, stone, boat, unrelated objects"
-    ).images[0]
-    camouflaged_tank_rgba = np.zeros((camouflaged_tank.height, camouflaged_tank.width, 4), dtype=np.uint8)
-    camouflaged_tank_rgba[:, :, :3] = np.array(camouflaged_tank)
-    camouflaged_tank_rgba[:, :, 3] = np.array(tank_mask)
-    camouflaged_tank_image = Image.fromarray(camouflaged_tank_rgba, mode="RGBA")
-    camouflaged_tank_image.save("camouflaged_tank.png")
-    return camouflaged_tank_image
-
-def fit_image_to_mask(original_image: Image, reference_image: Image, mask_x1: int, mask_y1: int, mask_x2: int, mask_y2: int) -> tuple:
-    mask_width = mask_x2 - mask_x1
-    mask_height = mask_y2 - mask_y1
-    if mask_width <= 0 or mask_height <= 0:
-        raise ValueError("Mask dimensions must be positive")
-    ref_width, ref_height = reference_image.size
-    aspect_ratio = ref_width / ref_height
-    if mask_width / mask_height > aspect_ratio:
-        new_height = mask_height
-        new_width = int(new_height * aspect_ratio)
-    else:
-        new_width = mask_width
-        new_height = int(new_width / aspect_ratio)
-    reference_image_resized = reference_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
-    guided_image = original_image.copy().convert("RGB")
-    paste_x = mask_x1 + (mask_width - new_width) // 2
-    paste_y = mask_y1 + (mask_height - new_height) // 2
-    guided_image.paste(reference_image_resized, (paste_x, paste_y), reference_image_resized)
-    mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2)
-    return guided_image, mask_image
-
-# Endpoints
-@app.post("/inpaint/")
-async def inpaint_image(
-    image: UploadFile = File(...),
-    mask: UploadFile = File(...),
-    prompt: str = "Fill the masked area with appropriate content."
-):
-    try:
-        image_bytes = await image.read()
-        mask_bytes = await mask.read()
-        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
-        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
-        if original_image.size != mask_image.size:
-            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")
-        result = pipe_runway(prompt=prompt, image=original_image, mask_image=mask_image).images[0]
-        result_bytes = io.BytesIO()
-        result.save(result_bytes, format="PNG")
-        result_bytes.seek(0)
-        return StreamingResponse(
-            result_bytes,
-            media_type="image/png",
-            headers={"Content-Disposition": "attachment; filename=inpainted_image.png"}
-        )
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}")
-
@app.post("/inpaint-with-reference/")
|
459 |
-
async def inpaint_with_reference(
|
460 |
-
image: UploadFile = File(...),
|
461 |
-
reference_image: UploadFile = File(...),
|
462 |
-
prompt: str = "Integrate the reference content naturally into the masked area, matching style and lighting.",
|
463 |
-
mask_x1: int = 100,
|
464 |
-
mask_y1: int = 100,
|
465 |
-
mask_x2: int = 200,
|
466 |
-
mask_y2: int = 200
|
467 |
-
):
|
468 |
-
try:
|
469 |
-
image_bytes = await image.read()
|
470 |
-
reference_bytes = await reference_image.read()
|
471 |
-
original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
472 |
-
reference_image = Image.open(io.BytesIO(reference_bytes)).convert("RGB")
|
473 |
-
if original_image.size != reference_image.size:
|
474 |
-
reference_image = reference_image.resize(original_image.size, Image.Resampling.LANCZOS)
|
475 |
-
mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2)
|
476 |
-
softened_mask = soften_mask(mask_image, softness=5)
|
477 |
-
guided_image = prepare_guided_image(original_image, reference_image, softened_mask)
|
478 |
-
result = pipe_runway(
|
479 |
-
prompt=prompt,
|
480 |
-
image=guided_image,
|
481 |
-
mask_image=softened_mask,
|
482 |
-
strength=0.75,
|
483 |
-
guidance_scale=7.5
|
484 |
-
).images[0]
|
485 |
-
result_bytes = io.BytesIO()
|
486 |
-
result.save(result_bytes, format="PNG")
|
487 |
-
result_bytes.seek(0)
|
488 |
-
return StreamingResponse(
|
489 |
-
result_bytes,
|
490 |
-
media_type="image/png",
|
491 |
-
headers={"Content-Disposition": "attachment; filename=natural_inpaint_image.png"}
|
492 |
-
)
|
493 |
except Exception as e:
|
494 |
-
raise HTTPException(status_code=500, detail=f"Error
|
495 |
-
|
496 |
-
@app.post("/fit-image-to-mask/")
|
497 |
-
async def fit_image_to_mask_endpoint(
|
498 |
-
image: UploadFile = File(...),
|
499 |
-
reference_image: UploadFile = File(...),
|
500 |
-
mask_x1: int = 200,
|
501 |
-
mask_y1: int = 200,
|
502 |
-
mask_x2: int = 500,
|
503 |
-
mask_y2: int = 500
|
504 |
-
):
|
505 |
-
try:
|
506 |
-
image_bytes = await image.read()
|
507 |
-
reference_bytes = await reference_image.read()
|
508 |
-
original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
509 |
-
reference_image = Image.open(io.BytesIO(reference_bytes)).convert("RGB")
|
510 |
-
camouflaged_tank = await apply_camouflage_to_tank(reference_image)
|
511 |
-
guided_image, mask_image = fit_image_to_mask(original_image, camouflaged_tank, mask_x1, mask_y1, mask_x2, mask_y2)
|
512 |
-
guided_image.save("guided_image_before_blending.png")
|
513 |
-
softened_mask = soften_mask(mask_image, softness=2)
|
514 |
-
result = pipe_runway(
|
515 |
-
prompt="Blend the camouflaged tank into the grassy field with trees, ensuring a non-snowy environment, matching the style, lighting, and surroundings.",
|
516 |
-
image=guided_image,
|
517 |
-
mask_image=softened_mask,
|
518 |
-
strength=0.2,
|
519 |
-
guidance_scale=7.5,
|
520 |
-
num_inference_steps=50,
|
521 |
-
negative_prompt="snow, ice, rock, stone, boat, unrelated objects"
|
522 |
-
).images[0]
|
523 |
-
result_bytes = io.BytesIO()
|
524 |
-
result.save(result_bytes, format="PNG")
|
525 |
-
result_bytes.seek(0)
|
526 |
-
return StreamingResponse(
|
527 |
-
result_bytes,
|
528 |
-
media_type="image/png",
|
529 |
-
headers={"Content-Disposition": "attachment; filename=fitted_image.png"}
|
530 |
-
)
|
531 |
-
except ValueError as ve:
|
532 |
-
raise HTTPException(status_code=400, detail=f"ValueError in processing: {str(ve)}")
|
533 |
-
except Exception as e:
|
534 |
-
raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}")
|
535 |
-
|
536 |
|
537 |
@app.post("/detect-json/")
|
538 |
-
async def detect_json(
|
539 |
-
file: UploadFile = File(..., description="Image file to process"),
|
540 |
-
text_query: str = DEFAULT_TEXT_QUERY
|
541 |
-
):
|
542 |
-
"""Endpoint to detect objects and return bounding box information as JSON."""
|
543 |
try:
|
544 |
-
# Read and convert the uploaded image
|
545 |
image_data = await file.read()
|
546 |
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
547 |
-
|
548 |
-
# Process with Grounding DINO
|
549 |
results = process_image_with_dino(image, text_query)
|
550 |
-
|
551 |
-
|
552 |
-
detections = []
|
553 |
-
for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
|
554 |
-
x_min, y_min, x_max, y_max = box.tolist()
|
555 |
-
detections.append({
|
556 |
"label": label,
|
557 |
-
"score": float(score),
|
558 |
-
"box": {
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
"y_max": y_max
|
563 |
-
}
|
564 |
-
})
|
565 |
-
|
566 |
return JSONResponse(content={"detections": detections})
|
567 |
except Exception as e:
|
568 |
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
|
569 |
|
570 |
@app.post("/segment-image/")
|
571 |
-
async def segment_image(
|
572 |
-
file: UploadFile = File(..., description="Image file to process"),
|
573 |
-
text_query: str = DEFAULT_TEXT_QUERY
|
574 |
-
):
|
575 |
-
"""Endpoint to segment objects and return the image with background removed."""
|
576 |
try:
|
577 |
-
# Read and convert the uploaded image
|
578 |
image_data = await file.read()
|
579 |
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
580 |
-
|
581 |
-
# Detect objects with Grounding DINO
|
582 |
results = process_image_with_dino(image, text_query)
|
583 |
-
|
584 |
-
# Extract boxes for segmentation, move to CPU
|
585 |
boxes = [
|
586 |
{"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
|
587 |
-
for box in results["boxes"].cpu()
|
588 |
]
|
589 |
-
|
590 |
-
# Segment with SAM 2
|
591 |
mask = segment_with_sam(image, boxes)
|
592 |
-
|
593 |
-
# Create background mask and apply it
|
594 |
image_np = np.array(image)
|
595 |
background_mask = create_background_mask(image_np, mask)
|
596 |
segmented_image = cv2.bitwise_and(image_np, background_mask)
|
597 |
-
|
598 |
-
# Convert to PIL Image and save to bytes
|
599 |
output_image = Image.fromarray(segmented_image)
|
600 |
img_byte_arr = io.BytesIO()
|
601 |
output_image.save(img_byte_arr, format="PNG")
|
602 |
img_byte_arr.seek(0)
|
603 |
-
|
604 |
-
return StreamingResponse(
|
605 |
-
img_byte_arr,
|
606 |
-
media_type="image/png",
|
607 |
-
headers={"Content-Disposition": "attachment; filename=segmented_image.png"}
|
608 |
-
)
|
609 |
-
except Exception as e:
|
610 |
-
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
|
611 |
-
|
612 |
-
@app.post("/mask-object/")
|
613 |
-
async def mask_object(
|
614 |
-
file: UploadFile = File(..., description="Image file to process"),
|
615 |
-
text_query: str = DEFAULT_TEXT_QUERY
|
616 |
-
):
|
617 |
-
"""Endpoint to mask the detected object and return the image with the object removed."""
|
618 |
-
try:
|
619 |
-
# Read and convert the uploaded image
|
620 |
-
image_data = await file.read()
|
621 |
-
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
622 |
-
|
623 |
-
# Detect objects with Grounding DINO
|
624 |
-
results = process_image_with_dino(image, text_query)
|
625 |
-
|
626 |
-
# Extract boxes for segmentation, move to CPU
|
627 |
-
boxes = [
|
628 |
-
{"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
|
629 |
-
for box in results["boxes"].cpu() # Move tensor to CPU here
|
630 |
-
]
|
631 |
-
|
632 |
-
# Segment with SAM 2
|
633 |
-
mask = segment_with_sam(image, boxes)
|
634 |
-
|
635 |
-
# Create object mask and apply it
|
636 |
-
image_np = np.array(image)
|
637 |
-
object_mask = create_object_mask(image_np, mask)
|
638 |
-
masked_image = cv2.bitwise_and(image_np, object_mask)
|
639 |
-
|
640 |
-
# Convert to PIL Image and save to bytes
|
641 |
-
output_image = Image.fromarray(masked_image)
|
642 |
-
img_byte_arr = io.BytesIO()
|
643 |
-
output_image.save(img_byte_arr, format="PNG")
|
644 |
-
img_byte_arr.seek(0)
|
645 |
-
|
646 |
-
return StreamingResponse(
|
647 |
-
img_byte_arr,
|
648 |
-
media_type="image/png",
|
649 |
-
headers={"Content-Disposition": "attachment; filename=masked_object_image.png"}
|
650 |
-
)
|
651 |
except Exception as e:
|
652 |
-
raise HTTPException(status_code=500, detail=f"Error
|
653 |
|
|
|
654 |
|
655 |
if __name__ == "__main__":
|
656 |
import uvicorn
|
|
|
1 |
+
+from fastapi import FastAPI, File, UploadFile, Form, HTTPException
+from fastapi.responses import StreamingResponse, JSONResponse, Response
 import io
 import math
+from PIL import Image, ImageOps, ImageDraw, ImageFilter
 import torch
+import numpy as np
+from diffusers import (
+    StableDiffusionInstructPix2PixPipeline,
+    StableDiffusionInpaintPipeline,
+    StableDiffusionXLPipeline,
+    UNet2DConditionModel,
+    EulerDiscreteScheduler,
+)
 from huggingface_hub import hf_hub_download, login
 from safetensors.torch import load_file
 from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
 from sam2.sam2_image_predictor import SAM2ImagePredictor
+import cv2
+import os
+from typing import Optional

 # Initialize FastAPI app
 app = FastAPI()

+# Device configuration
 device = "cuda" if torch.cuda.is_available() else "cpu"

+# Model variables (initially None, loaded lazily)
+pipe_edit = None  # InstructPix2Pix
+pipe_inpaint = None  # Stable Diffusion Inpainting
+pipe_generate = None  # Stable Diffusion XL
+pipe_runway = None  # Runway Inpainting
+dino_processor = None  # Grounding DINO processor
+dino_model = None  # Grounding DINO model
+sam_predictor = None  # SAM 2 predictor

+# Default configuration values
+DEFAULT_STEPS = 50
+DEFAULT_TEXT_CFG = 7.5
+DEFAULT_IMAGE_CFG = 1.5
+DEFAULT_SEED = 1371
 DEFAULT_TEXT_QUERY = "a tank."
+HF_TOKEN = os.getenv("HF_TOKEN")

+# Helper functions for lazy loading
+def load_instruct_pix2pix() -> StableDiffusionInstructPix2PixPipeline:
+    global pipe_edit
+    if pipe_edit is None:
+        model_id = "timbrooks/instruct-pix2pix"
+        pipe_edit = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+            model_id, torch_dtype=torch.float16, safety_checker=None
+        ).to(device)
+    return pipe_edit
+
+def load_inpaint_pipeline() -> StableDiffusionInpaintPipeline:
+    global pipe_inpaint
+    if pipe_inpaint is None:
+        inpaint_model_id = "stabilityai/stable-diffusion-2-inpainting"
+        pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
+            inpaint_model_id, torch_dtype=torch.float16, safety_checker=None
+        ).to(device)
+    return pipe_inpaint
+
+def load_generate_pipeline() -> StableDiffusionXLPipeline:
+    global pipe_generate
+    if pipe_generate is None:
+        try:
+            if HF_TOKEN:
+                login(token=HF_TOKEN)
+            base = "stabilityai/stable-diffusion-xl-base-1.0"
+            repo = "ByteDance/SDXL-Lightning"
+            ckpt = "sdxl_lightning_4step_unet.safetensors"
+            unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16)
+            unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
+            pipe_generate = StableDiffusionXLPipeline.from_pretrained(
+                base, unet=unet, torch_dtype=torch.float16, variant="fp16"
+            ).to(device)
+            pipe_generate.scheduler = EulerDiscreteScheduler.from_config(
+                pipe_generate.scheduler.config, timestep_spacing="trailing"
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to load generate pipeline: {str(e)}")
+    return pipe_generate
+
+def load_runway_inpaint() -> StableDiffusionInpaintPipeline:
+    global pipe_runway
+    if pipe_runway is None:
+        model_id_runway = "runwayml/stable-diffusion-inpainting"
+        pipe_runway = StableDiffusionInpaintPipeline.from_pretrained(model_id_runway).to(device)
+    return pipe_runway
+
+def load_dino() -> tuple[AutoProcessor, AutoModelForZeroShotObjectDetection]:
+    global dino_processor, dino_model
+    if dino_processor is None or dino_model is None:
+        dino_model_id = "IDEA-Research/grounding-dino-base"
+        dino_processor = AutoProcessor.from_pretrained(dino_model_id)
+        dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(dino_model_id).to(device)
+    return dino_processor, dino_model
+
+def load_sam() -> SAM2ImagePredictor:
+    global sam_predictor
+    if sam_predictor is None:
+        sam_predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-tiny")
+        sam_predictor.model.to(device)
+    return sam_predictor
+
+# Image processing helper functions (unchanged, included for completeness)
 def process_image_with_dino(image: Image.Image, text_query: str = DEFAULT_TEXT_QUERY):
+    processor, model = load_dino()
+    inputs = processor(images=image, text=text_query, return_tensors="pt").to(device)
     with torch.no_grad():
+        outputs = model(**inputs)
+    results = processor.post_process_grounded_object_detection(
+        outputs, inputs.input_ids, threshold=0.4, text_threshold=0.3, target_sizes=[image.size[::-1]]
     )
+    return results[0]

 def segment_with_sam(image: Image.Image, boxes: list):
+    predictor = load_sam()
     image_np = np.array(image)
+    predictor.set_image(image_np)
     if not boxes:
+        return np.zeros(image_np.shape[:2], dtype=bool)
     boxes_tensor = torch.tensor(
         [[box["x_min"], box["y_min"], box["x_max"], box["y_max"]] for box in boxes],
         dtype=torch.float32
     ).to(device)
+    masks, _, _ = predictor.predict(point_coords=None, point_labels=None, box=boxes_tensor, multimask_output=False)
+    return masks[0]

 def create_background_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    mask_inv = np.logical_not(mask).astype(np.uint8) * 255
+    mask_rgb = cv2.cvtColor(mask_inv, cv2.COLOR_GRAY2RGB)
     return mask_rgb

 def create_object_mask(image_np: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    mask_rgb = cv2.cvtColor(mask.astype(np.uint8) * 255, cv2.COLOR_GRAY2RGB)
     return mask_rgb

 def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int):
     width, height = input_image.size
     factor = 512 / max(width, height)
     factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
     width = int((width * factor) // 64) * 64
     height = int((height * factor) // 64) * 64
     input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
     if not instruction:
         return input_image
     generator = torch.manual_seed(seed)
+    pipe = load_instruct_pix2pix()
+    edited_image = pipe(
         instruction,
         image=input_image,
         guidance_scale=text_cfg_scale,
         num_inference_steps=steps,
         generator=generator,
     ).images[0]
     return edited_image

+# Endpoints
+@app.get("/generate")
+async def generate_image(prompt: str):
+    try:
+        pipe = load_generate_pipeline()
+        image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        buffer.seek(0)
+        return Response(content=buffer.getvalue(), media_type="image/png")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
+
 @app.post("/edit-image/")
 async def edit_image(
     file: UploadFile = File(...),
     image_cfg_scale: float = Form(default=DEFAULT_IMAGE_CFG),
     seed: int = Form(default=DEFAULT_SEED)
 ):
+    try:
+        image_data = await file.read()
+        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
+        edited_image = process_image(input_image, instruction, steps, text_cfg_scale, image_cfg_scale, seed)
+        img_byte_arr = io.BytesIO()
+        edited_image.save(img_byte_arr, format="PNG")
+        img_byte_arr.seek(0)
+        return StreamingResponse(img_byte_arr, media_type="image/png")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Error editing image: {str(e)}")
@app.post("/inpaint/")
|
200 |
async def inpaint_image(
|
201 |
file: UploadFile = File(...),
|
202 |
prompt: str = Form(...),
|
203 |
+
mask_coordinates: str = Form(...),
|
204 |
steps: int = Form(default=DEFAULT_STEPS),
|
205 |
guidance_scale: float = Form(default=7.5),
|
206 |
seed: int = Form(default=DEFAULT_SEED)
|
207 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
208 |
try:
|
|
|
209 |
image_data = await file.read()
|
210 |
input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
|
|
|
|
211 |
width, height = input_image.size
|
212 |
factor = 512 / max(width, height)
|
213 |
factor = math.ceil(min(width, height) * factor / 8) * 8 / min(width, height)
|
214 |
width = int((width * factor) // 8) * 8
|
215 |
height = int((height * factor) // 8) * 8
|
216 |
input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
|
217 |
+
mask = Image.new("L", (width, height), 0)
|
|
|
|
|
218 |
draw = ImageDraw.Draw(mask)
|
219 |
+
x1, y1, x2, y2 = map(int, mask_coordinates.split(","))
|
220 |
+
x1, y1, x2, y2 = int(x1 * factor), int(y1 * factor), int(x2 * factor), int(y2 * factor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
draw.rectangle([x1, y1, x2, y2], fill=255)
|
|
|
|
|
222 |
generator = torch.manual_seed(seed)
|
223 |
+
pipe = load_inpaint_pipeline()
|
224 |
+
inpainted_image = pipe(
|
|
|
225 |
prompt=prompt,
|
226 |
image=input_image,
|
227 |
mask_image=mask,
|
|
|
229 |
guidance_scale=guidance_scale,
|
230 |
generator=generator,
|
231 |
).images[0]
|
|
|
|
|
232 |
img_byte_arr = io.BytesIO()
|
233 |
inpainted_image.save(img_byte_arr, format="PNG")
|
234 |
img_byte_arr.seek(0)
|
|
|
|
|
235 |
return StreamingResponse(img_byte_arr, media_type="image/png")
|
236 |
+
except ValueError:
|
237 |
+
raise HTTPException(status_code=400, detail="Invalid mask coordinates format. Use 'x1,y1,x2,y2'.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
238 |
except Exception as e:
|
239 |
+
raise HTTPException(status_code=500, detail=f"Error inpainting image: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
|
241 |
@app.post("/detect-json/")
|
242 |
+
async def detect_json(file: UploadFile = File(...), text_query: str = DEFAULT_TEXT_QUERY):
|
|
|
|
|
|
|
|
|
243 |
try:
|
|
|
244 |
image_data = await file.read()
|
245 |
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
|
|
|
|
246 |
results = process_image_with_dino(image, text_query)
|
247 |
+
detections = [
|
248 |
+
{
|
|
|
|
|
|
|
|
|
249 |
"label": label,
|
250 |
+
"score": float(score),
|
251 |
+
"box": {"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
|
252 |
+
}
|
253 |
+
for box, label, score in zip(results["boxes"].cpu(), results["labels"], results["scores"])
|
254 |
+
]
|
|
|
|
|
|
|
|
|
255 |
return JSONResponse(content={"detections": detections})
|
256 |
except Exception as e:
|
257 |
raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
|
258 |
|
259 |
@app.post("/segment-image/")
|
260 |
+
async def segment_image(file: UploadFile = File(...), text_query: str = DEFAULT_TEXT_QUERY):
|
|
|
|
|
|
|
|
|
261 |
try:
|
|
|
262 |
image_data = await file.read()
|
263 |
image = Image.open(io.BytesIO(image_data)).convert("RGB")
|
|
|
|
|
264 |
results = process_image_with_dino(image, text_query)
|
|
|
|
|
265 |
boxes = [
|
266 |
{"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
|
267 |
+
for box in results["boxes"].cpu()
|
268 |
]
|
|
|
|
|
269 |
mask = segment_with_sam(image, boxes)
|
|
|
|
|
270 |
image_np = np.array(image)
|
271 |
background_mask = create_background_mask(image_np, mask)
|
272 |
segmented_image = cv2.bitwise_and(image_np, background_mask)
|
|
|
|
|
273 |
output_image = Image.fromarray(segmented_image)
|
274 |
img_byte_arr = io.BytesIO()
|
275 |
output_image.save(img_byte_arr, format="PNG")
|
276 |
img_byte_arr.seek(0)
|
277 |
+
return StreamingResponse(img_byte_arr, media_type="image/png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
except Exception as e:
|
279 |
+
raise HTTPException(status_code=500, detail=f"Error segmenting image: {str(e)}")
|
280 |
|
281 |
+
# Add other endpoints (e.g., /mask-object/, /fit-image-to-mask/) with similar lazy loading patterns as needed
|
282 |
|
283 |
if __name__ == "__main__":
|
284 |
import uvicorn
|
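The closing comment leaves /mask-object/ and /fit-image-to-mask/ to be re-added on top of the lazy loaders. One possible shape for /mask-object/, mirroring the removed endpoint and reusing only helpers that exist in the new file; this is a sketch, not part of the commit:

```python
@app.post("/mask-object/")
async def mask_object(file: UploadFile = File(...), text_query: str = DEFAULT_TEXT_QUERY):
    """Mask the detected object; models are loaded on demand inside the helpers."""
    try:
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        results = process_image_with_dino(image, text_query)  # lazily loads Grounding DINO
        boxes = [
            {"x_min": box[0].item(), "y_min": box[1].item(), "x_max": box[2].item(), "y_max": box[3].item()}
            for box in results["boxes"].cpu()
        ]
        mask = segment_with_sam(image, boxes)  # lazily loads SAM 2
        image_np = np.array(image)
        object_mask = create_object_mask(image_np, mask)
        masked_image = cv2.bitwise_and(image_np, object_mask)
        img_byte_arr = io.BytesIO()
        Image.fromarray(masked_image).save(img_byte_arr, format="PNG")
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error masking object: {str(e)}")
```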