multimodalart HF staff commited on
Commit
77ca2e1
1 Parent(s): 461d656

Suggested UI and ZeroGPU compatibility

Browse files
Files changed (1) hide show
  1. app.py +174 -131
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import gradio as gr
2
  import torch
 
3
  from PIL import Image
4
  import os
5
  from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
@@ -10,11 +11,10 @@ import torch.nn as nn
10
  import math
11
  import logging
12
  import sys
13
- from huggingface_hub import snapshot_download
14
  from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
15
- import spaces
16
 
17
- # 设置日志
18
  logging.basicConfig(
19
  level=logging.INFO,
20
  format='%(asctime)s - %(levelname)s - %(message)s',
@@ -24,10 +24,28 @@ logger = logging.getLogger(__name__)
24
 
25
  MODEL_ID = "Djrango/Qwen2vl-Flux"
26
  MODEL_CACHE_DIR = "model_cache"
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
- dtype = torch.bfloat16
29
 
30
- # 预下载模型
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  if not os.path.exists(MODEL_CACHE_DIR):
32
  logger.info("Starting model download...")
33
  try:
@@ -41,68 +59,70 @@ if not os.path.exists(MODEL_CACHE_DIR):
41
  logger.error(f"Error downloading models: {str(e)}")
42
  raise
43
 
44
- # 加载小模型到 GPU
45
- logger.info("Loading small models to GPU...")
 
 
46
  tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
47
  text_encoder = CLIPTextModel.from_pretrained(
48
  os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
49
- ).to(dtype).to(device)
50
 
51
  text_encoder_two = T5EncoderModel.from_pretrained(
52
  os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
53
- ).to(dtype).to(device)
54
 
55
  tokenizer_two = T5TokenizerFast.from_pretrained(
56
- os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
 
57
 
58
- # 大模型初始加载到 CPU
59
- logger.info("Loading large models to CPU...")
60
  vae = AutoencoderKL.from_pretrained(
61
  os.path.join(MODEL_CACHE_DIR, "flux/vae")
62
- ).to(dtype).cpu()
63
 
64
  transformer = FluxTransformer2DModel.from_pretrained(
65
  os.path.join(MODEL_CACHE_DIR, "flux/transformer")
66
- ).to(dtype).cpu()
67
 
68
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
69
  os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
70
  shift=1
71
  )
72
 
 
73
  qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
74
  os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
75
- ).to(dtype).cpu()
76
-
77
- qwen2vl_processor = AutoProcessor.from_pretrained(
78
- MODEL_ID,
79
- subfolder="qwen2-vl",
80
- min_pixels=256*28*28,
81
- max_pixels=256*28*28
82
- )
83
-
84
- # 加载 connector 和 embedder 到 CPU
85
- class Qwen2Connector(nn.Module):
86
- def __init__(self, input_dim=3584, output_dim=4096):
87
- super().__init__()
88
- self.linear = nn.Linear(input_dim, output_dim)
89
-
90
- def forward(self, x):
91
- return self.linear(x)
92
 
93
- connector = Qwen2Connector().to(dtype).cpu()
 
94
  connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
95
  connector_state = torch.load(connector_path, map_location='cpu')
96
- connector_state = {k.replace('module.', ''): v.to(dtype) for k, v in connector_state.items()}
97
  connector.load_state_dict(connector_state)
98
 
99
- t5_context_embedder = nn.Linear(4096, 3072).to(dtype).cpu()
100
  t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
101
  t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
102
- t5_embedder_state = {k: v.to(dtype) for k, v in t5_embedder_state.items()}
103
  t5_context_embedder.load_state_dict(t5_embedder_state)
104
 
105
- # 创建pipeline (先用CPU上的模型)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  pipeline = FluxPipeline(
107
  transformer=transformer,
108
  scheduler=scheduler,
@@ -111,30 +131,14 @@ pipeline = FluxPipeline(
111
  tokenizer=tokenizer,
112
  )
113
 
114
- # 设置所有模型为eval模式
115
- for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl,
116
- connector, t5_context_embedder]:
117
- model.requires_grad_(False)
118
- model.eval()
119
-
120
- # Aspect ratio options
121
- ASPECT_RATIOS = {
122
- "1:1": (1024, 1024),
123
- "16:9": (1344, 768),
124
- "9:16": (768, 1344),
125
- "2.4:1": (1536, 640),
126
- "3:4": (896, 1152),
127
- "4:3": (1152, 896),
128
- }
129
-
130
  def process_image(image):
131
  """Process image with Qwen2VL model"""
132
  try:
133
- # Qwen2VL 相关模型移到 GPU
134
  logger.info("Moving Qwen2VL models to GPU...")
135
- qwen2vl.to(device)
136
- connector.to(device)
137
-
138
  message = [
139
  {
140
  "role": "user",
@@ -156,27 +160,42 @@ def process_image(image):
156
  images=[image],
157
  padding=True,
158
  return_tensors="pt"
159
- ).to(device)
160
 
161
  output_hidden_state, image_token_mask, image_grid_thw = qwen2vl(**inputs)
162
  image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
163
  image_hidden_state = connector(image_hidden_state)
164
 
165
- # 保存结果到 CPU
166
  result = (image_hidden_state.cpu(), image_grid_thw)
167
-
168
- # 将模型移回 CPU 并清理显存
169
- logger.info("Moving Qwen2VL models back to CPU...")
170
- qwen2vl.cpu()
171
- connector.cpu()
172
- torch.cuda.empty_cache()
173
-
174
- return result
175
 
176
  except Exception as e:
177
  logger.error(f"Error in process_image: {str(e)}")
178
  raise
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  def compute_t5_text_embeddings(prompt):
181
  """Compute T5 embeddings for text prompt"""
182
  if prompt == "":
@@ -188,21 +207,15 @@ def compute_t5_text_embeddings(prompt):
188
  max_length=256,
189
  truncation=True,
190
  return_tensors="pt"
191
- ).to(device)
192
 
193
  prompt_embeds = text_encoder_two(text_inputs.input_ids)[0]
194
-
195
- # 将 t5_context_embedder 移到 GPU
196
- t5_context_embedder.to(device)
197
- prompt_embeds = t5_context_embedder(prompt_embeds)
198
-
199
- # 将 t5_context_embedder 移回 CPU
200
  t5_context_embedder.cpu()
201
 
202
  return prompt_embeds
203
 
204
  def compute_text_embeddings(prompt=""):
205
- """Compute text embeddings for the prompt"""
206
  with torch.no_grad():
207
  text_inputs = tokenizer(
208
  prompt,
@@ -210,18 +223,17 @@ def compute_text_embeddings(prompt=""):
210
  max_length=77,
211
  truncation=True,
212
  return_tensors="pt"
213
- ).to(device)
214
 
215
  prompt_embeds = text_encoder(
216
  text_inputs.input_ids,
217
  output_hidden_states=False
218
  )
219
- return prompt_embeds.pooler_output
 
220
 
221
- @spaces.GPU(duration=120) # 使用ZeroGPU装饰器
222
- def generate_images(input_image, prompt="", guidance_scale=3.5,
223
- num_inference_steps=28, num_images=1, seed=None, aspect_ratio="1:1"):
224
- """Generate images using the pipeline"""
225
  try:
226
  logger.info(f"Starting generation with prompt: {prompt}")
227
 
@@ -233,31 +245,34 @@ def generate_images(input_image, prompt="", guidance_scale=3.5,
233
  logger.info(f"Set random seed to: {seed}")
234
 
235
  # Process image with Qwen2VL
 
236
  qwen2_hidden_state, image_grid_thw = process_image(input_image)
 
237
 
238
  # Compute text embeddings
 
239
  pooled_prompt_embeds = compute_text_embeddings(prompt)
240
  t5_prompt_embeds = compute_t5_text_embeddings(prompt)
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  # Get dimensions
243
  width, height = ASPECT_RATIOS[aspect_ratio]
244
  logger.info(f"Using dimensions: {width}x{height}")
245
 
246
- # Generate images
247
  try:
248
  logger.info("Starting image generation...")
249
-
250
- # 将 Transformer 和 VAE 移到 GPU
251
- logger.info("Moving Transformer and VAE to GPU...")
252
- transformer.to(device)
253
- vae.to(device)
254
-
255
- # 更新 pipeline 中的模型引用
256
- pipeline.transformer = transformer
257
- pipeline.vae = vae
258
-
259
  output_images = pipeline(
260
- prompt_embeds=qwen2_hidden_state.to(device).repeat(num_images, 1, 1),
261
  pooled_prompt_embeds=pooled_prompt_embeds,
262
  t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
263
  num_inference_steps=num_inference_steps,
@@ -265,15 +280,8 @@ def generate_images(input_image, prompt="", guidance_scale=3.5,
265
  height=height,
266
  width=width,
267
  ).images
268
-
269
  logger.info("Image generation completed")
270
 
271
- # 将 Transformer 和 VAE 移回 CPU
272
- logger.info("Moving models back to CPU...")
273
- #transformer.cpu()
274
- #vae.cpu()
275
- torch.cuda.empty_cache()
276
-
277
  return output_images
278
 
279
  except Exception as e:
@@ -287,19 +295,32 @@ def generate_images(input_image, prompt="", guidance_scale=3.5,
287
  with gr.Blocks(
288
  theme=gr.themes.Soft(),
289
  css="""
290
- .container { max-width: 1200px; margin: auto; padding: 0 20px; }
291
- .header { text-align: center; margin: 20px 0 40px 0; padding: 20px; background: #f7f7f7; border-radius: 12px; }
292
- .param-row { padding: 10px 0; }
293
- footer { margin-top: 40px; padding: 20px; border-top: 1px solid #eee; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  """
295
  ) as demo:
296
  with gr.Column(elem_classes="container"):
297
- gr.Markdown("""
298
- <div class="header">
299
- # 🎨 Qwen2vl-Flux Image Variation Demo
300
- Generate creative variations of your images with optional text guidance
301
- </div>
302
- """)
303
 
304
  with gr.Row(equal_height=True):
305
  with gr.Column(scale=1):
@@ -309,14 +330,13 @@ with gr.Blocks(
309
  height=384,
310
  sources=["upload", "clipboard"]
311
  )
312
-
 
 
 
 
313
  with gr.Accordion("Advanced Settings", open=False):
314
  with gr.Group():
315
- prompt = gr.Textbox(
316
- label="Text Prompt (Optional)",
317
- placeholder="As Long As Possible...",
318
- lines=3
319
- )
320
 
321
  with gr.Row(elem_classes="param-row"):
322
  guidance = gr.Slider(
@@ -324,38 +344,48 @@ with gr.Blocks(
324
  maximum=10,
325
  value=3.5,
326
  step=0.5,
327
- label="Guidance Scale"
 
328
  )
329
  steps = gr.Slider(
330
  minimum=1,
331
- maximum=30,
332
  value=28,
333
  step=1,
334
- label="Sampling Steps"
 
335
  )
336
 
337
  with gr.Row(elem_classes="param-row"):
338
  num_images = gr.Slider(
339
  minimum=1,
340
- maximum=2,
341
- value=1, # 默认改为1
342
  step=1,
343
- label="Number of Images"
 
344
  )
345
  seed = gr.Number(
346
  label="Random Seed",
347
  value=None,
348
- precision=0
 
349
  )
350
  aspect_ratio = gr.Radio(
351
  label="Aspect Ratio",
352
  choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"],
353
- value="1:1"
 
354
  )
355
 
356
- submit_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
 
 
 
 
357
 
358
  with gr.Column(scale=1):
 
359
  output_gallery = gr.Gallery(
360
  label="Generated Variations",
361
  columns=2,
@@ -363,11 +393,23 @@ with gr.Blocks(
363
  height=700,
364
  object_fit="contain",
365
  show_label=True,
366
- allow_preview=True
 
367
  )
 
368
 
 
 
 
 
 
 
 
 
 
 
369
  submit_btn.click(
370
- fn=generate_images,
371
  inputs=[
372
  input_image,
373
  prompt,
@@ -376,14 +418,15 @@ with gr.Blocks(
376
  num_images,
377
  seed,
378
  aspect_ratio
379
- ],
380
  outputs=[output_gallery],
381
  show_progress=True
382
  )
383
 
 
384
  if __name__ == "__main__":
385
  demo.launch(
386
- server_name="0.0.0.0",
387
- server_port=7860,
388
- share=False
389
  )
 
1
  import gradio as gr
2
  import torch
3
+ import spaces
4
  from PIL import Image
5
  import os
6
  from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
 
11
  import math
12
  import logging
13
  import sys
 
14
  from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
15
+ from huggingface_hub import snapshot_download
16
 
17
+ # Set up logging
18
  logging.basicConfig(
19
  level=logging.INFO,
20
  format='%(asctime)s - %(levelname)s - %(message)s',
 
24
 
25
  MODEL_ID = "Djrango/Qwen2vl-Flux"
26
  MODEL_CACHE_DIR = "model_cache"
27
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
+ DTYPE = torch.bfloat16
29
 
30
+ # Aspect ratio options
31
+ ASPECT_RATIOS = {
32
+ "1:1": (1024, 1024),
33
+ "16:9": (1344, 768),
34
+ "9:16": (768, 1344),
35
+ "2.4:1": (1536, 640),
36
+ "3:4": (896, 1152),
37
+ "4:3": (1152, 896),
38
+ }
39
+
40
+ class Qwen2Connector(nn.Module):
41
+ def __init__(self, input_dim=3584, output_dim=4096):
42
+ super().__init__()
43
+ self.linear = nn.Linear(input_dim, output_dim)
44
+
45
+ def forward(self, x):
46
+ return self.linear(x)
47
+
48
+ # Download models if not present
49
  if not os.path.exists(MODEL_CACHE_DIR):
50
  logger.info("Starting model download...")
51
  try:
 
59
  logger.error(f"Error downloading models: {str(e)}")
60
  raise
61
 
62
+ # Initialize models in global context
63
+ logger.info("Starting model loading...")
64
+
65
+ # Load smaller models to GPU
66
  tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
67
  text_encoder = CLIPTextModel.from_pretrained(
68
  os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
69
+ ).to(DTYPE).to(DEVICE)
70
 
71
  text_encoder_two = T5EncoderModel.from_pretrained(
72
  os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
73
+ ).to(DTYPE).to(DEVICE)
74
 
75
  tokenizer_two = T5TokenizerFast.from_pretrained(
76
+ os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2")
77
+ )
78
 
79
+ # Load larger models to CPU
 
80
  vae = AutoencoderKL.from_pretrained(
81
  os.path.join(MODEL_CACHE_DIR, "flux/vae")
82
+ ).to(DTYPE).cpu()
83
 
84
  transformer = FluxTransformer2DModel.from_pretrained(
85
  os.path.join(MODEL_CACHE_DIR, "flux/transformer")
86
+ ).to(DTYPE).cpu()
87
 
88
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
89
  os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
90
  shift=1
91
  )
92
 
93
+ # Load Qwen2VL to CPU
94
  qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
95
  os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
96
+ ).to(DTYPE).cpu()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ # Load connector and embedder
99
+ connector = Qwen2Connector().to(DTYPE).cpu()
100
  connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
101
  connector_state = torch.load(connector_path, map_location='cpu')
102
+ connector_state = {k.replace('module.', ''): v.to(DTYPE) for k, v in connector_state.items()}
103
  connector.load_state_dict(connector_state)
104
 
105
+ t5_context_embedder = nn.Linear(4096, 3072).to(DTYPE).cpu()
106
  t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
107
  t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
108
+ t5_embedder_state = {k: v.to(DTYPE) for k, v in t5_embedder_state.items()}
109
  t5_context_embedder.load_state_dict(t5_embedder_state)
110
 
111
+ # Set all models to eval mode
112
+ for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, connector, t5_context_embedder]:
113
+ model.requires_grad_(False)
114
+ model.eval()
115
+
116
+ logger.info("All models loaded successfully")
117
+
118
+ # Initialize processors and pipeline
119
+ qwen2vl_processor = AutoProcessor.from_pretrained(
120
+ MODEL_ID,
121
+ subfolder="qwen2-vl",
122
+ min_pixels=256*28*28,
123
+ max_pixels=256*28*28
124
+ )
125
+
126
  pipeline = FluxPipeline(
127
  transformer=transformer,
128
  scheduler=scheduler,
 
131
  tokenizer=tokenizer,
132
  )
133
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  def process_image(image):
135
  """Process image with Qwen2VL model"""
136
  try:
137
+ # Move Qwen2VL models to GPU
138
  logger.info("Moving Qwen2VL models to GPU...")
139
+ qwen2vl.to(DEVICE)
140
+ connector.to(DEVICE)
141
+
142
  message = [
143
  {
144
  "role": "user",
 
160
  images=[image],
161
  padding=True,
162
  return_tensors="pt"
163
+ ).to(DEVICE)
164
 
165
  output_hidden_state, image_token_mask, image_grid_thw = qwen2vl(**inputs)
166
  image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
167
  image_hidden_state = connector(image_hidden_state)
168
 
 
169
  result = (image_hidden_state.cpu(), image_grid_thw)
170
+
171
+ # Move models back to CPU
172
+ qwen2vl.cpu()
173
+ connector.cpu()
174
+ torch.cuda.empty_cache()
175
+
176
+ return result
 
177
 
178
  except Exception as e:
179
  logger.error(f"Error in process_image: {str(e)}")
180
  raise
181
 
182
+ def resize_image(img, max_pixels=1050000):
183
+ if not isinstance(img, Image.Image):
184
+ img = Image.fromarray(img)
185
+
186
+ width, height = img.size
187
+ num_pixels = width * height
188
+
189
+ if num_pixels > max_pixels:
190
+ scale = math.sqrt(max_pixels / num_pixels)
191
+ new_width = int(width * scale)
192
+ new_height = int(height * scale)
193
+ new_width = new_width - (new_width % 8)
194
+ new_height = new_height - (new_height % 8)
195
+ img = img.resize((new_width, new_height), Image.LANCZOS)
196
+
197
+ return img
198
+
199
  def compute_t5_text_embeddings(prompt):
200
  """Compute T5 embeddings for text prompt"""
201
  if prompt == "":
 
207
  max_length=256,
208
  truncation=True,
209
  return_tensors="pt"
210
+ ).to(DEVICE)
211
 
212
  prompt_embeds = text_encoder_two(text_inputs.input_ids)[0]
213
+ prompt_embeds = t5_context_embedder.to(DEVICE)(prompt_embeds)
 
 
 
 
 
214
  t5_context_embedder.cpu()
215
 
216
  return prompt_embeds
217
 
218
  def compute_text_embeddings(prompt=""):
 
219
  with torch.no_grad():
220
  text_inputs = tokenizer(
221
  prompt,
 
223
  max_length=77,
224
  truncation=True,
225
  return_tensors="pt"
226
+ ).to(DEVICE)
227
 
228
  prompt_embeds = text_encoder(
229
  text_inputs.input_ids,
230
  output_hidden_states=False
231
  )
232
+ pooled_prompt_embeds = prompt_embeds.pooler_output
233
+ return pooled_prompt_embeds
234
 
235
+ @spaces.GPU(duration=75)
236
+ def generate(input_image, prompt="", guidance_scale=3.5, num_inference_steps=28, num_images=2, seed=None, aspect_ratio="1:1", progress=gr.Progress(track_tqdm=True)):
 
 
237
  try:
238
  logger.info(f"Starting generation with prompt: {prompt}")
239
 
 
245
  logger.info(f"Set random seed to: {seed}")
246
 
247
  # Process image with Qwen2VL
248
+ logger.info("Processing input image with Qwen2VL...")
249
  qwen2_hidden_state, image_grid_thw = process_image(input_image)
250
+ logger.info("Image processing completed")
251
 
252
  # Compute text embeddings
253
+ logger.info("Computing text embeddings...")
254
  pooled_prompt_embeds = compute_text_embeddings(prompt)
255
  t5_prompt_embeds = compute_t5_text_embeddings(prompt)
256
+ logger.info("Text embeddings computed")
257
+
258
+ # Move Transformer and VAE to GPU
259
+ logger.info("Moving Transformer and VAE to GPU...")
260
+ transformer.to(DEVICE)
261
+ vae.to(DEVICE)
262
+
263
+ # Update pipeline models
264
+ pipeline.transformer = transformer
265
+ pipeline.vae = vae
266
+ logger.info("Models moved to GPU")
267
 
268
  # Get dimensions
269
  width, height = ASPECT_RATIOS[aspect_ratio]
270
  logger.info(f"Using dimensions: {width}x{height}")
271
 
 
272
  try:
273
  logger.info("Starting image generation...")
 
 
 
 
 
 
 
 
 
 
274
  output_images = pipeline(
275
+ prompt_embeds=qwen2_hidden_state.to(DEVICE).repeat(num_images, 1, 1),
276
  pooled_prompt_embeds=pooled_prompt_embeds,
277
  t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
278
  num_inference_steps=num_inference_steps,
 
280
  height=height,
281
  width=width,
282
  ).images
 
283
  logger.info("Image generation completed")
284
 
 
 
 
 
 
 
285
  return output_images
286
 
287
  except Exception as e:
 
295
  with gr.Blocks(
296
  theme=gr.themes.Soft(),
297
  css="""
298
+ .container {
299
+ max-width: 1200px;
300
+ margin: auto;
301
+ }
302
+ .header {
303
+ text-align: center;
304
+ margin: 20px 0 40px 0;
305
+ padding: 20px;
306
+ background: #f7f7f7;
307
+ border-radius: 12px;
308
+ }
309
+ .param-row {
310
+ padding: 10px 0;
311
+ }
312
+ footer {
313
+ margin-top: 40px;
314
+ padding: 20px;
315
+ border-top: 1px solid #eee;
316
+ }
317
  """
318
  ) as demo:
319
  with gr.Column(elem_classes="container"):
320
+ gr.Markdown(
321
+ """# 🎨 Qwen2vl-Flux Image Variation Demo
322
+ Generate creative variations of your images with optional text guidance"""
323
+ )
 
 
324
 
325
  with gr.Row(equal_height=True):
326
  with gr.Column(scale=1):
 
330
  height=384,
331
  sources=["upload", "clipboard"]
332
  )
333
+ prompt = gr.Textbox(
334
+ label="Text Prompt (Optional)",
335
+ placeholder="As Long As Possible...",
336
+ lines=3
337
+ )
338
  with gr.Accordion("Advanced Settings", open=False):
339
  with gr.Group():
 
 
 
 
 
340
 
341
  with gr.Row(elem_classes="param-row"):
342
  guidance = gr.Slider(
 
344
  maximum=10,
345
  value=3.5,
346
  step=0.5,
347
+ label="Guidance Scale",
348
+ info="Higher values follow prompt more closely"
349
  )
350
  steps = gr.Slider(
351
  minimum=1,
352
+ maximum=50,
353
  value=28,
354
  step=1,
355
+ label="Sampling Steps",
356
+ info="More steps = better quality but slower"
357
  )
358
 
359
  with gr.Row(elem_classes="param-row"):
360
  num_images = gr.Slider(
361
  minimum=1,
362
+ maximum=4,
363
+ value=1,
364
  step=1,
365
+ label="Number of Images",
366
+ info="Generate multiple variations at once"
367
  )
368
  seed = gr.Number(
369
  label="Random Seed",
370
  value=None,
371
+ precision=0,
372
+ info="Set for reproducible results"
373
  )
374
  aspect_ratio = gr.Radio(
375
  label="Aspect Ratio",
376
  choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"],
377
+ value="1:1",
378
+ info="Choose aspect ratio for generated images"
379
  )
380
 
381
+ submit_btn = gr.Button(
382
+ "🎨 Generate Variations",
383
+ variant="primary",
384
+ size="lg"
385
+ )
386
 
387
  with gr.Column(scale=1):
388
+ # Output Section
389
  output_gallery = gr.Gallery(
390
  label="Generated Variations",
391
  columns=2,
 
393
  height=700,
394
  object_fit="contain",
395
  show_label=True,
396
+ allow_preview=True,
397
+ preview=True
398
  )
399
+ error_message = gr.Textbox(visible=False)
400
 
401
+ with gr.Row(elem_classes="footer"):
402
+ gr.Markdown("""
403
+ ### Tips:
404
+ - 📸 Upload any image to get started
405
+ - 💡 Add an optional text prompt to guide the generation
406
+ - 🎯 Adjust guidance scale to control prompt influence
407
+ - ⚙️ Increase steps for higher quality
408
+ - 🎲 Use seeds for reproducible results
409
+ """)
410
+
411
  submit_btn.click(
412
+ fn=generate,
413
  inputs=[
414
  input_image,
415
  prompt,
 
418
  num_images,
419
  seed,
420
  aspect_ratio
421
+ ],
422
  outputs=[output_gallery],
423
  show_progress=True
424
  )
425
 
426
+ # Launch the app
427
  if __name__ == "__main__":
428
  demo.launch(
429
+ server_name="0.0.0.0", # Listen on all network interfaces
430
+ server_port=7860, # Use a specific port
431
+ share=False, # Disable public URL sharing
432
  )