from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw, ImageFilter
import io
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline

# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained inpainting model (Stable Diffusion)
model_id = "runwayml/stable-diffusion-inpainting"
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}")


@app.get("/")
async def root():
    """
    Root endpoint for basic health check.
    """
    return {
        "message": "Stable Diffusion inpainting API is running. "
                   "Use POST /inpaint/, /inpaint-with-reference/, or /fit-image-to-mask/ to edit images."
    }


def prepare_guided_image(original_image: Image.Image, reference_image: Image.Image,
                         mask_image: Image.Image) -> Image.Image:
    """
    Prepare an initial image by softly blending the reference image into the masked area.

    - Areas to keep (black in mask, 0) remain fully from the original image.
    - Areas to inpaint (white in mask, 255) take content from the reference image with soft blending.
    """
    original_array = np.array(original_image)
    reference_array = np.array(reference_image)
    # Normalize the mask to [0, 1] and add a channel axis so it broadcasts over RGB
    mask_array = np.array(mask_image) / 255.0
    mask_array = mask_array[:, :, np.newaxis]
    blended_array = original_array * (1 - mask_array) + reference_array * mask_array
    return Image.fromarray(blended_array.astype(np.uint8))


def soften_mask(mask_image: Image.Image, softness: int = 5) -> Image.Image:
    """
    Soften the edges of the mask for smoother transitions.
    """
    return mask_image.filter(ImageFilter.GaussianBlur(radius=softness))


def generate_rectangular_mask(image_size: tuple, x1: int = 100, y1: int = 100,
                              x2: int = 200, y2: int = 200) -> Image.Image:
    """
    Generate a rectangular mask matching the image dimensions.

    - Black (0) for areas to keep, white (255) for areas to inpaint.
    """
    mask = Image.new("L", image_size, 0)
    draw = ImageDraw.Draw(mask)
    draw.rectangle([x1, y1, x2, y2], fill=255)
    return mask
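
# A minimal sketch (kept as comments so nothing runs at import time) of how the
# helpers above compose: generate a rectangular mask, soften it, and alpha-blend
# a reference image into the original. The file names are hypothetical placeholders.
#
#   original = Image.open("original.png").convert("RGB")
#   reference = Image.open("reference.png").convert("RGB").resize(original.size)
#   mask = generate_rectangular_mask(original.size, 100, 100, 200, 200)
#   soft = soften_mask(mask, softness=5)
#   # Fully white mask pixels (255) take the reference, black (0) keeps the
#   # original, and blurred edge pixels interpolate linearly between the two.
#   guided = prepare_guided_image(original, reference, soft)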
""" # Calculate mask dimensions mask_width = mask_x2 - mask_x1 mask_height = mask_y2 - mask_y1 # Ensure mask dimensions are positive if mask_width <= 0 or mask_height <= 0: raise ValueError("Mask dimensions must be positive") # Resize reference image to fit the mask while preserving aspect ratio ref_width, ref_height = reference_image.size aspect_ratio = ref_width / ref_height if mask_width / mask_height > aspect_ratio: # Fit to height new_height = mask_height new_width = int(new_height * aspect_ratio) else: # Fit to width new_width = mask_width new_height = int(new_width / aspect_ratio) # Resize reference image reference_image_resized = reference_image.resize((new_width, new_height), Image.Resampling.LANCZOS) # Create a copy of the original image to paste the reference image onto guided_image = original_image.copy() # Calculate position to center the resized image in the mask paste_x = mask_x1 + (mask_width - new_width) // 2 paste_y = mask_y1 + (mask_height - new_height) // 2 # Paste the resized reference image onto the original image guided_image.paste(reference_image_resized, (paste_x, paste_y)) # Generate the mask for inpainting (white in the pasted region) mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2) return guided_image, mask_image @app.post("/inpaint/") async def inpaint_image( image: UploadFile = File(...), mask: UploadFile = File(...), prompt: str = "Fill the masked area with appropriate content." ): """ Endpoint for image inpainting using a text prompt and an uploaded mask. - `image`: Original image file (PNG/JPG). - `mask`: Mask file indicating areas to inpaint (white for masked areas, black for unmasked). - `prompt`: Text prompt describing the desired output. Returns: - The inpainted image as a PNG file. """ try: # Load the uploaded image and mask image_bytes = await image.read() mask_bytes = await mask.read() original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB") mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L") # Ensure dimensions match between image and mask if original_image.size != mask_image.size: raise HTTPException(status_code=400, detail="Image and mask dimensions must match.") # Perform inpainting using the pipeline result = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0] # Convert result to bytes for response result_bytes = io.BytesIO() result.save(result_bytes, format="PNG") result_bytes.seek(0) # Return the image as a streaming response return StreamingResponse( result_bytes, media_type="image/png", headers={"Content-Disposition": "attachment; filename=inpainted_image.png"} ) except Exception as e: raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}") @app.post("/inpaint-with-reference/") async def inpaint_with_reference( image: UploadFile = File(...), reference_image: UploadFile = File(...), prompt: str = "Integrate the reference content naturally into the masked area, matching style and lighting.", mask_x1: int = 100, mask_y1: int = 100, mask_x2: int = 200, mask_y2: int = 200 ): """ Endpoint for replacing masked areas with reference image content, refined to look natural, using an autogenerated mask. 
""" try: image_bytes = await image.read() reference_bytes = await reference_image.read() original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB") reference_image = Image.open(io.BytesIO(reference_bytes)).convert("RGB") if original_image.size != reference_image.size: reference_image = reference_image.resize(original_image.size, Image.Resampling.LANCZOS) mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2) softened_mask = soften_mask(mask_image, softness=5) guided_image = prepare_guided_image(original_image, reference_image, softened_mask) result = pipe( prompt=prompt, image=guided_image, mask_image=softened_mask, strength=0.75, guidance_scale=7.5 ).images[0] result_bytes = io.BytesIO() result.save(result_bytes, format="PNG") result_bytes.seek(0) return StreamingResponse( result_bytes, media_type="image/png", headers={"Content-Disposition": "attachment; filename=natural_inpaint_image.png"} ) except Exception as e: raise HTTPException(status_code=500, detail=f"Error during natural inpainting: {e}") @app.post("/fit-image-to-mask/") async def fit_image_to_mask( image: UploadFile = File(...), reference_image: UploadFile = File(...), prompt: str = "Blend the fitted image naturally into the scene, matching style and lighting.", mask_x1: int = 100, mask_y1: int = 100, mask_x2: int = 200, mask_y2: int = 200 ): """ Endpoint for fitting a reference image into a masked region of the original image, refined to look natural, using an autogenerated mask. """ try: # Load the uploaded images image_bytes = await image.read() reference_bytes = await reference_image.read() original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB") reference_image = Image.open(io.BytesIO(reference_bytes)).convert("RGB") # Fit the reference image into the masked region result = fit_image_to_mask(original_image, reference_image, mask_x1, mask_y1, mask_x2, mask_y2) if not isinstance(result, tuple) or len(result) != 2: raise ValueError(f"Expected tuple of (guided_image, mask_image), got {type(result)}: {result}") guided_image, mask_image = result # Soften the mask for smoother transitions softened_mask = soften_mask(mask_image, softness=5) # Perform inpainting to blend the fitted image naturally result = pipe( prompt=prompt, image=guided_image, mask_image=softened_mask, strength=0.75, guidance_scale=7.5 ).images[0] # Convert result to bytes for response result_bytes = io.BytesIO() result.save(result_bytes, format="PNG") result_bytes.seek(0) return StreamingResponse( result_bytes, media_type="image/png", headers={"Content-Disposition": "attachment; filename=fitted_image.png"} ) except ValueError as ve: raise HTTPException(status_code=500, detail=f"ValueError in processing: {str(ve)}") except Exception as e: raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)