File size: 7,064 Bytes
be6ce26
dd59a2b
be6ce26
 
 
dd59a2b
be6ce26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1485575
 
 
 
dd59a2b
 
2b9ad07
dd59a2b
2b9ad07
 
 
dd59a2b
 
 
2b9ad07
dd59a2b
 
 
2b9ad07
dd59a2b
 
 
 
2b9ad07
dd59a2b
2b9ad07
dd59a2b
 
2b9ad07
dd59a2b
 
 
 
be6ce26
2b9ad07
 
 
 
 
 
 
 
 
 
 
 
 
 
be6ce26
 
 
 
 
 
 
dd59a2b
be6ce26
 
 
1485575
 
 
be6ce26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1485575
 
 
 
 
 
be6ce26
 
 
 
dd59a2b
 
 
 
 
2b9ad07
dd59a2b
 
2b9ad07
dd59a2b
2b9ad07
 
 
dd59a2b
 
2b9ad07
dd59a2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b9ad07
 
dd59a2b
2b9ad07
 
 
 
 
 
 
 
 
 
 
dd59a2b
 
 
 
 
 
 
 
 
 
2b9ad07
dd59a2b
 
 
2b9ad07
dd59a2b
be6ce26
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image
import io
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline

# Initialize FastAPI app
app = FastAPI()

# Checkpoint identifier for the pre-trained Stable Diffusion inpainting model.
model_id = "runwayml/stable-diffusion-inpainting"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pipeline once at import time; the request handlers below reuse the
# module-level `pipe` object.
try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    # Chain explicitly so the original failure is reported as the direct cause.
    raise RuntimeError(f"Failed to load model: {e}") from e

@app.get("/")
async def root():
    """
    Health-check endpoint: report that the service is up and name the POST routes.

    Returns:
        dict: A single "message" key holding a human-readable status string.

    NOTE(review): the status text says "InstructPix2Pix", while this module
    loads a Stable Diffusion inpainting checkpoint -- confirm the intended name.
    """
    status_text = "InstructPix2Pix API is running. Use POST /inpaint/ or /inpaint-with-reference/ to edit images."
    return {"message": status_text}

def prepare_guided_image(original_image: Image.Image, reference_image: Image.Image, mask_image: Image.Image) -> Image.Image:
    """
    Prepare an initial image by softly blending the reference image into the masked area.
    - Unmasked areas (white in mask, 255) remain fully from the original image.
    - Masked areas (black in mask, 0) take content from the reference image.
    - Intermediate grey values mix the two proportionally.

    Args:
        original_image (Image.Image): The original image (RGB).
        reference_image (Image.Image): The reference image to copy from (RGB).
        mask_image (Image.Image): The mask image; converted to grayscale (L mode) if needed.

    Returns:
        Image.Image: The blended image to guide inpainting.

    Raises:
        ValueError: If the three images do not all have the same (width, height).
    """
    # Fail early with a readable message instead of a NumPy broadcasting error.
    if not (original_image.size == reference_image.size == mask_image.size):
        raise ValueError(
            "original_image, reference_image and mask_image must have the same size; "
            f"got {original_image.size}, {reference_image.size} and {mask_image.size}."
        )

    # The weighting below assumes one 0-255 channel; normalise other mask modes
    # (e.g. "1" or "RGB") so they do not produce wrong weights or shapes.
    if mask_image.mode != "L":
        mask_image = mask_image.convert("L")

    original_array = np.array(original_image)
    reference_array = np.array(reference_image)

    # Per-pixel weight of the original image in [0, 1], with a trailing axis so
    # it broadcasts across the colour channels.
    keep_weight = (np.array(mask_image) / 255.0)[:, :, np.newaxis]

    blended_array = original_array * keep_weight + reference_array * (1 - keep_weight)

    # Round to nearest before casting: a bare astype() truncates, so float error
    # (e.g. 199.9999) would darken pixels by one level.
    blended_array = np.clip(np.rint(blended_array), 0, 255).astype(np.uint8)

    return Image.fromarray(blended_array)

def soften_mask(mask_image: Image.Image, softness: int = 5) -> Image.Image:
    """
    Soften the edges of the mask for smoother transitions.

    Args:
        mask_image (Image.Image): The original mask (grayscale, L mode).
        softness (int): Radius passed to the Gaussian blur filter.

    Returns:
        Image.Image: A blurred copy of the mask; the input image is not modified.
    """
    # Imported locally, as in the original, to keep the module-level imports unchanged.
    from PIL import ImageFilter

    blur_filter = ImageFilter.GaussianBlur(radius=softness)
    return mask_image.filter(blur_filter)

@app.post("/inpaint/")
async def inpaint_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content."
):
    """
    Endpoint for image inpainting using a text prompt.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file indicating areas to inpaint (black for masked areas, white for unmasked).
    - `prompt`: Text prompt describing the desired output.

    Returns:
    - The inpainted image as a PNG file.

    Raises:
    - HTTPException (400): The image and mask dimensions differ.
    - HTTPException (500): Any other failure while decoding the uploads or inpainting.
    """
    # Local import keeps the module-level import block unchanged.
    from PIL import ImageOps

    try:
        # Load the uploaded image and mask
        image_bytes = await image.read()
        mask_bytes = await mask.read()

        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")

        # Ensure dimensions match between image and mask
        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")

        # This endpoint documents black = inpaint, white = keep. Per the diffusers
        # documentation the pipeline repaints WHITE pixels and preserves black
        # ones, so the mask is inverted before it is handed over.
        pipeline_mask = ImageOps.invert(mask_image)

        # Perform inpainting using the pipeline
        result = pipe(prompt=prompt, image=original_image, mask_image=pipeline_mask).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)

        # Return the image as a streaming response
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=inpainted_image.png"}
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}") from e

@app.post("/inpaint-with-reference/")
async def inpaint_with_reference(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Integrate the reference content naturally into the masked area, matching style and lighting."
):
    """
    Endpoint for replacing masked areas with reference image content, refined to look natural.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file (black for areas to replace, white for areas to keep).
    - `reference_image`: Reference image to guide the replacement (PNG/JPG); resized to the
      original image's size when the two differ.
    - `prompt`: Text prompt for inpainting refinement.

    Returns:
    - The resulting image as a PNG file.

    Raises:
    - HTTPException (400): The image and mask dimensions differ.
    - HTTPException (500): Any other failure while decoding the uploads or inpainting.
    """
    # Local import keeps the module-level import block unchanged.
    from PIL import ImageOps

    try:
        # Load the uploaded image, mask, and reference image
        image_bytes = await image.read()
        mask_bytes = await mask.read()
        reference_bytes = await reference_image.read()

        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
        # Separate name so the UploadFile parameter is not rebound to a PIL image.
        reference_pil = Image.open(io.BytesIO(reference_bytes)).convert("RGB")

        # Ensure dimensions match between image, mask, and reference image
        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")
        if original_image.size != reference_pil.size:
            reference_pil = reference_pil.resize(original_image.size, Image.Resampling.LANCZOS)

        # Soften the mask for smoother transitions
        softened_mask = soften_mask(mask_image, softness=5)

        # Prepare the initial guided image by blending reference content into the masked area
        # (this helper uses the API convention: white = keep original, black = use reference).
        guided_image = prepare_guided_image(original_image, reference_pil, softened_mask)

        # Per the diffusers documentation the pipeline repaints WHITE pixels and
        # preserves black ones -- the opposite of this endpoint's mask convention.
        # Invert so the refinement hits the pasted reference area, not the kept area.
        pipeline_mask = ImageOps.invert(softened_mask)

        # Perform inpainting to refine the result and make it look natural
        result = pipe(
            prompt=prompt,
            image=guided_image,
            mask_image=pipeline_mask,
            strength=0.75,  # Control how much inpainting modifies the image (0.0 to 1.0)
            guidance_scale=7.5  # Control how closely the prompt is followed
        ).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)

        # Return the image as a streaming response
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=natural_inpaint_image.png"}
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during natural inpainting: {e}") from e

if __name__ == "__main__":
    # Script entry point: uvicorn is imported only when the file is run directly.
    import uvicorn

    # Listen on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)