ghost-vision / runway.py
sachin
adffrim?
be6ce26
raw
history blame
2.42 kB
import io

import torch
from diffusers import StableDiffusionInpaintPipeline
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image
# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained inpainting model (Stable Diffusion) once at import
# time so that every request reuses the same pipeline object.
model_id = "runwayml/stable-diffusion-inpainting"
# Use the GPU when torch reports one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    # Chain the original exception so the root cause (download failure,
    # out-of-memory, ...) stays visible in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e
@app.get("/")
def read_root():
    """Return the welcome message served at the API root."""
    welcome = {"message": "Welcome to the Image Inpainting API!"}
    return welcome
@app.post("/inpaint/")
async def inpaint_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content.",
):
    """
    Endpoint for image inpainting.

    - `image`: Original image file (PNG/JPG). It is converted to RGB.
    - `mask`: Mask file with the same dimensions as `image`. It is converted
      to single-channel grayscale; the diffusers inpainting pipeline repaints
      the white pixels and preserves the black ones.
    - `prompt`: Text prompt describing the desired output.

    Returns:
        The inpainted image, PNG-encoded, as an `image/png` response body.

    Raises:
        HTTPException(400): An upload cannot be decoded as an image, or the
            image and mask dimensions differ.
        HTTPException(500): The pipeline or the PNG encoding fails.
    """
    # Load the uploaded image and mask
    image_bytes = await image.read()
    mask_bytes = await mask.read()

    try:
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
    except OSError as e:
        # PIL's UnidentifiedImageError is an OSError subclass; an unreadable
        # upload is a client error, not a server failure.
        raise HTTPException(
            status_code=400,
            detail=f"Could not read image or mask: {e}",
        ) from e

    # Ensure dimensions match between image and mask. This check lives
    # outside the broad handler below so the 400 is not rewritten as a 500.
    if original_image.size != mask_image.size:
        raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")

    try:
        # Perform inpainting using the pipeline
        # NOTE(review): this call is synchronous and will block the event
        # loop for the duration of the inference.
        result = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]

        # Encode the result as PNG for the response body
        png_buffer = io.BytesIO()
        result.save(png_buffer, format="PNG")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}") from e

    # Send the actual image bytes; the declared media type now matches the body.
    return Response(content=png_buffer.getvalue(), media_type="image/png")
@app.get("/health")
async def root():
    """
    Basic health-check endpoint.

    Registered at `/health` rather than `/`: `GET /` is already bound to an
    earlier handler in this file, and routes are matched in registration
    order, so a second `GET /` handler would never be called.
    """
    return {"message": "Image Inpainting API is running. Use POST /inpaint/ to edit images."}
if __name__ == "__main__":
    # Start the development server only when this file is executed directly.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)