Severian commited on
Commit
14a4524
·
verified ·
1 Parent(s): 8816043

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +515 -774
app.py CHANGED
@@ -1,18 +1,15 @@
 
1
  import gradio as gr
2
- from PIL import Image
3
- import qrcode
 
4
  from pathlib import Path
5
- import requests
6
- import io
7
- import os
8
  from PIL import Image
9
- import numpy as np
10
- import cv2
11
- from pyzxing import BarCodeReader
12
- from PIL import ImageOps, ImageEnhance, ImageFilter
13
- from huggingface_hub import hf_hub_download, snapshot_download
14
- from PIL import ImageEnhance
15
- import replicate
16
  from dotenv import load_dotenv
17
 
18
  # Load environment variables from .env file
@@ -20,348 +17,438 @@ load_dotenv()
20
 
21
  USERNAME = os.getenv("USERNAME")
22
  PASSWORD = os.getenv("PASSWORD")
23
- REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
24
-
25
- # Set the Replicate API token
26
- os.environ["REPLICATE_API_TOKEN"] = REPLICATE_API_TOKEN
27
-
28
- qrcode_generator = qrcode.QRCode(
29
- version=1,
30
- error_correction=qrcode.ERROR_CORRECT_H,
31
- box_size=10,
32
- border=4,
33
- )
34
-
35
 
36
- # Define available models
37
- CONTROLNET_MODELS = {
38
- "QR Code Monster": "monster-labs/control_v1p_sd15_qrcode_monster/v2/",
39
- "QR Code": "DionTimmer/controlnet_qrcode-control_v1p_sd15",
40
- # Add more ControlNet models here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  }
42
 
43
- DIFFUSION_MODELS = {
44
- "GhostMix": "digiplay/GhostMixV1.2VAE",
45
- "Stable v1.5": "Jiali/stable-diffusion-1.5",
46
- # Add more diffusion models here
47
- }
48
 
49
- # Global variables to store loaded models
50
- loaded_controlnet = None
51
- loaded_pipe = None
52
 
53
- # def load_models_on_launch():
54
- # global loaded_controlnet, loaded_pipe
55
- # print("Loading models on launch...")
56
-
57
- # Download the main repository
58
- # main_repo_path = snapshot_download("monster-labs/control_v1p_sd15_qrcode_monster")
59
-
60
- # Construct the path to the subfolder
61
- # controlnet_path = os.path.join(main_repo_path, "v2")
62
-
63
- # loaded_controlnet = ControlNetModel.from_pretrained(
64
- # controlnet_path,
65
- # torch_dtype=torch.float16
66
- # ).to("mps")
67
-
68
- # diffusion_path = snapshot_download(DIFFUSION_MODELS["GhostMix"])
69
- # loaded_pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
70
- # diffusion_path,
71
- # controlnet=loaded_controlnet,
72
- # torch_dtype=torch.float16,
73
- # safety_checker=None,
74
- # ).to("mps")
75
- # print("Models loaded successfully!")
76
-
77
- # Modify the load_models function to use global variables
78
- #def load_models(controlnet_model, diffusion_model):
79
- # global loaded_controlnet, loaded_pipe
80
- # if loaded_controlnet is None or loaded_pipe is None:
81
- # load_models_on_launch()
82
- # return loaded_pipe
83
-
84
- # Add new functions for image adjustments
85
- def adjust_image(image, brightness, contrast, saturation):
86
- if image is None:
87
- return None
88
-
89
- img = Image.fromarray(image) if isinstance(image, np.ndarray) else image
90
-
91
- if brightness != 1:
92
- img = ImageEnhance.Brightness(img).enhance(brightness)
93
- if contrast != 1:
94
- img = ImageEnhance.Contrast(img).enhance(contrast)
95
- if saturation != 1:
96
- img = ImageEnhance.Color(img).enhance(saturation)
97
-
98
- return np.array(img)
99
-
100
- def resize_for_condition_image(input_image: Image.Image, resolution: int):
101
- input_image = input_image.convert("RGB")
102
- W, H = input_image.size
103
- k = float(resolution) / min(H, W)
104
- H *= k
105
- W *= k
106
- H = int(round(H / 64.0)) * 64
107
- W = int(round(W / 64.0)) * 64
108
- img = input_image.resize((W, H), resample=Image.LANCZOS)
109
- return img
110
-
111
-
112
- # SAMPLER_MAP = {
113
- # "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True, algorithm_type="sde-dpmsolver++"),
114
- # "DPM++ Karras": lambda config: DPMSolverMultistepScheduler.from_config(config, use_karras=True),
115
- # "Heun": lambda config: HeunDiscreteScheduler.from_config(config),
116
- # "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
117
- # "DDIM": lambda config: DDIMScheduler.from_config(config),
118
- # "DEIS": lambda config: DEISMultistepScheduler.from_config(config),
119
- #}
120
-
121
- def scan_qr_code(image):
122
- # Convert gradio image to PIL Image if necessary
123
- if isinstance(image, np.ndarray):
124
- image = Image.fromarray(image)
125
-
126
- # Convert to grayscale
127
- gray_image = image.convert('L')
128
-
129
- # Convert to numpy array
130
- np_image = np.array(gray_image)
131
-
132
- # Method 1: Using qrcode library
133
- try:
134
- qr = qrcode.QRCode()
135
- qr.add_data('')
136
- qr.decode(gray_image)
137
- return qr.data.decode('utf-8')
138
- except Exception:
139
- pass
140
-
141
- # Method 2: Using OpenCV
142
- try:
143
- qr_detector = cv2.QRCodeDetector()
144
- retval, decoded_info, points, straight_qrcode = qr_detector.detectAndDecodeMulti(np_image)
145
- if retval:
146
- return decoded_info[0]
147
- except Exception:
148
- pass
149
-
150
- # Method 3: Fallback to zxing-cpp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  try:
152
- reader = BarCodeReader()
153
- results = reader.decode(np_image)
154
- if results:
155
- return results[0].parsed
156
- except Exception:
157
- pass
158
-
159
- return None
160
-
161
- def invert_image(image):
162
- if image is None:
163
- return None
164
- if isinstance(image, np.ndarray):
165
- return 255 - image
166
- elif isinstance(image, Image.Image):
167
- return ImageOps.invert(image.convert('RGB'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  else:
169
- raise ValueError("Unsupported image type")
170
-
171
- def invert_displayed_image(image):
172
- if image is None:
173
- return None
174
- inverted = invert_image(image)
175
- if isinstance(inverted, np.ndarray):
176
- return Image.fromarray(inverted)
177
- return inverted
178
-
179
-
180
- #@spaces.GPU()
181
- def inference(
182
- qr_code_content: str,
183
- prompt: str,
184
- negative_prompt: str,
185
- guidance_scale: float = 9.0,
186
- qr_conditioning_scale: float = 1.47,
187
- num_inference_steps: int = 20,
188
- seed: int = -1,
189
- image_resolution: int = 512,
190
- scheduler: str = "K_EULER",
191
- eta: float = 0.0,
192
- num_outputs: int = 1,
193
- low_threshold: int = 100,
194
- high_threshold: int = 200,
195
- guess_mode: bool = False,
196
- disable_safety_check: bool = False,
197
- ):
198
- try:
199
- progress = gr.Progress()
200
- progress(0, desc="Generating QR code...")
201
-
202
- # Generate QR code image
203
- qr = qrcode.QRCode(
204
- version=1,
205
- error_correction=qrcode.constants.ERROR_CORRECT_H,
206
- box_size=10,
207
- border=4,
208
- )
209
- qr.add_data(qr_code_content)
210
- qr.make(fit=True)
211
- qr_image = qr.make_image(fill_color="black", back_color="white")
212
-
213
- # Save QR code image to a temporary file
214
- temp_qr_path = "temp_qr.png"
215
- qr_image.save(temp_qr_path)
216
-
217
- progress(0.3, desc="Running inference...")
218
-
219
- # Ensure num_outputs is within the allowed range
220
- num_outputs = max(1, min(num_outputs, 10))
221
-
222
- # Ensure seed is an integer and not null
223
- seed = int(seed) if seed != -1 else None
224
-
225
- # Ensure high_threshold is at least 1
226
- high_threshold = max(1, int(high_threshold))
227
-
228
- # Prepare the input dictionary
229
- input_dict = {
230
- "prompt": prompt,
231
- "qr_image": open(temp_qr_path, "rb"),
232
- "negative_prompt": negative_prompt,
233
- "guidance_scale": float(guidance_scale),
234
- "qr_conditioning_scale": float(qr_conditioning_scale),
235
- "num_inference_steps": int(num_inference_steps),
236
- "image_resolution": int(image_resolution),
237
- "scheduler": scheduler,
238
- "eta": float(eta),
239
- "num_outputs": num_outputs,
240
- "low_threshold": int(low_threshold),
241
- "high_threshold": high_threshold,
242
- "guess_mode": guess_mode,
243
- "disable_safety_check": disable_safety_check,
244
- }
245
-
246
- # Only add seed to input_dict if it's not None
247
- if seed is not None:
248
- input_dict["seed"] = seed
249
-
250
- # Run inference using Replicate API
251
- output = replicate.run(
252
- "anotherjesse/multi-control:76d8414a702e66c84fe2e6e9c8cbdc12e53f950f255aae9ffa5caa7873b12de0",
253
- input=input_dict
254
- )
255
-
256
- progress(0.9, desc="Processing results...")
257
-
258
- # Download the generated image
259
- response = requests.get(output[0])
260
- img = Image.open(io.BytesIO(response.content))
261
-
262
- # Clean up temporary file
263
- os.remove(temp_qr_path)
264
-
265
- progress(1.0, desc="Done!")
266
- return img, seed if seed is not None else -1
267
- except Exception as e:
268
- print(f"Error in inference: {str(e)}")
269
- return Image.new('RGB', (512, 512), color='white'), -1
270
-
271
-
272
-
273
- def invert_init_image_display(image):
274
- if image is None:
275
- return None
276
- inverted = invert_image(image)
277
- if isinstance(inverted, np.ndarray):
278
- return Image.fromarray(inverted)
279
- return inverted
280
-
281
- def adjust_color_balance(image, r, g, b):
282
- # Convert image to RGB if it's not already
283
- image = image.convert('RGB')
284
-
285
- # Split the image into its RGB channels
286
- r_channel, g_channel, b_channel = image.split()
287
-
288
- # Adjust each channel
289
- r_channel = r_channel.point(lambda i: i + (i * r))
290
- g_channel = g_channel.point(lambda i: i + (i * g))
291
- b_channel = b_channel.point(lambda i: i + (i * b))
292
-
293
- # Merge the channels back
294
- return Image.merge('RGB', (r_channel, g_channel, b_channel))
295
 
296
- def apply_qr_overlay(image, original_qr, overlay, opacity):
297
- if not overlay or original_qr is None:
298
- return image
299
-
300
- # Resize original QR to match the generated image
301
- original_qr = original_qr.resize(image.size)
302
-
303
- # Create a new image blending the generated image and the QR code
304
- return Image.blend(image, original_qr, opacity)
305
 
306
- def apply_edge_enhancement(image, strength):
307
- if strength == 0:
308
- return image
309
-
310
- # Apply edge enhancement
311
- enhanced = image.filter(ImageFilter.EDGE_ENHANCE)
312
-
313
- # Blend the original and enhanced images based on strength
314
- return Image.blend(image, enhanced, strength / 5.0)
 
 
315
 
 
 
 
 
 
 
316
 
317
  css = """
318
- h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, .centered-image {
319
- text-align: center;
 
 
 
 
 
 
 
 
 
320
  display: block;
321
  margin-left: auto;
322
  margin-right: auto;
 
 
323
  }
324
  ul, ol {
325
- margin-left: auto;
326
- margin-right: auto;
327
- display: table;
328
  }
329
- .centered-image {
330
- max-width: 100%;
331
- height: auto;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  }
333
  """
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  def login(username, password):
336
  if username == USERNAME and password == PASSWORD:
337
  return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value="Login successful! You can now access the QR Code Art Generator tab.", visible=True)
338
  else:
339
  return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value="Invalid username or password. Please try again.", visible=True)
340
-
341
- # Add login elements to the Gradio interface
342
- with gr.Blocks(theme='Hev832/Applio', css=css, fill_width=True, fill_height=True) as blocks:
343
- generated_images = gr.State([])
344
 
 
 
345
  with gr.Tab("Welcome"):
346
  with gr.Row():
347
- with gr.Column(scale=2):
348
  gr.Markdown(
349
  """
350
- <img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/8lHjpId7-JDalHq1JByPE.webp" alt="Yamamoto Logo" class="centered-image">
351
-
352
- # 🎨 Yamamoto QR Code Art Generator
353
-
354
- ## Transform Your QR Codes into Brand Masterpieces
355
- This cutting-edge tool empowers our creative team to craft visually stunning,<br>
356
- on-brand QR codes that perfectly blend functionality with artistic expression.
 
 
357
  ## 🚀 How It Works:
358
- 1. **Enter Your QR Code Content**: Start by inputting the URL or text for your QR code.
359
- 2. **Craft Your Prompt**: Describe the artistic style or theme you envision for your QR code.
360
- 3. **Fine-tune with Advanced Settings**: Adjust parameters to perfect your creation (see tips below).
361
- 4. **Generate and Iterate**: Click 'Run' to create your art, then refine as needed.
362
  """
363
  )
364
-
365
  with gr.Column(scale=1):
366
  with gr.Row():
367
  gr.Markdown(
@@ -382,479 +469,133 @@ with gr.Blocks(theme='Hev832/Applio', css=css, fill_width=True, fill_height=True
382
  login_button = gr.Button("Login", size="sm")
383
  login_message = gr.Markdown(visible=False)
384
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
- with gr.Tab("QR Code Art Generator", visible=False) as app_container:
387
  with gr.Row():
388
  with gr.Column():
389
- qr_code_content = gr.Textbox(
390
- label="QR Code Content",
391
- placeholder="Enter URL or text for your QR code",
392
- info="This is what your QR code will link to or display when scanned.",
393
- value="https://theunderground.digital/",
394
- lines=1,
395
- )
396
 
397
- prompt = gr.Textbox(
398
- label="Artistic Prompt",
399
- placeholder="Describe the style or theme for your QR code art (For best results, keep the prompt to 75 characters or less as seen below)",
400
- value="A high-res, photo-realistic minimalist rendering of Mount Fuji as a sharp, semi-realistic silhouette on the horizon. The mountain conveys strength and motion with clean, crisp lines and natural flow. Features detailed snow textures, subtle ridge highlights, and a powerful yet serene atmosphere. Emphasizes strength with clarity and precision in texture and light.",
401
- info="Describe the style or theme for your QR code art (For best results, keep the prompt to 75 characters or less as seen in the example)",
402
- lines=8,
403
  )
404
- negative_prompt = gr.Textbox(
405
- label="Elements to Avoid",
406
- placeholder="Describe what you don't want in the image",
407
- value="ugly, disfigured, low quality, blurry, nsfw, bad_pictures, poorly drawn, distorted, overexposed, flat shading, bad proportions, deformed, pixelated, messy details, lack of contrast, unrealistic textures, bad anatomy, rough edges, low resolution",
408
- info="List elements or styles you want to avoid in your QR code art.",
409
- lines=4,
410
  )
411
 
412
- run_btn = gr.Button("🎨 Create Your QR Art", variant="primary")
 
 
 
 
 
413
 
414
- with gr.Accordion(label="Needs Some Prompting Help?", open=False, visible=True):
415
- gr.Markdown(
416
- """
417
- ## 🌟 Tips for Spectacular Results:
418
- - Use concise details in your prompt to help the AI understand your vision.
419
- - Use negative prompts to avoid unwanted elements in your image.
420
- - Experiment with different ControlNet models and diffusion models to find the best combination for your prompt.
421
-
422
- ## 🎭 Prompt Ideas to Spark Your Creativity:
423
- - "A serene Japanese garden with cherry blossoms and a koi pond"
424
- - "A futuristic cityscape with neon lights and flying cars"
425
- - "An abstract painting with swirling colors and geometric shapes"
426
- - "A vintage-style travel poster featuring iconic landmarks"
427
-
428
- Remember, the magic lies in the details of your prompt and the fine-tuning of your settings.
429
- Happy creating!
430
- """
431
- )
432
 
433
- with gr.Accordion("Set Custom QR Code Colors", open=False, visible=False):
434
- bg_color = gr.ColorPicker(
435
- label="Background Color",
436
- value="#FFFFFF",
437
- info="Choose the background color for the QR code"
438
- )
439
- qr_color = gr.ColorPicker(
440
- label="QR Code Color",
441
- value="#000000",
442
- info="Choose the color for the QR code pattern"
443
- )
444
- invert_final_image = gr.Checkbox(
445
- label="Invert Final Image",
446
- value=False,
447
- info="Check this to invert the colors of the final image",
448
- visible=False,
449
- )
450
- with gr.Accordion("AI Model Selection", open=False, visible=False):
451
- controlnet_model_dropdown = gr.Dropdown(
452
- choices=list(CONTROLNET_MODELS.keys()),
453
- value="QR Code Monster",
454
- label="ControlNet Model",
455
- info="Select the ControlNet model for QR code generation"
456
- )
457
- diffusion_model_dropdown = gr.Dropdown(
458
- choices=list(DIFFUSION_MODELS.keys()),
459
- value="GhostMix",
460
- label="Diffusion Model",
461
- info="Select the main diffusion model for image generation"
462
- )
463
 
464
-
465
- with gr.Accordion(label="QR Code Image (Optional)", open=False, visible=False):
466
- qr_code_image = gr.Image(
467
- label="QR Code Image (Optional). Leave blank to automatically generate QR code",
468
- type="pil",
469
- )
470
-
471
  with gr.Column():
472
- gr.Markdown("### Your Generated QR Code Art")
473
- result_image = gr.Image(
474
- label="Your Artistic QR Code",
475
- show_download_button=True,
476
- show_fullscreen_button=True,
477
- container=False
478
- )
479
- gr.Markdown("💾 Right-click and save the image to download your QR code art.")
480
-
481
- scan_button = gr.Button("Verify QR Code Works", visible=False)
482
- scan_result = gr.Textbox(label="Validation Result of QR Code", interactive=False, visible=False)
483
- used_seed = gr.Number(label="Seed Used", interactive=False)
484
-
485
- with gr.Accordion(label="Use Your Own Image as a Reference", open=True, visible=True) as init_image_acc:
486
- init_image = gr.Image(label="Reference Image", type="pil")
487
- with gr.Row():
488
- use_qr_code_as_init_image = gr.Checkbox(
489
- label="Uncheck to use your own image for generation",
490
- value=True,
491
- interactive=True,
492
- info="Allows you to use your own image for generation, otherwise a generic QR Code is created automatically as the base image"
493
- )
494
- reference_image_strength = gr.Slider(
495
- minimum=0.0,
496
- maximum=5.0,
497
- step=0.05,
498
- value=0.6,
499
- label="Reference Image Influence",
500
- info="Controls how much the reference image influences the final result (0 = ignore, 5 = copy exactly)",
501
- visible=False
502
- )
503
- invert_init_image_button = gr.Button("Invert Init Image", size="sm", visible=False)
504
-
505
- with gr.Tab("Advanced Settings"):
506
- with gr.Accordion("Advanced Art Controls", open=True):
507
- with gr.Row():
508
- qr_conditioning_scale = gr.Slider(
509
- minimum=0.0,
510
- maximum=5.0,
511
- step=0.01,
512
- value=1.47,
513
- label="QR Code Visibility",
514
- )
515
- with gr.Accordion("QR Code Visibility Explanation", open=False):
516
- gr.Markdown(
517
- """
518
- **QR Code Visibility** controls how prominent the QR code is in the final image:
519
-
520
- - **Low (0.0-1.0)**: QR code blends more with the art, potentially harder to scan.
521
- - **Medium (1.0-3.0)**: Balanced visibility, usually scannable while maintaining artistic quality.
522
- - **High (3.0-5.0)**: QR code stands out more, easier to scan but less artistic.
523
-
524
- Start with 1.47 for a good balance between art and functionality.
525
- """
526
  )
527
-
528
- with gr.Row():
529
- guidance_scale = gr.Slider(
530
- minimum=0.1,
531
- maximum=30.0,
532
- step=0.1,
533
- value=9.0,
534
- label="Prompt Adherence",
535
- )
536
- with gr.Accordion("Prompt Adherence Explanation", open=False):
537
- gr.Markdown(
538
- """
539
- **Prompt Adherence** determines how closely the AI follows your prompt:
540
-
541
- - **Low (0.1-5.0)**: More creative freedom, may deviate from prompt.
542
- - **Medium (5.0-15.0)**: Balanced between prompt and AI creativity.
543
- - **High (15.0-30.0)**: Strictly follows the prompt, less creative freedom.
544
-
545
- A value of 9.0 provides a good balance between creativity and prompt adherence.
546
- """
547
- )
548
-
549
- with gr.Row():
550
- num_inference_steps = gr.Slider(
551
- minimum=1,
552
- maximum=100,
553
- step=1,
554
- value=20,
555
- label="Generation Steps",
556
- )
557
- with gr.Accordion("Generation Steps Explanation", open=False):
558
- gr.Markdown(
559
- """
560
- **Generation Steps** affects the detail and quality of the generated image:
561
-
562
- - **Low (1-10)**: Faster generation, less detailed results.
563
- - **Medium (11-30)**: Good balance between speed and quality.
564
- - **High (31-100)**: More detailed results, slower generation.
565
-
566
- 20 steps is a good starting point for most generations.
567
- """
568
  )
569
-
570
- with gr.Row():
571
- image_resolution = gr.Slider(
572
- minimum=256,
573
- maximum=1024,
574
- step=64,
575
- value=512,
576
- label="Image Resolution",
577
- )
578
- with gr.Accordion("Image Resolution Explanation", open=False):
579
- gr.Markdown(
580
- """
581
- **Image Resolution** determines the size and detail of the generated image:
582
-
583
- - **Low (256-384)**: Faster generation, less detailed.
584
- - **Medium (512-768)**: Good balance of detail and generation time.
585
- - **High (832-1024)**: More detailed, slower generation.
586
-
587
- 512x512 is a good default for most use cases.
588
- """
589
- )
590
-
591
- with gr.Row():
592
- seed = gr.Slider(
593
- minimum=-1,
594
- maximum=9999999999,
595
- step=1,
596
- value=-1,
597
- label="Generation Seed",
598
- )
599
- with gr.Accordion("Generation Seed Explanation", open=False):
600
- gr.Markdown(
601
- """
602
- **Generation Seed** controls the randomness of the generation:
603
-
604
- - **-1**: Random seed each time, producing different results.
605
- - **Any positive number**: Consistent results for the same inputs.
606
-
607
- Use -1 to explore various designs, or set a specific seed to recreate a particular result.
608
- """
609
- )
610
-
611
- with gr.Row():
612
- scheduler = gr.Dropdown(
613
- choices=["DDIM", "K_EULER", "DPMSolverMultistep", "K_EULER_ANCESTRAL", "PNDM", "KLMS"],
614
- value="K_EULER",
615
- label="Sampling Method",
616
- )
617
- with gr.Accordion("Sampling Method Explanation", open=False):
618
- gr.Markdown(
619
- """
620
- **Sampling Method** affects the image generation process:
621
-
622
- - **K_EULER**: Good balance of speed and quality.
623
- - **DDIM**: Can produce sharper results but may be slower.
624
- - **DPMSolverMultistep**: Often produces high-quality results.
625
- - **K_EULER_ANCESTRAL**: Can introduce more variations.
626
- - **PNDM**: Another quality-focused option.
627
- - **KLMS**: Can produce smooth results.
628
-
629
- Experiment with different methods to find what works best for your specific prompts.
630
- """
631
- )
632
-
633
- with gr.Row():
634
- eta = gr.Slider(
635
- minimum=0.0,
636
- maximum=1.0,
637
- step=0.01,
638
- value=0.0,
639
- label="ETA (Noise Level)",
640
- )
641
- with gr.Accordion("ETA Explanation", open=False):
642
- gr.Markdown(
643
- """
644
- **ETA (Noise Level)** controls the amount of noise in the generation process:
645
-
646
- - **0.0**: No added noise, more deterministic results.
647
- - **0.1-0.5**: Slight variations in output.
648
- - **0.6-1.0**: More variations, potentially more creative results.
649
-
650
- Start with 0.0 and increase if you want more variation in your outputs.
651
- """
652
  )
653
-
654
- with gr.Row():
655
- low_threshold = gr.Slider(
656
- minimum=1,
657
- maximum=255,
658
- step=1,
659
- value=100,
660
- label="Edge Detection Low Threshold",
661
- )
662
- high_threshold = gr.Slider(
663
- minimum=1,
664
- maximum=255,
665
- step=1,
666
- value=200,
667
- label="Edge Detection High Threshold",
668
- )
669
- with gr.Accordion("Edge Detection Thresholds Explanation", open=False):
670
- gr.Markdown(
671
- """
672
- **Edge Detection Thresholds** affect how the QR code edges are processed:
673
-
674
- - **Low Threshold**: Lower values detect more edges, higher values fewer.
675
- - **High Threshold**: Determines which edges are strong. Higher values result in fewer strong edges.
676
-
677
- Default values (100, 200) work well for most QR codes. Adjust if you need more or less edge definition.
678
- """
679
  )
680
-
681
- with gr.Row():
682
- guess_mode = gr.Checkbox(
683
- label="Guess Mode",
684
- value=False,
685
- )
686
- with gr.Accordion("Guess Mode Explanation", open=False):
687
- gr.Markdown(
688
- """
689
- **Guess Mode**, when enabled, allows the AI to interpret the input image more freely:
690
-
691
- - **Unchecked**: AI follows the QR code structure more strictly.
692
- - **Checked**: AI has more freedom to interpret the input, potentially leading to more creative results.
693
-
694
- Use this if you want more artistic interpretations of your QR code.
695
- """
696
  )
697
-
698
- with gr.Row():
699
- disable_safety_check = gr.Checkbox(
700
- label="Disable Safety Check",
701
- value=False,
702
- )
703
- with gr.Accordion("Safety Check Explanation", open=False):
704
- gr.Markdown(
705
- """
706
- **Disable Safety Check** removes content filtering from the generation process:
707
-
708
- - **Unchecked**: Normal content filtering applied.
709
- - **Checked**: No content filtering, may produce unexpected or inappropriate results.
710
-
711
- Use with caution and only if necessary for your specific use case.
712
- """
713
  )
714
- with gr.Tab("Image Editing"):
715
- with gr.Column():
716
- image_selector = gr.Dropdown(label="Select Image to Edit", choices=[], interactive=True, visible=False)
717
- image_to_edit = gr.Image(label="Your Artistic QR Code", show_download_button=True, show_fullscreen_button=True, container=True)
718
-
719
- with gr.Row():
720
- qr_overlay = gr.Checkbox(label="Overlay Original QR Code", value=False, visible=False)
721
- qr_opacity = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="QR Overlay Opacity", visible=False)
722
- edge_enhance = gr.Slider(minimum=0.0, maximum=5.0, step=0.1, value=0.0, label="Edge Enhancement", visible=False)
723
-
724
- with gr.Row():
725
- red_balance = gr.Slider(minimum=-1.0, maximum=1.0, step=0.1, value=0.0, label="Red Balance")
726
- green_balance = gr.Slider(minimum=-1.0, maximum=1.0, step=0.1, value=0.0, label="Green Balance")
727
- blue_balance = gr.Slider(minimum=-1.0, maximum=1.0, step=0.1, value=0.0, label="Blue Balance")
728
-
729
-
730
- with gr.Row():
731
- brightness = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Brightness")
732
- contrast = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Contrast")
733
- saturation = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Saturation")
734
- with gr.Row():
735
- invert_button = gr.Button("Invert Image", size="sm")
736
-
737
- with gr.Row():
738
- edited_image = gr.Image(label="Edited QR Code", show_download_button=True, show_fullscreen_button=True, visible=False)
739
- scan_button = gr.Button("Verify QR Code Works", size="sm", visible=False)
740
- scan_result = gr.Textbox(label="Validation Result of QR Code", interactive=False, visible=False)
741
-
742
- used_seed = gr.Number(label="Seed Used", interactive=False)
743
-
744
- gr.Markdown(
745
- """
746
- ### 🔍 Analyzing Your Creation
747
- - Is the QR code scannable? Check with your phone camera to see if it can scan it.
748
- - If not scannable, use the Brightness, Contrast, and Saturation sliders to optimize the QR code for scanning.
749
- - Does the art style match your prompt? If not, try adjusting the 'Prompt Adherence'.
750
- - Want more artistic flair? Increase the 'Artistic Freedom'.
751
- - Need a clearer QR code? Raise the 'QR Code Visibility'.
752
- """
753
- )
754
-
755
- def scan_and_display(image):
756
- if image is None:
757
- return "No image to scan"
758
-
759
- scanned_text = scan_qr_code(image)
760
- if scanned_text:
761
- return f"Scanned successfully: {scanned_text}"
762
- else:
763
- return "Failed to scan QR code. Try adjusting the settings for better visibility."
764
-
765
- def invert_displayed_image(image):
766
- if image is None:
767
- return None
768
- return invert_image(image)
769
-
770
- scan_button.click(
771
- scan_and_display,
772
- inputs=[result_image],
773
- outputs=[scan_result]
774
- )
775
-
776
- invert_button.click(
777
- invert_displayed_image,
778
- inputs=[result_image],
779
- outputs=[result_image]
780
- )
781
-
782
- invert_init_image_button.click(
783
- invert_init_image_display,
784
- inputs=[init_image],
785
- outputs=[init_image]
786
- )
787
-
788
- brightness.change(
789
- adjust_image,
790
- inputs=[result_image, brightness, contrast, saturation],
791
- outputs=[result_image]
792
- )
793
- contrast.change(
794
- adjust_image,
795
- inputs=[result_image, brightness, contrast, saturation],
796
- outputs=[result_image]
797
- )
798
- saturation.change(
799
- adjust_image,
800
- inputs=[result_image, brightness, contrast, saturation],
801
- outputs=[result_image]
802
- )
803
-
804
- # Add logic to show/hide the reference_image_strength slider
805
- def update_reference_image_strength_visibility(init_image, use_qr_code_as_init_image):
806
- return gr.update(visible=init_image is not None and not use_qr_code_as_init_image)
807
-
808
- init_image.change(
809
- update_reference_image_strength_visibility,
810
- inputs=[init_image, use_qr_code_as_init_image],
811
- outputs=[reference_image_strength]
812
- )
813
 
814
- use_qr_code_as_init_image.change(
815
- update_reference_image_strength_visibility,
816
- inputs=[init_image, use_qr_code_as_init_image],
817
- outputs=[reference_image_strength]
818
- )
819
-
820
- run_btn.click(
821
- fn=inference,
822
- inputs=[
823
- qr_code_content,
824
- prompt,
825
- negative_prompt,
826
- guidance_scale,
827
- qr_conditioning_scale,
828
- num_inference_steps,
829
- seed,
830
- image_resolution,
831
- scheduler,
832
- eta,
833
- low_threshold,
834
- high_threshold,
835
- guess_mode,
836
- disable_safety_check,
837
- ],
838
- outputs=[result_image, used_seed],
839
- concurrency_limit=20
840
- )
841
 
842
- # Define login button click behavior
843
- login_button.click(
844
- login,
845
- inputs=[username, password],
846
- outputs=[app_container, login_message, login_button, login_message]
847
- )
848
-
849
- # Define password textbox submit behavior
850
- password.submit(
851
- login,
852
- inputs=[username, password],
853
- outputs=[app_container, login_message, login_button, login_message]
 
854
  )
855
 
856
- # Load models on launch
857
- #load_models_on_launch()
858
 
859
- blocks.queue(max_size=20)
860
- blocks.launch(share=False, show_api=False)
 
1
+ import spaces
2
  import gradio as gr
3
+ from huggingface_hub import InferenceClient
4
+ from torch import nn
5
+ from transformers import AutoModel, AutoProcessor, AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast, AutoModelForCausalLM
6
  from pathlib import Path
7
+ import torch
8
+ import torch.amp.autocast_mode
 
9
  from PIL import Image
10
+ import os
11
+ import torchvision.transforms.functional as TVF
12
+
 
 
 
 
13
  from dotenv import load_dotenv
14
 
15
  # Load environment variables from .env file
 
17
 
18
  USERNAME = os.getenv("USERNAME")
19
  PASSWORD = os.getenv("PASSWORD")
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ CLIP_PATH = "google/siglip-so400m-patch14-384"
22
+ MODEL_PATH = "meta-llama/Meta-Llama-3.1-8B"
23
+ CHECKPOINT_PATH = Path("9em124t2-499968")
24
+ TITLE = "<h1><center>JoyCaption Alpha One (2024-09-20a)</center></h1>"
25
+ CAPTION_TYPE_MAP = {
26
+ ("descriptive", "formal", False, False): ["Write a descriptive caption for this image in a formal tone."],
27
+ ("descriptive", "formal", False, True): ["Write a descriptive caption for this image in a formal tone within {word_count} words."],
28
+ ("descriptive", "formal", True, False): ["Write a {length} descriptive caption for this image in a formal tone."],
29
+ ("descriptive", "informal", False, False): ["Write a descriptive caption for this image in a casual tone."],
30
+ ("descriptive", "informal", False, True): ["Write a descriptive caption for this image in a casual tone within {word_count} words."],
31
+ ("descriptive", "informal", True, False): ["Write a {length} descriptive caption for this image in a casual tone."],
32
+
33
+ ("training_prompt", "formal", False, False): ["Write a stable diffusion prompt for this image."],
34
+ ("training_prompt", "formal", False, True): ["Write a stable diffusion prompt for this image within {word_count} words."],
35
+ ("training_prompt", "formal", True, False): ["Write a {length} stable diffusion prompt for this image."],
36
+
37
+ ("rng-tags", "formal", False, False): ["Write a list of Booru tags for this image."],
38
+ ("rng-tags", "formal", False, True): ["Write a list of Booru tags for this image within {word_count} words."],
39
+ ("rng-tags", "formal", True, False): ["Write a {length} list of Booru tags for this image."],
40
+
41
+ ("style_prompt", "formal", False, False): ["Generate a detailed style prompt for this image, including lens type, film stock, composition notes, lighting aspects, and any special photographic techniques."],
42
+ ("style_prompt", "formal", False, True): ["Generate a detailed style prompt for this image within {word_count} words, including lens type, film stock, composition notes, lighting aspects, and any special photographic techniques."],
43
+ ("style_prompt", "formal", True, False): ["Generate a {length} detailed style prompt for this image, including lens type, film stock, composition notes, lighting aspects, and any special photographic techniques."],
44
  }
45
 
46
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 
 
 
47
 
 
 
 
48
 
49
+ class ImageAdapter(nn.Module):
50
+ def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
51
+ super().__init__()
52
+ self.deep_extract = deep_extract
53
+
54
+ if self.deep_extract:
55
+ input_features = input_features * 5
56
+
57
+ self.linear1 = nn.Linear(input_features, output_features)
58
+ self.activation = nn.GELU()
59
+ self.linear2 = nn.Linear(output_features, output_features)
60
+ self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
61
+ self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
62
+
63
+ # Mode token
64
+ #self.mode_token = nn.Embedding(n_modes, output_features)
65
+ #self.mode_token.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
66
+
67
+ # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
68
+ self.other_tokens = nn.Embedding(3, output_features)
69
+ self.other_tokens.weight.data.normal_(mean=0.0, std=0.02) # Matches HF's implementation of llama3
70
+
71
+ def forward(self, vision_outputs: torch.Tensor):
72
+ if self.deep_extract:
73
+ x = torch.concat((
74
+ vision_outputs[-2],
75
+ vision_outputs[3],
76
+ vision_outputs[7],
77
+ vision_outputs[13],
78
+ vision_outputs[20],
79
+ ), dim=-1)
80
+ assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}" # batch, tokens, features
81
+ assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
82
+ else:
83
+ x = vision_outputs[-2]
84
+
85
+ x = self.ln1(x)
86
+
87
+ if self.pos_emb is not None:
88
+ assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
89
+ x = x + self.pos_emb
90
+
91
+ x = self.linear1(x)
92
+ x = self.activation(x)
93
+ x = self.linear2(x)
94
+
95
+ # Mode token
96
+ #mode_token = self.mode_token(mode)
97
+ #assert mode_token.shape == (x.shape[0], mode_token.shape[1], x.shape[2]), f"Expected {(x.shape[0], 1, x.shape[2])}, got {mode_token.shape}"
98
+ #x = torch.cat((x, mode_token), dim=1)
99
+
100
+ # <|image_start|>, IMAGE, <|image_end|>
101
+ other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
102
+ assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
103
+ x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
104
+
105
+ return x
106
+
107
+ def get_eot_embedding(self):
108
+ return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
109
+
110
+
111
+
112
+ # Load CLIP
113
+ print("Loading CLIP")
114
+ clip_processor = AutoProcessor.from_pretrained(CLIP_PATH)
115
+ clip_model = AutoModel.from_pretrained(CLIP_PATH)
116
+ clip_model = clip_model.vision_model
117
+
118
+ if (CHECKPOINT_PATH / "clip_model.pt").exists():
119
+ print("Loading VLM's custom vision model")
120
+ checkpoint = torch.load(CHECKPOINT_PATH / "clip_model.pt", map_location='cpu')
121
+ checkpoint = {k.replace("_orig_mod.module.", ""): v for k, v in checkpoint.items()}
122
+ clip_model.load_state_dict(checkpoint)
123
+ del checkpoint
124
+
125
+ clip_model.eval()
126
+ clip_model.requires_grad_(False)
127
+ clip_model.to("cuda")
128
+
129
+
130
+ # Tokenizer
131
+ print("Loading tokenizer")
132
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
133
+ assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
134
+
135
+ # LLM
136
+ print("Loading LLM")
137
+ if (CHECKPOINT_PATH / "text_model").exists:
138
+ print("Loading VLM's custom text model")
139
+ text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
140
+ else:
141
+ text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
142
+
143
+ text_model.eval()
144
+
145
+ # Image Adapter
146
+ print("Loading image adapter")
147
+ image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
148
+ image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
149
+ image_adapter.eval()
150
+ image_adapter.to("cuda")
151
+
152
+
153
+ def preprocess_image(input_image: Image.Image) -> torch.Tensor:
154
+ """
155
+ Preprocess the input image for the CLIP model.
156
+ """
157
+ image = input_image.resize((384, 384), Image.LANCZOS)
158
+ pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
159
+ pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
160
+ return pixel_values.to('cuda')
161
+
162
+ def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
163
+ """
164
+ Generate a caption based on the image features and prompt.
165
+ """
166
+ prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
167
+ prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
168
+ embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
169
+ eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
170
+
171
+ inputs_embeds = torch.cat([
172
+ embedded_bos.expand(image_features.shape[0], -1, -1),
173
+ image_features.to(dtype=embedded_bos.dtype),
174
+ prompt_embeds.expand(image_features.shape[0], -1, -1),
175
+ eot_embed.expand(image_features.shape[0], -1, -1),
176
+ ], dim=1)
177
+
178
+ input_ids = torch.cat([
179
+ torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long),
180
+ torch.zeros((1, image_features.shape[1]), dtype=torch.long),
181
+ prompt,
182
+ torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
183
+ ], dim=1).to('cuda')
184
+ attention_mask = torch.ones_like(input_ids)
185
+
186
+ generate_ids = text_model.generate(input_ids, inputs_embeds=inputs_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, suppress_tokens=None)
187
+
188
+ generate_ids = generate_ids[:, input_ids.shape[1]:]
189
+ if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
190
+ generate_ids = generate_ids[:, :-1]
191
+
192
+ return tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0].strip()
193
+
194
+ @spaces.GPU()
195
+ @torch.no_grad()
196
+ def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int, lens_type: str = "", film_stock: str = "", composition_style: str = "", lighting_aspect: str = "", special_technique: str = "", color_effect: str = "") -> str:
197
+ """
198
+ Generate a caption, training prompt, tags, or a style prompt for image generation based on the input image and parameters.
199
+ """
200
+ # Check if an image has been uploaded
201
+ if input_image is None:
202
+ return "Error: Please upload an image before generating a caption."
203
+
204
+ torch.cuda.empty_cache()
205
+
206
  try:
207
+ length = None if caption_length == "any" else caption_length
208
+ if isinstance(length, str):
209
+ length = int(length)
210
+ except ValueError:
211
+ raise ValueError(f"Invalid caption length: {caption_length}")
212
+
213
+ if caption_type in ["rng-tags", "training_prompt", "style_prompt"]:
214
+ caption_tone = "formal"
215
+
216
+ prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
217
+ if prompt_key not in CAPTION_TYPE_MAP:
218
+ raise ValueError(f"Invalid caption type: {prompt_key}")
219
+
220
+ if caption_type == "style_prompt":
221
+ # For style prompt, we'll create a custom prompt for the LLM
222
+ base_prompt = "Analyze the given image and create a detailed Stable Diffusion prompt for generating a new, creative image inspired by it. "
223
+ base_prompt += "The prompt should describe the main elements, style, and mood of the image, "
224
+ base_prompt += "but also introduce creative variations or enhancements. "
225
+ base_prompt += "Include specific details about the composition, lighting, and overall atmosphere. "
226
+
227
+ # Add custom settings to the prompt
228
+ if lens_type:
229
+ lens_type_key = lens_type.split(":")[0].strip()
230
+ base_prompt += f"Incorporate the effect of a {lens_type_key} lens ({lens_types_info[lens_type_key]}). "
231
+ if film_stock:
232
+ film_stock_key = film_stock.split(":")[0].strip()
233
+ base_prompt += f"Apply the characteristics of {film_stock_key} film stock ({film_stocks_info[film_stock_key]}). "
234
+ if composition_style:
235
+ composition_style_key = composition_style.split(":")[0].strip()
236
+ base_prompt += f"Use a {composition_style_key} composition style ({composition_styles_info[composition_style_key]}). "
237
+ if lighting_aspect:
238
+ lighting_aspect_key = lighting_aspect.split(":")[0].strip()
239
+ base_prompt += f"Implement {lighting_aspect_key} lighting ({lighting_aspects_info[lighting_aspect_key]}). "
240
+ if special_technique:
241
+ special_technique_key = special_technique.split(":")[0].strip()
242
+ base_prompt += f"Apply the {special_technique_key} technique ({special_techniques_info[special_technique_key]}). "
243
+ if color_effect:
244
+ color_effect_key = color_effect.split(":")[0].strip()
245
+ base_prompt += f"Use a {color_effect_key} color effect ({color_effects_info[color_effect_key]}). "
246
+
247
+ base_prompt += f"The final prompt should be approximately {length} words long. "
248
+ base_prompt += "Format the output as a single paragraph without numbering or bullet points."
249
+
250
+ prompt_str = base_prompt
251
  else:
252
+ prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ # Debugging: Print the constructed prompt string
255
+ print(f"Constructed Prompt: {prompt_str}")
 
 
 
 
 
 
 
256
 
257
+ pixel_values = preprocess_image(input_image)
258
+
259
+ with torch.amp.autocast_mode.autocast('cuda', enabled=True):
260
+ vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
261
+ image_features = vision_outputs.hidden_states
262
+ embedded_images = image_adapter(image_features)
263
+ embedded_images = embedded_images.to('cuda')
264
+
265
+ # Load the model from MODEL_PATH
266
+ text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
267
+ text_model.eval()
268
 
269
+ # Debugging: Print the prompt string before passing to generate_caption
270
+ print(f"Prompt passed to generate_caption: {prompt_str}")
271
+
272
+ caption = generate_caption(text_model, tokenizer, embedded_images, prompt_str)
273
+
274
+ return caption
275
 
276
  css = """
277
+ h1, h2, h3, h4, h5, h6, p, li, ul, ol, a, img {
278
+ text-align: left;
279
+ }
280
+ img {
281
+ display: inline-block;
282
+ vertical-align: middle;
283
+ margin-right: 10px;
284
+ max-width: 100%;
285
+ height: auto;
286
+ }
287
+ .centered-image {
288
  display: block;
289
  margin-left: auto;
290
  margin-right: auto;
291
+ max-width: 100%;
292
+ height: auto;
293
  }
294
  ul, ol {
295
+ padding-left: 20px;
 
 
296
  }
297
+ .gradio-container {
298
+ max-width: 100% !important;
299
+ padding: 0 !important;
300
+ }
301
+ .gradio-row {
302
+ margin-left: 0 !important;
303
+ margin-right: 0 !important;
304
+ }
305
+ .gradio-column {
306
+ padding-left: 0 !important;
307
+ padding-right: 0 !important;
308
+ }
309
+ /* Left-align dropdown text */
310
+ .gradio-dropdown > div {
311
+ text-align: left !important;
312
+ }
313
+ /* Left-align checkbox labels */
314
+ .gradio-checkbox label {
315
+ text-align: left !important;
316
+ }
317
+ /* Left-align radio button labels */
318
+ .gradio-radio label {
319
+ text-align: left !important;
320
  }
321
  """
322
 
323
+ # Add detailed descriptions for each option
324
+ lens_types_info = {
325
+ "Standard": "A versatile lens with a field of view similar to human vision.",
326
+ "Wide-angle": "Captures a wider field of view, great for landscapes and architecture. Applies moderate to strong lens effect with image warp.",
327
+ "Telephoto": "Used for distant subjects, gives an 'award-winning' or 'National Geographic' look. Creates interesting effects when prompted.",
328
+ "Macro": "For extreme close-up photography, revealing tiny details.",
329
+ "Fish-eye": "Ultra-wide-angle lens that creates a strong bubble-like distortion. Generates panoramic photos with the entire image warping into a bubble.",
330
+ "Tilt-shift": "Allows adjusting the plane of focus, creating a 'miniature' effect. Known for the 'diorama miniature look'.",
331
+ "Zoom lens": "Variable focal length lens. Often zooms in on the subject, perfect for creating a base for inpainting. Interesting effect on landscapes with motion blur.",
332
+ "GoPro": "Wide-angle lens with clean digital look. Excludes film grain and most filter effects, resulting in natural colors and regular saturation.",
333
+ "Pinhole camera": "Creates a unique, foggy, low-detail, historic photograph look. Used since the 1850s, with peak popularity in the 1930s."
334
+ }
335
+
336
+ film_stocks_info = {
337
+ "Kodak Portra": "Professional color negative film known for its natural skin tones and low contrast.",
338
+ "Fujifilm Velvia": "Slide film known for vibrant colors and high saturation, popular among landscape photographers.",
339
+ "Ilford Delta": "Black and white film known for its fine grain and high sharpness.",
340
+ "Kodak Tri-X": "Classic high-speed black and white film, known for its distinctive grain and wide exposure latitude.",
341
+ "Fujifilm Provia": "Color reversal film known for its natural color reproduction and fine grain.",
342
+ "Cinestill": "Color photos with fine/low grain and higher than average resolution. Colors are slightly oversaturated or slightly desaturated.",
343
+ "Ektachrome": "Color photos with fine/low to moderate grain. Colors on the colder part of spectrum or regular, with normal or slightly higher saturation.",
344
+ "Ektar": "Modern Kodak film. Color photos with little to no grain. Results look like regular modern photography with artistic angles.",
345
+ "Film Washi": "Mostly black and white photos with fine/low to moderate grain. Occasionally gives colored photos with low saturation. Distinct style with high black contrast and soft camera lens effect.",
346
+ "Fomapan": "Black and white photos with fine/low to moderate grain, highly artistic exposure and angles. Adds very soft lens effect without distortion, dark photo vignette.",
347
+ "Fujicolor": "Color photos with fine/low to moderate grain. Colors are either very oversaturated or slightly desaturated, with entire color hue shifted in a very distinct manner.",
348
+ "Holga": "Color photos with high grain. Colors are either very oversaturated or slightly desaturated. Distinct contrast of black. Often applies photographic vignette.",
349
+ "Instax": "Instant color photos similar to Polaroid but clearer. Near perfect colors, regular saturation, fine/low to medium grain.",
350
+ "Lomography": "Color photos with high grain. Colors are either very oversaturated or slightly desaturated. Distinct contrast of black. Often applies photographic vignette.",
351
+ "Kodachrome": "Color photos with moderate grain. Colors on either colder part of spectrum or regular, with normal or slightly higher saturation.",
352
+ "Rollei": "Mostly black and white photos, sometimes color with fine/low grain. Can be sepia colored or have unusual hues and desaturation. Great for landscapes."
353
+ }
354
+
355
+ composition_styles_info = {
356
+ "Rule of Thirds": "Divides the frame into a 3x3 grid, placing key elements along the lines or at their intersections.",
357
+ "Golden Ratio": "Uses a spiral based on the golden ratio to create a balanced and aesthetically pleasing composition.",
358
+ "Symmetry": "Creates a mirror-like balance in the image, often used for architectural or nature photography.",
359
+ "Leading Lines": "Uses lines within the frame to draw the viewer's eye to the main subject or through the image.",
360
+ "Framing": "Uses elements within the scene to create a frame around the main subject.",
361
+ "Minimalism": "Simplifies the composition to its essential elements, often with a lot of negative space.",
362
+ "Fill the Frame": "The main subject dominates the entire frame, leaving little to no background.",
363
+ "Negative Space": "Uses empty space around the subject to create a sense of simplicity or isolation.",
364
+ "Centered Composition": "Places the main subject in the center of the frame, creating a sense of stability or importance.",
365
+ "Diagonal Lines": "Uses diagonal elements to create a sense of movement or dynamic tension in the image.",
366
+ "Triangular Composition": "Arranges elements in the frame to form a triangle, creating a sense of stability and harmony.",
367
+ "Radial Balance": "Arranges elements in a circular pattern around a central point, creating a sense of movement or completeness."
368
+ }
369
+
370
+ lighting_aspects_info = {
371
+ "Natural light": "Uses available light from the sun or sky, often creating soft, even illumination.",
372
+ "Studio lighting": "Controlled artificial lighting setup, allowing for precise manipulation of light and shadow.",
373
+ "Back light": "Light source behind the subject, creating silhouettes or rim lighting effects.",
374
+ "Split light": "Strong light source at 90-degree angle, lighting one half of the subject while leaving the other in shadow.",
375
+ "Broad light": "Light source at an angle to the subject, producing well-lit photographs with soft to moderate shadows.",
376
+ "Dim light": "Weak or distant light source, creating lower than average brightness and often dramatic images.",
377
+ "Flash photography": "Uses a brief, intense burst of light. Can be fill flash (even lighting) or harsh flash (strong contrasts).",
378
+ "Sunlight": "Direct light from the sun, often creating strong contrasts and warm tones.",
379
+ "Moonlight": "Soft, cool light from the moon, often creating a mysterious or romantic atmosphere.",
380
+ "Spotlight": "Focused beam of light illuminating a specific area, creating high contrast between light and shadow.",
381
+ "High-key lighting": "Bright, even lighting with minimal shadows, creating a light and airy feel.",
382
+ "Low-key lighting": "Predominantly dark tones with selective lighting, creating a moody or dramatic atmosphere.",
383
+ "Rembrandt lighting": "Classic portrait lighting technique creating a triangle of light on the cheek of the subject."
384
+ }
385
+
386
+ special_techniques_info = {
387
+ "Double exposure": "Superimposes two exposures to create a single image, often resulting in a dreamy or surreal effect.",
388
+ "Long exposure": "Uses a long shutter speed to capture motion over time, often creating smooth, blurred effects for moving elements.",
389
+ "Multiple exposure": "Superimposes multiple exposures, multiplying the subject or its key elements across the image.",
390
+ "HDR": "High Dynamic Range imaging, combining multiple exposures to capture a wider range of light and dark tones.",
391
+ "Bokeh effect": "Creates a soft, out-of-focus background, often with circular highlights.",
392
+ "Silhouette": "Captures the outline of a subject against a brighter background, creating a dramatic contrast.",
393
+ "Panning": "Follows a moving subject with the camera, creating a sharp subject with a blurred background.",
394
+ "Light painting": "Uses long exposure and moving light sources to 'paint' with light in the image.",
395
+ "Infrared photography": "Captures light in the infrared spectrum, often resulting in surreal, otherworldly images.",
396
+ "Ultraviolet photography": "Captures light in the ultraviolet spectrum, often revealing hidden patterns or creating a strong violet glow.",
397
+ "Kirlian photography": "High-voltage photographic technique that captures corona discharges around objects, creating a glowing effect.",
398
+ "Thermography": "Captures infrared radiation to create images based on temperature differences, resulting in false-color heat maps.",
399
+ "Astrophotography": "Specialized technique for capturing astronomical objects and celestial events, often resulting in stunning starry backgrounds.",
400
+ "Underwater photography": "Captures images beneath the surface of water, often in pools, seas, or aquariums.",
401
+ "Aerial photography": "Captures images from an elevated position, such as from drones, helicopters, or planes.",
402
+ "Macro photography": "Extreme close-up photography, revealing tiny details not visible to the naked eye."
403
+ }
404
+
405
+ color_effects_info = {
406
+ "Black and white": "Removes all color, leaving only shades of gray.",
407
+ "Sepia": "Reddish-brown monochrome effect, often associated with vintage photography.",
408
+ "Monochrome": "Uses variations of a single color.",
409
+ "Vintage color": "Muted or faded color palette reminiscent of old photographs.",
410
+ "Cross-processed": "Deliberate processing of film in the wrong chemicals, creating unusual color shifts.",
411
+ "Desaturated": "Reduces the intensity of all colors in the image.",
412
+ "Vivid colors": "Increases the saturation and intensity of colors.",
413
+ "Pastel colors": "Soft, pale colors with a light and airy feel.",
414
+ "High contrast": "Emphasizes the difference between light and dark areas in the image.",
415
+ "Low contrast": "Reduces the difference between light and dark areas, creating a softer look.",
416
+ "Color splash": "Converts most of the image to black and white while leaving one or more elements in color."
417
+ }
418
+
419
+ def get_dropdown_choices(info_dict):
420
+ return [f"{key}: {value}" for key, value in info_dict.items()]
421
+
422
  def login(username, password):
423
  if username == USERNAME and password == PASSWORD:
424
  return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(value="Login successful! You can now access the QR Code Art Generator tab.", visible=True)
425
  else:
426
  return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(value="Invalid username or password. Please try again.", visible=True)
 
 
 
 
427
 
428
+ # Gradio interface
429
+ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True) as demo:
430
  with gr.Tab("Welcome"):
431
  with gr.Row():
432
+ with gr.Column(scale=2):
433
  gr.Markdown(
434
  """
435
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/LVZnwLV43UUvKu3HORqSs.webp" alt="UDG" width="250" style="max-width: 100%; height: auto; class="centered-image">
436
+
437
+ # 🎨 Underground Digital's Caption Captain: AI-Powered Art Inspiration
438
+
439
+ ## Accelerate Your Creative Workflow with Intelligent Image Analysis
440
+
441
+ This innovative tool empowers Yamamoto's artists to quickly generate descriptive captions,<br>
442
+ training prompts, and tags from existing artwork, fueling the creative process for GenAI models.
443
+
444
  ## 🚀 How It Works:
445
+ 1. **Upload Your Inspiration**: Drop in an image (e.g., a charcoal horse picture) that embodies your desired style.
446
+ 2. **Choose Your Output**: Select from descriptive captions, training prompts, or tags.
447
+ 3. **Customize the Results**: Adjust tone, length, and other parameters to fine-tune the output.
448
+ 4. **Generate and Iterate**: Click 'Caption' to analyze your image and use the results to inspire new creations.
449
  """
450
  )
451
+
452
  with gr.Column(scale=1):
453
  with gr.Row():
454
  gr.Markdown(
 
469
  login_button = gr.Button("Login", size="sm")
470
  login_message = gr.Markdown(visible=False)
471
 
472
+ with gr.Tab("Caption Captain") as app_container:
473
+ with gr.Accordion("How to Use Caption Captain", open=False):
474
+ gr.Markdown("""
475
+ # How to Use Caption Captain
476
+
477
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/Ce_Z478iOXljvpZ_Fr_Y7.png" alt="Captain" width="100" style="max-width: 100%; height: auto;">
478
+
479
+ Hello, artist! Let's make some fun captions for your pictures. Here's how:
480
+
481
+ 1. **Pick a Picture**: Find a cool picture you want to talk about and upload it.
482
+
483
+ 2. **Choose What You Want**:
484
+ - **Caption Type**:
485
+ * "Descriptive" tells you what's in the picture
486
+ * "Training Prompt" helps computers make similar pictures
487
+ * "RNG-Tags" gives you short words about the picture
488
+ * "Style Prompt" creates detailed prompts for image generation
489
+
490
+ 3. **Pick a Style** (for "Descriptive" and "Style Prompt" only):
491
+ - "Formal" sounds like a teacher talking
492
+ - "Informal" sounds like a friend chatting
493
+
494
+ 4. **Decide How Long**:
495
+ - "Any" lets the computer decide
496
+ - Or pick a size from "very short" to "very long"
497
+ - You can even choose a specific number of words!
498
+
499
+ 5. **Advanced Options** (for "Style Prompt" only):
500
+ - Choose lens type, film stock, composition, and lighting details
501
+
502
+ 6. **Make the Caption**: Click the "Make My Caption!" button and watch the magic happen!
503
+
504
+ Remember, have fun and be creative with your captions!
505
+
506
+ ## Tips for Great Captions:
507
+ - Try different types to see what you like best
508
+ - Experiment with formal and informal tones for fun variations
509
+ - Adjust the length to get just the right amount of detail
510
+ - For "Style Prompt", play with the advanced options for more specific results
511
+ - If you don't like a caption, just click "Make My Caption!" again for a new one
512
+
513
+ Have a great time captioning your art!
514
+ """)
515
 
 
516
  with gr.Row():
517
  with gr.Column():
518
+ input_image = gr.Image(type="pil", label="Input Image")
 
 
 
 
 
 
519
 
520
+ caption_type = gr.Dropdown(
521
+ choices=["descriptive", "training_prompt", "rng-tags", "style_prompt"],
522
+ label="Caption Type",
523
+ value="descriptive",
 
 
524
  )
525
+
526
+ caption_tone = gr.Dropdown(
527
+ choices=["formal", "informal"],
528
+ label="Caption Tone",
529
+ value="formal",
 
530
  )
531
 
532
+ caption_length = gr.Dropdown(
533
+ choices=["any", "very short", "short", "medium-length", "long", "very long"] +
534
+ [str(i) for i in range(20, 261, 10)],
535
+ label="Caption Length",
536
+ value="any",
537
+ )
538
 
539
+ gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags`, `training_prompt`, and `style_prompt`.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
 
 
 
 
 
 
 
 
542
  with gr.Column():
543
+ error_message = gr.Markdown(visible=False) # Add this line
544
+ output_caption = gr.Textbox(label="Generated Caption")
545
+ run_button = gr.Button("Make My Caption!")
546
+
547
+ # Container for advanced options
548
+ with gr.Column(visible=False) as advanced_options:
549
+ gr.Markdown("### Advanced Options for Style Prompt")
550
+ lens_type = gr.Dropdown(
551
+ choices=get_dropdown_choices(lens_types_info),
552
+ label="Lens Type",
553
+ info="Select a lens type to define the perspective and field of view of the image."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
  )
555
+ film_stock = gr.Dropdown(
556
+ choices=get_dropdown_choices(film_stocks_info),
557
+ label="Film Stock",
558
+ info="Choose a film stock to determine the color, grain, and overall look of the image."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
  )
560
+ composition_style = gr.Dropdown(
561
+ choices=get_dropdown_choices(composition_styles_info),
562
+ label="Composition Style",
563
+ info="Select a composition style to guide the arrangement of elements in the image."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
564
  )
565
+ lighting_aspect = gr.Dropdown(
566
+ choices=get_dropdown_choices(lighting_aspects_info),
567
+ label="Lighting Aspect",
568
+ info="Choose a lighting style to define the mood and atmosphere of the image."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  )
570
+ special_technique = gr.Dropdown(
571
+ choices=get_dropdown_choices(special_techniques_info),
572
+ label="Special Technique",
573
+ info="Select a special photographic technique to add unique effects to the image."
 
 
 
 
 
 
 
 
 
 
 
 
574
  )
575
+ color_effect = gr.Dropdown(
576
+ choices=get_dropdown_choices(color_effects_info),
577
+ label="Color Effect",
578
+ info="Choose a color effect to alter the overall color palette of the image."
 
 
 
 
 
 
 
 
 
 
 
 
579
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
 
581
+ def update_style_options(caption_type):
582
+ return gr.update(visible=caption_type == "style_prompt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
 
584
+ caption_type.change(update_style_options, inputs=[caption_type], outputs=[advanced_options])
585
+
586
+ def process_and_handle_errors(input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect):
587
+ try:
588
+ result = stream_chat(input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect)
589
+ return gr.update(visible=False), result
590
+ except Exception as e:
591
+ return gr.update(visible=True, value=f"Error: {str(e)}"), ""
592
+
593
+ run_button.click(
594
+ fn=process_and_handle_errors,
595
+ inputs=[input_image, caption_type, caption_tone, caption_length, lens_type, film_stock, composition_style, lighting_aspect, special_technique, color_effect],
596
+ outputs=[error_message, output_caption]
597
  )
598
 
 
 
599
 
600
+ if __name__ == "__main__":
601
+ demo.launch()