# ghost-vision / runway.py
# Author: sachin
# Commit dd59a2b: "add blend image"
# File view: raw | history | blame | 6.11 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
import io
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline
# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained inpainting model (Stable Diffusion) once at import time;
# every request handled below reuses this single `pipe` instance.
model_id = "runwayml/stable-diffusion-inpainting"
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    # Fail fast at startup: none of the endpoints can work without the model.
    # Chain the original error so its traceback is kept as the explicit cause.
    raise RuntimeError(f"Failed to load model: {e}") from e
@app.get("/")
async def root():
    """
    Root endpoint for basic health check.

    Returns:
        dict: A JSON object with a single "message" key that names the service
        and points clients at the two image-editing endpoints.
    """
    # The service wraps a Stable Diffusion inpainting pipeline (see model_id),
    # so it should not advertise itself as InstructPix2Pix.
    return {
        "message": (
            "Stable Diffusion inpainting API is running. "
            "Use POST /inpaint/ or /inpaint-with-reference/ to edit images."
        )
    }
def blend_images(
    original_image: Image.Image,
    reference_image: Image.Image,
    mask_image: Image.Image,
) -> Image.Image:
    """
    Blend the original image with a reference image using the mask.

    - White mask pixels (255) take pixels from the original image.
    - Black mask pixels (0) take pixels from the reference image.
    - Intermediate grey values mix the two linearly, weighted by mask / 255.

    NOTE(review): StableDiffusionInpaintPipeline repaints *white* mask pixels
    and preserves *black* ones. When this output is passed to the pipeline
    with the same mask, the reference pixels therefore land in the preserved
    region. Confirm that this is the intended behavior.

    Args:
        original_image: The original image (RGB).
        reference_image: The reference image to blend with (RGB). It is
            expected to match ``original_image`` in size.
        mask_image: The mask image. It is converted to grayscale ("L") if it is
            in another mode, so RGB masks are accepted as well.

    Returns:
        The blended image, built from a uint8 array with the inputs' channels.
    """
    # A multi-channel mask would gain a fourth axis below and broadcast
    # incorrectly against the RGB arrays, so collapse it to one channel first.
    if mask_image.mode != "L":
        mask_image = mask_image.convert("L")

    original_array = np.asarray(original_image, dtype=np.float64)
    reference_array = np.asarray(reference_image, dtype=np.float64)
    # Normalize the mask to [0, 1] and add a trailing axis so it broadcasts
    # over the colour channels.
    weight = np.asarray(mask_image, dtype=np.float64)[:, :, np.newaxis] / 255.0

    # weight == 1 (white) keeps the original; weight == 0 (black) uses the reference.
    blended = original_array * weight + reference_array * (1.0 - weight)
    # Round instead of truncating: floating-point error can yield e.g.
    # 199.99999 where 200 is meant, which a bare astype(uint8) floors to 199.
    blended = np.clip(np.rint(blended), 0, 255).astype(np.uint8)
    return Image.fromarray(blended)
@app.post("/inpaint/")
async def inpaint_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content.",
):
    """
    Endpoint for image inpainting using a text prompt.

    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file with the same dimensions as `image`. It follows the
      diffusers inpainting convention: white pixels are repainted and black
      pixels are preserved.
    - `prompt`: Text prompt describing the desired output.

    Returns:
    - The inpainted image as a PNG file.

    Errors:
    - 400 if an upload cannot be decoded as an image, or if the image and mask
      dimensions differ.
    - 500 if the inpainting pipeline fails.
    """
    # Load the uploaded image and mask
    image_bytes = await image.read()
    mask_bytes = await mask.read()
    try:
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
    except OSError as e:
        # Pillow signals undecodable data with UnidentifiedImageError, an
        # OSError subclass; that is a bad request, not a server failure.
        raise HTTPException(
            status_code=400, detail=f"Could not decode image or mask: {e}"
        ) from e

    # Validate before entering the generic 500 handler below, so the client
    # actually receives this 400.
    if original_image.size != mask_image.size:
        raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")

    try:
        # NOTE: pipe() is synchronous, so this call blocks the event loop for
        # the whole diffusion run, which also serializes requests.
        result = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]
        # Encode the result as PNG for the response.
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}") from e

    result_bytes.seek(0)
    return StreamingResponse(
        result_bytes,
        media_type="image/png",
        headers={"Content-Disposition": "attachment; filename=inpainted_image.png"},
    )
@app.post("/inpaint-with-reference/")
async def inpaint_with_reference(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content.",
):
    """
    Endpoint for image inpainting using both a text prompt and a reference image.

    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file with the same dimensions as `image`. It follows the
      diffusers inpainting convention: white pixels are repainted and black
      pixels are preserved.
    - `reference_image`: Reference image (PNG/JPG). It is resized (Lanczos) to
      the size of `image` when needed. Before inpainting, it replaces the
      original wherever the mask is black, and the original is kept wherever
      the mask is white.
    - `prompt`: Text prompt describing the desired output.

    Returns:
    - The inpainted image as a PNG file.

    Errors:
    - 400 if an upload cannot be decoded as an image, or if the image and mask
      dimensions differ.
    - 500 if resizing, blending or the inpainting pipeline fails.
    """
    # Load the uploaded image, mask, and reference image
    image_bytes = await image.read()
    mask_bytes = await mask.read()
    reference_bytes = await reference_image.read()
    try:
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
        # Use a separate name rather than shadowing the `reference_image` upload.
        reference = Image.open(io.BytesIO(reference_bytes)).convert("RGB")
    except OSError as e:
        # Pillow signals undecodable data with UnidentifiedImageError, an
        # OSError subclass; that is a bad request, not a server failure.
        raise HTTPException(
            status_code=400,
            detail=f"Could not decode image, mask or reference image: {e}",
        ) from e

    # Validate before entering the generic 500 handler below, so the client
    # actually receives this 400.
    if original_image.size != mask_image.size:
        raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")

    try:
        # The reference only needs to line up pixel-for-pixel with the original.
        if reference.size != original_image.size:
            reference = reference.resize(original_image.size, Image.Resampling.LANCZOS)
        # NOTE(review): blend_images takes reference pixels where the mask is
        # black, which is the region the pipeline preserves (it repaints white).
        # Confirm that is intended, rather than pre-filling the repainted region.
        blended_image = blend_images(original_image, reference, mask_image)
        # NOTE: pipe() is synchronous, so this call blocks the event loop for
        # the whole diffusion run, which also serializes requests.
        result = pipe(prompt=prompt, image=blended_image, mask_image=mask_image).images[0]
        # Encode the result as PNG for the response.
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error during inpainting with reference: {e}"
        ) from e

    result_bytes.seek(0)
    return StreamingResponse(
        result_bytes,
        media_type="image/png",
        headers={"Content-Disposition": "attachment; filename=inpainted_with_reference_image.png"},
    )
if __name__ == "__main__":
    # Direct-run entry point: serve `app` with uvicorn on every interface,
    # port 7860. uvicorn is only imported when the file is run as a script.
    from uvicorn import run as serve

    serve(app, port=7860, host="0.0.0.0")