rayochoajr committed on
Commit
3a88d20
·
verified ·
1 Parent(s): 86bb1a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -90
app.py CHANGED
@@ -1,102 +1,250 @@
1
  import requests
2
  from requests.adapters import HTTPAdapter
3
- from requests.packages.urllib3.util.retry import Retry
4
  import json
5
  import base64
6
  import time
7
- import gradio as gr
8
- from PIL import Image
9
- from io import BytesIO
10
  import os
 
 
 
 
 
 
11
 
12
- host = "http://18.119.36.46:8888"
13
-
14
- def image_prompt(prompt, image1, image2, image3, image4):
15
- session = requests.Session()
16
- retries = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
17
- session.mount('http://', HTTPAdapter(max_retries=retries))
18
-
19
- # Read and encode images
20
- image_paths = [image1, image2, image3, image4]
21
- image_data = [
22
- {
23
- "cn_img": base64.b64encode(open(image_path, "rb").read()).decode('utf-8'),
24
- "cn_stop": 1,
25
- "cn_weight": 1,
26
- "cn_type": "ImagePrompt"
27
- } for image_path in image_paths if image_path
28
- ]
29
-
30
- params = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  "prompt": prompt,
32
- "image_prompts": image_data,
33
- "async_process": True
34
- }
35
-
36
- response = session.post(
37
- url=f"{host}/v2/generation/text-to-image-with-ip",
38
- data=json.dumps(params),
39
- headers={"Content-Type": "application/json"},
40
- timeout=10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  )
 
 
 
 
 
 
 
 
 
 
42
 
43
- result = response.json()
44
- job_id = result.get('job_id')
45
-
46
- if not job_id:
47
- return None, "Job ID not found."
48
-
49
- # Polling for job status
50
- start_time = time.time()
51
- max_wait_time = 300 # 5 minutes max wait time
52
- while time.time() - start_time < max_wait_time:
53
- query_url = f"{host}/v1/generation/query-job?job_id={job_id}&require_step_preview=true"
54
- response = session.get(query_url, timeout=10)
55
- job_data = response.json()
56
-
57
- job_stage = job_data.get("job_stage")
58
- job_step_preview = job_data.get("job_step_preview")
59
- job_result = job_data.get("job_result")
60
-
61
- # If there is a step preview, display it
62
- if job_step_preview:
63
- step_image = Image.open(BytesIO(base64.b64decode(job_step_preview)))
64
- return step_image, "Processing..." # Update the gr.Image widget with step preview
65
-
66
- # If the job is completed successfully, display the final image
67
- if job_stage == "SUCCESS":
68
- final_image_url = job_result[0].get("url")
69
- if final_image_url:
70
- final_image_url = final_image_url.replace("127.0.0.1", "18.119.36.46")
71
- image_response = session.get(final_image_url, timeout=10)
72
- final_image = Image.open(BytesIO(image_response.content))
73
- return final_image, "Job completed successfully."
74
- return None, "Final image URL not found in the job data."
75
-
76
- # If the job failed
77
- elif job_stage == "FAILED":
78
- return None, "Job failed."
79
-
80
- # If the job is still running, continue polling
81
- time.sleep(2)
82
-
83
- return None, "Job timed out."
84
-
85
- def gradio_app():
86
  with gr.Blocks() as demo:
87
- prompt = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
88
  with gr.Row():
89
- image1 = gr.Image(label="Image Prompt 1", type="filepath")
90
- image2 = gr.Image(label="Image Prompt 2", type="filepath")
91
- image3 = gr.Image(label="Image Prompt 3", type="filepath")
92
- image4 = gr.Image(label="Image Prompt 4", type="filepath")
93
- output_image = gr.Image(label="Generated Image")
94
- status = gr.Textbox(label="Status")
95
-
96
- generate_button = gr.Button("Generate Image")
97
- generate_button.click(image_prompt, inputs=[prompt, image1, image2, image3, image4], outputs=[output_image, status])
98
-
99
- demo.launch()
100
-
101
- if __name__ == "__main__":
102
- gradio_app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  from requests.adapters import HTTPAdapter
3
+ from urllib3.util.retry import Retry
4
  import json
5
  import base64
6
  import time
 
 
 
7
  import os
8
+ import random
9
+ import io
10
+ from dotenv import load_dotenv
11
+ import replicate
12
+ from PIL import Image, ImageOps
13
+ from io import BytesIO
14
 
15
# Load environment variables from a local .env file (python-dotenv).
load_dotenv()
# Constants
# Replicate API credentials; None when the variable is unset.
# NOTE(review): the replicate client reads REPLICATE_API_TOKEN from the
# environment itself — confirm this module-level copy is still needed.
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
19
+
20
# Create the tab for the image analyzer
def image_analyzer_tab():
    """Define the image-analyzer helper.

    NOTE(review): this function only defines the nested helper and returns
    None — no Gradio tab is actually built here; confirm it is still used.
    """

    # Function to analyze the image
    def analyze_image(image):
        """Send *image* to the BLIP-2 model on Replicate and return its answer."""
        png_buffer = BytesIO()
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
        return replicate.run(
            "andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608",
            input={
                "image": "data:image/png;base64," + encoded,
                "prompt": "what's in this picture?",
            },
        )
32
+
33
+
34
+
35
class Config:
    """Application configuration container."""

    # Mirrors the module-level constant read from the environment at import time.
    REPLICATE_API_TOKEN = REPLICATE_API_TOKEN
37
+
38
class ImageUtils:
    """Small static helpers for PIL image handling."""

    @staticmethod
    def image_to_base64(image):
        """Return *image* encoded as a base64 JPEG string."""
        jpeg_buffer = io.BytesIO()
        image.save(jpeg_buffer, format="JPEG")
        return base64.b64encode(jpeg_buffer.getvalue()).decode('utf-8')

    @staticmethod
    def convert_image_mode(image, mode="RGB"):
        """Convert *image* to *mode*, returning it unchanged if already there."""
        return image if image.mode == mode else image.convert(mode)
50
+
51
def pad_image(image, padding_color=(255, 255, 255)):
    """Return a copy of *image* with a 10-pixel border on every side.

    The border is filled with *padding_color* and the original image is
    pasted centered, so the result is 20px larger in each dimension.
    """
    w, h = image.size
    canvas = Image.new(image.mode, (w + 20, h + 20), padding_color)
    canvas.paste(image, (10, 10))
    return canvas
58
+
59
def resize_and_pad_image(image, target_width, target_height, padding_color=(255, 255, 255)):
    """Fit *image* inside a (target_width, target_height) box, preserving its
    aspect ratio, and center it on a *padding_color* canvas of exactly that size.

    Returns a new PIL image; the input is not modified.
    """
    original_width, original_height = image.size
    aspect_ratio = original_width / original_height
    target_aspect_ratio = target_width / target_height

    if aspect_ratio > target_aspect_ratio:
        # Wider than the target box: width is the binding dimension.
        new_width = target_width
        new_height = int(target_width / aspect_ratio)
    else:
        # Taller (or equal): height is the binding dimension.
        new_width = int(target_height * aspect_ratio)
        new_height = target_height

    # Image.ANTIALIAS was deprecated in Pillow 9.1 and removed in Pillow 10;
    # LANCZOS is the same filter under its current name.
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    padded_image = Image.new(image.mode, (target_width, target_height), padding_color)
    padded_image.paste(resized_image, ((target_width - new_width) // 2, (target_height - new_height) // 2))
    return padded_image
75
+
76
def image_prompt(prompt, cn_img1, cn_img2, cn_img3, cn_img4, weight1, weight2, weight3, weight4):
    """Generate an image with the Fooocus model on Replicate.

    Parameters
    ----------
    prompt : str
        Text prompt for the generation.
    cn_img1..cn_img4 : PIL.Image
        Control images; cn_img1 is the sketch and is padded before encoding.
        NOTE(review): a None slot crashes on .save — confirm the UI always
        supplies all four images.
    weight1..weight4 : float
        Control-net weights for the corresponding images.

    Returns
    -------
    (str, str)
        Path of the saved image ("output.png") and a status message.
    """

    def _to_png_base64(img):
        # Encode a PIL image as base64 PNG for the data-URI payload.
        buf = BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode('utf-8')

    cn_img1 = pad_image(cn_img1)
    cn_img1_base64 = _to_png_base64(cn_img1)
    cn_img2_base64 = _to_png_base64(cn_img2)
    cn_img3_base64 = _to_png_base64(cn_img3)
    cn_img4_base64 = _to_png_base64(cn_img4)

    # Resize and pad the sketch input image to match the aspect ratio selection
    aspect_ratio_width, aspect_ratio_height = 1280, 768
    uov_input_image = resize_and_pad_image(cn_img1, aspect_ratio_width, aspect_ratio_height)
    uov_input_image_base64 = _to_png_base64(uov_input_image)

    # Call the Replicate API to generate the image
    fooocus_model = replicate.models.get("vetkastar/fooocus").versions.get("d555a800025fe1c171e386d299b1de635f8d8fc3f1ade06a14faf5154eba50f3")
    image = replicate.predictions.create(version=fooocus_model, input={
        "prompt": prompt,
        "cn_type1": "PyraCanny",
        "cn_type2": "ImagePrompt",
        "cn_type3": "ImagePrompt",
        "cn_type4": "ImagePrompt",
        "cn_weight1": weight1,
        "cn_weight2": weight2,
        "cn_weight3": weight3,
        "cn_weight4": weight4,
        "cn_img1": "data:image/png;base64," + cn_img1_base64,
        "cn_img2": "data:image/png;base64," + cn_img2_base64,
        "cn_img3": "data:image/png;base64," + cn_img3_base64,
        "cn_img4": "data:image/png;base64," + cn_img4_base64,
        "uov_input_image": "data:image/png;base64," + uov_input_image_base64,
        "sharpness": 2,
        "image_seed": -1,
        "image_number": 1,
        "guidance_scale": 7,
        "refiner_switch": 0.5,
        "negative_prompt": "",
        "inpaint_strength": 0.5,
        "style_selections": "Fooocus V2,Fooocus Enhance,Fooocus Sharp",
        "loras_custom_urls": "",
        "uov_upscale_value": 0,
        "use_default_loras": True,
        "outpaint_selections": "",
        "outpaint_distance_top": 0,
        "performance_selection": "Lightning",
        "outpaint_distance_left": 0,
        "aspect_ratios_selection": "1280*768",
        "outpaint_distance_right": 0,
        "outpaint_distance_bottom": 0,
        "inpaint_additional_prompt": "",
        "uov_method": "Vary (Subtle)"
    })
    # Block until the prediction finishes.
    image.wait()
    # Fetch the generated image from the output URL.
    response = requests.get(image.output["paths"][0])

    # Persist the raw bytes; the Gradio output widget loads from this path.
    with open("output.png", "wb") as f:
        f.write(response.content)
    return "output.png", "Job completed successfully using Replicate API."
147
+
148
def create_status_image():
    """Return the path of the last generated image, or None if none exists yet."""
    return "output.png" if os.path.exists("output.png") else None
153
+
154
def preload_images(cn_img2, cn_img3, cn_img4):
    """Ignore the current images and return three random picsum.photos URLs.

    Each URL uses a fresh random seed in [0, 1000] so the placeholder images
    differ between clicks.
    """
    fresh = tuple(
        f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400"
        for _ in range(3)
    )
    return fresh
159
+
160
def shuffle_and_load_images(files):
    """Pick three random entries from *files* for the image-prompt slots.

    Returns a 3-tuple; missing slots are filled with None (Gradio renders an
    empty image for None). The previous version called an undefined
    generate_placeholder_image() on empty uploads (NameError), raised
    IndexError for fewer than three files, and shuffled the caller's list
    in place.
    """
    if not files:
        return None, None, None
    # Sample a copy so the caller's list is not mutated.
    picks = random.sample(list(files), k=min(3, len(files)))
    picks += [None] * (3 - len(picks))
    return picks[0], picks[1], picks[2]
166
+
167
def analyze_image(image: Image.Image) -> dict:
    """Ask the BLIP-2 model on Replicate what is in *image*."""
    png_buf = BytesIO()
    image.save(png_buf, format="PNG")
    data_uri = "data:image/png;base64," + base64.b64encode(png_buf.getvalue()).decode("utf-8")
    # NOTE(review): the annotation says dict, but replicate.run may return a
    # plain string for this model — confirm against get_prompt_from_image,
    # which calls .get() on the result.
    return replicate.run(
        "andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608",
        input={"image": data_uri, "prompt": "what's in this picture?"},
    )
176
+
177
def get_prompt_from_image(image: Image.Image) -> str:
    """Run BLIP-2 analysis on *image* and return its "describe" field.

    Falls back to the empty string when the key is missing.
    """
    return analyze_image(image).get("describe", "")
180
+
181
def generate_prompt(image: Image.Image, current_prompt: str) -> str:
    """Replace *current_prompt* with a caption derived from *image*.

    The existing prompt text is intentionally ignored.
    """
    return get_prompt_from_image(image)
183
+
184
+ import gradio as gr
185
 
186
def create_gradio_interface():
    """Build and launch the Gradio UI for the Fooocus image generator.

    Layout: a left column holds the sketch input (which doubles as control
    image 1 and the analyzer input) plus project-file uploads; a right column
    holds the output/status widgets, the prompt box, and three extra
    image-prompt slots. Launches on 0.0.0.0:6644 with a public share link.

    Fix: the emoji button/accordion labels were UTF-8 mojibake (e.g. "๐Ÿ“");
    they are restored to the intended characters (📁, 🌸, 📂, 🚀).
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=0):
                with gr.Tab(label="Sketch"):
                    # Same component serves as sketch input and cn_img1.
                    image_input = cn_img1_input = gr.Image(label="Sketch", type="pil")
                    weight1 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.75)
                    copy_to_sketch_button = gr.Button("Grab Last Output")

                with gr.Accordion("Upload Project Files", open=False):
                    with gr.Accordion("📁", open=False):
                        file_upload = gr.File(file_count="multiple", elem_classes="gradio-column")
                        image_gallery = gr.Gallery(label="Image Gallery", elem_classes="gradio-column")
                        file_upload.change(shuffle_and_load_images, inputs=[file_upload], outputs=[image_gallery])
            with gr.Column(scale=2):
                with gr.Tab(label="Node"):
                    with gr.Accordion("Output"):
                        with gr.Column():
                            status = gr.Textbox(label="Status")
                            status_image = gr.Image(label="Queue Status", interactive=False)
                    with gr.Row():
                        with gr.Column(scale=1):
                            analysis_output = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
                        with gr.Column(scale=0):
                            analyze_button = gr.Button("Analyze Image")
                            # Caption the sketch into the prompt box.
                            analyze_button.click(fn=analyze_image, inputs=image_input, outputs=analysis_output)
                    with gr.Row():
                        preload_button = gr.Button("🌸")
                        shuffle_and_load_button = gr.Button("📂")
                        generate_button = gr.Button("🚀 Generate 🚀")

                    with gr.Row():
                        with gr.Column():
                            cn_img2_input = gr.Image(label="Image Prompt 2", type="pil", height=256)
                            weight2 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                        with gr.Column():
                            cn_img3_input = gr.Image(label="Image Prompt 3", type="pil", height=256)
                            weight3 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                        with gr.Column():
                            cn_img4_input = gr.Image(label="Image Prompt 4", type="pil", height=256)
                            weight4 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)

                    with gr.Row():
                        # Fill slots 2-4 with random placeholders / uploaded files.
                        preload_button.click(preload_images, inputs=[cn_img2_input, cn_img3_input, cn_img4_input], outputs=[cn_img2_input, cn_img3_input, cn_img4_input])
                        shuffle_and_load_button.click(shuffle_and_load_images, inputs=[file_upload], outputs=[cn_img2_input, cn_img3_input, cn_img4_input])

                    generate_button.click(
                        fn=image_prompt,
                        inputs=[analysis_output, cn_img1_input, cn_img2_input, cn_img3_input, cn_img4_input, weight1, weight2, weight3, weight4],
                        outputs=[status_image, status]
                    )

                    # Pull the last generated output back into the sketch slot.
                    copy_to_sketch_button.click(
                        fn=lambda: Image.open("output.png") if os.path.exists("output.png") else None,
                        inputs=[],
                        outputs=[cn_img1_input]
                    )

        # ⏲️ Update the status image every 5 seconds.
        demo.load(create_status_image, every=5, outputs=status_image)

    demo.launch(server_name="0.0.0.0", server_port=6644, share=True)
249
+
250
+ create_gradio_interface()