Spaces:
Paused
Paused
File size: 9,999 Bytes
be6ce26 dd59a2b 7ba3751 be6ce26 dd59a2b be6ce26 1485575 a14b836 dd59a2b 2b9ad07 dd59a2b 2b9ad07 57cd301 dd59a2b a14b836 dd59a2b 57cd301 a14b836 be6ce26 2b9ad07 7ba3751 57cd301 7ba3751 57cd301 7ba3751 57cd301 7ba3751 a14b836 19259ed a14b836 be6ce26 b8546a6 be6ce26 b8546a6 be6ce26 b8546a6 be6ce26 b8546a6 be6ce26 b8546a6 be6ce26 b8546a6 be6ce26 b8546a6 1485575 be6ce26 dd59a2b 7ba3751 dd59a2b b8546a6 a14b836 b8546a6 dd59a2b a14b836 dd59a2b a14b836 19259ed b8546a6 19259ed 7ba3751 2b9ad07 dd59a2b a14b836 2b9ad07 a14b836 2b9ad07 dd59a2b a14b836 dd59a2b 19259ed dd59a2b 19259ed dd59a2b be6ce26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 |
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw
import io
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline
# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained inpainting model (Stable Diffusion) once at import time;
# every endpoint below reuses this single module-level `pipe` instance.
model_id = "runwayml/stable-diffusion-inpainting"
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    # Chain explicitly so the underlying load/download error stays visible as the cause.
    raise RuntimeError(f"Failed to load model: {e}") from e
@app.get("/")
async def root():
    """
    Health-check endpoint.

    Returns a static JSON payload confirming the service is up and pointing
    callers at the POST endpoints available for image editing.
    """
    status_message = (
        "InstructPix2Pix API is running. "
        "Use POST /inpaint/, /inpaint-with-reference/, or /fit-image-to-mask/ to edit images."
    )
    return {"message": status_message}
def prepare_guided_image(
    original_image: Image.Image,
    reference_image: Image.Image,
    mask_image: Image.Image,
) -> Image.Image:
    """
    Prepare an initial image by softly blending the reference image into the masked area.

    Each output pixel is ``original * (1 - m) + reference * m``, where ``m`` is the
    mask value scaled from [0, 255] to [0, 1]:
    - Areas to keep (black in mask, 0) remain fully from the original image.
    - Areas to inpaint (white in mask, 255) take content from the reference image;
      intermediate grey values (e.g. a blurred mask edge) give a soft blend.

    Args:
        original_image: Base image whose pixels are kept where the mask is black.
        reference_image: Image supplying content where the mask is white; must have
            the same size and channel layout as ``original_image``.
        mask_image: Blend mask; converted to single-channel "L" before use.

    Returns:
        The blended image as an 8-bit PIL image.

    Raises:
        ValueError: If the three images do not share the same size.
    """
    if not (original_image.size == reference_image.size == mask_image.size):
        raise ValueError(
            "Image, reference and mask sizes must match: "
            f"{original_image.size}, {reference_image.size}, {mask_image.size}"
        )
    original_array = np.asarray(original_image, dtype=np.float64)
    reference_array = np.asarray(reference_image, dtype=np.float64)
    # Force a single-channel mask so the trailing axis added below broadcasts
    # across the colour channels (mode "1"/"RGB" masks would otherwise break this).
    mask_array = np.asarray(mask_image.convert("L"), dtype=np.float64) / 255.0
    mask_array = mask_array[:, :, np.newaxis]
    blended_array = original_array * (1.0 - mask_array) + reference_array * mask_array
    # Round instead of truncating so float error (e.g. 179.9999) does not bias
    # pixels one level darker; clip to keep values inside the uint8 range.
    blended_array = np.clip(np.rint(blended_array), 0, 255)
    return Image.fromarray(blended_array.astype(np.uint8))
def soften_mask(mask_image: Image, softness: int = 5) -> Image:
    """
    Return a feathered copy of ``mask_image``.

    A Gaussian blur with radius ``softness`` turns hard mask edges into a
    gradient, which gives smoother transitions when the mask is used for blending.
    """
    from PIL import ImageFilter

    feather = ImageFilter.GaussianBlur(radius=softness)
    return mask_image.filter(feather)
def generate_rectangular_mask(image_size: tuple, x1: int = 100, y1: int = 100, x2: int = 200, y2: int = 200) -> Image:
    """
    Build a single-channel ("L") mask of size ``image_size`` containing one filled box.

    - The box from (x1, y1) to (x2, y2) is white (255): the area to inpaint.
    - Everything else is black (0): the area to keep.
    """
    canvas = Image.new("L", image_size, 0)
    ImageDraw.Draw(canvas).rectangle((x1, y1, x2, y2), fill=255)
    return canvas
def fit_image_to_mask(
    original_image: Image.Image,
    reference_image: Image.Image,
    mask_x1: int,
    mask_y1: int,
    mask_x2: int,
    mask_y2: int,
) -> tuple:
    """
    Fit the reference image into the masked region of the original image.

    The reference is scaled (aspect ratio preserved) to the largest size that fits
    inside the box ``(mask_x1, mask_y1)``-``(mask_x2, mask_y2)``, centred in that
    box, and pasted onto a copy of the original. A rectangular inpainting mask
    covering the whole box is generated alongside it.

    Args:
        original_image (Image.Image): The original image (RGB). Not modified.
        reference_image (Image.Image): The image to fit into the masked region (RGB).
        mask_x1, mask_y1, mask_x2, mask_y2 (int): Coordinates of the masked region.

    Returns:
        tuple: (guided_image, mask_image) - The image with the fitted reference and
        the corresponding mask.

    Raises:
        ValueError: If the mask box or the reference image has a non-positive size.
    """
    # Calculate mask dimensions
    mask_width = mask_x2 - mask_x1
    mask_height = mask_y2 - mask_y1
    if mask_width <= 0 or mask_height <= 0:
        raise ValueError("Mask dimensions must be positive")

    ref_width, ref_height = reference_image.size
    if ref_width <= 0 or ref_height <= 0:
        raise ValueError("Reference image must have non-zero dimensions")

    # Resize reference image to fit the mask while preserving aspect ratio.
    # round() + max(1, ...) avoids truncating an extreme aspect ratio down to a
    # zero-pixel side; the result never exceeds the mask box on either axis.
    aspect_ratio = ref_width / ref_height
    if mask_width / mask_height > aspect_ratio:
        # Mask is relatively wider than the reference: fit to height.
        new_height = mask_height
        new_width = max(1, round(new_height * aspect_ratio))
    else:
        # Mask is relatively taller (or same shape): fit to width.
        new_width = mask_width
        new_height = max(1, round(new_width / aspect_ratio))

    reference_image_resized = reference_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Paste onto a copy so the caller's original image is left untouched,
    # centring the resized reference inside the mask box.
    guided_image = original_image.copy()
    paste_x = mask_x1 + (mask_width - new_width) // 2
    paste_y = mask_y1 + (mask_height - new_height) // 2
    guided_image.paste(reference_image_resized, (paste_x, paste_y))

    # The inpainting mask covers the whole box, including any letterbox margins
    # left around the fitted reference.
    mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2)
    return guided_image, mask_image
@app.post("/inpaint/")
async def inpaint_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content."
):
    """
    Endpoint for image inpainting using a text prompt and an uploaded mask.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file indicating areas to inpaint (white for masked areas, black for unmasked).
    - `prompt`: Text prompt describing the desired output.
    Returns:
    - The inpainted image as a PNG file.
    Errors:
    - 400 if the image and mask sizes differ.
    - 500 for any other failure while decoding or inpainting.
    """
    try:
        # Load the uploaded image and mask (mask forced to single-channel "L").
        image_bytes = await image.read()
        mask_bytes = await mask.read()
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")

        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")

        result = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]

        # Serialize the result to an in-memory PNG for streaming back.
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=inpainted_image.png"}
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must propagate unchanged
        # rather than being rewrapped as a 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}") from e
@app.post("/inpaint-with-reference/")
async def inpaint_with_reference(
    image: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Integrate the reference content naturally into the masked area, matching style and lighting.",
    mask_x1: int = 100,
    mask_y1: int = 100,
    mask_x2: int = 200,
    mask_y2: int = 200
):
    """
    Endpoint for replacing masked areas with reference image content, refined to look natural, using an autogenerated mask.

    The reference is resized to the original's size if needed. It is then soft-blended
    into the rectangle (mask_x1, mask_y1)-(mask_x2, mask_y2), and the result is refined
    by the inpainting pipeline.
    Errors:
    - 400 if the rectangle is inverted (mask_x2 < mask_x1 or mask_y2 < mask_y1).
    - 500 for any other failure while decoding or inpainting.
    """
    # An inverted box is a client error; without this check it would only surface
    # from the mask drawing code as a generic 500.
    if mask_x2 < mask_x1 or mask_y2 < mask_y1:
        raise HTTPException(
            status_code=400,
            detail="mask_x2/mask_y2 must be greater than or equal to mask_x1/mask_y1.",
        )
    try:
        image_bytes = await image.read()
        reference_bytes = await reference_image.read()
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        reference = Image.open(io.BytesIO(reference_bytes)).convert("RGB")
        # The pixel-wise blend requires both images to share one size.
        if original_image.size != reference.size:
            reference = reference.resize(original_image.size, Image.Resampling.LANCZOS)

        mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2)
        softened_mask = soften_mask(mask_image, softness=5)
        guided_image = prepare_guided_image(original_image, reference, softened_mask)

        result = pipe(
            prompt=prompt,
            image=guided_image,
            mask_image=softened_mask,
            strength=0.75,
            guidance_scale=7.5
        ).images[0]

        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=natural_inpaint_image.png"}
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during natural inpainting: {e}") from e
# NOTE: this handler must not be named `fit_image_to_mask`. A same-named module-level
# function would rebind the fitting helper's name, so the call below would invoke
# this coroutine function instead (returning an un-awaited coroutine, never a tuple).
# The public route path is unchanged.
@app.post("/fit-image-to-mask/")
async def fit_image_to_mask_endpoint(
    image: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Blend the fitted image naturally into the scene, matching style and lighting.",
    mask_x1: int = 100,
    mask_y1: int = 100,
    mask_x2: int = 200,
    mask_y2: int = 200
):
    """
    Endpoint for fitting a reference image into a masked region of the original image, refined to look natural, using an autogenerated mask.

    The reference is scaled (aspect ratio preserved) and centred inside the rectangle
    (mask_x1, mask_y1)-(mask_x2, mask_y2). The rectangle is then refined with the
    inpainting pipeline using a softened mask.
    Errors:
    - 400 if the rectangle has non-positive width or height.
    - 500 for any other failure while decoding, fitting or inpainting.
    """
    # Reject empty/inverted rectangles as a client error before doing any work.
    if mask_x2 <= mask_x1 or mask_y2 <= mask_y1:
        raise HTTPException(status_code=400, detail="Mask dimensions must be positive")
    try:
        # Load the uploaded images
        image_bytes = await image.read()
        reference_bytes = await reference_image.read()
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        reference = Image.open(io.BytesIO(reference_bytes)).convert("RGB")

        # Fit the reference image into the masked region (module-level helper).
        guided_image, mask_image = fit_image_to_mask(
            original_image, reference, mask_x1, mask_y1, mask_x2, mask_y2
        )

        # Soften the mask so the pipeline blends across the box edges.
        softened_mask = soften_mask(mask_image, softness=5)

        result = pipe(
            prompt=prompt,
            image=guided_image,
            mask_image=softened_mask,
            strength=0.75,
            guidance_scale=7.5
        ).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=fitted_image.png"}
        )
    except ValueError as ve:
        raise HTTPException(status_code=500, detail=f"ValueError in processing: {str(ve)}") from ve
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}") from e
if __name__ == "__main__":
    import uvicorn

    # Serve the app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)