ginipick committed
Commit 6bee32e · verified · 1 Parent(s): 00fc055

Update app.py

Files changed (1)
app.py +503 -336
app.py CHANGED
@@ -1,347 +1,514 @@
- import os
- import uuid
- import gradio as gr
  import spaces
- from clip_slider_pipeline import CLIPSliderFlux
- from diffusers import FluxPipeline, AutoencoderTiny
  import torch
- import numpy as np
- import cv2
- from PIL import Image
- from diffusers.utils import load_image
- from diffusers.utils import export_to_video
- import random
-
- # English menu labels
- english_labels = {
-     "Prompt": "Prompt",
-     "1st direction to steer": "1st Direction",
-     "2nd direction to steer": "2nd Direction",
-     "Strength": "Strength",
-     "Generate directions": "Generate Directions",
-     "Generated Images": "Generated Images",
-     "From 1st to 2nd direction": "From 1st to 2nd Direction",
-     "Strip": "Image Strip",
-     "Looping video": "Looping Video",
-     "Advanced options": "Advanced Options",
-     "Num of intermediate images": "Number of Intermediate Images",
-     "Num iterations for clip directions": "Number of CLIP Direction Iterations",
-     "Num inference steps": "Number of Inference Steps",
-     "Guidance scale": "Guidance Scale",
-     "Randomize seed": "Randomize Seed",
-     "Seed": "Seed"
- }

- # Load pipelines
- base_model = "black-forest-labs/FLUX.1-schnell"
- taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=torch.bfloat16).to("cuda")
- pipe = FluxPipeline.from_pretrained(
-     base_model,
-     vae=taef1,
-     torch_dtype=torch.bfloat16
- )
- pipe.transformer.to(memory_format=torch.channels_last)
- clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))
- MAX_SEED = 2**32 - 1
-
- def save_images_with_unique_filenames(image_list, save_directory):
-     if not os.path.exists(save_directory):
-         os.makedirs(save_directory)
-     paths = []
-     for image in image_list:
-         unique_filename = f"{uuid.uuid4()}.png"
-         file_path = os.path.join(save_directory, unique_filename)
-         image.save(file_path)
-         paths.append(file_path)
-     return paths
-
- def convert_to_centered_scale(num):
-     if num % 2 == 0:  # even
-         start = -(num // 2 - 1)
-         end = num // 2
-     else:  # odd
-         start = -(num // 2)
-         end = num // 2
-     return tuple(range(start, end + 1))
-
- def is_korean(text):
- """한글 포함 여부 확인"""
-     return any('\u3131' <= char <= '\u3163' or '\uac00' <= char <= '\ud7a3' for char in text)
-
- @spaces.GPU(duration=85)
- def generate(prompt,
-              concept_1,
-              concept_2,
-              scale,
-              randomize_seed=True,
-              seed=42,
-              recalc_directions=True,
-              iterations=200,
-              steps=3,
-              interm_steps=33,
-              guidance_scale=3.5,
-              x_concept_1="", x_concept_2="",
-              avg_diff_x=None,
-              total_images=[],
-              gradio_progress=gr.Progress()):
-     # Check if there is Korean text and warn if so
-     if is_korean(prompt) or is_korean(concept_1) or is_korean(concept_2):
-         print("Korean text detected. The model will use it directly without translation.")
-
-     print(f"Prompt: {prompt}, ← {concept_2}, {concept_1} ➡️ . scale {scale}, interm steps {interm_steps}")
-     slider_x = [concept_2, concept_1]
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]) or recalc_directions:
-         gradio_progress(0, desc="Calculating directions...")
-         avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations)
-         x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
-     else:
-         avg_diff = avg_diff_x
-     images = []
-     high_scale = scale
-     low_scale = -1 * scale
-     for i in gradio_progress.tqdm(range(interm_steps), desc="Generating images"):
-         cur_scale = low_scale + (high_scale - low_scale) * i / (interm_steps - 1)
-         image = clip_slider.generate(
-             prompt,
-             width=768,
-             height=768,
-             guidance_scale=guidance_scale,
-             scale=cur_scale,
-             seed=seed,
-             num_inference_steps=steps,
-             avg_diff=avg_diff
-         )
-         images.append(image)
-     canvas = Image.new('RGB', (256 * interm_steps, 256))
-     for i, im in enumerate(images):
-         canvas.paste(im.resize((256, 256)), (256 * i, 0))
-     comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"
-     scale_total = convert_to_centered_scale(interm_steps)
-     scale_min = scale_total[0]
-     scale_max = scale_total[-1]
-     scale_middle = scale_total.index(0)
-     post_generation_slider_update = gr.update(label=comma_concepts_x, value=0, minimum=scale_min, maximum=scale_max, interactive=True)
-     avg_diff_x = avg_diff.cpu()
-     video_path = f"{uuid.uuid4()}.mp4"
-     print(video_path)
-     return x_concept_1, x_concept_2, avg_diff_x, export_to_video(images, video_path, fps=5), canvas, images, images[scale_middle], post_generation_slider_update, seed
-
- def update_pre_generated_images(slider_value, total_images):
-     number_images = len(total_images) if total_images else 0
-     if number_images > 0:
-         scale_tuple = convert_to_centered_scale(number_images)
-         return total_images[scale_tuple.index(slider_value)][0]
-     else:
-         return None
-
- def reset_recalc_directions():
-     return True
-
- # Five "Time Stream" themed examples (one Korean example included)
- examples = [
-     ["신선한 토마토가 부패한 토마토로 변해가는 과정", "Fresh", "Rotten", 2.0],
-     ["A blooming flower gradually withers into decay", "Bloom", "Wither", 1.5],
-     ["A vibrant cityscape transforms into a derelict ruin over time", "Modern", "Ruined", 2.5],
-     ["A lively forest slowly changes into an autumnal landscape", "Spring", "Autumn", 2.0],
-     ["A calm ocean evolves into a stormy seascape as time passes", "Calm", "Stormy", 3.0]
- ]
-
- css = """
- /* Bright and modern UI with background image */
- body {
-     background: #ffffff url('https://images.unsplash.com/photo-1506748686214-e9df14d4d9d0?ixlib=rb-1.2.1&auto=format&fit=crop&w=1600&q=80') no-repeat center center fixed;
-     background-size: cover;
-     font-family: "Helvetica Neue", Helvetica, Arial, sans-serif;
-     color: #333;
- }
- footer {
-     visibility: hidden;
- }
- .container {
-     max-width: 1200px;
-     margin: 20px auto;
-     padding: 0 10px;
- }
- .main-panel {
-     background-color: rgba(255, 255, 255, 0.9);
-     border-radius: 12px;
-     padding: 20px;
-     margin-bottom: 20px;
-     box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
- }
- .controls-panel {
-     background-color: rgba(255, 255, 255, 0.85);
-     border-radius: 8px;
-     padding: 16px;
-     box-shadow: inset 0 2px 4px rgba(0, 0, 0, 0.05);
- }
- .image-display {
-     min-height: 400px;
-     display: flex;
-     flex-direction: column;
-     justify-content: center;
- }
- .slider-container {
-     padding: 10px 0;
- }
- .advanced-panel {
-     margin-top: 20px;
-     border-top: 1px solid #eaeaea;
-     padding-top: 20px;
  }
  """

- # Do not pass show_api=False to Blocks(...) here
- with gr.Blocks(css=css, title="Time Stream") as demo:
-     gr.Markdown("# Time Stream")
-
-     x_concept_1 = gr.State("")
-     x_concept_2 = gr.State("")
-     total_images = gr.State([])
-     avg_diff_x = gr.State()
-     recalc_directions = gr.State(False)
-
-     with gr.Row(elem_classes="container"):
-         with gr.Column(scale=4):
-             with gr.Group(elem_classes="main-panel"):
-                 gr.Markdown("### Image Generation Controls")
-                 with gr.Group(elem_classes="controls-panel"):
-                     prompt = gr.Textbox(
-                         label=english_labels["Prompt"],
-                         info="Enter the description",
-                         placeholder="A dog in the park",
-                         lines=2
-                     )
-                     with gr.Row():
-                         with gr.Column(scale=1):
-                             concept_1 = gr.Textbox(
-                                 label=english_labels["1st direction to steer"],
-                                 info="Initial state",
-                                 placeholder="Fresh"
-                             )
-                         with gr.Column(scale=1):
-                             concept_2 = gr.Textbox(
-                                 label=english_labels["2nd direction to steer"],
-                                 info="Final state",
-                                 placeholder="Rotten"
                              )
-                     with gr.Row(elem_classes="slider-container"):
-                         x = gr.Slider(
-                             minimum=0,
-                             value=1.75,
-                             step=0.1,
-                             maximum=4.0,
-                             label=english_labels["Strength"],
-                             info="Maximum strength for each direction (above 2.5 may be unstable)"
-                         )
-                     submit = gr.Button(english_labels["Generate directions"], size="lg", variant="primary")
-                 with gr.Accordion(label=english_labels["Advanced options"], open=False, elem_classes="advanced-panel"):
-                     with gr.Row():
-                         with gr.Column(scale=1):
-                             interm_steps = gr.Slider(
-                                 label=english_labels["Num of intermediate images"],
-                                 minimum=3,
-                                 value=7,
-                                 maximum=65,
-                                 step=2
-                             )
-                         with gr.Column(scale=1):
-                             guidance_scale = gr.Slider(
-                                 label=english_labels["Guidance scale"],
-                                 minimum=0.1,
-                                 maximum=10.0,
-                                 step=0.1,
-                                 value=3.5
-                             )
-                     with gr.Row():
-                         with gr.Column(scale=1):
-                             iterations = gr.Slider(
-                                 label=english_labels["Num iterations for clip directions"],
-                                 minimum=0,
-                                 value=200,
-                                 maximum=400,
-                                 step=1
-                             )
-                         with gr.Column(scale=1):
-                             steps = gr.Slider(
-                                 label=english_labels["Num inference steps"],
-                                 minimum=1,
-                                 value=3,
-                                 maximum=4,
-                                 step=1
-                             )
-                     with gr.Row():
-                         with gr.Column(scale=1):
-                             randomize_seed = gr.Checkbox(
-                                 True,
-                                 label=english_labels["Randomize seed"]
-                             )
-                         with gr.Column(scale=1):
-                             seed = gr.Slider(
-                                 minimum=0,
-                                 maximum=MAX_SEED,
-                                 step=1,
-                                 label=english_labels["Seed"],
-                                 interactive=True,
-                                 randomize=True
-                             )
-         # Right Column - Output
-         with gr.Column(scale=8):
-             with gr.Group(elem_classes="main-panel"):
-                 gr.Markdown("### Generated Results")
-                 image_strip = gr.Image(label="Image Strip", type="filepath", elem_id="strip", height=200)
-                 output_video = gr.Video(label=english_labels["Looping video"], elem_id="video", loop=True, autoplay=True, height=600)
-                 with gr.Row():
-                     post_generation_image = gr.Image(
-                         label=english_labels["Generated Images"],
-                         type="filepath",
-                         elem_id="interactive",
-                         elem_classes="image-display"
-                     )
-                     post_generation_slider = gr.Slider(
-                         minimum=-10,
-                         maximum=10,
-                         value=0,
-                         step=1,
-                         label=english_labels["From 1st to 2nd direction"]
                      )

-     gr.Examples(
-         examples=examples,
-         inputs=[prompt, concept_1, concept_2, x]
-     )

-     submit.click(
-         fn=generate,
-         inputs=[
-             prompt, concept_1, concept_2, x, randomize_seed, seed,
-             recalc_directions, iterations, steps, interm_steps,
-             guidance_scale, x_concept_1, x_concept_2, avg_diff_x, total_images
-         ],
-         outputs=[
-             x_concept_1, x_concept_2, avg_diff_x,
-             output_video,
-             image_strip,
-             total_images,
-             post_generation_image,
-             post_generation_slider,
-             seed
-         ]
-     )

-     iterations.change(fn=reset_recalc_directions, outputs=[recalc_directions])
-     seed.change(fn=reset_recalc_directions, outputs=[recalc_directions])
-     post_generation_slider.change(
-         fn=update_pre_generated_images,
-         inputs=[post_generation_slider, total_images],
-         outputs=[post_generation_image],
-         queue=False,
-         show_progress="hidden",
-         concurrency_limit=None
-     )
-
- # Set show_api=False only in demo.launch(...)
- if __name__ == "__main__":
-     demo.launch(server_name="0.0.0.0", server_port=7860)
+ import argparse
  import spaces
+ from visualcloze import VisualClozeModel
+ import gradio as gr
+ import examples
  import torch
+ from functools import partial
+ from data.prefix_instruction import get_layout_instruction
+ from huggingface_hub import snapshot_download

+ # Define the missing variables here
+ GUIDANCE = """
+ ## How to use this demo:
+ 1. Select a task example from the right side, or prepare your own in-context examples and query.
+ 2. The grid will be filled with in-context examples and a query row.
+ 3. You can modify the task description or add content descriptions.
+ 4. Click "Generate" to create images following the pattern shown in examples.
+ """
+
+ NOTE = """
+ **Note:** The examples on the right side demonstrate various tasks.
+ Click on any example to load it into the interface. You can then modify images or prompts as needed.
+ """
+
+ CITATION = """
+ ## Paper Citation
+ ```
+ @article{liu2024visualcloze,
+     title={VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning},
+     author={Liu, Zhaoyang and Lian, Yuheng and Wang, Jianfeng and Zhou, Aojun and Liu, Jiashi and Ye, Hang and Chen, Kai and Wang, Jingdong and Zhao, Deli},
+     journal={arXiv preprint arXiv:2504.07960},
+     year={2024}
  }
+ ```
  """

+ max_grid_h = 5
+ max_grid_w = 5
+ default_grid_h = 2
+ default_grid_w = 3
+ default_upsampling_noise = 0.4
+ default_steps = 30
+
+
+ def create_demo(model):
+     with gr.Blocks(title="VisualCloze Demo") as demo:
+         gr.Markdown("# VisualCloze: A Universal Image Generation Framework via Visual In-Context Learning")
+
+         gr.HTML("""
+         <div style="display:flex;column-gap:4px;">
+             <a href="https://github.com/lzyhha/VisualCloze">
+                 <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
+             </a>
+             <a href="https://visualcloze.github.io/">
+                 <img src='https://img.shields.io/badge/Project-Website-green'>
+             </a>
+             <a href="https://arxiv.org/abs/2504.07960">
+                 <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
+             </a>
+             <a href="https://huggingface.co/VisualCloze/VisualCloze">
+                 <img src='https://img.shields.io/badge/VisualCloze%20checkpoint-HF%20Model-green?logoColor=violet&label=%F0%9F%A4%97%20Checkpoint'>
+             </a>
+             <a href="https://huggingface.co/datasets/VisualCloze/Graph200K">
+                 <img src='https://img.shields.io/badge/VisualCloze%20datasets-HF%20Dataset-6B88E3?logoColor=violet&label=%F0%9F%A4%97%20Graph200k%20Dataset'>
+             </a>
+         </div>
+         """)
+
+         gr.Markdown(GUIDANCE)
+
+         # Pre-create all possible image components
+         all_image_inputs = []
+         rows = []
+         row_texts = []
+         with gr.Row():
+
+             with gr.Column(scale=2):
+                 # Image grid
+                 for i in range(max_grid_h):
+                     # Add row label before each row
+                     row_texts.append(gr.Markdown(
+                         "## Query" if i == default_grid_h - 1 else f"## In-context Example {i + 1}",
+                         elem_id=f"row_text_{i}",
+                         visible=i < default_grid_h
+                     ))
+                     with gr.Row(visible=i < default_grid_h, elem_id=f"row_{i}") as row:
+                         rows.append(row)
+                         for j in range(max_grid_w):
+                             img_input = gr.Image(
+                                 label=f"In-context Example {i + 1}/{j + 1}" if i != default_grid_h - 1 else f"Query {j + 1}",
+                                 type="pil",
+                                 visible=i < default_grid_h and j < default_grid_w,
+                                 interactive=True,
+                                 elem_id=f"img_{i}_{j}"
                              )
+                             all_image_inputs.append(img_input)
+
+                 # Prompts
+                 layout_prompt = gr.Textbox(
+                     label="Layout Description (Auto-filled, Read-only)",
+                     placeholder="Layout description will be automatically filled based on grid size...",
+                     value=get_layout_instruction(default_grid_w, default_grid_h),
+                     elem_id="layout_prompt",
+                     interactive=False
+                 )
+
+                 task_prompt = gr.Textbox(
+                     label="Task Description (can be modified, following the examples, to perform custom tasks, but this may lead to unstable results)",
+                     placeholder="Describe what task should be performed...",
+                     value="",
+                     elem_id="task_prompt"
+                 )
+
+                 content_prompt = gr.Textbox(
+                     label="(Optional) Content Description (image caption, editing instructions, etc.)",
+                     placeholder="Describe the content requirements...",
+                     value="",
+                     elem_id="content_prompt"
+                 )
+
+                 generate_btn = gr.Button("Generate", elem_id="generate_btn")
+                 gr.Markdown(NOTE)
+
+                 grid_h = gr.Slider(minimum=0, maximum=max_grid_h-1, value=default_grid_h-1, step=1, label="Number of In-context Examples", elem_id="grid_h")
+                 grid_w = gr.Slider(minimum=1, maximum=max_grid_w, value=default_grid_w, step=1, label="Task Columns", elem_id="grid_w")
+
+                 with gr.Accordion("Advanced options", open=False):
+                     seed = gr.Number(label="Seed (0 for random)", value=0, precision=0)
+                     steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=default_steps, step=1)
+                     cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=50.0, value=30, step=1)
+                     upsampling_steps = gr.Slider(label="Upsampling steps (SDEdit)", minimum=1, maximum=100.0, value=10, step=1)
+                     upsampling_noise = gr.Slider(label="Upsampling noise (SDEdit)", minimum=0, maximum=1.0, value=default_upsampling_noise, step=0.05)
+
+                 gr.Markdown(CITATION)
+
+             # Output
+             with gr.Column(scale=2):
+                 output_gallery = gr.Gallery(
+                     label="Generated Results",
+                     show_label=True,
+                     elem_id="output_gallery",
+                     columns=None,
+                     rows=None,
+                     height="auto",
+                     allow_preview=True,
+                     object_fit="contain"
+                 )
+
+ gr.Markdown("# Task Examples")
150
+ gr.Markdown("Each click on a task may result in different examples.")
151
+ text_dense_prediction_tasks = gr.Textbox(label="Task", visible=False)
152
+ dense_prediction_tasks = gr.Dataset(
153
+ samples=examples.dense_prediction_text,
154
+ label='Dense Prediction',
155
+ samples_per_page=1000,
156
+ components=[text_dense_prediction_tasks])
157
+
158
+ text_conditional_generation_tasks = gr.Textbox(label="Task", visible=False)
159
+ conditional_generation_tasks = gr.Dataset(
160
+ samples=examples.conditional_generation_text,
161
+ label='Conditional Generation',
162
+ samples_per_page=1000,
163
+ components=[text_conditional_generation_tasks])
164
+
165
+ text_image_restoration_tasks = gr.Textbox(label="Task", visible=False)
166
+ image_restoration_tasks = gr.Dataset(
167
+ samples=examples.image_restoration_text,
168
+ label='Image Restoration',
169
+ samples_per_page=1000,
170
+ components=[text_image_restoration_tasks])
171
+
172
+ text_style_transfer_tasks = gr.Textbox(label="Task", visible=False)
173
+ style_transfer_tasks = gr.Dataset(
174
+ samples=examples.style_transfer_text,
175
+ label='Style Transfer',
176
+ samples_per_page=1000,
177
+ components=[text_style_transfer_tasks])
178
+
179
+ text_style_condition_fusion_tasks = gr.Textbox(label="Task", visible=False)
180
+ style_condition_fusion_tasks = gr.Dataset(
181
+ samples=examples.style_condition_fusion_text,
182
+ label='Style Condition Fusion',
183
+ samples_per_page=1000,
184
+ components=[text_style_condition_fusion_tasks])
185
+
186
+ text_tryon_tasks = gr.Textbox(label="Task", visible=False)
187
+ tryon_tasks = gr.Dataset(
188
+ samples=examples.tryon_text,
189
+ label='Virtual Try-On',
190
+ samples_per_page=1000,
191
+ components=[text_tryon_tasks])
192
+
193
+ text_relighting_tasks = gr.Textbox(label="Task", visible=False)
194
+ relighting_tasks = gr.Dataset(
195
+ samples=examples.relighting_text,
196
+ label='Relighting',
197
+ samples_per_page=1000,
198
+ components=[text_relighting_tasks])
199
+
200
+ text_photodoodle_tasks = gr.Textbox(label="Task", visible=False)
201
+ photodoodle_tasks = gr.Dataset(
202
+ samples=examples.photodoodle_text,
203
+ label='Photodoodle',
204
+ samples_per_page=1000,
205
+ components=[text_photodoodle_tasks])
206
+
207
+ text_editing_tasks = gr.Textbox(label="Task", visible=False)
208
+ editing_tasks = gr.Dataset(
209
+ samples=examples.editing_text,
210
+ label='Editing',
211
+ samples_per_page=1000,
212
+ components=[text_editing_tasks])
213
+
214
+ text_unseen_tasks = gr.Textbox(label="Task", visible=False)
215
+ unseen_tasks = gr.Dataset(
216
+ samples=examples.unseen_tasks_text,
217
+ label='Unseen Tasks (May produce unstable effects)',
218
+ samples_per_page=1000,
219
+ components=[text_unseen_tasks])
220
+
221
+ gr.Markdown("# Subject-driven Tasks Examples")
222
+ text_subject_driven_tasks = gr.Textbox(label="Task", visible=False)
223
+ subject_driven_tasks = gr.Dataset(
224
+ samples=examples.subject_driven_text,
225
+ label='Subject-driven Generation',
226
+ samples_per_page=1000,
227
+ components=[text_subject_driven_tasks])
228
+
229
+ text_condition_subject_fusion_tasks = gr.Textbox(label="Task", visible=False)
230
+ condition_subject_fusion_tasks = gr.Dataset(
231
+ samples=examples.condition_subject_fusion_text,
232
+ label='Condition+Subject Fusion',
233
+ samples_per_page=1000,
234
+ components=[text_condition_subject_fusion_tasks])
235
+
236
+ text_style_transfer_with_subject_tasks = gr.Textbox(label="Task", visible=False)
237
+ style_transfer_with_subject_tasks = gr.Dataset(
238
+ samples=examples.style_transfer_with_subject_text,
239
+ label='Style Transfer with Subject',
240
+ samples_per_page=1000,
241
+ components=[text_style_transfer_with_subject_tasks])
242
+
243
+ text_condition_subject_style_fusion_tasks = gr.Textbox(label="Task", visible=False)
244
+ condition_subject_style_fusion_tasks = gr.Dataset(
245
+ samples=examples.condition_subject_style_fusion_text,
246
+ label='Condition+Subject+Style Fusion',
247
+ samples_per_page=1000,
248
+ components=[text_condition_subject_style_fusion_tasks])
249
+
250
+ text_editing_with_subject_tasks = gr.Textbox(label="Task", visible=False)
251
+ editing_with_subject_tasks = gr.Dataset(
252
+ samples=examples.editing_with_subject_text,
253
+ label='Editing with Subject',
254
+ samples_per_page=1000,
255
+ components=[text_editing_with_subject_tasks])
256
+
257
+ text_image_restoration_with_subject_tasks = gr.Textbox(label="Task", visible=False)
258
+ image_restoration_with_subject_tasks = gr.Dataset(
259
+ samples=examples.image_restoration_with_subject_text,
260
+ label='Image Restoration with Subject',
261
+ samples_per_page=1000,
262
+ components=[text_image_restoration_with_subject_tasks])
263
+
264
+         def update_grid(h, w):
+             actual_h = h + 1
+             model.set_grid_size(actual_h, w)
+
+             updates = []
+
+             # Update image component visibility
+             for i in range(max_grid_h * max_grid_w):
+                 curr_row = i // max_grid_w
+                 curr_col = i % max_grid_w
+                 updates.append(
+                     gr.update(
+                         label=f"In-context Example {curr_row + 1}/{curr_col + 1}" if curr_row != actual_h - 1 else f"Query {curr_col + 1}",
+                         elem_id=f"img_{curr_row}_{curr_col}",
+                         visible=(curr_row < actual_h and curr_col < w)))
+
+             # Update row visibility and labels
+             updates_row = []
+             updates_row_text = []
+             for i in range(max_grid_h):
+                 updates_row.append(gr.update(f"row_{i}", visible=(i < actual_h)))
+                 updates_row_text.append(
+                     gr.update(
+                         elem_id=f"row_text_{i}",
+                         visible=i < actual_h,
+                         value="## Query" if i == actual_h - 1 else f"## In-context Example {i + 1}",
                     )
+                 )
+
+             updates.extend(updates_row)
+             updates.extend(updates_row_text)
+             updates.append(gr.update(elem_id="layout_prompt", value=get_layout_instruction(w, actual_h)))
+             return updates
+
+         def generate_image(*inputs):
+             images = []
+             if grid_h.value + 1 != model.grid_h or grid_w.value != model.grid_w:
+                 raise gr.Error('Please wait for the loading to complete.')
+             for i in range(model.grid_h):
+                 images.append([])
+                 for j in range(model.grid_w):
+                     images[i].append(inputs[i * max_grid_w + j])
+                     if i != model.grid_h - 1:
+                         if inputs[i * max_grid_w + j] is None:
+                             raise gr.Error('Please upload in-context examples. The task examples may still be loading; wait a few seconds and click the button again.')
+             seed, cfg, steps, upsampling_steps, upsampling_noise, layout_text, task_text, content_text = inputs[-8:]
+
+             try:
+                 results = generate(
+                     images,
+                     [layout_text, task_text, content_text],
+                     seed=seed, cfg=cfg, steps=steps,
+                     upsampling_steps=upsampling_steps, upsampling_noise=upsampling_noise
+                 )
+             except Exception as e:
+                 raise gr.Error('Processing failed. The task examples may still be loading; wait a few seconds and click the button again. Error: ' + str(e))
+
+             output = gr.update(
+                 elem_id='output_gallery',
+                 value=results,
+                 columns=min(len(results), 2),
+                 rows=int(len(results) / 2 + 0.5))
+
+             return output
+
+         def process_tasks(task, func):
+             outputs = func(task)
+             mask = outputs[0]
+             state = outputs[1:8]
+             if state[5] is None:
+                 state[5] = default_upsampling_noise
+             if state[6] is None:
+                 state[6] = default_steps
+             images = outputs[8:-len(mask)]
+             output = outputs[-len(mask):]
+             for i in range(len(mask)):
+                 if mask[i] == 1:
+                     images.append(None)
+                 else:
+                     images.append(output[-len(mask) + i])
+
+             state[0] = state[0] - 1
+             cur_hrid_h = state[0]
+             cur_hrid_w = state[1]
+
+             current_example = [None] * 25
+             for i, image in enumerate(images):
+                 pos = (i // cur_hrid_w) * 5 + (i % cur_hrid_w)
+                 if image is not None:
+                     current_example[pos] = image
+             update_grid(cur_hrid_h, cur_hrid_w)
+             output = gr.update(
+                 elem_id='output_gallery',
+                 value=[o for o, m in zip(output, mask) if m == 1],
+                 columns=min(sum(mask), 2),
+                 rows=int(sum(mask) / 2 + 0.5))
+             return [output] + current_example + state
+
+         dense_prediction_tasks.click(
+             partial(process_tasks, func=examples.process_dense_prediction_tasks),
+             inputs=[dense_prediction_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full",
+             show_progress_on=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps] + [generate_btn])
+
+         conditional_generation_tasks.click(
+             partial(process_tasks, func=examples.process_conditional_generation_tasks),
+             inputs=[conditional_generation_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         image_restoration_tasks.click(
+             partial(process_tasks, func=examples.process_image_restoration_tasks),
+             inputs=[image_restoration_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         style_transfer_tasks.click(
+             partial(process_tasks, func=examples.process_style_transfer_tasks),
+             inputs=[style_transfer_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         style_condition_fusion_tasks.click(
+             partial(process_tasks, func=examples.process_style_condition_fusion_tasks),
+             inputs=[style_condition_fusion_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         relighting_tasks.click(
+             partial(process_tasks, func=examples.process_relighting_tasks),
+             inputs=[relighting_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         tryon_tasks.click(
+             partial(process_tasks, func=examples.process_tryon_tasks),
+             inputs=[tryon_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         photodoodle_tasks.click(
+             partial(process_tasks, func=examples.process_photodoodle_tasks),
+             inputs=[photodoodle_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         editing_tasks.click(
+             partial(process_tasks, func=examples.process_editing_tasks),
+             inputs=[editing_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         unseen_tasks.click(
+             partial(process_tasks, func=examples.process_unseen_tasks),
+             inputs=[unseen_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         subject_driven_tasks.click(
+             partial(process_tasks, func=examples.process_subject_driven_tasks),
+             inputs=[subject_driven_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         style_transfer_with_subject_tasks.click(
+             partial(process_tasks, func=examples.process_style_transfer_with_subject_tasks),
+             inputs=[style_transfer_with_subject_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         condition_subject_fusion_tasks.click(
+             partial(process_tasks, func=examples.process_condition_subject_fusion_tasks),
+             inputs=[condition_subject_fusion_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         condition_subject_style_fusion_tasks.click(
+             partial(process_tasks, func=examples.process_condition_subject_style_fusion_tasks),
+             inputs=[condition_subject_style_fusion_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         editing_with_subject_tasks.click(
+             partial(process_tasks, func=examples.process_editing_with_subject_tasks),
+             inputs=[editing_with_subject_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+
+         image_restoration_with_subject_tasks.click(
+             partial(process_tasks, func=examples.process_image_restoration_with_subject_tasks),
+             inputs=[image_restoration_with_subject_tasks],
+             outputs=[output_gallery] + all_image_inputs + [grid_h, grid_w, layout_prompt, task_prompt, content_prompt, upsampling_noise, steps],
+             show_progress="full")
+         # Initialize grid
+         model.set_grid_size(default_grid_h, default_grid_w)
+
+         # Connect the grid-size sliders to every component that needs updating
+         output_components = all_image_inputs + rows + row_texts + [layout_prompt]
+
+         grid_h.change(fn=update_grid, inputs=[grid_h, grid_w], outputs=output_components)
+         grid_w.change(fn=update_grid, inputs=[grid_h, grid_w], outputs=output_components)
+
+         # Wire up the Generate button
+         generate_btn.click(
+             fn=generate_image,
+             inputs=all_image_inputs + [seed, cfg, steps, upsampling_steps, upsampling_noise] + [layout_prompt, task_prompt, content_prompt],
+             outputs=output_gallery
+         )
+
+     return demo
+
+
+ @spaces.GPU()
+ def generate(
+         images,
+         prompts,
+         seed, cfg, steps,
+         upsampling_steps, upsampling_noise):
+     with torch.no_grad():
+         return model.process_images(
+             images=images,
+             prompts=prompts,
+             seed=seed,
+             cfg=cfg,
+             steps=steps,
+             upsampling_steps=upsampling_steps,
+             upsampling_noise=upsampling_noise)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_path", type=str, default="checkpoints/visualcloze-384-lora.pth")
+     parser.add_argument("--precision", type=str, choices=["fp32", "bf16", "fp16"], default="bf16")
+     parser.add_argument("--resolution", type=int, default=384)
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parse_args()

+     snapshot_download(repo_id="VisualCloze/VisualCloze", repo_type="model", local_dir="checkpoints")
+
+     # Initialize model
+     model = VisualClozeModel(resolution=args.resolution, model_path=args.model_path, precision=args.precision)

+     # Create Gradio demo
+     demo = create_demo(model)

+     # Start Gradio server
+     demo.launch()
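
For readers who want to exercise the new pipeline outside the Gradio UI, the sketch below mirrors what this commit's `__main__` block and `generate()` helper do. It is a minimal, hypothetical sketch rather than part of the commit: it assumes the repository's `visualcloze` and `data.prefix_instruction` modules are importable, the image paths and task description are placeholders, and the numeric values simply repeat the defaults shown in `parse_args()` and the Advanced options panel.

```python
# Minimal sketch (not part of the commit): drive the VisualCloze pipeline programmatically.
# Assumes the same repository layout as app.py; the *.png paths are hypothetical placeholders.
import torch
from PIL import Image
from huggingface_hub import snapshot_download

from visualcloze import VisualClozeModel
from data.prefix_instruction import get_layout_instruction

# Fetch the checkpoint exactly as app.py's __main__ block does.
snapshot_download(repo_id="VisualCloze/VisualCloze", repo_type="model", local_dir="checkpoints")

model = VisualClozeModel(
    resolution=384,                                     # --resolution default
    model_path="checkpoints/visualcloze-384-lora.pth",  # --model_path default
    precision="bf16",                                   # --precision default
)
model.set_grid_size(2, 3)  # one in-context example row plus the query row, three columns

# A 2x3 grid: the first row is a complete in-context example; in the query row the
# final cell is typically left as None for the model to fill in, matching generate_image().
grid = [
    [Image.open("example_src.png"), Image.open("example_cond.png"), Image.open("example_target.png")],
    [Image.open("query_src.png"), Image.open("query_cond.png"), None],
]
prompts = [
    get_layout_instruction(3, 2),          # layout description (auto-filled in the UI)
    "<task description, e.g. copied from one of the task examples>",
    "",                                    # optional content description
]

with torch.no_grad():
    results = model.process_images(
        images=grid,
        prompts=prompts,
        seed=0,
        cfg=30,
        steps=30,
        upsampling_steps=10,
        upsampling_noise=0.4,
    )
```

Running the Space itself is unchanged from the `__main__` block above: `python app.py` downloads the checkpoint, builds the model, and serves the demo via `demo.launch()`.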