"""FastAPI service exposing a Stable Diffusion inpainting endpoint.

Endpoints:
    POST /inpaint  -- inpaint an uploaded image inside an uploaded mask.
    GET  /health   -- report liveness, CUDA availability and model state.
"""

import base64
import contextlib
import gc
import io
import logging

import torch
from diffusers import StableDiffusionInpaintPipeline
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image

# basicConfig is a no-op when the hosting process already configured logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware.
# NOTE(review): allowing every origin together with credentials is very
# permissive; restrict allow_origins before exposing this service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variable for the model; populated lazily by load_model().
pipe = None

# Target length (in pixels) of the shorter image side after resizing.
MAX_SIZE = 512

# Largest accepted upload, in bytes (10MB).
MAX_UPLOAD_BYTES = 10 * 1024 * 1024

# Stable Diffusion works on latents downsampled by 8, so both image sides
# must be multiples of 8.
_SIZE_MULTIPLE = 8


def load_model():
    """Return the shared inpainting pipeline, loading it on first use.

    The pipeline is placed on CUDA when available, otherwise on the CPU.
    Half precision is only requested on CUDA; on CPU the weights are loaded
    as float32 because many CPU kernels lack half-precision support.

    Returns:
        The loaded ``StableDiffusionInpaintPipeline``.

    Raises:
        Exception: whatever the underlying load raised, after logging it.
    """
    global pipe
    if pipe is None:
        model_id = "Uminosachi/realisticVisionV51_v51VAE-inpainting"
        try:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.float16 if device == "cuda" else torch.float32
            pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                safety_checker=None,
            ).to(device)
            if device == "cuda":
                pipe.enable_attention_slicing()
            logger.info("Model loaded on %s with optimizations", device)
        except Exception:
            logger.exception("Error loading model")
            raise
    return pipe


@app.on_event("startup")
async def startup_event():
    """Try to load the model at startup without preventing the app from starting.

    A failure here is deliberately swallowed: /inpaint calls load_model()
    again, so the load is retried on the first request.
    """
    try:
        load_model()
    except Exception as e:
        logger.error("Startup error: %s", e)


def image_to_base64(image: Image.Image) -> str:
    """Encode *image* as PNG and return the bytes as a base64 string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


def _snap_to_multiple(value: float, multiple: int = _SIZE_MULTIPLE) -> int:
    """Round *value* to the nearest positive multiple of *multiple*."""
    return max(multiple, int(round(value / multiple)) * multiple)


def resize_for_condition_image(
    input_image: Image.Image, resolution: int
) -> Image.Image:
    """Resize so the shorter side is about *resolution*, keeping aspect ratio.

    Both output sides are snapped to multiples of 8, so the aspect ratio is
    preserved only up to that rounding.

    Args:
        input_image: Image to resize.
        resolution: Target length of the shorter side, in pixels.

    Returns:
        A resized copy of *input_image*.

    Raises:
        ValueError: if *input_image* has a zero-length side.
    """
    input_width, input_height = input_image.size
    shorter_side = min(input_width, input_height)
    if shorter_side <= 0:
        raise ValueError("Cannot resize an image with a zero-length side")

    scale = resolution / shorter_side
    width = _snap_to_multiple(input_width * scale)
    height = _snap_to_multiple(input_height * scale)
    return input_image.resize((width, height))


async def _read_capped(upload: UploadFile, limit: int) -> bytes:
    """Read at most ``limit + 1`` bytes so an oversized upload is detectable
    without buffering all of it in memory."""
    return await upload.read(limit + 1)


@app.post("/inpaint")
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "add some flowers and a fountain",
    negative_prompt: str = "blurry, low quality, distorted",
):
    """Inpaint the uploaded image inside the region given by the mask.

    Args:
        image: Source image upload.
        mask: Mask upload; it is converted to grayscale and resized to the
            size of the (resized) source image.
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative text prompt passed to the pipeline.

    Returns:
        ``{"status": "success", "image": <base64 PNG>}`` on success, or a
        400 JSON response with an ``"error"`` key when an upload exceeds
        10MB or cannot be decoded as an image.

    Raises:
        HTTPException: status 500 for any other failure.
    """
    # NOTE: the pipeline call below is synchronous, so it occupies the event
    # loop for the duration of inference.
    try:
        # Read each upload exactly once; skip the mask if the image is
        # already too large.
        image_data = await _read_capped(image, MAX_UPLOAD_BYTES)
        too_large = len(image_data) > MAX_UPLOAD_BYTES
        if not too_large:
            mask_data = await _read_capped(mask, MAX_UPLOAD_BYTES)
            too_large = len(mask_data) > MAX_UPLOAD_BYTES
        if too_large:
            return JSONResponse(
                status_code=400,
                content={"error": "File size too large. Maximum size is 10MB"},
            )

        try:
            # convert() forces a full decode, so truncated or non-image data
            # is rejected here rather than deep inside the pipeline.
            original_image = Image.open(io.BytesIO(image_data)).convert("RGB")
            mask_image = Image.open(io.BytesIO(mask_data)).convert("L")
        except OSError:
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded image or mask could not be decoded"},
            )

        original_image = resize_for_condition_image(original_image, MAX_SIZE)
        # The mask must line up with the image pixel-for-pixel.
        mask_image = mask_image.resize(original_image.size)

        # Retries the load if the startup attempt failed.
        model = load_model()

        on_cuda = torch.cuda.is_available()
        # Reduce steps even more for CPU.
        num_steps = 20 if on_cuda else 5
        autocast_ctx = (
            torch.autocast("cuda") if on_cuda else contextlib.nullcontext()
        )

        with autocast_ctx:
            output_image = model(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=original_image,
                mask_image=mask_image,
                # Pass the size explicitly so the pipeline does not fall back
                # to its default resolution and distort the aspect ratio.
                height=original_image.height,
                width=original_image.width,
                num_inference_steps=num_steps,
                guidance_scale=7.0,  # Slightly reduced for speed
            ).images[0]

        return {"status": "success", "image": image_to_base64(output_image)}
    except Exception as e:
        logger.exception("Inpainting request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Release memory on failure as well as on success.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()


@app.get("/health")
async def health_check():
    """Report liveness, CUDA availability and whether the model is loaded."""
    return {
        "status": "healthy",
        "cuda_available": torch.cuda.is_available(),
        "model_loaded": pipe is not None,
    }