ginipick committed on
Commit
84ce695
·
verified ·
1 Parent(s): 6bee32e

Update app.py

Files changed (1)
  1. app.py +341 -503
app.py CHANGED
@@ -1,514 +1,352 @@
1
- import argparse
2
- import spaces
3
- from visualcloze import VisualClozeModel
4
  import gradio as gr
5
- import examples
6
  import torch
7
- from functools import partial
8
- from data.prefix_instruction import get_layout_instruction
9
- from huggingface_hub import snapshot_download
10
-
11
- # Define the missing variables here
12
- GUIDANCE = """
13
- ## How to use this demo:
14
- 1. Select a task example from the right side, or prepare your own in-context examples and query.
15
- 2. The grid will be filled with in-context examples and a query row.
16
- 3. You can modify the task description or add content descriptions.
17
- 4. Click "Generate" to create images following the pattern shown in examples.
18
- """
19
-
20
- NOTE = """
21
- **Note:** The examples on the right side demonstrate various tasks.
22
- Click on any example to load it into the interface. You can then modify images or prompts as needed.
23
- """
 
 
 
 
 
 
 
 
 
24
 
25
- CITATION = """
26
- ## Paper Citation
27
- ```
28
- @article{liu2024visualcloze,
29
- title={VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning},
30
- author={Liu, Zhaoyang and Lian, Yuheng and Wang, Jianfeng and Zhou, Aojun and Liu, Jiashi and Ye, Hang and Chen, Kai and Wang, Jingdong and Zhao, Deli},
31
- journal={arXiv preprint arXiv:2504.07960},
32
- year={2024}
33
  }
34
- ```
35
  """
36
 
37
- max_grid_h = 5
38
- max_grid_w = 5
39
- default_grid_h = 2
40
- default_grid_w = 3
41
- default_upsampling_noise = 0.4
42
- default_steps = 30
43
-
44
-
45
- def create_demo(model):
46
- with gr.Blocks(title="VisualCloze Demo") as demo:
47
- gr.Markdown("# VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning")
48
-
49
- gr.HTML("""
50
- <div style="display:flex;column-gap:4px;">
51
- <a href="https://github.com/lzyhha/VisualCloze">
52
- <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
53
- </a>
54
- <a href="https://visualcloze.github.io/">
55
- <img src='https://img.shields.io/badge/Project-Website-green'>
56
- </a>
57
- <a href="https://arxiv.org/abs/2504.07960">
58
- <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
59
- </a>
60
- <a href="https://huggingface.co/VisualCloze/VisualCloze">
61
- <img src='https://img.shields.io/badge/VisualCloze%20checkpoint-HF%20Model-green?logoColor=violet&label=%F0%9F%A4%97%20Checkpoint'>
62
- </a>
63
- <a href="https://huggingface.co/datasets/VisualCloze/Graph200K">
64
- <img src='https://img.shields.io/badge/VisualCloze%20datasets-HF%20Dataset-6B88E3?logoColor=violet&label=%F0%9F%A4%97%20Graph200k%20Dataset'>
65
- </a>
66
- </div>
67
- """)
68
-
69
- gr.Markdown(GUIDANCE)
70
-
71
- # Pre-create all possible image components
72
- all_image_inputs = []
73
- rows = []
74
- row_texts = []
75
- with gr.Row():
76
-
77
- with gr.Column(scale=2):
78
- # Image grid
79
- for i in range(max_grid_h):
80
- # Add row label before each row
81
- row_texts.append(gr.Markdown(
82
- "## Query" if i == default_grid_h - 1 else f"## In-context Example {i + 1}",
83
- elem_id=f"row_text_{i}",
84
- visible=i < default_grid_h
85
- ))
86
- with gr.Row(visible=i < default_grid_h, elem_id=f"row_{i}") as row:
87
- rows.append(row)
88
- for j in range(max_grid_w):
89
- img_input = gr.Image(
90
- label=f"In-context Example {i + 1}/{j + 1}" if i != default_grid_h - 1 else f"Query {j + 1}",
91
- type="pil",
92
- visible= i < default_grid_h and j < default_grid_w,
93
- interactive=True,
94
- elem_id=f"img_{i}_{j}"
95
  )
96
- all_image_inputs.append(img_input)
97
-
98
- # Prompts
99
- layout_prompt = gr.Textbox(
100
- label="Layout Description (Auto-filled, Read-only)",
101
- placeholder="Layout description will be automatically filled based on grid size...",
102
- value=get_layout_instruction(default_grid_w, default_grid_h),
103
- elem_id="layout_prompt",
104
- interactive=False
105
- )
106
-
107
- task_prompt = gr.Textbox(
108
- label="Task Description (Can be modified by referring to examples to perform custom tasks, but may lead to unstable results)",
109
- placeholder="Describe what task should be performed...",
110
- value="",
111
- elem_id="task_prompt"
112
- )
113
-
114
- content_prompt = gr.Textbox(
115
- label="(Optional) Content Description (Image caption, Editing instructions, etc.)",
116
- placeholder="Describe the content requirements...",
117
- value="",
118
- elem_id="content_prompt"
119
- )
120
-
121
- generate_btn = gr.Button("Generate", elem_id="generate_btn")
122
- gr.Markdown(NOTE)
123
-
124
- grid_h = gr.Slider(minimum=0, maximum=max_grid_h-1, value=default_grid_h-1, step=1, label="Number of In-context Examples", elem_id="grid_h")
125
- grid_w = gr.Slider(minimum=1, maximum=max_grid_w, value=default_grid_w, step=1, label="Task Columns", elem_id="grid_w")
126
-
127
- with gr.Accordion("Advanced options", open=False):
128
- seed = gr.Number(label="Seed (0 for random)", value=0, precision=0)
129
- steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=default_steps, step=1)
130
- cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=50.0, value=30, step=1)
131
- upsampling_steps = gr.Slider(label="Upsampling steps (SDEdit)", minimum=1, maximum=100.0, value=10, step=1)
132
- upsampling_noise = gr.Slider(label="Upsampling noise (SDEdit)", minimum=0, maximum=1.0, value=default_upsampling_noise, step=0.05)
133
-
134
- gr.Markdown(CITATION)
135
-
136
- # Output
137
- with gr.Column(scale=2):
138
- output_gallery = gr.Gallery(
139
- label="Generated Results",
140
- show_label=True,
141
- elem_id="output_gallery",
142
- columns=None,
143
- rows=None,
144
- height="auto",
145
- allow_preview=True,
146
- object_fit="contain"
147
- )
148
-
149
- gr.Markdown("# Task Examples")
150
- gr.Markdown("Each click on a task may result in different examples.")
151
- text_dense_prediction_tasks = gr.Textbox(label="Task", visible=False)
152
- dense_prediction_tasks = gr.Dataset(
153
- samples=examples.dense_prediction_text,
154
- label='Dense Prediction',
155
- samples_per_page=1000,
156
- components=[text_dense_prediction_tasks])
157
-
158
- text_conditional_generation_tasks = gr.Textbox(label="Task", visible=False)
159
- conditional_generation_tasks = gr.Dataset(
160
- samples=examples.conditional_generation_text,
161
- label='Conditional Generation',
162
- samples_per_page=1000,
163
- components=[text_conditional_generation_tasks])
164
-
165
- text_image_restoration_tasks = gr.Textbox(label="Task", visible=False)
166
- image_restoration_tasks = gr.Dataset(
167
- samples=examples.image_restoration_text,
168
- label='Image Restoration',
169
- samples_per_page=1000,
170
- components=[text_image_restoration_tasks])
171
-
172
- text_style_transfer_tasks = gr.Textbox(label="Task", visible=False)
173
- style_transfer_tasks = gr.Dataset(
174
- samples=examples.style_transfer_text,
175
- label='Style Transfer',
176
- samples_per_page=1000,
177
- components=[text_style_transfer_tasks])
178
-
179
- text_style_condition_fusion_tasks = gr.Textbox(label="Task", visible=False)
180
- style_condition_fusion_tasks = gr.Dataset(
181
- samples=examples.style_condition_fusion_text,
182
- label='Style Condition Fusion',
183
- samples_per_page=1000,
184
- components=[text_style_condition_fusion_tasks])
185
-
186
- text_tryon_tasks = gr.Textbox(label="Task", visible=False)
187
- tryon_tasks = gr.Dataset(
188
- samples=examples.tryon_text,
189
- label='Virtual Try-On',
190
- samples_per_page=1000,
191
- components=[text_tryon_tasks])
192
-
193
- text_relighting_tasks = gr.Textbox(label="Task", visible=False)
194
- relighting_tasks = gr.Dataset(
195
- samples=examples.relighting_text,
196
- label='Relighting',
197
- samples_per_page=1000,
198
- components=[text_relighting_tasks])
199
-
200
- text_photodoodle_tasks = gr.Textbox(label="Task", visible=False)
201
- photodoodle_tasks = gr.Dataset(
202
- samples=examples.photodoodle_text,
203
- label='Photodoodle',
204
- samples_per_page=1000,
205
- components=[text_photodoodle_tasks])
206
-
207
- text_editing_tasks = gr.Textbox(label="Task", visible=False)
208
- editing_tasks = gr.Dataset(
209
- samples=examples.editing_text,
210
- label='Editing',
211
- samples_per_page=1000,
212
- components=[text_editing_tasks])
213
-
214
- text_unseen_tasks = gr.Textbox(label="Task", visible=False)
215
- unseen_tasks = gr.Dataset(
216
- samples=examples.unseen_tasks_text,
217
- label='Unseen Tasks (May produce unstable effects)',
218
- samples_per_page=1000,
219
- components=[text_unseen_tasks])
220
-
221
- gr.Markdown("# Subject-driven Tasks Examples")
222
- text_subject_driven_tasks = gr.Textbox(label="Task", visible=False)
223
- subject_driven_tasks = gr.Dataset(
224
- samples=examples.subject_driven_text,
225
- label='Subject-driven Generation',
226
- samples_per_page=1000,
227
- components=[text_subject_driven_tasks])
228
-
229
- text_condition_subject_fusion_tasks = gr.Textbox(label="Task", visible=False)
230
- condition_subject_fusion_tasks = gr.Dataset(
231
- samples=examples.condition_subject_fusion_text,
232
- label='Condition+Subject Fusion',
233
- samples_per_page=1000,
234
- components=[text_condition_subject_fusion_tasks])
235
-
236
- text_style_transfer_with_subject_tasks = gr.Textbox(label="Task", visible=False)
237
- style_transfer_with_subject_tasks = gr.Dataset(
238
- samples=examples.style_transfer_with_subject_text,
239
- label='Style Transfer with Subject',
240
- samples_per_page=1000,
241
- components=[text_style_transfer_with_subject_tasks])
242
-
243
- text_condition_subject_style_fusion_tasks = gr.Textbox(label="Task", visible=False)
244
- condition_subject_style_fusion_tasks = gr.Dataset(
245
- samples=examples.condition_subject_style_fusion_text,
246
- label='Condition+Subject+Style Fusion',
247
- samples_per_page=1000,
248
- components=[text_condition_subject_style_fusion_tasks])
249
-
250
- text_editing_with_subject_tasks = gr.Textbox(label="Task", visible=False)
251
- editing_with_subject_tasks = gr.Dataset(
252
- samples=examples.editing_with_subject_text,
253
- label='Editing with Subject',
254
- samples_per_page=1000,
255
- components=[text_editing_with_subject_tasks])
256
-
257
- text_image_restoration_with_subject_tasks = gr.Textbox(label="Task", visible=False)
258
- image_restoration_with_subject_tasks = gr.Dataset(
259
- samples=examples.image_restoration_with_subject_text,
260
- label='Image Restoration with Subject',
261
- samples_per_page=1000,
262
- components=[text_image_restoration_with_subject_tasks])
263
-
264
- def update_grid(h, w):
265
- actual_h = h + 1
266
- model.set_grid_size(actual_h, w)
267
-
268
- updates = []
269
-
270
- # Update image component visibility
271
- for i in range(max_grid_h * max_grid_w):
272
- curr_row = i // max_grid_w
273
- curr_col = i % max_grid_w
274
- updates.append(
275
- gr.update(
276
- label=f"In-context Example {curr_row + 1}/{curr_col + 1}" if curr_row != actual_h - 1 else f"Query {curr_col + 1}",
277
- elem_id=f"img_{curr_row}_{curr_col}",
278
- visible=(curr_row < actual_h and curr_col < w)))
279
-
280
- # Update row visibility and labels
281
- updates_row = []
282
- updates_row_text = []
283
- for i in range(max_grid_h):
284
- updates_row.append(gr.update(f"row_{i}", visible=(i < actual_h)))
285
- updates_row_text.append(
286
- gr.update(
287
- elem_id=f"row_text_{i}",
288
- visible=i < actual_h,
289
- value="## Query" if i == actual_h - 1 else f"## In-context Example {i + 1}",
290
  )
291
- )
292
-
293
- updates.extend(updates_row)
294
- updates.extend(updates_row_text)
295
- updates.append(gr.update(elem_id="layout_prompt", value=get_layout_instruction(w, actual_h)))
296
- return updates
297
-
298
- def generate_image(*inputs):
299
- images = []
300
- if grid_h.value + 1 != model.grid_h or grid_w.value != model.grid_w:
301
- raise gr.Error('Please wait for the loading to complete.')
302
- for i in range(model.grid_h):
303
- images.append([])
304
- for j in range(model.grid_w):
305
- images[i].append(inputs[i * max_grid_w + j])
306
- if i != model.grid_h - 1:
307
- if inputs[i * max_grid_w + j] is None:
308
- raise gr.Error('Please upload in-context examples. Possible that the task examples have not finished loading yet, and you can try waiting a few seconds before clicking the button again.')
309
- seed, cfg, steps, upsampling_steps, upsampling_noise, layout_text, task_text, content_text = inputs[-8:]
310
-
311
- try:
312
- results = generate(
313
- images,
314
- [layout_text, task_text, content_text],
315
- seed=seed, cfg=cfg, steps=steps,
316
- upsampling_steps=upsampling_steps, upsampling_noise=upsampling_noise
317
- )
318
- except Exception as e:
319
- raise gr.Error('Process error. Possible that the task examples have not finished loading yet, and you can try waiting a few seconds before clicking the button again. Error: ' + str(e))
320
-
321
- output = gr.update(
322
- elem_id='output_gallery',
323
- value=results,
324
- columns=min(len(results), 2),
325
- rows=int(len(results) / 2 + 0.5))
326
-
327
- return output
328
-
329
- def process_tasks(task, func):
330
- outputs = func(task)
331
- mask = outputs[0]
332
- state = outputs[1:8]
333
- if state[5] is None:
334
- state[5] = default_upsampling_noise
335
- if state[6] is None:
336
- state[6] = default_steps
337
- images = outputs[8:-len(mask)]
338
- output = outputs[-len(mask):]
339
- for i in range(len(mask)):
340
- if mask[i] == 1:
341
- images.append(None)
342
- else:
343
- images.append(output[-len(mask) + i])
344
-
345
- state[0] = state[0] - 1
346
- cur_hrid_h = state[0]
347
- cur_hrid_w = state[1]
348
-
349
- current_example = [None] * 25
350
- for i, image in enumerate(images):
351
- pos = (i // cur_hrid_w) * 5 + (i % cur_hrid_w)
352
- if image is not None:
353
- current_example[pos] = image
354
- update_grid(cur_hrid_h, cur_hrid_w)
355
- output = gr.update(
356
- elem_id='output_gallery',
357
- value=[o for o, m in zip(output, mask) if m == 1],
358
- columns=min(sum(mask), 2),
359
- rows=int(sum(mask) / 2 + 0.5))
360
- return [output] + current_example + state
361
-
362
- dense_prediction_tasks.click(
363
- partial(process_tasks, func=examples.process_dense_prediction_tasks),
364
- inputs=[dense_prediction_tasks],
365
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
366
- show_progress="full",
367
- show_progress_on=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps] + [generate_btn])
368
-
369
- conditional_generation_tasks.click(
370
- partial(process_tasks, func=examples.process_conditional_generation_tasks),
371
- inputs=[conditional_generation_tasks],
372
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
373
- show_progress="full")
374
-
375
- image_restoration_tasks.click(
376
- partial(process_tasks, func=examples.process_image_restoration_tasks),
377
- inputs=[image_restoration_tasks],
378
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
379
- show_progress="full")
380
-
381
- style_transfer_tasks.click(
382
- partial(process_tasks, func=examples.process_style_transfer_tasks),
383
- inputs=[style_transfer_tasks],
384
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
385
- show_progress="full")
386
-
387
- style_condition_fusion_tasks.click(
388
- partial(process_tasks, func=examples.process_style_condition_fusion_tasks),
389
- inputs=[style_condition_fusion_tasks],
390
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
391
- show_progress="full")
392
-
393
- relighting_tasks.click(
394
- partial(process_tasks, func=examples.process_relighting_tasks),
395
- inputs=[relighting_tasks],
396
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
397
- show_progress="full")
398
-
399
- tryon_tasks.click(
400
- partial(process_tasks, func=examples.process_tryon_tasks),
401
- inputs=[tryon_tasks],
402
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
403
- show_progress="full")
404
-
405
- photodoodle_tasks.click(
406
- partial(process_tasks, func=examples.process_photodoodle_tasks),
407
- inputs=[photodoodle_tasks],
408
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
409
- show_progress="full")
410
-
411
- editing_tasks.click(
412
- partial(process_tasks, func=examples.process_editing_tasks),
413
- inputs=[editing_tasks],
414
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
415
- show_progress="full")
416
-
417
- unseen_tasks.click(
418
- partial(process_tasks, func=examples.process_unseen_tasks),
419
- inputs=[unseen_tasks],
420
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
421
- show_progress="full")
422
-
423
- subject_driven_tasks.click(
424
- partial(process_tasks, func=examples.process_subject_driven_tasks),
425
- inputs=[subject_driven_tasks],
426
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
427
- show_progress="full")
428
-
429
- style_transfer_with_subject_tasks.click(
430
- partial(process_tasks, func=examples.process_style_transfer_with_subject_tasks),
431
- inputs=[style_transfer_with_subject_tasks],
432
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
433
- show_progress="full")
434
-
435
- condition_subject_fusion_tasks.click(
436
- partial(process_tasks, func=examples.process_condition_subject_fusion_tasks),
437
- inputs=[condition_subject_fusion_tasks],
438
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
439
- show_progress="full")
440
-
441
- condition_subject_style_fusion_tasks.click(
442
- partial(process_tasks, func=examples.process_condition_subject_style_fusion_tasks),
443
- inputs=[condition_subject_style_fusion_tasks],
444
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
445
- show_progress="full")
446
-
447
- editing_with_subject_tasks.click(
448
- partial(process_tasks, func=examples.process_editing_with_subject_tasks),
449
- inputs=[editing_with_subject_tasks],
450
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
451
- show_progress="full")
452
-
453
- image_restoration_with_subject_tasks.click(
454
- partial(process_tasks, func=examples.process_image_restoration_with_subject_tasks),
455
- inputs=[image_restoration_with_subject_tasks],
456
- outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
457
- show_progress="full")
458
- # Initialize grid
459
- model.set_grid_size(default_grid_h, default_grid_w)
460
-
461
- # Connect event processing function to all components that need updating
462
- output_components = all_image_inputs + rows + row_texts + [layout_prompt]
463
-
464
- grid_h.change(fn=update_grid, inputs=[grid_h, grid_w], outputs=output_components)
465
- grid_w.change(fn=update_grid, inputs=[grid_h, grid_w], outputs=output_components)
466
-
467
- # Modify generate button click event
468
- generate_btn.click(
469
- fn=generate_image,
470
- inputs=all_image_inputs + [seed, cfg, steps, upsampling_steps, upsampling_noise] + [layout_prompt, task_prompt, content_prompt],
471
- outputs=output_gallery
472
- )
473
-
474
- return demo
475
-
476
-
477
- @spaces.GPU()
478
- def generate(
479
- images,
480
- prompts,
481
- seed, cfg, steps,
482
- upsampling_steps, upsampling_noise):
483
- with torch.no_grad():
484
- return model.process_images(
485
- images=images,
486
- prompts=prompts,
487
- seed=seed,
488
- cfg=cfg,
489
- steps=steps,
490
- upsampling_steps=upsampling_steps,
491
- upsampling_noise=upsampling_noise)
492
-
493
-
494
- def parse_args():
495
- parser = argparse.ArgumentParser()
496
- parser.add_argument("--model_path", type=str, default="checkpoints/visualcloze-384-lora.pth")
497
- parser.add_argument("--precision", type=str, choices=["fp32", "bf16", "fp16"], default="bf16")
498
- parser.add_argument("--resolution", type=int, default=384)
499
- return parser.parse_args()
500
-
501
-
502
- if __name__ == "__main__":
503
- args = parse_args()
504
 
505
- snapshot_download(repo_id="VisualCloze/VisualCloze", repo_type="model", local_dir="checkpoints")
506
-
507
- # Initialize model
508
- model = VisualClozeModel(resolution=args.resolution, model_path=args.model_path, precision=args.precision)
 
509
 
510
- # Create Gradio demo
511
- demo = create_demo(model)
512
 
513
- # Start Gradio server
514
- demo.launch()
1
+ import os
2
+ import uuid
 
3
  import gradio as gr
4
+ import spaces
5
+ from clip_slider_pipeline import CLIPSliderFlux
6
+ from diffusers import FluxPipeline, AutoencoderTiny
7
  import torch
8
+ import numpy as np
9
+ import cv2
10
+ from PIL import Image
11
+ from diffusers.utils import load_image
12
+ from diffusers.utils import export_to_video
13
+ import random
14
+
15
+ # English menu labels
16
+ english_labels = {
17
+ "Prompt": "Prompt",
18
+ "1st direction to steer": "1st Direction",
19
+ "2nd direction to steer": "2nd Direction",
20
+ "Strength": "Strength",
21
+ "Generate directions": "Generate Directions",
22
+ "Generated Images": "Generated Images",
23
+ "From 1st to 2nd direction": "From 1st to 2nd Direction",
24
+ "Strip": "Image Strip",
25
+ "Looping video": "Looping Video",
26
+ "Advanced options": "Advanced Options",
27
+ "Num of intermediate images": "Number of Intermediate Images",
28
+ "Num iterations for clip directions": "Number of CLIP Direction Iterations",
29
+ "Num inference steps": "Number of Inference Steps",
30
+ "Guidance scale": "Guidance Scale",
31
+ "Randomize seed": "Randomize Seed",
32
+ "Seed": "Seed"
33
+ }
34
 
35
+ # Load pipelines
36
+ base_model = "black-forest-labs/FLUX.1-schnell"
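+ # madebyollin/taef1 is a tiny autoencoder (TAE) for FLUX.1, used in place of the full VAE to speed up latent decoding.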
37
+ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
38
+ pipe = FluxPipeline.from_pretrained(
39
+ base_model,
40
+ vae=taef1,
41
+ torch_dtype=torch.bfloat16
42
+ )
43
+ pipe.transformer.to(memory_format=torch.channels_last)
44
+ clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))
45
+ MAX_SEED = 2**32 - 1
46
+
47
+ def save_images_with_unique_filenames(image_list, save_directory):
48
+ if not os.path.exists(save_directory):
49
+ os.makedirs(save_directory)
50
+ paths = []
51
+ for image in image_list:
52
+ unique_filename = f"{uuid.uuid4()}.png"
53
+ file_path = os.path.join(save_directory, unique_filename)
54
+ image.save(file_path)
55
+ paths.append(file_path)
56
+ return paths
57
+
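+ # convert_to_centered_scale maps an image count to a symmetric range of integer offsets
+ # around 0, e.g. 7 -> (-3, ..., 3) and 6 -> (-2, ..., 3); the post-generation slider uses
+ # these offsets to index the pre-generated frames.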
58
+ def convert_to_centered_scale(num):
59
+ if num % 2 == 0: # even
60
+ start = -(num // 2 - 1)
61
+ end = num // 2
62
+ else: # odd
63
+ start = -(num // 2)
64
+ end = num // 2
65
+ return tuple(range(start, end + 1))
66
+
67
+ def is_korean(text):
68
+ """ํ•œ๊ธ€ ํฌํ•จ ์—ฌ๋ถ€ ํ™•์ธ"""
69
+ return any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text)
70
+
71
+ @spaces.GPU(duration=85)
72
+ def generate(prompt,
73
+ concept_1,
74
+ concept_2,
75
+ scale,
76
+ randomize_seed=True,
77
+ seed=42,
78
+ recalc_directions=True,
79
+ iterations=200,
80
+ steps=3,
81
+ interm_steps=33,
82
+ guidance_scale=3.5,
83
+ x_concept_1="", x_concept_2="",
84
+ avg_diff_x=None,
85
+ total_images=[],
86
+ gradio_progress=gr.Progress()):
87
+ # Check if there is Korean text and warn if so
88
+ if is_korean(prompt) or is_korean(concept_1) or is_korean(concept_2):
89
+ print("Korean text detected. The model will use it directly without translation.")
90
+
91
+ print(f"Prompt: {prompt}, โ† {concept_2}, {concept_1} โžก๏ธ . scale {scale}, interm steps {interm_steps}")
92
+ slider_x = [concept_2, concept_1]
93
+ if randomize_seed:
94
+ seed = random.randint(0, MAX_SEED)
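+ # Recompute the CLIP latent direction only when the concept pair has changed or a
+ # recalculation was explicitly requested; otherwise reuse the cached direction.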
95
+ if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]) or recalc_directions:
96
+ gradio_progress(0, desc="Calculating directions...")
97
+ avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations)
98
+ x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
99
+ else:
100
+ avg_diff = avg_diff_x
101
+ images = []
102
+ high_scale = scale
103
+ low_scale = -1 * scale
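+ # Sweep the edit strength linearly from -scale to +scale over interm_steps frames,
+ # generating one image per step along the learned direction.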
104
+ for i in gradio_progress.tqdm(range(interm_steps), desc="Generating images"):
105
+ cur_scale = low_scale + (high_scale - low_scale) * i / (interm_steps - 1)
106
+ image = clip_slider.generate(
107
+ prompt,
108
+ width=768,
109
+ height=768,
110
+ guidance_scale=guidance_scale,
111
+ scale=cur_scale,
112
+ seed=seed,
113
+ num_inference_steps=steps,
114
+ avg_diff=avg_diff
115
+ )
116
+ images.append(image)
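+ # Paste 256x256 thumbnails side by side into a single horizontal preview strip.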
117
+ canvas = Image.new('RGB', (256 * interm_steps, 256))
118
+ for i, im in enumerate(images):
119
+ canvas.paste(im.resize((256, 256)), (256 * i, 0))
120
+ comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"
121
+ scale_total = convert_to_centered_scale(interm_steps)
122
+ scale_min = scale_total[0]
123
+ scale_max = scale_total[-1]
124
+ scale_middle = scale_total.index(0)
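+ # Relabel the post-generation slider with the concept pair and re-range it so that value 0 selects the middle frame.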
125
+ post_generation_slider_update = gr.update(label=comma_concepts_x, value=0, minimum=scale_min, maximum=scale_max, interactive=True)
126
+ avg_diff_x = avg_diff.cpu()
127
+ video_path = f"{uuid.uuid4()}.mp4"
128
+ print(video_path)
129
+ return x_concept_1, x_concept_2, avg_diff_x, export_to_video(images, video_path, fps=5), canvas, images, images[scale_middle], post_generation_slider_update, seed
130
+
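+ # Map the slider's centered offset back to an index into the pre-generated frames.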
131
+ def update_pre_generated_images(slider_value, total_images):
132
+ number_images = len(total_images) if total_images else 0
133
+ if number_images > 0:
134
+ scale_tuple = convert_to_centered_scale(number_images)
135
+ return total_images[scale_tuple.index(slider_value)]
136
+ else:
137
+ return None
138
+
139
+ def reset_recalc_directions():
140
+ return True
141
+
142
+ # Five "Time Stream" themed examples (one Korean example included)
143
+ examples = [
144
+ ["์‹ ์„ ํ•œ ํ† ๋งˆํ† ๊ฐ€ ๋ถ€ํŒจํ•œ ํ† ๋งˆํ† ๋กœ ๋ณ€ํ•ด๊ฐ€๋Š” ๊ณผ์ •", "Fresh", "Rotten", 2.0],
145
+ ["A blooming flower gradually withers into decay", "Bloom", "Wither", 1.5],
146
+ ["A vibrant cityscape transforms into a derelict ruin over time", "Modern", "Ruined", 2.5],
147
+ ["A lively forest slowly changes into an autumnal landscape", "Spring", "Autumn", 2.0],
148
+ ["A calm ocean evolves into a stormy seascape as time passes", "Calm", "Stormy", 3.0]
149
+ ]
150
+
151
+ # CSS for a bright and modern UI with a background image
152
+ css = """
153
+ /* Bright and modern UI with background image */
154
+ body {
155
+ background: #ffffff url('https://images.unsplash.com/photo-1506748686214-e9df14d4d9d0?ixlib=rb-1.2.1&auto=format&fit=crop&w=1600&q=80') no-repeat center center fixed;
156
+ background-size: cover;
157
+ font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
158
+ color: #333;
159
+ }
160
+ footer {
161
+ visibility: hidden;
162
+ }
163
+ .container {
164
+ max-width: 1200px;
165
+ margin: 20px auto;
166
+ padding: 0 10px;
167
+ }
168
+ .main-panel {
169
+ background-color: rgba(255, 255, 255, 0.9);
170
+ border-radius: 12px;
171
+ padding: 20px;
172
+ margin-bottom: 20px;
173
+ box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
174
+ }
175
+ .controls-panel {
176
+ background-color: rgba(255, 255, 255, 0.85);
177
+ border-radius: 8px;
178
+ padding: 16px;
179
+ box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.05);
180
+ }
181
+ .image-display {
182
+ min-height: 400px;
183
+ display: flex;
184
+ flex-direction: column;
185
+ justify-content: center;
186
+ }
187
+ .slider-container {
188
+ padding: 10px 0;
189
+ }
190
+ .advanced-panel {
191
+ margin-top: 20px;
192
+ border-top: 1px solid #eaeaea;
193
+ padding-top: 20px;
194
  }
 
195
  """
196
 
197
+ # Gradio's OpenAPI schema generation is disabled via show_api=False in demo.launch() at the bottom of the file (gr.Blocks itself does not accept show_api).
198
+ with gr.Blocks(css=css, title="Time Stream") as demo:
199
+ gr.Markdown("# Time Stream")
200
+
201
+ x_concept_1 = gr.State("")
202
+ x_concept_2 = gr.State("")
203
+ total_images = gr.State([])
204
+ avg_diff_x = gr.State()
205
+ recalc_directions = gr.State(False)
206
+
207
+ with gr.Row(elem_classes="container"):
208
+ # Left Column - Controls
209
+ with gr.Column(scale=4):
210
+ with gr.Group(elem_classes="main-panel"):
211
+ gr.Markdown("### Image Generation Controls")
212
+ with gr.Group(elem_classes="controls-panel"):
213
+ prompt = gr.Textbox(
214
+ label=english_labels["Prompt"],
215
+ info="Enter the description",
216
+ placeholder="A dog in the park",
217
+ lines=2
218
+ )
219
+ with gr.Row():
220
+ with gr.Column(scale=1):
221
+ concept_1 = gr.Textbox(
222
+ label=english_labels["1st direction to steer"],
223
+ info="Initial state",
224
+ placeholder="Fresh"
225
  )
226
+ with gr.Column(scale=1):
227
+ concept_2 = gr.Textbox(
228
+ label=english_labels["2nd direction to steer"],
229
+ info="Final state",
230
+ placeholder="Rotten"
231
+ )
232
+ with gr.Row(elem_classes="slider-container"):
233
+ x = gr.Slider(
234
+ minimum=0,
235
+ value=1.75,
236
+ step=0.1,
237
+ maximum=4.0,
238
+ label=english_labels["Strength"],
239
+ info="Maximum strength for each direction (above 2.5 may be unstable)"
240
+ )
241
+ submit = gr.Button(english_labels["Generate directions"], size="lg", variant="primary")
242
+ with gr.Accordion(label=english_labels["Advanced options"], open=False, elem_classes="advanced-panel"):
243
+ with gr.Row():
244
+ with gr.Column(scale=1):
245
+ interm_steps = gr.Slider(
246
+ label=english_labels["Num of intermediate images"],
247
+ minimum=3,
248
+ value=7,
249
+ maximum=65,
250
+ step=2
251
+ )
252
+ with gr.Column(scale=1):
253
+ guidance_scale = gr.Slider(
254
+ label=english_labels["Guidance scale"],
255
+ minimum=0.1,
256
+ maximum=10.0,
257
+ step=0.1,
258
+ value=3.5
259
+ )
260
+ with gr.Row():
261
+ with gr.Column(scale=1):
262
+ iterations = gr.Slider(
263
+ label=english_labels["Num iterations for clip directions"],
264
+ minimum=0,
265
+ value=200,
266
+ maximum=400,
267
+ step=1
268
+ )
269
+ with gr.Column(scale=1):
270
+ steps = gr.Slider(
271
+ label=english_labels["Num inference steps"],
272
+ minimum=1,
273
+ value=3,
274
+ maximum=4,
275
+ step=1
276
+ )
277
+ with gr.Row():
278
+ with gr.Column(scale=1):
279
+ randomize_seed = gr.Checkbox(
280
+ True,
281
+ label=english_labels["Randomize seed"]
282
+ )
283
+ with gr.Column(scale=1):
284
+ seed = gr.Slider(
285
+ minimum=0,
286
+ maximum=MAX_SEED,
287
+ step=1,
288
+ label=english_labels["Seed"],
289
+ interactive=True,
290
+ randomize=True
291
+ )
292
+ # Right Column - Output
293
+ with gr.Column(scale=8):
294
+ with gr.Group(elem_classes="main-panel"):
295
+ gr.Markdown("### Generated Results")
296
+ # Swapped order: Image strip on top, video below (video is larger)
297
+ image_strip = gr.Image(label="Image Strip", type="filepath", elem_id="strip", height=200)
298
+ output_video = gr.Video(label=english_labels["Looping video"], elem_id="video", loop=True, autoplay=True, height=600)
299
+ with gr.Row():
300
+ post_generation_image = gr.Image(
301
+ label=english_labels["Generated Images"],
302
+ type="filepath",
303
+ elem_id="interactive",
304
+ elem_classes="image-display"
305
+ )
306
+ post_generation_slider = gr.Slider(
307
+ minimum=-10,
308
+ maximum=10,
309
+ value=0,
310
+ step=1,
311
+ label=english_labels["From 1st to 2nd direction"]
312
  )
313
 
314
+ # Examples
315
+ gr.Examples(
316
+ examples=examples,
317
+ inputs=[prompt, concept_1, concept_2, x]
318
+ )
319
 
320
+ # Event Handlers
321
+ submit.click(
322
+ fn=generate,
323
+ inputs=[
324
+ prompt, concept_1, concept_2, x, randomize_seed, seed,
325
+ recalc_directions, iterations, steps, interm_steps,
326
+ guidance_scale, x_concept_1, x_concept_2, avg_diff_x, total_images
327
+ ],
328
+ outputs=[
329
+ x_concept_1, x_concept_2, avg_diff_x,
330
+ output_video, # video output
331
+ image_strip, # canvas (image strip)
332
+ total_images,
333
+ post_generation_image,
334
+ post_generation_slider,
335
+ seed
336
+ ]
337
+ )
338
 
339
+ iterations.change(fn=reset_recalc_directions, outputs=[recalc_directions])
340
+ seed.change(fn=reset_recalc_directions, outputs=[recalc_directions])
341
+ post_generation_slider.change(
342
+ fn=update_pre_generated_images,
343
+ inputs=[post_generation_slider, total_images],
344
+ outputs=[post_generation_image],
345
+ queue=False,
346
+ show_progress="hidden",
347
+ concurrency_limit=None
348
+ )
349
+
350
+ if __name__ == "__main__":
351
+ # To hide the Gradio API schema, pass show_api=False as shown below.
352
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_api=False)