Spaces:
Paused
Paused
sachin
committed on
Commit
·
f7e3ec7
1
Parent(s):
475ca6b
add-instruct pix
Browse files- Dockerfile +1 -1
- intruct.py +96 -0
Dockerfile
CHANGED
@@ -35,4 +35,4 @@ USER appuser
|
|
35 |
EXPOSE 7860
|
36 |
|
37 |
# Run the server
|
38 |
-
CMD ["python", "/app/
|
|
|
35 |
EXPOSE 7860
|
36 |
|
37 |
# Run the server
|
38 |
+
CMD ["python", "/app/intruct.py"]
|
intruct.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, File, UploadFile, Form
|
2 |
+
from fastapi.responses import StreamingResponse
|
3 |
+
import io
|
4 |
+
import math
|
5 |
+
from PIL import Image, ImageOps
|
6 |
+
import torch
|
7 |
+
from diffusers import StableDiffusionInstructPix2PixPipeline
|
8 |
+
|
# Initialize FastAPI app
app = FastAPI()

# Load the pre-trained model once at startup so every request reuses the
# same pipeline instead of reloading the weights per call.
model_id = "timbrooks/instruct-pix2pix"

# Fall back to CPU with full precision when CUDA is unavailable: the
# previous unconditional .to("cuda") crashed on CPU-only hosts, and
# float16 is poorly supported on CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_dtype = torch.float16 if _device == "cuda" else torch.float32
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    model_id, torch_dtype=_dtype, safety_checker=None
).to(_device)

# Default generation parameters (overridable per request via form fields)
DEFAULT_STEPS = 50
DEFAULT_TEXT_CFG = 7.5
DEFAULT_IMAGE_CFG = 1.5
DEFAULT_SEED = 1371
def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int) -> Image.Image:
    """
    Edit *input_image* according to *instruction* using InstructPix2Pix.

    The image is first rescaled so that its longer side is capped near
    512 px and both sides are multiples of 64 (a model requirement).
    If the instruction is empty, the resized image is returned unedited.

    Args:
        input_image: Source image (expected RGB).
        instruction: Natural-language edit instruction.
        steps: Number of diffusion inference steps.
        text_cfg_scale: Classifier-free guidance weight for the text.
        image_cfg_scale: Guidance weight toward the source image.
        seed: RNG seed for reproducible outputs.

    Returns:
        The edited (or merely resized) PIL image.
    """
    # Resize to fit model requirements: sides become multiples of 64.
    width, height = input_image.size
    factor = 512 / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width = int((width * factor) // 64) * 64
    height = int((height * factor) // 64) * 64
    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)

    if not instruction:
        return input_image

    # Use a dedicated generator so seeding does not clobber the global
    # torch RNG state shared by concurrent requests. A fresh Generator
    # seeded the same way produces the identical stream that the previous
    # torch.manual_seed(seed) call did.
    generator = torch.Generator().manual_seed(seed)

    # Generate the edited image
    edited_image = pipe(
        instruction,
        image=input_image,
        guidance_scale=text_cfg_scale,
        image_guidance_scale=image_cfg_scale,
        num_inference_steps=steps,
        generator=generator,
    ).images[0]

    return edited_image
@app.post("/edit-image/")
async def edit_image(
    file: UploadFile = File(...),
    instruction: str = Form(...),
    steps: int = Form(default=DEFAULT_STEPS),
    text_cfg_scale: float = Form(default=DEFAULT_TEXT_CFG),
    image_cfg_scale: float = Form(default=DEFAULT_IMAGE_CFG),
    seed: int = Form(default=DEFAULT_SEED)
):
    """
    Edit an uploaded image according to a text instruction.

    - file: The input image to edit.
    - instruction: The text instruction for editing the image.
    - steps: Number of inference steps.
    - text_cfg_scale: Text CFG weight.
    - image_cfg_scale: Image CFG weight.
    - seed: Random seed for reproducibility.
    """
    # Decode the upload into an RGB PIL image.
    raw_bytes = await file.read()
    source = Image.open(io.BytesIO(raw_bytes)).convert("RGB")

    # Run the InstructPix2Pix edit.
    result = process_image(source, instruction, steps, text_cfg_scale, image_cfg_scale, seed)

    # Serialize the result as PNG and stream it back to the client.
    buffer = io.BytesIO()
    result.save(buffer, format="PNG")
    buffer.seek(0)
    return StreamingResponse(buffer, media_type="image/png")
@app.get("/")
async def root():
    """
    Root endpoint for basic health check.

    Returns a fixed JSON payload pointing callers at POST /edit-image/.
    """
    return {"message": "InstructPix2Pix API is running. Use POST /edit-image/ to edit images."}
if __name__ == "__main__":
    # Launch the ASGI server directly when run as a script; port 7860
    # matches the EXPOSE 7860 in the accompanying Dockerfile.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)