from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
import io
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline

# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained inpainting model (Stable Diffusion)
model_id = "runwayml/stable-diffusion-inpainting"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}")
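
# Optional: on a CUDA device, passing torch_dtype=torch.float16 to from_pretrained
# roughly halves GPU memory use; full precision is kept here so the same code also
# runs on CPU.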

@app.get("/")
async def root():
    """
    Root endpoint for basic health check.
    """
    return {"message": "InstructPix2Pix API is running. Use POST /inpaint/ or /inpaint-with-reference/ to edit images."}

def blend_images(original_image: Image.Image, reference_image: Image.Image, mask_image: Image.Image) -> Image.Image:
    """
    Blend the original image with a reference image using the mask.
    - Areas to keep (black in the mask) take pixels from the original image.
    - Areas to inpaint (white in the mask) take pixels from the reference image as a starting point.

    Args:
        original_image (Image.Image): The original image (RGB).
        reference_image (Image.Image): The reference image to blend with (RGB).
        mask_image (Image.Image): The mask image (grayscale, L mode).

    Returns:
        Image.Image: The blended image.
    """
    # Convert images to numpy arrays
    original_array = np.array(original_image)
    reference_array = np.array(reference_image)
    mask_array = np.array(mask_image) / 255.0  # Normalize mask to [0, 1]

    # Ensure mask is broadcastable to RGB channels
    mask_array = mask_array[:, :, np.newaxis]

    # Blend: areas to keep (mask=0) use the original, areas to inpaint (mask=1) use the reference
    blended_array = original_array * (1 - mask_array) + reference_array * mask_array
    blended_array = blended_array.astype(np.uint8)

    return Image.fromarray(blended_array)

@app.post("/inpaint/")
async def inpaint_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content."
):
    """
    Endpoint for image inpainting using a text prompt.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file indicating areas to inpaint (white for areas to repaint, black for areas to keep, following the diffusers mask convention).
    - `prompt`: Text prompt describing the desired output.
    
    Returns:
    - The inpainted image as a PNG file.
    """
    try:
        # Load the uploaded image and mask
        image_bytes = await image.read()
        mask_bytes = await mask.read()

        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")

        # Ensure dimensions match between image and mask
        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")

        # Perform inpainting using the pipeline
        result = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)

        # Return the image as a streaming response
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=inpainted_image.png"}
        )
    
    except HTTPException:
        # Re-raise client errors (e.g., the 400 above) without wrapping them in a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}")

@app.post("/inpaint-with-reference/")
async def inpaint_with_reference(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content."
):
    """
    Endpoint for image inpainting using both a text prompt and a reference image.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file indicating areas to inpaint (white for areas to repaint, black for areas to keep, following the diffusers mask convention).
    - `reference_image`: Reference image to guide the inpainting (PNG/JPG).
    - `prompt`: Text prompt describing the desired output.
    
    Returns:
    - The inpainted image as a PNG file.
    """
    try:
        # Load the uploaded image, mask, and reference image
        image_bytes = await image.read()
        mask_bytes = await mask.read()
        reference_bytes = await reference_image.read()

        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
        reference_img = Image.open(io.BytesIO(reference_bytes)).convert("RGB")  # new name avoids shadowing the UploadFile parameter

        # Ensure dimensions match between image, mask, and reference image
        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")
        if original_image.size != reference_img.size:
            reference_img = reference_img.resize(original_image.size, Image.Resampling.LANCZOS)

        # Blend the original and reference images using the mask
        blended_image = blend_images(original_image, reference_img, mask_image)

        # Perform inpainting using the pipeline
        result = pipe(prompt=prompt, image=blended_image, mask_image=mask_image).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)

        # Return the image as a streaming response
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=inpainted_with_reference_image.png"}
        )
    
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting with reference: {e}")
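
# Example client call for the reference-guided endpoint (same assumptions as above,
# plus a reference.png on disk):
#
#   import requests
#   files = {
#       "image": open("input.png", "rb"),
#       "mask": open("mask.png", "rb"),
#       "reference_image": open("reference.png", "rb"),
#   }
#   resp = requests.post(
#       "http://localhost:7860/inpaint-with-reference/",
#       params={"prompt": "match the reference"},
#       files=files,
#   )
#   resp.raise_for_status()
#   with open("inpainted_with_reference.png", "wb") as out:
#       out.write(resp.content)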

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)