File size: 9,999 Bytes
be6ce26
dd59a2b
7ba3751
be6ce26
 
dd59a2b
be6ce26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1485575
 
 
 
a14b836
dd59a2b
2b9ad07
dd59a2b
2b9ad07
57cd301
 
dd59a2b
 
 
a14b836
dd59a2b
57cd301
a14b836
be6ce26
2b9ad07
 
 
 
 
 
 
7ba3751
 
 
57cd301
7ba3751
57cd301
7ba3751
57cd301
7ba3751
 
a14b836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19259ed
 
 
 
a14b836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be6ce26
 
 
b8546a6
 
be6ce26
 
b8546a6
 
 
 
 
 
 
be6ce26
 
b8546a6
be6ce26
b8546a6
 
be6ce26
b8546a6
 
 
 
 
 
 
be6ce26
b8546a6
 
be6ce26
 
 
b8546a6
 
1485575
 
 
 
 
be6ce26
 
 
dd59a2b
 
 
 
7ba3751
 
 
 
 
dd59a2b
 
b8546a6
a14b836
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b8546a6
dd59a2b
 
a14b836
dd59a2b
 
 
 
 
a14b836
19259ed
 
b8546a6
19259ed
7ba3751
2b9ad07
 
dd59a2b
a14b836
2b9ad07
 
 
a14b836
 
 
2b9ad07
dd59a2b
 
 
 
 
 
 
 
a14b836
dd59a2b
19259ed
 
dd59a2b
19259ed
dd59a2b
be6ce26
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw
import io
import torch
import numpy as np
from diffusers import StableDiffusionInpaintPipeline

# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained inpainting model (Stable Diffusion)
model_id = "runwayml/stable-diffusion-inpainting"
# Use the GPU when PyTorch reports one is available; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The pipeline is created once, at import time, and reused by the endpoints below.
try:
    pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id)
    pipe.to(device)
except Exception as e:
    # Chain the original exception so the underlying cause stays in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e

@app.get("/")
async def root():
    """
    Basic health check: respond with a fixed status message that lists the POST routes.
    """
    # NOTE(review): the text names "InstructPix2Pix", while this module loads a
    # Stable Diffusion inpainting pipeline — confirm the intended service name.
    status_text = "InstructPix2Pix API is running. Use POST /inpaint/, /inpaint-with-reference/, or /fit-image-to-mask/ to edit images."
    return {"message": status_text}

def prepare_guided_image(original_image: Image.Image, reference_image: Image.Image, mask_image: Image.Image) -> Image.Image:
    """
    Prepare an initial image by blending the reference image into the masked area.

    Every output pixel is ``original * (1 - w) + reference * w``, where ``w`` is the
    mask value scaled from 0..255 to 0..1:
    - Areas to keep (black in mask, 0) remain fully from the original image.
    - Areas to inpaint (white in mask, 255) are taken fully from the reference image.
    - Grey mask values (for example from a blurred mask) mix the two images.

    Args:
        original_image: Base image.
        reference_image: Image supplying the replacement content. Must have the same
            size as ``original_image``; it is converted to the original's mode if the
            modes differ.
        mask_image: Blend weights. Must have the same size as ``original_image``; it
            is converted to single-channel "L" if it is in another mode.

    Returns:
        A new image holding the blend, with values truncated to 8-bit.

    Raises:
        ValueError: If the three images do not share the same size.
    """
    # Fail early with a readable message instead of a NumPy broadcasting error.
    if reference_image.size != original_image.size or mask_image.size != original_image.size:
        raise ValueError(
            "original, reference and mask images must have the same size; got "
            f"{original_image.size}, {reference_image.size} and {mask_image.size}"
        )

    # The weights must be one 0..255 channel; other modes (RGB, "1", ...) would
    # produce wrong weights or an array of the wrong rank.
    if mask_image.mode != "L":
        mask_image = mask_image.convert("L")
    # Both pixel arrays need the same channel layout to be combined element-wise.
    if reference_image.mode != original_image.mode:
        reference_image = reference_image.convert(original_image.mode)

    original_array = np.array(original_image)
    reference_array = np.array(reference_image)
    weights = np.array(mask_image) / 255.0
    # Multi-channel images need a trailing axis on the weights so they broadcast
    # across channels; single-channel images are already the same rank as the mask.
    if original_array.ndim == 3:
        weights = weights[:, :, np.newaxis]
    blended_array = original_array * (1 - weights) + reference_array * weights
    return Image.fromarray(blended_array.astype(np.uint8))

def soften_mask(mask_image: Image, softness: int = 5) -> Image:
    """
    Apply a Gaussian blur to the mask so that its edges fade gradually.

    `softness` is used as the blur radius.
    """
    from PIL import ImageFilter

    blur = ImageFilter.GaussianBlur(radius=softness)
    feathered = mask_image.filter(blur)
    return feathered

def generate_rectangular_mask(image_size: tuple, x1: int = 100, y1: int = 100, x2: int = 200, y2: int = 200) -> Image.Image:
    """
    Generate a rectangular mask matching the image dimensions.
    - Black (0) for areas to keep, white (255) for areas to inpaint.

    Args:
        image_size: (width, height) of the mask to create.
        x1, y1, x2, y2: Two opposite corners of the rectangle to inpaint. The
            corners may be given in either order.

    Returns:
        A single-channel ("L") image that is 0 everywhere except the filled rectangle.
    """
    # Pillow's rectangle() expects x1 >= x0 and y1 >= y0, so put the corners in
    # order rather than relying on the caller to do it.
    left, right = sorted((x1, x2))
    top, bottom = sorted((y1, y2))

    mask = Image.new("L", image_size, 0)
    draw = ImageDraw.Draw(mask)
    draw.rectangle([left, top, right, bottom], fill=255)
    return mask

def fit_image_to_mask(original_image: Image.Image, reference_image: Image.Image, mask_x1: int, mask_y1: int, mask_x2: int, mask_y2: int) -> tuple:
    """
    Fit the reference image into the masked region of the original image.

    The reference is scaled (preserving its aspect ratio) to fit inside the
    rectangle, centred in it, and pasted onto a copy of the original.

    Args:
        original_image (Image): The original image (RGB).
        reference_image (Image): The image to fit into the masked region (RGB).
        mask_x1, mask_y1, mask_x2, mask_y2 (int): Coordinates of the masked region.

    Returns:
        tuple: (guided_image, mask_image) - The image with the fitted reference and the corresponding mask.

    Raises:
        ValueError: If the rectangle has zero or negative width or height.
    """
    # Calculate mask dimensions
    mask_width = mask_x2 - mask_x1
    mask_height = mask_y2 - mask_y1

    # Ensure mask dimensions are positive
    if mask_width <= 0 or mask_height <= 0:
        raise ValueError("Mask dimensions must be positive")

    # Resize reference image to fit the mask while preserving aspect ratio
    ref_width, ref_height = reference_image.size
    aspect_ratio = ref_width / ref_height

    # int() truncates, so a very elongated reference could yield a 0-pixel side,
    # which resize() cannot handle; keep every side at least 1 pixel.
    if mask_width / mask_height > aspect_ratio:
        # Mask is relatively wider than the reference: fit to height
        new_height = mask_height
        new_width = max(1, int(new_height * aspect_ratio))
    else:
        # Mask is relatively taller (or equal): fit to width
        new_width = mask_width
        new_height = max(1, int(new_width / aspect_ratio))

    # Resize reference image
    reference_image_resized = reference_image.resize((new_width, new_height), Image.Resampling.LANCZOS)

    # Create a copy of the original image to paste the reference image onto
    guided_image = original_image.copy()

    # Calculate position to center the resized image in the mask
    paste_x = mask_x1 + (mask_width - new_width) // 2
    paste_y = mask_y1 + (mask_height - new_height) // 2

    # Paste the resized reference image onto the copy
    guided_image.paste(reference_image_resized, (paste_x, paste_y))

    # Generate the mask for inpainting. It is white over the whole requested
    # rectangle, which can be larger than the pasted image on one axis.
    mask_image = generate_rectangular_mask(original_image.size, mask_x1, mask_y1, mask_x2, mask_y2)

    return guided_image, mask_image

@app.post("/inpaint/")
async def inpaint_image(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    prompt: str = "Fill the masked area with appropriate content."
):
    """
    Endpoint for image inpainting using a text prompt and an uploaded mask.
    - `image`: Original image file (PNG/JPG).
    - `mask`: Mask file indicating areas to inpaint (white for masked areas, black for unmasked).
    - `prompt`: Text prompt describing the desired output.

    Returns:
    - The inpainted image as a PNG file.

    Raises:
    - HTTPException 400 if the image and mask sizes differ.
    - HTTPException 500 for any other failure while decoding, inpainting or encoding.
    """
    try:
        # Load the uploaded image and mask
        image_bytes = await image.read()
        mask_bytes = await mask.read()

        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")

        # Ensure dimensions match between image and mask
        if original_image.size != mask_image.size:
            raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")

        # Perform inpainting using the pipeline
        result = pipe(prompt=prompt, image=original_image, mask_image=mask_image).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        result.save(result_bytes, format="PNG")
        result_bytes.seek(0)

        # Return the image as a streaming response
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=inpainted_image.png"}
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) must reach the client
        # unchanged instead of being rewrapped as a 500 by the handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}") from e

@app.post("/inpaint-with-reference/")
async def inpaint_with_reference(
    image: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Integrate the reference content naturally into the masked area, matching style and lighting.",
    mask_x1: int = 100,
    mask_y1: int = 100,
    mask_x2: int = 200,
    mask_y2: int = 200
):
    """
    Replace a rectangular region of the uploaded image with content from the
    reference image and run the inpainting pipeline over it.

    - `image`: Original image file.
    - `reference_image`: Image whose content is blended into the rectangle.
    - `prompt`: Text prompt passed to the pipeline.
    - `mask_x1`, `mask_y1`, `mask_x2`, `mask_y2`: Corners of the autogenerated mask.

    Returns the result as a PNG attachment; any failure is reported as HTTP 500.
    """
    try:
        base_payload = await image.read()
        ref_payload = await reference_image.read()
        base_rgb = Image.open(io.BytesIO(base_payload)).convert("RGB")
        ref_rgb = Image.open(io.BytesIO(ref_payload)).convert("RGB")

        # The blend is done pixel by pixel, so stretch the reference to the
        # original's dimensions when they differ.
        size_differs = base_rgb.size != ref_rgb.size
        if size_differs:
            ref_rgb = ref_rgb.resize(base_rgb.size, Image.Resampling.LANCZOS)

        # Build the rectangle mask, blur its edges, and seed the masked area
        # with the reference content.
        hard_mask = generate_rectangular_mask(base_rgb.size, mask_x1, mask_y1, mask_x2, mask_y2)
        feathered_mask = soften_mask(hard_mask, softness=5)
        seeded_image = prepare_guided_image(base_rgb, ref_rgb, feathered_mask)

        pipeline_output = pipe(
            prompt=prompt,
            image=seeded_image,
            mask_image=feathered_mask,
            strength=0.75,
            guidance_scale=7.5
        )
        inpainted = pipeline_output.images[0]

        # Encode as PNG in memory and rewind so the response streams from the start.
        png_buffer = io.BytesIO()
        inpainted.save(png_buffer, format="PNG")
        png_buffer.seek(0)

        download_headers = {"Content-Disposition": "attachment; filename=natural_inpaint_image.png"}
        return StreamingResponse(png_buffer, media_type="image/png", headers=download_headers)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during natural inpainting: {e}")

# The endpoint below has the same name as the synchronous helper defined earlier
# in this module, so after this definition the module-level name refers to the
# endpoint. Keep a private handle on the helper first; otherwise the call inside
# the endpoint would invoke the endpoint itself and receive a coroutine instead
# of the (guided_image, mask_image) tuple.
_fit_reference_into_mask = fit_image_to_mask


@app.post("/fit-image-to-mask/")
async def fit_image_to_mask(
    image: UploadFile = File(...),
    reference_image: UploadFile = File(...),
    prompt: str = "Blend the fitted image naturally into the scene, matching style and lighting.",
    mask_x1: int = 100,
    mask_y1: int = 100,
    mask_x2: int = 200,
    mask_y2: int = 200
):
    """
    Endpoint for fitting a reference image into a masked region of the original image, refined to look natural, using an autogenerated mask.

    - `image`: Original image file.
    - `reference_image`: Image to scale and paste into the rectangle.
    - `prompt`: Text prompt passed to the inpainting pipeline.
    - `mask_x1`, `mask_y1`, `mask_x2`, `mask_y2`: Corners of the rectangle.

    Returns:
    - The blended image as a PNG file.

    Raises:
    - HTTPException 500 on any failure, including a rectangle with non-positive size.
    """
    try:
        # Load the uploaded images
        image_bytes = await image.read()
        reference_bytes = await reference_image.read()
        original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        reference_rgb = Image.open(io.BytesIO(reference_bytes)).convert("RGB")

        # Fit the reference image into the masked region (via the helper alias
        # bound above, not this endpoint)
        guided_image, mask_image = _fit_reference_into_mask(
            original_image, reference_rgb, mask_x1, mask_y1, mask_x2, mask_y2
        )

        # Soften the mask for smoother transitions
        softened_mask = soften_mask(mask_image, softness=5)

        # Perform inpainting to blend the fitted image naturally
        blended = pipe(
            prompt=prompt,
            image=guided_image,
            mask_image=softened_mask,
            strength=0.75,
            guidance_scale=7.5
        ).images[0]

        # Convert result to bytes for response
        result_bytes = io.BytesIO()
        blended.save(result_bytes, format="PNG")
        result_bytes.seek(0)
        return StreamingResponse(
            result_bytes,
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=fitted_image.png"}
        )
    except ValueError as ve:
        raise HTTPException(status_code=500, detail=f"ValueError in processing: {str(ve)}")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error during fitting and inpainting: {str(e)}")

if __name__ == "__main__":
    # Imported only when the file is run as a script, so importing this module
    # elsewhere does not pull in uvicorn.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)