sachin commited on
Commit
dd59a2b
·
1 Parent(s): 1485575

add- blend image

Browse files
Files changed (1) hide show
  1. runway.py +86 -3
runway.py CHANGED
@@ -1,8 +1,9 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
- from fastapi.responses import StreamingResponse, Response
3
  from PIL import Image
4
  import io
5
  import torch
 
6
  from diffusers import StableDiffusionInpaintPipeline
7
 
8
  # Initialize FastAPI app
@@ -23,7 +24,35 @@ async def root():
23
  """
24
  Root endpoint for basic health check.
25
  """
26
- return {"message": "InstructPix2Pix API is running. Use POST /edit-image/ or /inpaint/ to edit images."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  @app.post("/inpaint/")
29
  async def inpaint_image(
@@ -32,7 +61,7 @@ async def inpaint_image(
32
  prompt: str = "Fill the masked area with appropriate content."
33
  ):
34
  """
35
- Endpoint for image inpainting.
36
  - `image`: Original image file (PNG/JPG).
37
  - `mask`: Mask file indicating areas to inpaint (black for masked areas, white for unmasked).
38
  - `prompt`: Text prompt describing the desired output.
@@ -70,6 +99,60 @@ async def inpaint_image(
70
  except Exception as e:
71
  raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}")
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  if __name__ == "__main__":
74
  import uvicorn
75
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import StreamingResponse
3
  from PIL import Image
4
  import io
5
  import torch
6
+ import numpy as np
7
  from diffusers import StableDiffusionInpaintPipeline
8
 
9
  # Initialize FastAPI app
 
24
  """
25
  Root endpoint for basic health check.
26
  """
27
+ return {"message": "InstructPix2Pix API is running. Use POST /inpaint/ or /inpaint-with-reference/ to edit images."}
28
+
29
+ def blend_images(original_image: Image, reference_image: Image, mask_image: Image) -> Image:
30
+ """
31
+ Blend the original image with a reference image using the mask.
32
+ - Unmasked areas (white in mask) take pixels from the original image.
33
+ - Masked areas (black in mask) take pixels from the reference image as a starting point.
34
+
35
+ Args:
36
+ original_image (Image): The original image (RGB).
37
+ reference_image (Image): The reference image to blend with (RGB).
38
+ mask_image (Image): The mask image (grayscale, L mode).
39
+
40
+ Returns:
41
+ Image: The blended image.
42
+ """
43
+ # Convert images to numpy arrays
44
+ original_array = np.array(original_image)
45
+ reference_array = np.array(reference_image)
46
+ mask_array = np.array(mask_image) / 255.0 # Normalize mask to [0, 1]
47
+
48
+ # Ensure mask is broadcastable to RGB channels
49
+ mask_array = mask_array[:, :, np.newaxis]
50
+
51
+ # Blend: unmasked areas (mask=1) keep original, masked areas (mask=0) use reference
52
+ blended_array = original_array * mask_array + reference_array * (1 - mask_array)
53
+ blended_array = blended_array.astype(np.uint8)
54
+
55
+ return Image.fromarray(blended_array)
56
 
57
  @app.post("/inpaint/")
58
  async def inpaint_image(
 
61
  prompt: str = "Fill the masked area with appropriate content."
62
  ):
63
  """
64
+ Endpoint for image inpainting using a text prompt.
65
  - `image`: Original image file (PNG/JPG).
66
  - `mask`: Mask file indicating areas to inpaint (black for masked areas, white for unmasked).
67
  - `prompt`: Text prompt describing the desired output.
 
99
  except Exception as e:
100
  raise HTTPException(status_code=500, detail=f"Error during inpainting: {e}")
101
 
102
+ @app.post("/inpaint-with-reference/")
103
+ async def inpaint_with_reference(
104
+ image: UploadFile = File(...),
105
+ mask: UploadFile = File(...),
106
+ reference_image: UploadFile = File(...),
107
+ prompt: str = "Fill the masked area with appropriate content."
108
+ ):
109
+ """
110
+ Endpoint for image inpainting using both a text prompt and a reference image.
111
+ - `image`: Original image file (PNG/JPG).
112
+ - `mask`: Mask file indicating areas to inpaint (black for masked areas, white for unmasked).
113
+ - `reference_image`: Reference image to guide the inpainting (PNG/JPG).
114
+ - `prompt`: Text prompt describing the desired output.
115
+
116
+ Returns:
117
+ - The inpainted image as a PNG file.
118
+ """
119
+ try:
120
+ # Load the uploaded image, mask, and reference image
121
+ image_bytes = await image.read()
122
+ mask_bytes = await mask.read()
123
+ reference_bytes = await reference_image.read()
124
+
125
+ original_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
126
+ mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")
127
+ reference_image = Image.open(io.BytesIO(reference_bytes)).convert("RGB")
128
+
129
+ # Ensure dimensions match between image, mask, and reference image
130
+ if original_image.size != mask_image.size:
131
+ raise HTTPException(status_code=400, detail="Image and mask dimensions must match.")
132
+ if original_image.size != reference_image.size:
133
+ reference_image = reference_image.resize(original_image.size, Image.Resampling.LANCZOS)
134
+
135
+ # Blend the original and reference images using the mask
136
+ blended_image = blend_images(original_image, reference_image, mask_image)
137
+
138
+ # Perform inpainting using the pipeline
139
+ result = pipe(prompt=prompt, image=blended_image, mask_image=mask_image).images[0]
140
+
141
+ # Convert result to bytes for response
142
+ result_bytes = io.BytesIO()
143
+ result.save(result_bytes, format="PNG")
144
+ result_bytes.seek(0)
145
+
146
+ # Return the image as a streaming response
147
+ return StreamingResponse(
148
+ result_bytes,
149
+ media_type="image/png",
150
+ headers={"Content-Disposition": "attachment; filename=inpainted_with_reference_image.png"}
151
+ )
152
+
153
+ except Exception as e:
154
+ raise HTTPException(status_code=500, detail=f"Error during inpainting with reference: {e}")
155
+
156
  if __name__ == "__main__":
157
  import uvicorn
158
  uvicorn.run(app, host="0.0.0.0", port=7860)