from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
import io
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline

# Initialize FastAPI app
app = FastAPI()
# Load the pre-trained inpainting model (Stable Diffusion)
model_id = "runwayml/stable-diffusion-inpainting"
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}")

@app.get("/")
async def root():
    """
    Root endpoint for basic health check.
    """
    return {"message": "Inpainting API is running. Use POST /inpaint/ or /inpaint-with-reference/ to edit images."}

def blend_images(original_image: Image.Image, reference_image: Image.Image, mask_image: Image.Image) -> Image.Image:
    """
    Blend the original image with a reference image using the mask.
    Following the diffusers inpainting convention (white pixels are repainted,
    black pixels are preserved):
    - Unmasked areas (black in mask) keep pixels from the original image.
    - Masked areas (white in mask) take pixels from the reference image as a starting point.

    Args:
        original_image (Image.Image): The original image (RGB).
        reference_image (Image.Image): The reference image to blend with (RGB).
        mask_image (Image.Image): The mask image (grayscale, L mode).

    Returns:
        Image.Image: The blended image.
    """
    # Convert images to numpy arrays
    original_array = np.array(original_image)
    reference_array = np.array(reference_image)
    mask_array = np.array(mask_image) / 255.0  # Normalize mask to [0, 1]
    # Add a channel axis so the mask broadcasts across the RGB channels
    mask_array = mask_array[:, :, np.newaxis]
    # Blend: unmasked areas (mask=0) keep the original, masked areas (mask=1) use the reference
    blended_array = original_array * (1 - mask_array) + reference_array * mask_array
    blended_array = blended_array.astype(np.uint8)
    return Image.fromarray(blended_array)
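
# Quick sanity check (hypothetical toy values): with an all-white mask the
# blend returns the reference image everywhere; an all-black mask leaves the
# original untouched.
#
#     orig = Image.new("RGB", (2, 2), "red")
#     ref = Image.new("RGB", (2, 2), "blue")
#     mask = Image.new("L", (2, 2), 255)   # all white -> fully masked
#     assert blend_images(orig, ref, mask).getpixel((0, 0)) == (0, 0, 255)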

@app.post("/inpaint/")
async def inpaint_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = Form("Fill the masked area with appropriate content.")
):
    """
    Endpoint for image inpainting using a text prompt.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file indicating areas to inpaint (white for areas to repaint, black for areas to preserve).
    - `prompt`: Text prompt describing the desired output.

    Returns:
    - The inpainted image as a PNG file.
    """
    try:
        # Load the uploaded image and mask
        image_bytes = await image.read()
        mask_bytes = await mask.read()
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
        # Ensure dimensions match between image and mask
        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")
        # Perform inpainting using the pipeline
        result = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]
        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)
        # Return the image as a streaming response
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=inpainted_image.png"}
        )
    except HTTPException:
        # Re-raise client errors (e.g., 400) instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}")
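
# Example request (assumption: server on localhost:7860, local files
# photo.png and mask.png; the prompt is sent as a multipart form field):
#   curl -X POST http://localhost:7860/inpaint/ \
#        -F "image=@photo.png" -F "mask=@mask.png" \
#        -F "prompt=a red sports car" \
#        -o inpainted_image.png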

@app.post("/inpaint-with-reference/")
async def inpaint_with_reference(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = Form("Fill the masked area with appropriate content.")
):
    """
    Endpoint for image inpainting using both a text prompt and a reference image.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file indicating areas to inpaint (white for areas to repaint, black for areas to preserve).
    - `reference_image`: Reference image to guide the inpainting (PNG/JPG).
    - `prompt`: Text prompt describing the desired output.

    Returns:
    - The inpainted image as a PNG file.
    """
    try:
        # Load the uploaded image, mask, and reference image
        image_bytes = await image.read()
        mask_bytes = await mask.read()
        reference_bytes = await reference_image.read()
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
        ref_image = Image.open(io.BytesIO(reference_bytes)).convert("RGB")
        # Ensure dimensions match between image and mask
        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")
        # Resize the reference image to the original's dimensions if needed
        if original_image.size != ref_image.size:
            ref_image = ref_image.resize(original_image.size, Image.Resampling.LANCZOS)
        # Blend the original and reference images using the mask
        blended_image = blend_images(original_image, ref_image, mask_image)
        # Perform inpainting using the pipeline
        result = pipe(prompt=prompt, image=blended_image, mask_image=mask_image).images[0]
        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)
        # Return the image as a streaming response
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=inpainted_with_reference_image.png"}
        )
    except HTTPException:
        # Re-raise client errors (e.g., 400) instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting with reference: {e}")
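
# Example request (assumption: server on localhost:7860, local files
# photo.png, mask.png, and reference.png):
#   curl -X POST http://localhost:7860/inpaint-with-reference/ \
#        -F "image=@photo.png" -F "mask=@mask.png" \
#        -F "reference_image=@reference.png" \
#        -F "prompt=match the reference style" \
#        -o inpainted_with_reference_image.png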

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)