Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,18 +1,15 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
-
from
|
3 |
-
import
|
|
|
4 |
from pathlib import Path
|
5 |
-
import
|
6 |
-
import
|
7 |
-
import os
|
8 |
from PIL import Image
|
9 |
-
import
|
10 |
-
import
|
11 |
-
|
12 |
-
from PIL import ImageOps, ImageEnhance, ImageFilter
|
13 |
-
from huggingface_hub import hf_hub_download, snapshot_download
|
14 |
-
from PIL import ImageEnhance
|
15 |
-
import replicate
|
16 |
from dotenv import load_dotenv
|
17 |
|
18 |
# Load environment variables from .env file
|
@@ -20,348 +17,438 @@ load_dotenv()
|
|
20 |
|
21 |
USERNAME = os.getenv("USERNAME")
|
22 |
PASSWORD = os.getenv("PASSWORD")
|
23 |
-
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
|
24 |
-
|
25 |
-
# Set the Replicate API token
|
26 |
-
os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
|
27 |
-
|
28 |
-
qrcode_generator = qrcode.QRCode(
|
29 |
-
version=1,
|
30 |
-
error_correction=qrcode.ERROR_CORRECT_H,
|
31 |
-
box_size=10,
|
32 |
-
border=4,
|
33 |
-
)
|
34 |
-
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
}
|
42 |
|
43 |
-
|
44 |
-
"GhostMix": "digiplay/GhostMixV1.2VAE",
|
45 |
-
"Stable v1.5": "Jiali/stable-diffusion-1.5",
|
46 |
-
# Add more diffusion models here
|
47 |
-
}
|
48 |
|
49 |
-
# Global variables to store loaded models
|
50 |
-
loaded_controlnet = None
|
51 |
-
loaded_pipe = None
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
#
|
69 |
-
#
|
70 |
-
|
71 |
-
#
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
#
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
try:
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
if
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
else:
|
169 |
-
|
170 |
-
|
171 |
-
def invert_displayed_image(image):
|
172 |
-
if image is None:
|
173 |
-
return None
|
174 |
-
inverted = invert_image(image)
|
175 |
-
if isinstance(inverted, np.ndarray):
|
176 |
-
return Image.fromarray(inverted)
|
177 |
-
return inverted
|
178 |
-
|
179 |
-
|
180 |
-
#@spaces.GPU()
|
181 |
-
def inference(
|
182 |
-
qr_code_content: str,
|
183 |
-
prompt: str,
|
184 |
-
negative_prompt: str,
|
185 |
-
guidance_scale: float = 9.0,
|
186 |
-
qr_conditioning_scale: float = 1.47,
|
187 |
-
num_inference_steps: int = 20,
|
188 |
-
seed: int = -1,
|
189 |
-
image_resolution: int = 512,
|
190 |
-
scheduler: str = "K_EULER",
|
191 |
-
eta: float = 0.0,
|
192 |
-
num_outputs: int = 1,
|
193 |
-
low_threshold: int = 100,
|
194 |
-
high_threshold: int = 200,
|
195 |
-
guess_mode: bool = False,
|
196 |
-
disable_safety_check: bool = False,
|
197 |
-
):
|
198 |
-
try:
|
199 |
-
progress = gr.Progress()
|
200 |
-
progress(0, desc="Generating QR code...")
|
201 |
-
|
202 |
-
# Generate QR code image
|
203 |
-
qr = qrcode.QRCode(
|
204 |
-
version=1,
|
205 |
-
error_correction=qrcode.constants.ERROR_CORRECT_H,
|
206 |
-
box_size=10,
|
207 |
-
border=4,
|
208 |
-
)
|
209 |
-
qr.add_data(qr_code_content)
|
210 |
-
qr.make(fit=True)
|
211 |
-
qr_image = qr.make_image(fill_color="black", back_color="white")
|
212 |
-
|
213 |
-
# Save QR code image to a temporary file
|
214 |
-
temp_qr_path = "temp_qr.png"
|
215 |
-
qr_image.save(temp_qr_path)
|
216 |
-
|
217 |
-
progress(0.3, desc="Running inference...")
|
218 |
-
|
219 |
-
# Ensure num_outputs is within the allowed range
|
220 |
-
num_outputs = max(1, min(num_outputs, 10))
|
221 |
-
|
222 |
-
# Ensure seed is an integer and not null
|
223 |
-
seed = int(seed) if seed != -1 else None
|
224 |
-
|
225 |
-
# Ensure high_threshold is at least 1
|
226 |
-
high_threshold = max(1, int(high_threshold))
|
227 |
-
|
228 |
-
# Prepare the input dictionary
|
229 |
-
input_dict = {
|
230 |
-
"prompt": prompt,
|
231 |
-
"qr_image": open(temp_qr_path, "rb"),
|
232 |
-
"negative_prompt": negative_prompt,
|
233 |
-
"guidance_scale": float(guidance_scale),
|
234 |
-
"qr_conditioning_scale": float(qr_conditioning_scale),
|
235 |
-
"num_inference_steps": int(num_inference_steps),
|
236 |
-
"image_resolution": int(image_resolution),
|
237 |
-
"scheduler": scheduler,
|
238 |
-
"eta": float(eta),
|
239 |
-
"num_outputs": num_outputs,
|
240 |
-
"low_threshold": int(low_threshold),
|
241 |
-
"high_threshold": high_threshold,
|
242 |
-
"guess_mode": guess_mode,
|
243 |
-
"disable_safety_check": disable_safety_check,
|
244 |
-
}
|
245 |
-
|
246 |
-
# Only add seed to input_dict if it's not None
|
247 |
-
if seed is not None:
|
248 |
-
input_dict["seed"] = seed
|
249 |
-
|
250 |
-
# Run inference using Replicate API
|
251 |
-
output = replicate.run(
|
252 |
-
"anotherjesse/multi-control:76d8414a702e66c84fe2e6e9c8cbdc12e53f950f255aae9ffa5caa7873b12de0",
|
253 |
-
input=input_dict
|
254 |
-
)
|
255 |
-
|
256 |
-
progress(0.9, desc="Processing results...")
|
257 |
-
|
258 |
-
# Download the generated image
|
259 |
-
response = requests.get(output[0])
|
260 |
-
img = Image.open(io.BytesIO(response.content))
|
261 |
-
|
262 |
-
# Clean up temporary file
|
263 |
-
os.remove(temp_qr_path)
|
264 |
-
|
265 |
-
progress(1.0, desc="Done!")
|
266 |
-
return img, seed if seed is not None else -1
|
267 |
-
except Exception as e:
|
268 |
-
print(f"Error in inference: {str(e)}")
|
269 |
-
return Image.new('RGB', (512, 512), color='white'), -1
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
def invert_init_image_display(image):
|
274 |
-
if image is None:
|
275 |
-
return None
|
276 |
-
inverted = invert_image(image)
|
277 |
-
if isinstance(inverted, np.ndarray):
|
278 |
-
return Image.fromarray(inverted)
|
279 |
-
return inverted
|
280 |
-
|
281 |
-
def adjust_color_balance(image, r, g, b):
|
282 |
-
# Convert image to RGB if it's not already
|
283 |
-
image = image.convert('RGB')
|
284 |
-
|
285 |
-
# Split the image into its RGB channels
|
286 |
-
r_channel, g_channel, b_channel = image.split()
|
287 |
-
|
288 |
-
# Adjust each channel
|
289 |
-
r_channel = r_channel.point(lambda i: i + (i * r))
|
290 |
-
g_channel = g_channel.point(lambda i: i + (i * g))
|
291 |
-
b_channel = b_channel.point(lambda i: i + (i * b))
|
292 |
-
|
293 |
-
# Merge the channels back
|
294 |
-
return Image.merge('RGB', (r_channel, g_channel, b_channel))
|
295 |
|
296 |
-
|
297 |
-
|
298 |
-
return image
|
299 |
-
|
300 |
-
# Resize original QR to match the generated image
|
301 |
-
original_qr = original_qr.resize(image.size)
|
302 |
-
|
303 |
-
# Create a new image blending the generated image and the QR code
|
304 |
-
return Image.blend(image, original_qr, opacity)
|
305 |
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
|
|
|
|
315 |
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
css = """
|
318 |
-
h1, h2, h3, h4, h5, h6, p, li, ul, ol, a,
|
319 |
-
text-align:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
display: block;
|
321 |
margin-left: auto;
|
322 |
margin-right: auto;
|
|
|
|
|
323 |
}
|
324 |
ul, ol {
|
325 |
-
|
326 |
-
margin-right: auto;
|
327 |
-
display: table;
|
328 |
}
|
329 |
-
.
|
330 |
-
max-width: 100
|
331 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
}
|
333 |
"""
|
334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
def login(username, password):
|
336 |
if username == USERNAME and password == PASSWORD:
|
337 |
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value="Login successful! You can now access the QR Code Art Generator tab.", visible=True)
|
338 |
else:
|
339 |
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value="Invalid username or password. Please try again.", visible=True)
|
340 |
-
|
341 |
-
# Add login elements to the Gradio interface
|
342 |
-
with gr.Blocks(theme='Hev832/Applio', css=css, fill_width=True, fill_height=True) as blocks:
|
343 |
-
generated_images = gr.State([])
|
344 |
|
|
|
|
|
345 |
with gr.Tab("Welcome"):
|
346 |
with gr.Row():
|
347 |
-
with gr.Column(scale=2):
|
348 |
gr.Markdown(
|
349 |
"""
|
350 |
-
<img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/
|
351 |
-
|
352 |
-
# 🎨
|
353 |
-
|
354 |
-
##
|
355 |
-
|
356 |
-
|
|
|
|
|
357 |
## 🚀 How It Works:
|
358 |
-
1. **
|
359 |
-
2. **
|
360 |
-
3. **
|
361 |
-
4. **Generate and Iterate**: Click '
|
362 |
"""
|
363 |
)
|
364 |
-
|
365 |
with gr.Column(scale=1):
|
366 |
with gr.Row():
|
367 |
gr.Markdown(
|
@@ -382,479 +469,133 @@ with gr.Blocks(theme='Hev832/Applio', css=css, fill_width=True, fill_height=True
|
|
382 |
login_button = gr.Button("Login", size="sm")
|
383 |
login_message = gr.Markdown(visible=False)
|
384 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
|
386 |
-
with gr.Tab("QR Code Art Generator", visible=False) as app_container:
|
387 |
with gr.Row():
|
388 |
with gr.Column():
|
389 |
-
|
390 |
-
label="QR Code Content",
|
391 |
-
placeholder="Enter URL or text for your QR code",
|
392 |
-
info="This is what your QR code will link to or display when scanned.",
|
393 |
-
value="https://theunderground.digital/",
|
394 |
-
lines=1,
|
395 |
-
)
|
396 |
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
value="
|
401 |
-
info="Describe the style or theme for your QR code art (For best results, keep the prompt to 75 characters or less as seen in the example)",
|
402 |
-
lines=8,
|
403 |
)
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
lines=4,
|
410 |
)
|
411 |
|
412 |
-
|
|
|
|
|
|
|
|
|
|
|
413 |
|
414 |
-
|
415 |
-
gr.Markdown(
|
416 |
-
"""
|
417 |
-
## 🌟 Tips for Spectacular Results:
|
418 |
-
- Use concise details in your prompt to help the AI understand your vision.
|
419 |
-
- Use negative prompts to avoid unwanted elements in your image.
|
420 |
-
- Experiment with different ControlNet models and diffusion models to find the best combination for your prompt.
|
421 |
-
|
422 |
-
## 🎭 Prompt Ideas to Spark Your Creativity:
|
423 |
-
- "A serene Japanese garden with cherry blossoms and a koi pond"
|
424 |
-
- "A futuristic cityscape with neon lights and flying cars"
|
425 |
-
- "An abstract painting with swirling colors and geometric shapes"
|
426 |
-
- "A vintage-style travel poster featuring iconic landmarks"
|
427 |
-
|
428 |
-
Remember, the magic lies in the details of your prompt and the fine-tuning of your settings.
|
429 |
-
Happy creating!
|
430 |
-
"""
|
431 |
-
)
|
432 |
|
433 |
-
with gr.Accordion("Set Custom QR Code Colors", open=False, visible=False):
|
434 |
-
bg_color = gr.ColorPicker(
|
435 |
-
label="Background Color",
|
436 |
-
value="#FFFFFF",
|
437 |
-
info="Choose the background color for the QR code"
|
438 |
-
)
|
439 |
-
qr_color = gr.ColorPicker(
|
440 |
-
label="QR Code Color",
|
441 |
-
value="#000000",
|
442 |
-
info="Choose the color for the QR code pattern"
|
443 |
-
)
|
444 |
-
invert_final_image = gr.Checkbox(
|
445 |
-
label="Invert Final Image",
|
446 |
-
value=False,
|
447 |
-
info="Check this to invert the colors of the final image",
|
448 |
-
visible=False,
|
449 |
-
)
|
450 |
-
with gr.Accordion("AI Model Selection", open=False, visible=False):
|
451 |
-
controlnet_model_dropdown = gr.Dropdown(
|
452 |
-
choices=list(CONTROLNET_MODELS.keys()),
|
453 |
-
value="QR Code Monster",
|
454 |
-
label="ControlNet Model",
|
455 |
-
info="Select the ControlNet model for QR code generation"
|
456 |
-
)
|
457 |
-
diffusion_model_dropdown = gr.Dropdown(
|
458 |
-
choices=list(DIFFUSION_MODELS.keys()),
|
459 |
-
value="GhostMix",
|
460 |
-
label="Diffusion Model",
|
461 |
-
info="Select the main diffusion model for image generation"
|
462 |
-
)
|
463 |
|
464 |
-
|
465 |
-
with gr.Accordion(label="QR Code Image (Optional)", open=False, visible=False):
|
466 |
-
qr_code_image = gr.Image(
|
467 |
-
label="QR Code Image (Optional). Leave blank to automatically generate QR code",
|
468 |
-
type="pil",
|
469 |
-
)
|
470 |
-
|
471 |
with gr.Column():
|
472 |
-
gr.Markdown(
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
used_seed = gr.Number(label="Seed Used", interactive=False)
|
484 |
-
|
485 |
-
with gr.Accordion(label="Use Your Own Image as a Reference", open=True, visible=True) as init_image_acc:
|
486 |
-
init_image = gr.Image(label="Reference Image", type="pil")
|
487 |
-
with gr.Row():
|
488 |
-
use_qr_code_as_init_image = gr.Checkbox(
|
489 |
-
label="Uncheck to use your own image for generation",
|
490 |
-
value=True,
|
491 |
-
interactive=True,
|
492 |
-
info="Allows you to use your own image for generation, otherwise a generic QR Code is created automatically as the base image"
|
493 |
-
)
|
494 |
-
reference_image_strength = gr.Slider(
|
495 |
-
minimum=0.0,
|
496 |
-
maximum=5.0,
|
497 |
-
step=0.05,
|
498 |
-
value=0.6,
|
499 |
-
label="Reference Image Influence",
|
500 |
-
info="Controls how much the reference image influences the final result (0 = ignore, 5 = copy exactly)",
|
501 |
-
visible=False
|
502 |
-
)
|
503 |
-
invert_init_image_button = gr.Button("Invert Init Image", size="sm", visible=False)
|
504 |
-
|
505 |
-
with gr.Tab("Advanced Settings"):
|
506 |
-
with gr.Accordion("Advanced Art Controls", open=True):
|
507 |
-
with gr.Row():
|
508 |
-
qr_conditioning_scale = gr.Slider(
|
509 |
-
minimum=0.0,
|
510 |
-
maximum=5.0,
|
511 |
-
step=0.01,
|
512 |
-
value=1.47,
|
513 |
-
label="QR Code Visibility",
|
514 |
-
)
|
515 |
-
with gr.Accordion("QR Code Visibility Explanation", open=False):
|
516 |
-
gr.Markdown(
|
517 |
-
"""
|
518 |
-
**QR Code Visibility** controls how prominent the QR code is in the final image:
|
519 |
-
|
520 |
-
- **Low (0.0-1.0)**: QR code blends more with the art, potentially harder to scan.
|
521 |
-
- **Medium (1.0-3.0)**: Balanced visibility, usually scannable while maintaining artistic quality.
|
522 |
-
- **High (3.0-5.0)**: QR code stands out more, easier to scan but less artistic.
|
523 |
-
|
524 |
-
Start with 1.47 for a good balance between art and functionality.
|
525 |
-
"""
|
526 |
)
|
527 |
-
|
528 |
-
|
529 |
-
|
530 |
-
|
531 |
-
maximum=30.0,
|
532 |
-
step=0.1,
|
533 |
-
value=9.0,
|
534 |
-
label="Prompt Adherence",
|
535 |
-
)
|
536 |
-
with gr.Accordion("Prompt Adherence Explanation", open=False):
|
537 |
-
gr.Markdown(
|
538 |
-
"""
|
539 |
-
**Prompt Adherence** determines how closely the AI follows your prompt:
|
540 |
-
|
541 |
-
- **Low (0.1-5.0)**: More creative freedom, may deviate from prompt.
|
542 |
-
- **Medium (5.0-15.0)**: Balanced between prompt and AI creativity.
|
543 |
-
- **High (15.0-30.0)**: Strictly follows the prompt, less creative freedom.
|
544 |
-
|
545 |
-
A value of 9.0 provides a good balance between creativity and prompt adherence.
|
546 |
-
"""
|
547 |
-
)
|
548 |
-
|
549 |
-
with gr.Row():
|
550 |
-
num_inference_steps = gr.Slider(
|
551 |
-
minimum=1,
|
552 |
-
maximum=100,
|
553 |
-
step=1,
|
554 |
-
value=20,
|
555 |
-
label="Generation Steps",
|
556 |
-
)
|
557 |
-
with gr.Accordion("Generation Steps Explanation", open=False):
|
558 |
-
gr.Markdown(
|
559 |
-
"""
|
560 |
-
**Generation Steps** affects the detail and quality of the generated image:
|
561 |
-
|
562 |
-
- **Low (1-10)**: Faster generation, less detailed results.
|
563 |
-
- **Medium (11-30)**: Good balance between speed and quality.
|
564 |
-
- **High (31-100)**: More detailed results, slower generation.
|
565 |
-
|
566 |
-
20 steps is a good starting point for most generations.
|
567 |
-
"""
|
568 |
)
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
maximum=1024,
|
574 |
-
step=64,
|
575 |
-
value=512,
|
576 |
-
label="Image Resolution",
|
577 |
-
)
|
578 |
-
with gr.Accordion("Image Resolution Explanation", open=False):
|
579 |
-
gr.Markdown(
|
580 |
-
"""
|
581 |
-
**Image Resolution** determines the size and detail of the generated image:
|
582 |
-
|
583 |
-
- **Low (256-384)**: Faster generation, less detailed.
|
584 |
-
- **Medium (512-768)**: Good balance of detail and generation time.
|
585 |
-
- **High (832-1024)**: More detailed, slower generation.
|
586 |
-
|
587 |
-
512x512 is a good default for most use cases.
|
588 |
-
"""
|
589 |
-
)
|
590 |
-
|
591 |
-
with gr.Row():
|
592 |
-
seed = gr.Slider(
|
593 |
-
minimum=-1,
|
594 |
-
maximum=9999999999,
|
595 |
-
step=1,
|
596 |
-
value=-1,
|
597 |
-
label="Generation Seed",
|
598 |
-
)
|
599 |
-
with gr.Accordion("Generation Seed Explanation", open=False):
|
600 |
-
gr.Markdown(
|
601 |
-
"""
|
602 |
-
**Generation Seed** controls the randomness of the generation:
|
603 |
-
|
604 |
-
- **-1**: Random seed each time, producing different results.
|
605 |
-
- **Any positive number**: Consistent results for the same inputs.
|
606 |
-
|
607 |
-
Use -1 to explore various designs, or set a specific seed to recreate a particular result.
|
608 |
-
"""
|
609 |
-
)
|
610 |
-
|
611 |
-
with gr.Row():
|
612 |
-
scheduler = gr.Dropdown(
|
613 |
-
choices=["DDIM", "K_EULER", "DPMSolverMultistep", "K_EULER_ANCESTRAL", "PNDM", "KLMS"],
|
614 |
-
value="K_EULER",
|
615 |
-
label="Sampling Method",
|
616 |
-
)
|
617 |
-
with gr.Accordion("Sampling Method Explanation", open=False):
|
618 |
-
gr.Markdown(
|
619 |
-
"""
|
620 |
-
**Sampling Method** affects the image generation process:
|
621 |
-
|
622 |
-
- **K_EULER**: Good balance of speed and quality.
|
623 |
-
- **DDIM**: Can produce sharper results but may be slower.
|
624 |
-
- **DPMSolverMultistep**: Often produces high-quality results.
|
625 |
-
- **K_EULER_ANCESTRAL**: Can introduce more variations.
|
626 |
-
- **PNDM**: Another quality-focused option.
|
627 |
-
- **KLMS**: Can produce smooth results.
|
628 |
-
|
629 |
-
Experiment with different methods to find what works best for your specific prompts.
|
630 |
-
"""
|
631 |
-
)
|
632 |
-
|
633 |
-
with gr.Row():
|
634 |
-
eta = gr.Slider(
|
635 |
-
minimum=0.0,
|
636 |
-
maximum=1.0,
|
637 |
-
step=0.01,
|
638 |
-
value=0.0,
|
639 |
-
label="ETA (Noise Level)",
|
640 |
-
)
|
641 |
-
with gr.Accordion("ETA Explanation", open=False):
|
642 |
-
gr.Markdown(
|
643 |
-
"""
|
644 |
-
**ETA (Noise Level)** controls the amount of noise in the generation process:
|
645 |
-
|
646 |
-
- **0.0**: No added noise, more deterministic results.
|
647 |
-
- **0.1-0.5**: Slight variations in output.
|
648 |
-
- **0.6-1.0**: More variations, potentially more creative results.
|
649 |
-
|
650 |
-
Start with 0.0 and increase if you want more variation in your outputs.
|
651 |
-
"""
|
652 |
)
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
maximum=255,
|
658 |
-
step=1,
|
659 |
-
value=100,
|
660 |
-
label="Edge Detection Low Threshold",
|
661 |
-
)
|
662 |
-
high_threshold = gr.Slider(
|
663 |
-
minimum=1,
|
664 |
-
maximum=255,
|
665 |
-
step=1,
|
666 |
-
value=200,
|
667 |
-
label="Edge Detection High Threshold",
|
668 |
-
)
|
669 |
-
with gr.Accordion("Edge Detection Thresholds Explanation", open=False):
|
670 |
-
gr.Markdown(
|
671 |
-
"""
|
672 |
-
**Edge Detection Thresholds** affect how the QR code edges are processed:
|
673 |
-
|
674 |
-
- **Low Threshold**: Lower values detect more edges, higher values fewer.
|
675 |
-
- **High Threshold**: Determines which edges are strong. Higher values result in fewer strong edges.
|
676 |
-
|
677 |
-
Default values (100, 200) work well for most QR codes. Adjust if you need more or less edge definition.
|
678 |
-
"""
|
679 |
)
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
value=False,
|
685 |
-
)
|
686 |
-
with gr.Accordion("Guess Mode Explanation", open=False):
|
687 |
-
gr.Markdown(
|
688 |
-
"""
|
689 |
-
**Guess Mode**, when enabled, allows the AI to interpret the input image more freely:
|
690 |
-
|
691 |
-
- **Unchecked**: AI follows the QR code structure more strictly.
|
692 |
-
- **Checked**: AI has more freedom to interpret the input, potentially leading to more creative results.
|
693 |
-
|
694 |
-
Use this if you want more artistic interpretations of your QR code.
|
695 |
-
"""
|
696 |
)
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
value=False,
|
702 |
-
)
|
703 |
-
with gr.Accordion("Safety Check Explanation", open=False):
|
704 |
-
gr.Markdown(
|
705 |
-
"""
|
706 |
-
**Disable Safety Check** removes content filtering from the generation process:
|
707 |
-
|
708 |
-
- **Unchecked**: Normal content filtering applied.
|
709 |
-
- **Checked**: No content filtering, may produce unexpected or inappropriate results.
|
710 |
-
|
711 |
-
Use with caution and only if necessary for your specific use case.
|
712 |
-
"""
|
713 |
)
|
714 |
-
with gr.Tab("Image Editing"):
|
715 |
-
with gr.Column():
|
716 |
-
image_selector = gr.Dropdown(label="Select Image to Edit", choices=[], interactive=True, visible=False)
|
717 |
-
image_to_edit = gr.Image(label="Your Artistic QR Code", show_download_button=True, show_fullscreen_button=True, container=True)
|
718 |
-
|
719 |
-
with gr.Row():
|
720 |
-
qr_overlay = gr.Checkbox(label="Overlay Original QR Code", value=False, visible=False)
|
721 |
-
qr_opacity = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="QR Overlay Opacity", visible=False)
|
722 |
-
edge_enhance = gr.Slider(minimum=0.0, maximum=5.0, step=0.1, value=0.0, label="Edge Enhancement", visible=False)
|
723 |
-
|
724 |
-
with gr.Row():
|
725 |
-
red_balance = gr.Slider(minimum=-1.0, maximum=1.0, step=0.1, value=0.0, label="Red Balance")
|
726 |
-
green_balance = gr.Slider(minimum=-1.0, maximum=1.0, step=0.1, value=0.0, label="Green Balance")
|
727 |
-
blue_balance = gr.Slider(minimum=-1.0, maximum=1.0, step=0.1, value=0.0, label="Blue Balance")
|
728 |
-
|
729 |
-
|
730 |
-
with gr.Row():
|
731 |
-
brightness = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Brightness")
|
732 |
-
contrast = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Contrast")
|
733 |
-
saturation = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Saturation")
|
734 |
-
with gr.Row():
|
735 |
-
invert_button = gr.Button("Invert Image", size="sm")
|
736 |
-
|
737 |
-
with gr.Row():
|
738 |
-
edited_image = gr.Image(label="Edited QR Code", show_download_button=True, show_fullscreen_button=True, visible=False)
|
739 |
-
scan_button = gr.Button("Verify QR Code Works", size="sm", visible=False)
|
740 |
-
scan_result = gr.Textbox(label="Validation Result of QR Code", interactive=False, visible=False)
|
741 |
-
|
742 |
-
used_seed = gr.Number(label="Seed Used", interactive=False)
|
743 |
-
|
744 |
-
gr.Markdown(
|
745 |
-
"""
|
746 |
-
### 🔍 Analyzing Your Creation
|
747 |
-
- Is the QR code scannable? Check with your phone camera to see if it can scan it.
|
748 |
-
- If not scannable, use the Brightness, Contrast, and Saturation sliders to optimize the QR code for scanning.
|
749 |
-
- Does the art style match your prompt? If not, try adjusting the 'Prompt Adherence'.
|
750 |
-
- Want more artistic flair? Increase the 'Artistic Freedom'.
|
751 |
-
- Need a clearer QR code? Raise the 'QR Code Visibility'.
|
752 |
-
"""
|
753 |
-
)
|
754 |
-
|
755 |
-
def scan_and_display(image):
|
756 |
-
if image is None:
|
757 |
-
return "No image to scan"
|
758 |
-
|
759 |
-
scanned_text = scan_qr_code(image)
|
760 |
-
if scanned_text:
|
761 |
-
return f"Scanned successfully: {scanned_text}"
|
762 |
-
else:
|
763 |
-
return "Failed to scan QR code. Try adjusting the settings for better visibility."
|
764 |
-
|
765 |
-
def invert_displayed_image(image):
|
766 |
-
if image is None:
|
767 |
-
return None
|
768 |
-
return invert_image(image)
|
769 |
-
|
770 |
-
scan_button.click(
|
771 |
-
scan_and_display,
|
772 |
-
inputs=[result_image],
|
773 |
-
outputs=[scan_result]
|
774 |
-
)
|
775 |
-
|
776 |
-
invert_button.click(
|
777 |
-
invert_displayed_image,
|
778 |
-
inputs=[result_image],
|
779 |
-
outputs=[result_image]
|
780 |
-
)
|
781 |
-
|
782 |
-
invert_init_image_button.click(
|
783 |
-
invert_init_image_display,
|
784 |
-
inputs=[init_image],
|
785 |
-
outputs=[init_image]
|
786 |
-
)
|
787 |
-
|
788 |
-
brightness.change(
|
789 |
-
adjust_image,
|
790 |
-
inputs=[result_image, brightness, contrast, saturation],
|
791 |
-
outputs=[result_image]
|
792 |
-
)
|
793 |
-
contrast.change(
|
794 |
-
adjust_image,
|
795 |
-
inputs=[result_image, brightness, contrast, saturation],
|
796 |
-
outputs=[result_image]
|
797 |
-
)
|
798 |
-
saturation.change(
|
799 |
-
adjust_image,
|
800 |
-
inputs=[result_image, brightness, contrast, saturation],
|
801 |
-
outputs=[result_image]
|
802 |
-
)
|
803 |
-
|
804 |
-
# Add logic to show/hide the reference_image_strength slider
|
805 |
-
def update_reference_image_strength_visibility(init_image, use_qr_code_as_init_image):
|
806 |
-
return gr.update(visible=init_image is not None and not use_qr_code_as_init_image)
|
807 |
-
|
808 |
-
init_image.change(
|
809 |
-
update_reference_image_strength_visibility,
|
810 |
-
inputs=[init_image, use_qr_code_as_init_image],
|
811 |
-
outputs=[reference_image_strength]
|
812 |
-
)
|
813 |
|
814 |
-
|
815 |
-
|
816 |
-
inputs=[init_image, use_qr_code_as_init_image],
|
817 |
-
outputs=[reference_image_strength]
|
818 |
-
)
|
819 |
-
|
820 |
-
run_btn.click(
|
821 |
-
fn=inference,
|
822 |
-
inputs=[
|
823 |
-
qr_code_content,
|
824 |
-
prompt,
|
825 |
-
negative_prompt,
|
826 |
-
guidance_scale,
|
827 |
-
qr_conditioning_scale,
|
828 |
-
num_inference_steps,
|
829 |
-
seed,
|
830 |
-
image_resolution,
|
831 |
-
scheduler,
|
832 |
-
eta,
|
833 |
-
low_threshold,
|
834 |
-
high_threshold,
|
835 |
-
guess_mode,
|
836 |
-
disable_safety_check,
|
837 |
-
],
|
838 |
-
outputs=[result_image, used_seed],
|
839 |
-
concurrency_limit=20
|
840 |
-
)
|
841 |
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
|
848 |
-
|
849 |
-
|
850 |
-
|
851 |
-
|
852 |
-
|
853 |
-
|
|
|
854 |
)
|
855 |
|
856 |
-
# Load models on launch
|
857 |
-
#load_models_on_launch()
|
858 |
|
859 |
-
|
860 |
-
|
|
|
1 |
+
import spaces
|
2 |
import gradio as gr
|
3 |
+
from huggingface_hub import InferenceClient
|
4 |
+
from torch import nn
|
5 |
+
from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
|
6 |
from pathlib import Path
|
7 |
+
import torch
|
8 |
+
import torch.amp.autocast_mode
|
|
|
9 |
from PIL import Image
|
10 |
+
import os
|
11 |
+
import torchvision.transforms.functional as TVF
|
12 |
+
|
|
|
|
|
|
|
|
|
13 |
from dotenv import load_dotenv
|
14 |
|
15 |
# Load environment variables from .env file
|
|
|
17 |
|
18 |
USERNAME = os.getenv("USERNAME")
|
19 |
PASSWORD = os.getenv("PASSWORD")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
CLIP_PATH = "google/siglip-so400m-patch14-384"
|
22 |
+
MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
|
23 |
+
CHECKPOINT_PATH = Path("9em124t2-499968")
|
24 |
+
TITLE = "<h1><center>JoyCaption Alpha One (2024-09-20a)</center></h1>"
|
25 |
+
CAPTION_TYPE_MAP = {
|
26 |
+
("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
|
27 |
+
("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
|
28 |
+
("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
|
29 |
+
("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
|
30 |
+
("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
|
31 |
+
("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
|
32 |
+
|
33 |
+
("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
|
34 |
+
("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
|
35 |
+
("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
|
36 |
+
|
37 |
+
("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
|
38 |
+
("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
|
39 |
+
("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
|
40 |
+
|
41 |
+
("style_prompt", "formal", False, False): ["Generate a detailed style prompt for this image, including lens type, film stock, composition notes, lighting aspects, and any special photographic techniques."],
|
42 |
+
("style_prompt", "formal", False, True): ["Generate a detailed style prompt for this image within {word_count} words, including lens type, film stock, composition notes, lighting aspects, and any special photographic techniques."],
|
43 |
+
("style_prompt", "formal", True, False): ["Generate a {length} detailed style prompt for this image, including lens type, film stock, composition notes, lighting aspects, and any special photographic techniques."],
|
44 |
}
|
45 |
|
46 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
|
|
|
|
|
|
|
|
47 |
|
|
|
|
|
|
|
48 |
|
49 |
+
class ImageAdapter(nn.Module):
|
50 |
+
def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
|
51 |
+
super().__init__()
|
52 |
+
self.deep_extract = deep_extract
|
53 |
+
|
54 |
+
if self.deep_extract:
|
55 |
+
input_features = input_features * 5
|
56 |
+
|
57 |
+
self.linear1 = nn.Linear(input_features, output_features)
|
58 |
+
self.activation = nn.GELU()
|
59 |
+
self.linear2 = nn.Linear(output_features, output_features)
|
60 |
+
self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
|
61 |
+
self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
|
62 |
+
|
63 |
+
# Mode token
|
64 |
+
#self.mode_token = nn.Embedding(n_modes, output_features)
|
65 |
+
#self.mode_token.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
|
66 |
+
|
67 |
+
# Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
|
68 |
+
self.other_tokens = nn.Embedding(3, output_features)
|
69 |
+
self.other_tokens.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
|
70 |
+
|
71 |
+
def forward(self, vision_outputs: torch.Tensor):
|
72 |
+
if self.deep_extract:
|
73 |
+
x = torch.concat((
|
74 |
+
vision_outputs[-2],
|
75 |
+
vision_outputs[3],
|
76 |
+
vision_outputs[7],
|
77 |
+
vision_outputs[13],
|
78 |
+
vision_outputs[20],
|
79 |
+
), dim=-1)
|
80 |
+
assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}" # batch, tokens, features
|
81 |
+
assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
|
82 |
+
else:
|
83 |
+
x = vision_outputs[-2]
|
84 |
+
|
85 |
+
x = self.ln1(x)
|
86 |
+
|
87 |
+
if self.pos_emb is not None:
|
88 |
+
assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
|
89 |
+
x = x + self.pos_emb
|
90 |
+
|
91 |
+
x = self.linear1(x)
|
92 |
+
x = self.activation(x)
|
93 |
+
x = self.linear2(x)
|
94 |
+
|
95 |
+
# Mode token
|
96 |
+
#mode_token = self.mode_token(mode)
|
97 |
+
#assert mode_token.shape == (x.shape[0], mode_token.shape[1], x.shape[2]), f"Expected {(x.shape[0], 1, x.shape[2])}, got {mode_token.shape}"
|
98 |
+
#x = torch.cat((x, mode_token), dim=1)
|
99 |
+
|
100 |
+
# <|image_start|>, IMAGE, <|image_end|>
|
101 |
+
other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
|
102 |
+
assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
|
103 |
+
x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
|
104 |
+
|
105 |
+
return x
|
106 |
+
|
107 |
+
def get_eot_embedding(self):
|
108 |
+
return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
# Load CLIP
|
113 |
+
print("Loading CLIP")
|
114 |
+
clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
|
115 |
+
clip_model = AutoModel.from_pretrained(CLIP_PATH)
|
116 |
+
clip_model = clip_model.vision_model
|
117 |
+
|
118 |
+
if (CHECKPOINT_PATH / "clip_model.pt").exists():
|
119 |
+
print("Loading VLM's custom vision model")
|
120 |
+
checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
|
121 |
+
checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
|
122 |
+
clip_model.load_state_dict(checkpoint)
|
123 |
+
del checkpoint
|
124 |
+
|
125 |
+
clip_model.eval()
|
126 |
+
clip_model.requires_grad_(False)
|
127 |
+
clip_model.to("cuda")
|
128 |
+
|
129 |
+
|
130 |
+
# Tokenizer
|
131 |
+
print("Loading tokenizer")
|
132 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
|
133 |
+
assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
|
134 |
+
|
135 |
+
# LLM
|
136 |
+
print("Loading LLM")
|
137 |
+
if (CHECKPOINT_PATH / "text_model").exists:
|
138 |
+
print("Loading VLM's custom text model")
|
139 |
+
text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
|
140 |
+
else:
|
141 |
+
text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
|
142 |
+
|
143 |
+
text_model.eval()
|
144 |
+
|
145 |
+
# Image Adapter
|
146 |
+
print("Loading image adapter")
|
147 |
+
image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
|
148 |
+
image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
|
149 |
+
image_adapter.eval()
|
150 |
+
image_adapter.to("cuda")
|
151 |
+
|
152 |
+
|
153 |
+
def preprocess_image(input_image: Image.Image) -> torch.Tensor:
|
154 |
+
"""
|
155 |
+
Preprocess the input image for the CLIP model.
|
156 |
+
"""
|
157 |
+
image = input_image.resize((384, 384), Image.LANCZOS)
|
158 |
+
pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
|
159 |
+
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
|
160 |
+
return pixel_values.to('cuda')
|
161 |
+
|
162 |
+
def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
|
163 |
+
"""
|
164 |
+
Generate a caption based on the image features and prompt.
|
165 |
+
"""
|
166 |
+
prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
|
167 |
+
prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
|
168 |
+
embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
|
169 |
+
eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
|
170 |
+
|
171 |
+
inputs_embeds = torch.cat([
|
172 |
+
embedded_bos.expand(image_features.shape[0], -1, -1),
|
173 |
+
image_features.to(dtype=embedded_bos.dtype),
|
174 |
+
prompt_embeds.expand(image_features.shape[0], -1, -1),
|
175 |
+
eot_embed.expand(image_features.shape[0], -1, -1),
|
176 |
+
], dim=1)
|
177 |
+
|
178 |
+
input_ids = torch.cat([
|
179 |
+
torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
|
180 |
+
torch.zeros((1, image_features.shape[1]), dtype=torch.long),
|
181 |
+
prompt,
|
182 |
+
torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
|
183 |
+
], dim=1).to('cuda')
|
184 |
+
attention_mask = torch.ones_like(input_ids)
|
185 |
+
|
186 |
+
generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, suppress_tokens=None)
|
187 |
+
|
188 |
+
generate_ids = generate_ids[:, input_ids.shape[1]:]
|
189 |
+
if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
|
190 |
+
generate_ids = generate_ids[:, :-1]
|
191 |
+
|
192 |
+
return tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0].strip()
|
193 |
+
|
194 |
+
@spaces.GPU()
|
195 |
+
@torch.no_grad()
|
196 |
+
def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, lens_type: str = "", film_stock: str = "", composition_style: str = "", lighting_aspect: str = "", special_technique: str = "", color_effect: str = "") -> str:
|
197 |
+
"""
|
198 |
+
Generate a caption, training prompt, tags, or a style prompt for image generation based on the input image and parameters.
|
199 |
+
"""
|
200 |
+
# Check if an image has been uploaded
|
201 |
+
if input_image is None:
|
202 |
+
return "Error: Please upload an image before generating a caption."
|
203 |
+
|
204 |
+
torch.cuda.empty_cache()
|
205 |
+
|
206 |
try:
|
207 |
+
length = None if caption_length == "any" else caption_length
|
208 |
+
if isinstance(length, str):
|
209 |
+
length = int(length)
|
210 |
+
except ValueError:
|
211 |
+
raise ValueError(f"Invalid caption length: {caption_length}")
|
212 |
+
|
213 |
+
if caption_type in ["rng-tags", "training_prompt", "style_prompt"]:
|
214 |
+
caption_tone = "formal"
|
215 |
+
|
216 |
+
prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
|
217 |
+
if prompt_key not in CAPTION_TYPE_MAP:
|
218 |
+
raise ValueError(f"Invalid caption type: {prompt_key}")
|
219 |
+
|
220 |
+
if caption_type == "style_prompt":
|
221 |
+
# For style prompt, we'll create a custom prompt for the LLM
|
222 |
+
base_prompt = "Analyze the given image and create a detailed Stable Diffusion prompt for generating a new, creative image inspired by it. "
|
223 |
+
base_prompt += "The prompt should describe the main elements, style, and mood of the image, "
|
224 |
+
base_prompt += "but also introduce creative variations or enhancements. "
|
225 |
+
base_prompt += "Include specific details about the composition, lighting, and overall atmosphere. "
|
226 |
+
|
227 |
+
# Add custom settings to the prompt
|
228 |
+
if lens_type:
|
229 |
+
lens_type_key = lens_type.split(":")[0].strip()
|
230 |
+
base_prompt += f"Incorporate the effect of a {lens_type_key} lens ({lens_types_info[lens_type_key]}). "
|
231 |
+
if film_stock:
|
232 |
+
film_stock_key = film_stock.split(":")[0].strip()
|
233 |
+
base_prompt += f"Apply the characteristics of {film_stock_key} film stock ({film_stocks_info[film_stock_key]}). "
|
234 |
+
if composition_style:
|
235 |
+
composition_style_key = composition_style.split(":")[0].strip()
|
236 |
+
base_prompt += f"Use a {composition_style_key} composition style ({composition_styles_info[composition_style_key]}). "
|
237 |
+
if lighting_aspect:
|
238 |
+
lighting_aspect_key = lighting_aspect.split(":")[0].strip()
|
239 |
+
base_prompt += f"Implement {lighting_aspect_key} lighting ({lighting_aspects_info[lighting_aspect_key]}). "
|
240 |
+
if special_technique:
|
241 |
+
special_technique_key = special_technique.split(":")[0].strip()
|
242 |
+
base_prompt += f"Apply the {special_technique_key} technique ({special_techniques_info[special_technique_key]}). "
|
243 |
+
if color_effect:
|
244 |
+
color_effect_key = color_effect.split(":")[0].strip()
|
245 |
+
base_prompt += f"Use a {color_effect_key} color effect ({color_effects_info[color_effect_key]}). "
|
246 |
+
|
247 |
+
base_prompt += f"The final prompt should be approximately {length} words long. "
|
248 |
+
base_prompt += "Format the output as a single paragraph without numbering or bullet points."
|
249 |
+
|
250 |
+
prompt_str = base_prompt
|
251 |
else:
|
252 |
+
prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
+
# Debugging: Print the constructed prompt string
|
255 |
+
print(f"Constructed Prompt: {prompt_str}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
256 |
|
257 |
+
pixel_values = preprocess_image(input_image)
|
258 |
+
|
259 |
+
with torch.amp.autocast_mode.autocast('cuda', enabled=True):
|
260 |
+
vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
|
261 |
+
image_features = vision_outputs.hidden_states
|
262 |
+
embedded_images = image_adapter(image_features)
|
263 |
+
embedded_images = embedded_images.to('cuda')
|
264 |
+
|
265 |
+
# Load the model from MODEL_PATH
|
266 |
+
text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
|
267 |
+
text_model.eval()
|
268 |
|
269 |
+
# Debugging: Print the prompt string before passing to generate_caption
|
270 |
+
print(f"Prompt passed to generate_caption: {prompt_str}")
|
271 |
+
|
272 |
+
caption = generate_caption(text_model, tokenizer, embedded_images, prompt_str)
|
273 |
+
|
274 |
+
return caption
|
275 |
|
276 |
css = """
|
277 |
+
h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, img {
|
278 |
+
text-align: left;
|
279 |
+
}
|
280 |
+
img {
|
281 |
+
display: inline-block;
|
282 |
+
vertical-align: middle;
|
283 |
+
margin-right: 10px;
|
284 |
+
max-width: 100%;
|
285 |
+
height: auto;
|
286 |
+
}
|
287 |
+
.centered-image {
|
288 |
display: block;
|
289 |
margin-left: auto;
|
290 |
margin-right: auto;
|
291 |
+
max-width: 100%;
|
292 |
+
height: auto;
|
293 |
}
|
294 |
ul, ol {
|
295 |
+
padding-left: 20px;
|
|
|
|
|
296 |
}
|
297 |
+
.gradio-container {
|
298 |
+
max-width: 100% !important;
|
299 |
+
padding: 0 !important;
|
300 |
+
}
|
301 |
+
.gradio-row {
|
302 |
+
margin-left: 0 !important;
|
303 |
+
margin-right: 0 !important;
|
304 |
+
}
|
305 |
+
.gradio-column {
|
306 |
+
padding-left: 0 !important;
|
307 |
+
padding-right: 0 !important;
|
308 |
+
}
|
309 |
+
/* Left-align dropdown text */
|
310 |
+
.gradio-dropdown > div {
|
311 |
+
text-align: left !important;
|
312 |
+
}
|
313 |
+
/* Left-align checkbox labels */
|
314 |
+
.gradio-checkbox label {
|
315 |
+
text-align: left !important;
|
316 |
+
}
|
317 |
+
/* Left-align radio button labels */
|
318 |
+
.gradio-radio label {
|
319 |
+
text-align: left !important;
|
320 |
}
|
321 |
"""
|
322 |
|
323 |
+
# Add detailed descriptions for each option
|
324 |
+
lens_types_info = {
|
325 |
+
"Standard": "A versatile lens with a field of view similar to human vision.",
|
326 |
+
"Wide-angle": "Captures a wider field of view, great for landscapes and architecture. Applies moderate to strong lens effect with image warp.",
|
327 |
+
"Telephoto": "Used for distant subjects, gives an 'award-winning' or 'National Geographic' look. Creates interesting effects when prompted.",
|
328 |
+
"Macro": "For extreme close-up photography, revealing tiny details.",
|
329 |
+
"Fish-eye": "Ultra-wide-angle lens that creates a strong bubble-like distortion. Generates panoramic photos with the entire image warping into a bubble.",
|
330 |
+
"Tilt-shift": "Allows adjusting the plane of focus, creating a 'miniature' effect. Known for the 'diorama miniature look'.",
|
331 |
+
"Zoom lens": "Variable focal length lens. Often zooms in on the subject, perfect for creating a base for inpainting. Interesting effect on landscapes with motion blur.",
|
332 |
+
"GoPro": "Wide-angle lens with clean digital look. Excludes film grain and most filter effects, resulting in natural colors and regular saturation.",
|
333 |
+
"Pinhole camera": "Creates a unique, foggy, low-detail, historic photograph look. Used since the 1850s, with peak popularity in the 1930s."
|
334 |
+
}
|
335 |
+
|
336 |
+
film_stocks_info = {
|
337 |
+
"Kodak Portra": "Professional color negative film known for its natural skin tones and low contrast.",
|
338 |
+
"Fujifilm Velvia": "Slide film known for vibrant colors and high saturation, popular among landscape photographers.",
|
339 |
+
"Ilford Delta": "Black and white film known for its fine grain and high sharpness.",
|
340 |
+
"Kodak Tri-X": "Classic high-speed black and white film, known for its distinctive grain and wide exposure latitude.",
|
341 |
+
"Fujifilm Provia": "Color reversal film known for its natural color reproduction and fine grain.",
|
342 |
+
"Cinestill": "Color photos with fine/low grain and higher than average resolution. Colors are slightly oversaturated or slightly desaturated.",
|
343 |
+
"Ektachrome": "Color photos with fine/low to moderate grain. Colors on the colder part of spectrum or regular, with normal or slightly higher saturation.",
|
344 |
+
"Ektar": "Modern Kodak film. Color photos with little to no grain. Results look like regular modern photography with artistic angles.",
|
345 |
+
"Film Washi": "Mostly black and white photos with fine/low to moderate grain. Occasionally gives colored photos with low saturation. Distinct style with high black contrast and soft camera lens effect.",
|
346 |
+
"Fomapan": "Black and white photos with fine/low to moderate grain, highly artistic exposure and angles. Adds very soft lens effect without distortion, dark photo vignette.",
|
347 |
+
"Fujicolor": "Color photos with fine/low to moderate grain. Colors are either very oversaturated or slightly desaturated, with entire color hue shifted in a very distinct manner.",
|
348 |
+
"Holga": "Color photos with high grain. Colors are either very oversaturated or slightly desaturated. Distinct contrast of black. Often applies photographic vignette.",
|
349 |
+
"Instax": "Instant color photos similar to Polaroid but clearer. Near perfect colors, regular saturation, fine/low to medium grain.",
|
350 |
+
"Lomography": "Color photos with high grain. Colors are either very oversaturated or slightly desaturated. Distinct contrast of black. Often applies photographic vignette.",
|
351 |
+
"Kodachrome": "Color photos with moderate grain. Colors on either colder part of spectrum or regular, with normal or slightly higher saturation.",
|
352 |
+
"Rollei": "Mostly black and white photos, sometimes color with fine/low grain. Can be sepia colored or have unusual hues and desaturation. Great for landscapes."
|
353 |
+
}
|
354 |
+
|
355 |
+
composition_styles_info = {
|
356 |
+
"Rule of Thirds": "Divides the frame into a 3x3 grid, placing key elements along the lines or at their intersections.",
|
357 |
+
"Golden Ratio": "Uses a spiral based on the golden ratio to create a balanced and aesthetically pleasing composition.",
|
358 |
+
"Symmetry": "Creates a mirror-like balance in the image, often used for architectural or nature photography.",
|
359 |
+
"Leading Lines": "Uses lines within the frame to draw the viewer's eye to the main subject or through the image.",
|
360 |
+
"Framing": "Uses elements within the scene to create a frame around the main subject.",
|
361 |
+
"Minimalism": "Simplifies the composition to its essential elements, often with a lot of negative space.",
|
362 |
+
"Fill the Frame": "The main subject dominates the entire frame, leaving little to no background.",
|
363 |
+
"Negative Space": "Uses empty space around the subject to create a sense of simplicity or isolation.",
|
364 |
+
"Centered Composition": "Places the main subject in the center of the frame, creating a sense of stability or importance.",
|
365 |
+
"Diagonal Lines": "Uses diagonal elements to create a sense of movement or dynamic tension in the image.",
|
366 |
+
"Triangular Composition": "Arranges elements in the frame to form a triangle, creating a sense of stability and harmony.",
|
367 |
+
"Radial Balance": "Arranges elements in a circular pattern around a central point, creating a sense of movement or completeness."
|
368 |
+
}
|
369 |
+
|
370 |
+
lighting_aspects_info = {
|
371 |
+
"Natural light": "Uses available light from the sun or sky, often creating soft, even illumination.",
|
372 |
+
"Studio lighting": "Controlled artificial lighting setup, allowing for precise manipulation of light and shadow.",
|
373 |
+
"Back light": "Light source behind the subject, creating silhouettes or rim lighting effects.",
|
374 |
+
"Split light": "Strong light source at 90-degree angle, lighting one half of the subject while leaving the other in shadow.",
|
375 |
+
"Broad light": "Light source at an angle to the subject, producing well-lit photographs with soft to moderate shadows.",
|
376 |
+
"Dim light": "Weak or distant light source, creating lower than average brightness and often dramatic images.",
|
377 |
+
"Flash photography": "Uses a brief, intense burst of light. Can be fill flash (even lighting) or harsh flash (strong contrasts).",
|
378 |
+
"Sunlight": "Direct light from the sun, often creating strong contrasts and warm tones.",
|
379 |
+
"Moonlight": "Soft, cool light from the moon, often creating a mysterious or romantic atmosphere.",
|
380 |
+
"Spotlight": "Focused beam of light illuminating a specific area, creating high contrast between light and shadow.",
|
381 |
+
"High-key lighting": "Bright, even lighting with minimal shadows, creating a light and airy feel.",
|
382 |
+
"Low-key lighting": "Predominantly dark tones with selective lighting, creating a moody or dramatic atmosphere.",
|
383 |
+
"Rembrandt lighting": "Classic portrait lighting technique creating a triangle of light on the cheek of the subject."
|
384 |
+
}
|
385 |
+
|
386 |
+
special_techniques_info = {
|
387 |
+
"Double exposure": "Superimposes two exposures to create a single image, often resulting in a dreamy or surreal effect.",
|
388 |
+
"Long exposure": "Uses a long shutter speed to capture motion over time, often creating smooth, blurred effects for moving elements.",
|
389 |
+
"Multiple exposure": "Superimposes multiple exposures, multiplying the subject or its key elements across the image.",
|
390 |
+
"HDR": "High Dynamic Range imaging, combining multiple exposures to capture a wider range of light and dark tones.",
|
391 |
+
"Bokeh effect": "Creates a soft, out-of-focus background, often with circular highlights.",
|
392 |
+
"Silhouette": "Captures the outline of a subject against a brighter background, creating a dramatic contrast.",
|
393 |
+
"Panning": "Follows a moving subject with the camera, creating a sharp subject with a blurred background.",
|
394 |
+
"Light painting": "Uses long exposure and moving light sources to 'paint' with light in the image.",
|
395 |
+
"Infrared photography": "Captures light in the infrared spectrum, often resulting in surreal, otherworldly images.",
|
396 |
+
"Ultraviolet photography": "Captures light in the ultraviolet spectrum, often revealing hidden patterns or creating a strong violet glow.",
|
397 |
+
"Kirlian photography": "High-voltage photographic technique that captures corona discharges around objects, creating a glowing effect.",
|
398 |
+
"Thermography": "Captures infrared radiation to create images based on temperature differences, resulting in false-color heat maps.",
|
399 |
+
"Astrophotography": "Specialized technique for capturing astronomical objects and celestial events, often resulting in stunning starry backgrounds.",
|
400 |
+
"Underwater photography": "Captures images beneath the surface of water, often in pools, seas, or aquariums.",
|
401 |
+
"Aerial photography": "Captures images from an elevated position, such as from drones, helicopters, or planes.",
|
402 |
+
"Macro photography": "Extreme close-up photography, revealing tiny details not visible to the naked eye."
|
403 |
+
}
|
404 |
+
|
405 |
+
color_effects_info = {
|
406 |
+
"Black and white": "Removes all color, leaving only shades of gray.",
|
407 |
+
"Sepia": "Reddish-brown monochrome effect, often associated with vintage photography.",
|
408 |
+
"Monochrome": "Uses variations of a single color.",
|
409 |
+
"Vintage color": "Muted or faded color palette reminiscent of old photographs.",
|
410 |
+
"Cross-processed": "Deliberate processing of film in the wrong chemicals, creating unusual color shifts.",
|
411 |
+
"Desaturated": "Reduces the intensity of all colors in the image.",
|
412 |
+
"Vivid colors": "Increases the saturation and intensity of colors.",
|
413 |
+
"Pastel colors": "Soft, pale colors with a light and airy feel.",
|
414 |
+
"High contrast": "Emphasizes the difference between light and dark areas in the image.",
|
415 |
+
"Low contrast": "Reduces the difference between light and dark areas, creating a softer look.",
|
416 |
+
"Color splash": "Converts most of the image to black and white while leaving one or more elements in color."
|
417 |
+
}
|
418 |
+
|
419 |
+
def get_dropdown_choices(info_dict):
|
420 |
+
return [f"{key}: {value}" for key, value in info_dict.items()]
|
421 |
+
|
422 |
def login(username, password):
|
423 |
if username == USERNAME and password == PASSWORD:
|
424 |
return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value="Login successful! You can now access the QR Code Art Generator tab.", visible=True)
|
425 |
else:
|
426 |
return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value="Invalid username or password. Please try again.", visible=True)
|
|
|
|
|
|
|
|
|
427 |
|
428 |
+
# Gradio interface
|
429 |
+
with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True) as demo:
|
430 |
with gr.Tab("Welcome"):
|
431 |
with gr.Row():
|
432 |
+
with gr.Column(scale=2):
|
433 |
gr.Markdown(
|
434 |
"""
|
435 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/LVZnwLV43UUvKu3HORqSs.webp" alt="UDG" width="250" style="max-width: 100%; height: auto; class="centered-image">
|
436 |
+
|
437 |
+
# 🎨 Underground Digital's Caption Captain: AI-Powered Art Inspiration
|
438 |
+
|
439 |
+
## Accelerate Your Creative Workflow with Intelligent Image Analysis
|
440 |
+
|
441 |
+
This innovative tool empowers Yamamoto's artists to quickly generate descriptive captions,<br>
|
442 |
+
training prompts, and tags from existing artwork, fueling the creative process for GenAI models.
|
443 |
+
|
444 |
## 🚀 How It Works:
|
445 |
+
1. **Upload Your Inspiration**: Drop in an image (e.g., a charcoal horse picture) that embodies your desired style.
|
446 |
+
2. **Choose Your Output**: Select from descriptive captions, training prompts, or tags.
|
447 |
+
3. **Customize the Results**: Adjust tone, length, and other parameters to fine-tune the output.
|
448 |
+
4. **Generate and Iterate**: Click 'Caption' to analyze your image and use the results to inspire new creations.
|
449 |
"""
|
450 |
)
|
451 |
+
|
452 |
with gr.Column(scale=1):
|
453 |
with gr.Row():
|
454 |
gr.Markdown(
|
|
|
469 |
login_button = gr.Button("Login", size="sm")
|
470 |
login_message = gr.Markdown(visible=False)
|
471 |
|
472 |
+
with gr.Tab("Caption Captain") as app_container:
|
473 |
+
with gr.Accordion("How to Use Caption Captain", open=False):
|
474 |
+
gr.Markdown("""
|
475 |
+
# How to Use Caption Captain
|
476 |
+
|
477 |
+
<img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/Ce_Z478iOXljvpZ_Fr_Y7.png" alt="Captain" width="100" style="max-width: 100%; height: auto;">
|
478 |
+
|
479 |
+
Hello, artist! Let's make some fun captions for your pictures. Here's how:
|
480 |
+
|
481 |
+
1. **Pick a Picture**: Find a cool picture you want to talk about and upload it.
|
482 |
+
|
483 |
+
2. **Choose What You Want**:
|
484 |
+
- **Caption Type**:
|
485 |
+
* "Descriptive" tells you what's in the picture
|
486 |
+
* "Training Prompt" helps computers make similar pictures
|
487 |
+
* "RNG-Tags" gives you short words about the picture
|
488 |
+
* "Style Prompt" creates detailed prompts for image generation
|
489 |
+
|
490 |
+
3. **Pick a Style** (for "Descriptive" and "Style Prompt" only):
|
491 |
+
- "Formal" sounds like a teacher talking
|
492 |
+
- "Informal" sounds like a friend chatting
|
493 |
+
|
494 |
+
4. **Decide How Long**:
|
495 |
+
- "Any" lets the computer decide
|
496 |
+
- Or pick a size from "very short" to "very long"
|
497 |
+
- You can even choose a specific number of words!
|
498 |
+
|
499 |
+
5. **Advanced Options** (for "Style Prompt" only):
|
500 |
+
- Choose lens type, film stock, composition, and lighting details
|
501 |
+
|
502 |
+
6. **Make the Caption**: Click the "Make My Caption!" button and watch the magic happen!
|
503 |
+
|
504 |
+
Remember, have fun and be creative with your captions!
|
505 |
+
|
506 |
+
## Tips for Great Captions:
|
507 |
+
- Try different types to see what you like best
|
508 |
+
- Experiment with formal and informal tones for fun variations
|
509 |
+
- Adjust the length to get just the right amount of detail
|
510 |
+
- For "Style Prompt", play with the advanced options for more specific results
|
511 |
+
- If you don't like a caption, just click "Make My Caption!" again for a new one
|
512 |
+
|
513 |
+
Have a great time captioning your art!
|
514 |
+
""")
|
515 |
|
|
|
516 |
with gr.Row():
|
517 |
with gr.Column():
|
518 |
+
input_image = gr.Image(type="pil", label="Input Image")
|
|
|
|
|
|
|
|
|
|
|
|
|
519 |
|
520 |
+
caption_type = gr.Dropdown(
|
521 |
+
choices=["descriptive", "training_prompt", "rng-tags", "style_prompt"],
|
522 |
+
label="Caption Type",
|
523 |
+
value="descriptive",
|
|
|
|
|
524 |
)
|
525 |
+
|
526 |
+
caption_tone = gr.Dropdown(
|
527 |
+
choices=["formal", "informal"],
|
528 |
+
label="Caption Tone",
|
529 |
+
value="formal",
|
|
|
530 |
)
|
531 |
|
532 |
+
caption_length = gr.Dropdown(
|
533 |
+
choices=["any", "very short", "short", "medium-length", "long", "very long"] +
|
534 |
+
[str(i) for i in range(20, 261, 10)],
|
535 |
+
label="Caption Length",
|
536 |
+
value="any",
|
537 |
+
)
|
538 |
|
539 |
+
gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags`, `training_prompt`, and `style_prompt`.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
540 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
541 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
542 |
with gr.Column():
|
543 |
+
error_message = gr.Markdown(visible=False) # Add this line
|
544 |
+
output_caption = gr.Textbox(label="Generated Caption")
|
545 |
+
run_button = gr.Button("Make My Caption!")
|
546 |
+
|
547 |
+
# Container for advanced options
|
548 |
+
with gr.Column(visible=False) as advanced_options:
|
549 |
+
gr.Markdown("### Advanced Options for Style Prompt")
|
550 |
+
lens_type = gr.Dropdown(
|
551 |
+
choices=get_dropdown_choices(lens_types_info),
|
552 |
+
label="Lens Type",
|
553 |
+
info="Select a lens type to define the perspective and field of view of the image."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
554 |
)
|
555 |
+
film_stock = gr.Dropdown(
|
556 |
+
choices=get_dropdown_choices(film_stocks_info),
|
557 |
+
label="Film Stock",
|
558 |
+
info="Choose a film stock to determine the color, grain, and overall look of the image."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
559 |
)
|
560 |
+
composition_style = gr.Dropdown(
|
561 |
+
choices=get_dropdown_choices(composition_styles_info),
|
562 |
+
label="Composition Style",
|
563 |
+
info="Select a composition style to guide the arrangement of elements in the image."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
564 |
)
|
565 |
+
lighting_aspect = gr.Dropdown(
|
566 |
+
choices=get_dropdown_choices(lighting_aspects_info),
|
567 |
+
label="Lighting Aspect",
|
568 |
+
info="Choose a lighting style to define the mood and atmosphere of the image."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
569 |
)
|
570 |
+
special_technique = gr.Dropdown(
|
571 |
+
choices=get_dropdown_choices(special_techniques_info),
|
572 |
+
label="Special Technique",
|
573 |
+
info="Select a special photographic technique to add unique effects to the image."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
)
|
575 |
+
color_effect = gr.Dropdown(
|
576 |
+
choices=get_dropdown_choices(color_effects_info),
|
577 |
+
label="Color Effect",
|
578 |
+
info="Choose a color effect to alter the overall color palette of the image."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
579 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
580 |
|
581 |
+
def update_style_options(caption_type):
|
582 |
+
return gr.update(visible=caption_type == "style_prompt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
583 |
|
584 |
+
caption_type.change(update_style_options, inputs=[caption_type], outputs=[advanced_options])
|
585 |
+
|
586 |
+
def process_and_handle_errors(input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect):
|
587 |
+
try:
|
588 |
+
result = stream_chat(input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect)
|
589 |
+
return gr.update(visible=False), result
|
590 |
+
except Exception as e:
|
591 |
+
return gr.update(visible=True, value=f"Error: {str(e)}"), ""
|
592 |
+
|
593 |
+
run_button.click(
|
594 |
+
fn=process_and_handle_errors,
|
595 |
+
inputs=[input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect],
|
596 |
+
outputs=[error_message, output_caption]
|
597 |
)
|
598 |
|
|
|
|
|
599 |
|
600 |
+
if __name__ == "__main__":
|
601 |
+
demo.launch()
|