sachin committed
Commit f7e3ec7 · 1 Parent(s): 475ca6b

add-instruct pix

Files changed (2)
  1. Dockerfile +1 -1
  2. intruct.py +96 -0
Dockerfile CHANGED
@@ -35,4 +35,4 @@ USER appuser
 EXPOSE 7860

 # Run the server
-CMD ["python", "/app/main.py"]
+CMD ["python", "/app/intruct.py"]
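
The CMD change points the container's entrypoint at the new FastAPI server instead of main.py; the service still listens on the exposed port 7860. As a quick smoke test (a minimal sketch, assuming the container is already running and reachable on localhost:7860, and that the third-party requests package is installed, neither of which is part of this commit), the root health-check route added below can be queried:

    import requests  # assumed installed; not part of the commit

    # Hit the health-check route defined in intruct.py (GET /)
    resp = requests.get("http://localhost:7860/")
    resp.raise_for_status()
    print(resp.json())  # expected: {"message": "InstructPix2Pix API is running. ..."}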
intruct.py ADDED
@@ -0,0 +1,96 @@
+from fastapi import FastAPI, File, UploadFile, Form
+from fastapi.responses import StreamingResponse
+import io
+import math
+from PIL import Image, ImageOps
+import torch
+from diffusers import StableDiffusionInstructPix2PixPipeline
+
+# Initialize FastAPI app
+app = FastAPI()
+
+# Load the pre-trained model once at startup
+model_id = "timbrooks/instruct-pix2pix"
+pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
+    model_id, torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+
+# Default configuration values
+DEFAULT_STEPS = 50
+DEFAULT_TEXT_CFG = 7.5
+DEFAULT_IMAGE_CFG = 1.5
+DEFAULT_SEED = 1371
+
+def process_image(input_image: Image.Image, instruction: str, steps: int, text_cfg_scale: float, image_cfg_scale: float, seed: int):
+    """
+    Process the input image with the given instruction using InstructPix2Pix.
+    """
+    # Resize image to fit model requirements
+    width, height = input_image.size
+    factor = 512 / max(width, height)
+    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
+    width = int((width * factor) // 64) * 64
+    height = int((height * factor) // 64) * 64
+    input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
+
+    if not instruction:
+        return input_image
+
+    # Set the random seed for reproducibility
+    generator = torch.manual_seed(seed)
+
+    # Generate the edited image
+    edited_image = pipe(
+        instruction,
+        image=input_image,
+        guidance_scale=text_cfg_scale,
+        image_guidance_scale=image_cfg_scale,
+        num_inference_steps=steps,
+        generator=generator,
+    ).images[0]
+
+    return edited_image
+
+@app.post("/edit-image/")
+async def edit_image(
+    file: UploadFile = File(...),
+    instruction: str = Form(...),
+    steps: int = Form(default=DEFAULT_STEPS),
+    text_cfg_scale: float = Form(default=DEFAULT_TEXT_CFG),
+    image_cfg_scale: float = Form(default=DEFAULT_IMAGE_CFG),
+    seed: int = Form(default=DEFAULT_SEED)
+):
+    """
+    Endpoint to edit an image based on a text instruction.
+    - file: The input image to edit.
+    - instruction: The text instruction for editing the image.
+    - steps: Number of inference steps.
+    - text_cfg_scale: Text CFG weight.
+    - image_cfg_scale: Image CFG weight.
+    - seed: Random seed for reproducibility.
+    """
+    # Read and convert the uploaded image
+    image_data = await file.read()
+    input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+    # Process the image
+    edited_image = process_image(input_image, instruction, steps, text_cfg_scale, image_cfg_scale, seed)
+
+    # Convert the edited image to bytes
+    img_byte_arr = io.BytesIO()
+    edited_image.save(img_byte_arr, format="PNG")
+    img_byte_arr.seek(0)
+
+    # Return the image as a streaming response
+    return StreamingResponse(img_byte_arr, media_type="image/png")
+
+@app.get("/")
+async def root():
+    """
+    Root endpoint for basic health check.
+    """
+    return {"message": "InstructPix2Pix API is running. Use POST /edit-image/ to edit images."}
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
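
The resizing block in process_image first scales the image so the longer side becomes 512, then nudges both sides to multiples of 64, the spatial granularity Stable Diffusion's latent space expects. A worked example of the same arithmetic (plain Python, using a hypothetical 1024x768 input):

    import math

    width, height = 1024, 768                 # hypothetical input size
    factor = 512 / max(width, height)         # 0.5 -> longer side becomes 512
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
                                              # ceil(384 / 64) * 64 / 768 = 0.5
    width = int((width * factor) // 64) * 64  # 512
    height = int((height * factor) // 64) * 64  # 384
    print(width, height)                      # 512 384, both multiples of 64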
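
To exercise the new endpoint end to end, a client POSTs a multipart form with the image file plus the form fields. A minimal sketch, assuming the server is reachable on localhost:7860 and requests is installed; the file names input.jpg and edited.png are hypothetical:

    import requests  # assumed installed; not part of the commit

    with open("input.jpg", "rb") as f:  # hypothetical input image
        resp = requests.post(
            "http://localhost:7860/edit-image/",
            files={"file": ("input.jpg", f, "image/jpeg")},
            data={
                "instruction": "make it look like a watercolor painting",
                "steps": "30",            # form fields travel as strings
                "text_cfg_scale": "7.5",
                "image_cfg_scale": "1.5",
                "seed": "1371",
            },
        )
    resp.raise_for_status()
    with open("edited.png", "wb") as out:  # hypothetical output path
        out.write(resp.content)  # server streams back a PNG

As a rule of thumb for these parameters, raising image_cfg_scale keeps the output closer to the input image, while raising text_cfg_scale pushes it to follow the instruction more strongly.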