# ghost-vision / runway.py
# sachin
# inpaint-mask
# b8546a6
# raw
# history blame
# 10 kB
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw
import io
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline
# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained inpainting model (Stable Diffusion).
# The pipeline is created once at import time and shared by every endpoint.
model_id = "runwayml/stable-diffusion-inpainting"
# Prefer the GPU when torch reports one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    # Abort start-up with a clear message; chain the original exception so the
    # underlying cause stays attached as __cause__ in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e
@app.get("/")
async def root():
    """
    Health-check endpoint: returns a static JSON message listing the POST routes.

    NOTE(review): the message text names "InstructPix2Pix", while the model
    loaded by this module is the Stable Diffusion inpainting checkpoint.
    The string is left untouched here because clients may match on it.
    """
    status_payload = {
        "message": "InstructPix2Pix API is running. Use POST /inpaint/, /inpaint-with-reference/, or /fit-image-to-mask/ to edit images."
    }
    return status_payload
def prepare_guided_image(
    original_image: "Image.Image",
    reference_image: "Image.Image",
    mask_image: "Image.Image",
) -> "Image.Image":
    """
    Prepare an initial image by softly blending the reference image into the masked area.

    - Areas to keep (black in mask, 0) remain fully from the original image.
    - Areas to inpaint (white in mask, 255) take content from the reference image.
    - Intermediate mask values mix the two images proportionally (soft blending).

    Args:
        original_image: Base image.
        reference_image: Image supplying the new content. Must have the same
            size as ``original_image`` and should have the same mode (channel
            count); otherwise NumPy broadcasting fails.
        mask_image: Per-pixel blend weights. Converted to single-channel "L"
            when it is in any other mode.

    Returns:
        The blended image, built from a uint8 array of the same shape as the
        original image's array.

    Raises:
        ValueError: If the three images do not all share the same size.
    """
    if not (original_image.size == reference_image.size == mask_image.size):
        raise ValueError(
            "original, reference and mask images must all have the same size, got "
            f"{original_image.size}, {reference_image.size} and {mask_image.size}"
        )

    # The blend needs exactly one weight per pixel in the 0-255 range; masks in
    # other modes (e.g. "1", "RGB", "RGBA") are normalised to 8-bit greyscale.
    if mask_image.mode != "L":
        mask_image = mask_image.convert("L")

    original_array = np.asarray(original_image, dtype=np.float64)
    reference_array = np.asarray(reference_image, dtype=np.float64)
    weights = np.asarray(mask_image, dtype=np.float64) / 255.0

    # Multi-channel images need a trailing axis so the weights broadcast over
    # the channels; single-channel images are already (height, width).
    if original_array.ndim == 3:
        weights = weights[:, :, np.newaxis]

    blended_array = original_array * (1.0 - weights) + reference_array * weights

    # Round to nearest before casting: a plain uint8 cast truncates, so float
    # error (e.g. 254.99999999999997) would darken pixels by one level.
    return Image.fromarray(np.rint(blended_array).astype(np.uint8))
def soften_mask(mask_image: Image, softness: int = 5) -> Image:
    """
    Return a Gaussian-blurred copy of the mask so its edges fade gradually.

    `softness` is passed to the blur filter as its radius.
    """
    # Imported locally: ImageFilter is only needed by this helper.
    from PIL import ImageFilter

    blur = ImageFilter.GaussianBlur(radius=softness)
    return mask_image.filter(blur)
def generate_rectangular_mask(image_size: tuple, x1: int = 100, y1: int = 100, x2: int = 200, y2: int = 200) -> Image:
    """
    Build a single-channel ("L") mask of the given size with one filled rectangle.

    - Everything starts black (0): areas to keep.
    - The rectangle [x1, y1, x2, y2] is painted white (255): area to inpaint.
    """
    canvas = Image.new("L", image_size, 0)
    box = [x1, y1, x2, y2]
    ImageDraw.Draw(canvas).rectangle(box, fill=255)
    return canvas
def fit_image_to_mask(
    original_image: "Image.Image",
    reference_image: "Image.Image",
    mask_x1: int,
    mask_y1: int,
    mask_x2: int,
    mask_y2: int,
) -> tuple:
    """
    Fit the reference image into the masked region of the original image.

    The reference is scaled to the largest size that fits inside the rectangle
    while preserving its aspect ratio, then pasted centred in that rectangle
    on a copy of the original.

    Args:
        original_image (Image): The original image (RGB).
        reference_image (Image): The image to fit into the masked region (RGB).
        mask_x1, mask_y1, mask_x2, mask_y2 (int): Coordinates of the masked region.

    Returns:
        tuple: (guided_image, mask_image) - The image with the fitted reference
        and the corresponding mask. The mask covers the whole requested
        rectangle, not only the pasted area, so any margin left around the
        fitted reference is inside the masked region as well.

    Raises:
        ValueError: If the rectangle has non-positive width or height, or the
            reference image has a zero-sized dimension.
    """
    # Calculate mask dimensions
    mask_width = mask_x2 - mask_x1
    mask_height = mask_y2 - mask_y1

    # Ensure mask dimensions are positive
    if mask_width <= 0 or mask_height <= 0:
        raise ValueError("Mask dimensions must be positive")

    ref_width, ref_height = reference_image.size
    # Guard the divisions below instead of failing with ZeroDivisionError.
    if ref_width <= 0 or ref_height <= 0:
        raise ValueError("Reference image dimensions must be positive")

    # Compare aspect ratios by cross-multiplication and scale with integer
    # division: this is exact, whereas float ratios can be off by a pixel.
    # The max(1, ...) keeps very elongated references from collapsing to a
    # zero-sized dimension, which resize() would reject.
    if mask_width * ref_height > mask_height * ref_width:
        # Rectangle is relatively wider than the reference: fit to height
        new_height = mask_height
        new_width = max(1, mask_height * ref_width // ref_height)
    else:
        # Fit to width
        new_width = mask_width
        new_height = max(1, mask_width * ref_height // ref_width)

    # Resize reference image
    reference_image_resized = reference_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Work on a copy so the caller's original image is not modified
    guided_image = original_image.copy()

    # Centre the resized image inside the rectangle
    paste_x = mask_x1 + (mask_width - new_width) // 2
    paste_y = mask_y1 + (mask_height - new_height) // 2
    guided_image.paste(reference_image_resized, (paste_x, paste_y))

    # Generate the mask for inpainting (white over the full requested rectangle)
    mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2)
    return guided_image, mask_image
@app.post("/inpaint/")
async def inpaint_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content."
):
    """
    Endpoint for image inpainting using a text prompt and an uploaded mask.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file indicating areas to inpaint (white for masked areas, black for unmasked).
    - `prompt`: Text prompt describing the desired output.
    Returns:
    - The inpainted image as a PNG file.
    Raises:
    - HTTPException 400 when the image and mask sizes differ.
    - HTTPException 500 for any other failure while loading or inpainting.
    """
    try:
        # Load the uploaded image and mask
        image_bytes = await image.read()
        mask_bytes = await mask.read()
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")

        # Ensure dimensions match between image and mask
        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")

        # Perform inpainting using the pipeline
        result = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)

        # Return the image as a streaming response
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=inpainted_image.png"}
        )
    except HTTPException:
        # Deliberate HTTP errors (the 400 above) must reach the client as-is;
        # the generic handler below would otherwise rewrap them as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}") from e
@app.post("/inpaint-with-reference/")
async def inpaint_with_reference(
    image: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Integrate the reference content naturally into the masked area, matching style and lighting.",
    mask_x1: int = 100,
    mask_y1: int = 100,
    mask_x2: int = 200,
    mask_y2: int = 200
):
    """
    Replace a rectangular region of the uploaded image with reference content and refine it with the inpainting pipeline.
    - `image`: Original image file.
    - `reference_image`: Image supplying the replacement content; resized to the original's size when the sizes differ.
    - `prompt`: Text prompt passed to the pipeline.
    - `mask_x1`, `mask_y1`, `mask_x2`, `mask_y2`: Corners of the autogenerated rectangular mask.
    Returns:
    - The pipeline's output image as a PNG file.
    Any failure is reported as an HTTP 500 error.
    """
    try:
        base_bytes = await image.read()
        donor_bytes = await reference_image.read()
        base_picture = Image.open(io.BytesIO(base_bytes)).convert("RGB")
        donor_picture = Image.open(io.BytesIO(donor_bytes)).convert("RGB")

        # The blend below works pixel-by-pixel, so both pictures need one size.
        if donor_picture.size != base_picture.size:
            donor_picture = donor_picture.resize(base_picture.size, Image.Resampling.LANCZOS)

        # Hard-edged rectangle first, then blurred so the transition is gradual.
        box_mask = generate_rectangular_mask(base_picture.size, mask_x1, mask_y1, mask_x2, mask_y2)
        feathered_mask = soften_mask(box_mask, softness=5)

        # Seed the masked area with reference content before running the pipeline.
        seeded_picture = prepare_guided_image(base_picture, donor_picture, feathered_mask)

        pipeline_output = pipe(
            prompt=prompt,
            image=seeded_picture,
            mask_image=feathered_mask,
            strength=0.75,
            guidance_scale=7.5
        )
        rendered = pipeline_output.images[0]

        # Serialise to an in-memory PNG and rewind it for streaming.
        png_buffer = io.BytesIO()
        rendered.save(png_buffer, format="PNG")
        png_buffer.seek(0)

        download_headers = {"Content-Disposition": "attachment; filename=natural_inpaint_image.png"}
        return StreamingResponse(png_buffer, media_type="image/png", headers=download_headers)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during natural inpainting: {e}")
# Keep a handle on the synchronous helper defined earlier in this module.
# The route handler below reuses the name `fit_image_to_mask`; without this
# alias the handler would shadow the helper, call itself, and receive a
# coroutine instead of the (guided_image, mask_image) tuple.
_fit_reference_into_mask = fit_image_to_mask


@app.post("/fit-image-to-mask/")
async def fit_image_to_mask(
    image: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Blend the fitted image naturally into the scene, matching style and lighting.",
    mask_x1: int = 100,
    mask_y1: int = 100,
    mask_x2: int = 200,
    mask_y2: int = 200
):
    """
    Endpoint for fitting a reference image into a masked region of the original image, refined to look natural, using an autogenerated mask.
    - `image`: Original image file.
    - `reference_image`: Image scaled (aspect ratio preserved) and centred inside the rectangle.
    - `prompt`: Text prompt passed to the inpainting pipeline.
    - `mask_x1`, `mask_y1`, `mask_x2`, `mask_y2`: Corners of the rectangle to fill.
    Returns:
    - The pipeline's output image as a PNG file.
    Raises:
    - HTTPException 500 on any failure; ValueErrors (e.g. a rectangle with
      non-positive width or height) get a dedicated detail message.
    """
    try:
        # Load the uploaded images
        image_bytes = await image.read()
        reference_bytes = await reference_image.read()
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        reference_picture = Image.open(io.BytesIO(reference_bytes)).convert("RGB")

        # Fit the reference image into the masked region (synchronous helper)
        guided_image, mask_image = _fit_reference_into_mask(
            original_image, reference_picture, mask_x1, mask_y1, mask_x2, mask_y2
        )

        # Soften the mask for smoother transitions
        softened_mask = soften_mask(mask_image, softness=5)

        # Perform inpainting to blend the fitted image naturally
        result = pipe(
            prompt=prompt,
            image=guided_image,
            mask_image=softened_mask,
            strength=0.75,
            guidance_scale=7.5
        ).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=fitted_image.png"}
        )
    except ValueError as ve:
        raise HTTPException(status_code=500, detail=f"ValueError in processing: {str(ve)}") from ve
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}") from e
if __name__ == "__main__":
    # Script entry point: uvicorn is imported only when the module is run
    # directly, and serves the app on all interfaces at port 7860.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)