Commit b91a6aa by jbilcke-hf · Parent: 5b7dcad

debugging the preview tab

Files changed (3):
  1. vms/config.py +40 -1
  2. vms/services/previewing.py +110 -44
  3. vms/tabs/preview_tab.py +330 -27
vms/config.py CHANGED
@@ -58,7 +58,6 @@ if NORMALIZE_IMAGES_TO not in ['png', 'jpg']:
     raise ValueError("NORMALIZE_IMAGES_TO must be either 'png' or 'jpg'")
 JPEG_QUALITY = int(os.environ.get('JPEG_QUALITY', '97'))
 
-# Expanded model types to include Wan-2.1-T2V
 MODEL_TYPES = {
     "HunyuanVideo": "hunyuan_video",
     "LTX-Video": "ltx_video",
@@ -71,6 +70,46 @@ TRAINING_TYPES = {
     "Full Finetune": "full-finetune"
 }
 
+# Model variants for each model type
+MODEL_VARIANTS = {
+    "wan": {
+        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers": {
+            "name": "Wan 2.1 T2V 1.3B (text-only, smaller)",
+            "type": "text-to-video",
+            "description": "Faster, smaller model (1.3B parameters)"
+        },
+        "Wan-AI/Wan2.1-T2V-14B-Diffusers": {
+            "name": "Wan 2.1 T2V 14B (text-only, larger)",
+            "type": "text-to-video",
+            "description": "Higher quality but slower (14B parameters)"
+        },
+        "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers": {
+            "name": "Wan 2.1 I2V 480p (image+text)",
+            "type": "image-to-video",
+            "description": "Image conditioning at 480p resolution"
+        },
+        "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers": {
+            "name": "Wan 2.1 I2V 720p (image+text)",
+            "type": "image-to-video",
+            "description": "Image conditioning at 720p resolution"
+        }
+    },
+    "ltx_video": {
+        "Lightricks/LTX-Video": {
+            "name": "LTX Video (official)",
+            "type": "text-to-video",
+            "description": "Official LTX Video model"
+        }
+    },
+    "hunyuan_video": {
+        "hunyuanvideo-community/HunyuanVideo": {
+            "name": "Hunyuan Video (official)",
+            "type": "text-to-video",
+            "description": "Official Hunyuan Video model"
+        }
+    }
+}
+
 DEFAULT_SEED = 42
 
 DEFAULT_REMOVE_COMMON_LLM_CAPTION_PREFIXES = True
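
For illustration only (not part of this commit), a minimal sketch of how the new MODEL_VARIANTS table can be queried; the helper name variants_by_type is hypothetical, and the import assumes the dict lands in vms.config exactly as added above:

# Hypothetical helper around the MODEL_VARIANTS table added above.
from vms.config import MODEL_VARIANTS

def variants_by_type(model_family: str, conditioning: str) -> dict:
    """Return the variants of one family ("wan", "ltx_video", "hunyuan_video")
    whose "type" field matches the requested conditioning mode."""
    return {
        repo_id: info
        for repo_id, info in MODEL_VARIANTS.get(model_family, {}).items()
        if info.get("type") == conditioning
    }

# Example: only the image-to-video Wan checkpoints need a conditioning image.
print(list(variants_by_type("wan", "image-to-video")))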
vms/services/previewing.py CHANGED
@@ -6,13 +6,13 @@ Handles the video generation logic and model integration
 
 import logging
 import tempfile
-import torch
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple, Callable
+import time
 
 from vms.config import (
     OUTPUT_PATH, STORAGE_PATH, MODEL_TYPES, TRAINING_PATH,
-    DEFAULT_PROMPT_PREFIX
+    DEFAULT_PROMPT_PREFIX, MODEL_VARIANTS
 )
 from vms.utils import format_time
 
@@ -48,9 +48,14 @@ class PreviewingService:
             logger.error(f"Error finding LoRA weights: {e}")
             return None
 
+    def get_model_variants(self, model_type: str) -> Dict[str, Dict[str, str]]:
+        """Get available model variants for the given model type"""
+        return MODEL_VARIANTS.get(model_type, {})
+
     def generate_video(
         self,
         model_type: str,
+        model_variant: str,
         prompt: str,
         negative_prompt: str,
         prompt_prefix: str,
@@ -62,7 +67,8 @@
         lora_weight: float,
         inference_steps: int,
         enable_cpu_offload: bool,
-        fps: int
+        fps: int,
+        conditioning_image: Optional[str] = None
     ) -> Tuple[Optional[str], str, str]:
        """Generate a video using the trained model"""
        try:
@@ -71,6 +77,7 @@
            def log(msg: str):
                log_messages.append(msg)
                logger.info(msg)
+               # Return updated log string for UI updates
                return "\n".join(log_messages)
 
            # Find latest LoRA weights
@@ -95,7 +102,30 @@
            if not internal_model_type:
                return None, f"Error: Invalid model type {model_type}", log(f"Error: Invalid model type {model_type}")
 
+           # Check if model variant is valid for this model type
+           variants = self.get_model_variants(internal_model_type)
+           if model_variant not in variants:
+               # Use default variant if specified one is invalid
+               if len(variants) > 0:
+                   model_variant = next(iter(variants.keys()))
+                   log(f"Warning: Invalid model variant, using default: {model_variant}")
+               else:
+                   # Fall back to default IDs if no variants defined
+                   if internal_model_type == "wan":
+                       model_variant = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
+                   elif internal_model_type == "ltx_video":
+                       model_variant = "Lightricks/LTX-Video"
+                   elif internal_model_type == "hunyuan_video":
+                       model_variant = "hunyuanvideo-community/HunyuanVideo"
+                   log(f"Warning: No variants defined for model type, using default: {model_variant}")
+
+           # Check if this is an image-to-video model but no image was provided
+           variant_info = variants.get(model_variant, {})
+           if variant_info.get("type") == "image-to-video" and not conditioning_image:
+               return None, "Error: This model requires a conditioning image", log("Error: This model variant requires a conditioning image but none was provided")
+
            log(f"Generating video with model type: {internal_model_type}")
+           log(f"Using model variant: {model_variant}")
            log(f"Using LoRA weights from: {lora_path}")
            log(f"Resolution: {width}x{height}, Frames: {num_frames}, FPS: {fps}")
            log(f"Guidance Scale: {guidance_scale}, Flow Shift: {flow_shift}, LoRA Weight: {lora_weight}")
@@ -107,19 +137,22 @@
                return self.generate_wan_video(
                    full_prompt, negative_prompt, width, height, num_frames,
                    guidance_scale, flow_shift, lora_path, lora_weight,
-                   inference_steps, enable_cpu_offload, fps, log
+                   inference_steps, enable_cpu_offload, fps, log,
+                   model_variant, conditioning_image
                )
            elif internal_model_type == "ltx_video":
                return self.generate_ltx_video(
                    full_prompt, negative_prompt, width, height, num_frames,
                    guidance_scale, flow_shift, lora_path, lora_weight,
-                   inference_steps, enable_cpu_offload, fps, log
+                   inference_steps, enable_cpu_offload, fps, log,
+                   model_variant, conditioning_image
                )
            elif internal_model_type == "hunyuan_video":
                return self.generate_hunyuan_video(
                    full_prompt, negative_prompt, width, height, num_frames,
                    guidance_scale, flow_shift, lora_path, lora_weight,
-                   inference_steps, enable_cpu_offload, fps, log
+                   inference_steps, enable_cpu_offload, fps, log,
+                   model_variant, conditioning_image
                )
            else:
                return None, f"Error: Unsupported model type {internal_model_type}", log(f"Error: Unsupported model type {internal_model_type}")
@@ -142,28 +175,31 @@
        inference_steps: int,
        enable_cpu_offload: bool,
        fps: int,
-       log_fn: Callable
+       log_fn: Callable,
+       model_variant: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+       conditioning_image: Optional[str] = None
    ) -> Tuple[Optional[str], str, str]:
        """Generate video using Wan model"""
-       start_time = torch.cuda.Event(enable_timing=True)
-       end_time = torch.cuda.Event(enable_timing=True)
-
+
        try:
            import torch
            from diffusers import AutoencoderKLWan, WanPipeline
            from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
            from diffusers.utils import export_to_video
+           from PIL import Image
+           import os
+
+           start_time = torch.cuda.Event(enable_timing=True)
+           end_time = torch.cuda.Event(enable_timing=True)
 
            log_fn("Importing Wan model components...")
 
-           # Use the smaller model for faster inference
-           model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-
-           log_fn(f"Loading VAE from {model_id}...")
-           vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-
-           log_fn(f"Loading transformer from {model_id}...")
-           pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+           log_fn(f"Loading VAE from {model_variant}...")
+           vae = AutoencoderKLWan.from_pretrained(model_variant, subfolder="vae", torch_dtype=torch.float32)
+
+           log_fn(f"Loading transformer from {model_variant}...")
+           pipe = WanPipeline.from_pretrained(model_variant, vae=vae, torch_dtype=torch.bfloat16)
 
            log_fn(f"Configuring scheduler with flow_shift={flow_shift}...")
            pipe.scheduler = UniPCMultistepScheduler.from_config(
@@ -189,15 +225,36 @@
            log_fn("Starting video generation...")
            start_time.record()
 
-           output = pipe(
-               prompt=prompt,
-               negative_prompt=negative_prompt,
-               height=height,
-               width=width,
-               num_frames=num_frames,
-               guidance_scale=guidance_scale,
-               num_inference_steps=inference_steps,
-           ).frames[0]
+           # Check if this is an image-to-video model
+           is_i2v = "I2V" in model_variant
+
+           if is_i2v and conditioning_image:
+               log_fn(f"Loading conditioning image from {conditioning_image}...")
+               image = Image.open(conditioning_image).convert("RGB")
+               image = image.resize((width, height))
+
+               log_fn("Generating video with image conditioning...")
+               output = pipe(
+                   prompt=prompt,
+                   negative_prompt=negative_prompt,
+                   image=image,
+                   height=height,
+                   width=width,
+                   num_frames=num_frames,
+                   guidance_scale=guidance_scale,
+                   num_inference_steps=inference_steps,
+               ).frames[0]
+           else:
+               log_fn("Generating video with text-only conditioning...")
+               output = pipe(
+                   prompt=prompt,
+                   negative_prompt=negative_prompt,
+                   height=height,
+                   width=width,
+                   num_frames=num_frames,
+                   guidance_scale=guidance_scale,
+                   num_inference_steps=inference_steps,
+               ).frames[0]
 
            end_time.record()
            torch.cuda.synchronize()
@@ -236,23 +293,25 @@
        inference_steps: int,
        enable_cpu_offload: bool,
        fps: int,
-       log_fn: Callable
+       log_fn: Callable,
+       model_variant: str = "Lightricks/LTX-Video",
+       conditioning_image: Optional[str] = None
    ) -> Tuple[Optional[str], str, str]:
        """Generate video using LTX model"""
-       start_time = torch.cuda.Event(enable_timing=True)
-       end_time = torch.cuda.Event(enable_timing=True)
-
+
        try:
            import torch
            from diffusers import LTXPipeline
            from diffusers.utils import export_to_video
+           from PIL import Image
 
+           start_time = torch.cuda.Event(enable_timing=True)
+           end_time = torch.cuda.Event(enable_timing=True)
+
            log_fn("Importing LTX model components...")
 
-           model_id = "Lightricks/LTX-Video"
-
-           log_fn(f"Loading pipeline from {model_id}...")
-           pipe = LTXPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+           log_fn(f"Loading pipeline from {model_variant}...")
+           pipe = LTXPipeline.from_pretrained(model_variant, torch_dtype=torch.bfloat16)
 
            log_fn("Moving pipeline to CUDA device...")
            pipe.to("cuda")
@@ -272,6 +331,7 @@
            log_fn("Starting video generation...")
            start_time.record()
 
+           # LTX doesn't currently support image conditioning in the standard way
            video = pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
@@ -321,31 +381,33 @@
        inference_steps: int,
        enable_cpu_offload: bool,
        fps: int,
-       log_fn: Callable
+       log_fn: Callable,
+       model_variant: str = "hunyuanvideo-community/HunyuanVideo",
+       conditioning_image: Optional[str] = None
    ) -> Tuple[Optional[str], str, str]:
        """Generate video using HunyuanVideo model"""
-       start_time = torch.cuda.Event(enable_timing=True)
-       end_time = torch.cuda.Event(enable_timing=True)
+
 
        try:
            import torch
            from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel, AutoencoderKLHunyuanVideo
            from diffusers.utils import export_to_video
 
+           start_time = torch.cuda.Event(enable_timing=True)
+           end_time = torch.cuda.Event(enable_timing=True)
+
            log_fn("Importing HunyuanVideo model components...")
 
-           model_id = "hunyuanvideo-community/HunyuanVideo"
-
-           log_fn(f"Loading transformer from {model_id}...")
+           log_fn(f"Loading transformer from {model_variant}...")
            transformer = HunyuanVideoTransformer3DModel.from_pretrained(
-               model_id,
+               model_variant,
                subfolder="transformer",
                torch_dtype=torch.bfloat16
            )
 
-           log_fn(f"Loading pipeline from {model_id}...")
+           log_fn(f"Loading pipeline from {model_variant}...")
            pipe = HunyuanVideoPipeline.from_pretrained(
-               model_id,
+               model_variant,
                transformer=transformer,
                torch_dtype=torch.float16
            )
@@ -371,9 +433,13 @@
            log_fn("Starting video generation...")
            start_time.record()
 
+           # Fix for Issue #2: The pipe() expected list rather than float
+           # Make sure negative_prompt is a list or None
+           neg_prompt = [negative_prompt] if negative_prompt else None
+
            output = pipe(
                prompt=prompt,
-               negative_prompt=negative_prompt if negative_prompt else None,
+               negative_prompt=neg_prompt,
                height=height,
                width=width,
                num_frames=num_frames,
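
The variant validation added to generate_video above boils down to a small rule: keep the requested variant if it is registered, otherwise fall back to the first registered variant, otherwise to a hard-coded default repo id. A standalone sketch for reference (the function name resolve_variant is hypothetical; the fallback ids are the ones used in the diff):

from typing import Dict, Tuple

# Hard-coded fallbacks mirroring the ones in PreviewingService.generate_video above.
HARDCODED_DEFAULTS = {
    "wan": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "ltx_video": "Lightricks/LTX-Video",
    "hunyuan_video": "hunyuanvideo-community/HunyuanVideo",
}

def resolve_variant(internal_model_type: str, requested: str, variants: Dict[str, dict]) -> Tuple[str, str]:
    """Return (variant_id, warning); the warning is empty when the requested variant is valid."""
    if requested in variants:
        return requested, ""
    if variants:
        default = next(iter(variants))
        return default, f"Warning: Invalid model variant, using default: {default}"
    default = HARDCODED_DEFAULTS.get(internal_model_type, "")
    return default, f"Warning: No variants defined for model type, using default: {default}"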
vms/tabs/preview_tab.py CHANGED
@@ -6,9 +6,10 @@ import gradio as gr
 import logging
 from pathlib import Path
 from typing import Dict, Any, List, Optional, Tuple
+import time
 
-from vms.tabs import BaseTab
-from vms.config import (
+from .base_tab import BaseTab
+from ..config import (
     MODEL_TYPES, DEFAULT_PROMPT_PREFIX
 )
 
@@ -21,10 +22,7 @@ class PreviewTab(BaseTab):
        super().__init__(app_state)
        self.id = "preview_tab"
        self.title = "6️⃣ Preview"
-
-       # Get reference to the preview service from app_state
-       self.previewing_service = app_state.previewing
-
+
    def create(self, parent=None) -> gr.TabItem:
        """Create the Preview tab UI components"""
        with gr.TabItem(self.title, id=self.id) as tab:
@@ -53,12 +51,32 @@
                )
 
                with gr.Row():
+                   # Get the currently selected model type from training tab if possible
+                   default_model = self.get_default_model_type()
+
+                   # Make model_type read-only (disabled), as it must match what was trained
                    self.components["model_type"] = gr.Dropdown(
                        choices=list(MODEL_TYPES.keys()),
-                       label="Model Type",
-                       value=list(MODEL_TYPES.keys())[0]
+                       label="Model Type (from training)",
+                       value=default_model,
+                       interactive=False
                    )
 
+                   # Add model variant selection based on model type
+                   self.components["model_variant"] = gr.Dropdown(
+                       label="Model Variant",
+                       choices=self.get_variant_choices(default_model),
+                       value=self.get_default_variant(default_model)
+                   )
+
+                   # Add image input for image-to-video models
+                   self.components["conditioning_image"] = gr.Image(
+                       label="Conditioning Image (for Image-to-Video models)",
+                       type="filepath",
+                       visible=False
+                   )
+
+               with gr.Row():
                    self.components["resolution_preset"] = gr.Dropdown(
                        choices=["480p", "720p"],
                        label="Resolution Preset",
@@ -150,15 +168,89 @@
                        interactive=False
                    )
 
-               with gr.Accordion("Log", open=False):
+               with gr.Accordion("Log", open=True):
                    self.components["log"] = gr.TextArea(
                        label="Generation Log",
                        interactive=False,
-                       lines=10
+                       lines=15
                    )
 
        return tab
 
+   def get_variant_choices(self, model_type: str) -> List[str]:
+       """Get model variant choices based on model type"""
+       # Convert UI display name to internal name
+       internal_type = MODEL_TYPES.get(model_type)
+       if not internal_type:
+           return []
+
+       # Get variants from preview service
+       variants = self.app.previewing.get_model_variants(internal_type)
+       if not variants:
+           return []
+
+       # Format choices with display name and description
+       choices = []
+       for model_id, info in variants.items():
+           choices.append(f"{model_id} - {info.get('name', '')}")
+
+       return choices
+
+   def get_default_variant(self, model_type: str) -> str:
+       """Get default model variant for the model type"""
+       choices = self.get_variant_choices(model_type)
+       if choices:
+           return choices[0]
+       return ""
+
+   def get_default_model_type(self) -> str:
+       """Get the currently selected model type from training tab"""
+       try:
+           # Try to get the model type from UI state
+           ui_state = self.app.training.load_ui_state()
+           model_type = ui_state.get("model_type")
+
+           # Make sure it's a valid model type
+           if model_type in MODEL_TYPES:
+               return model_type
+
+           # If we couldn't get a valid model type, try to get it from the training tab directly
+           if hasattr(self.app, 'tabs') and 'train_tab' in self.app.tabs:
+               train_tab = self.app.tabs['train_tab']
+               if hasattr(train_tab, 'components') and 'model_type' in train_tab.components:
+                   train_model_type = train_tab.components['model_type'].value
+                   if train_model_type in MODEL_TYPES:
+                       return train_model_type
+
+           # Fallback to first model type
+           return list(MODEL_TYPES.keys())[0]
+       except Exception as e:
+           logger.warning(f"Failed to get default model type: {e}")
+           return list(MODEL_TYPES.keys())[0]
+
+   def extract_model_id(self, variant_choice: str) -> str:
+       """Extract model ID from variant choice string"""
+       if " - " in variant_choice:
+           return variant_choice.split(" - ")[0].strip()
+       return variant_choice
+
+   def get_variant_type(self, model_type: str, model_variant: str) -> str:
+       """Get the variant type (text-to-video or image-to-video)"""
+       # Convert UI display name to internal name
+       internal_type = MODEL_TYPES.get(model_type)
+       if not internal_type:
+           return "text-to-video"
+
+       # Extract model_id from variant choice
+       model_id = self.extract_model_id(model_variant)
+
+       # Get variants from preview service
+       variants = self.app.previewing.get_model_variants(internal_type)
+       variant_info = variants.get(model_id, {})
+
+       # Return the variant type or default to text-to-video
+       return variant_info.get("type", "text-to-video")
+
    def connect_events(self) -> None:
        """Connect event handlers to UI components"""
        # Update resolution when preset changes
@@ -172,11 +264,70 @@
            ]
        )
 
+       # Update model_variant choices when model_type changes or tab is selected
+       if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
+           self.app.tabs_component.select(
+               fn=self.sync_model_type_and_variants,
+               inputs=[],
+               outputs=[
+                   self.components["model_type"],
+                   self.components["model_variant"]
+               ]
+           )
+
+       # Update variant-specific UI elements when variant changes
+       self.components["model_variant"].change(
+           fn=self.update_variant_ui,
+           inputs=[
+               self.components["model_type"],
+               self.components["model_variant"]
+           ],
+           outputs=[
+               self.components["conditioning_image"]
+           ]
+       )
+
+       # Load preview UI state when the tab is selected
+       if hasattr(self.app, 'tabs_component') and self.app.tabs_component is not None:
+           self.app.tabs_component.select(
+               fn=self.load_preview_state,
+               inputs=[],
+               outputs=[
+                   self.components["prompt"],
+                   self.components["negative_prompt"],
+                   self.components["prompt_prefix"],
+                   self.components["width"],
+                   self.components["height"],
+                   self.components["num_frames"],
+                   self.components["fps"],
+                   self.components["guidance_scale"],
+                   self.components["flow_shift"],
+                   self.components["lora_weight"],
+                   self.components["inference_steps"],
+                   self.components["enable_cpu_offload"],
+                   self.components["model_variant"]
+               ]
+           )
+
+       # Save preview UI state when values change
+       for component_name in [
+           "prompt", "negative_prompt", "prompt_prefix", "model_variant", "resolution_preset",
+           "width", "height", "num_frames", "fps", "guidance_scale", "flow_shift",
+           "lora_weight", "inference_steps", "enable_cpu_offload"
+       ]:
+           if component_name in self.components:
+               self.components[component_name].change(
+                   fn=self.save_preview_state_value,
+                   inputs=[self.components[component_name]],
+                   outputs=[]
+               )
+
        # Generate button click
        self.components["generate_btn"].click(
            fn=self.generate_video,
            inputs=[
                self.components["model_type"],
+               self.components["model_variant"],
                self.components["prompt"],
                self.components["negative_prompt"],
                self.components["prompt_prefix"],
@@ -188,7 +339,8 @@
                self.components["lora_weight"],
                self.components["inference_steps"],
                self.components["enable_cpu_offload"],
-               self.components["fps"]
+               self.components["fps"],
+               self.components["conditioning_image"]
            ],
            outputs=[
                self.components["preview_video"],
@@ -197,6 +349,23 @@
            ]
        )
 
+   def update_variant_ui(self, model_type: str, model_variant: str) -> Dict[str, Any]:
+       """Update UI based on the selected model variant"""
+       variant_type = self.get_variant_type(model_type, model_variant)
+
+       # Show conditioning image input only for image-to-video models
+       show_conditioning_image = variant_type == "image-to-video"
+
+       return {
+           self.components["conditioning_image"]: gr.Image(visible=show_conditioning_image)
+       }
+
+   def sync_model_type_and_variants(self) -> Tuple[str, str]:
+       """Sync model type with training tab when preview tab is selected and update variant choices"""
+       model_type = self.get_default_model_type()
+       model_variant = self.get_default_variant(model_type)
+       return model_type, model_variant
+
    def update_resolution(self, preset: str) -> Tuple[int, int, float]:
        """Update resolution and flow shift based on preset"""
        if preset == "480p":
@@ -206,9 +375,88 @@
        else:
            return 832, 480, 3.0
 
+   def load_preview_state(self) -> Tuple:
+       """Load saved preview UI state"""
+       # Try to get the saved state
+       try:
+           state = self.app.training.load_ui_state()
+           preview_state = state.get("preview", {})
+
+           # Get model type (can't be changed in UI)
+           model_type = self.get_default_model_type()
+
+           # If model_variant not in choices for current model_type, use default
+           model_variant = preview_state.get("model_variant", "")
+           variant_choices = self.get_variant_choices(model_type)
+           if model_variant not in variant_choices and variant_choices:
+               model_variant = variant_choices[0]
+
+           return (
+               preview_state.get("prompt", ""),
+               preview_state.get("negative_prompt", "worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background"),
+               preview_state.get("prompt_prefix", DEFAULT_PROMPT_PREFIX),
+               preview_state.get("width", 832),
+               preview_state.get("height", 480),
+               preview_state.get("num_frames", 49),
+               preview_state.get("fps", 16),
+               preview_state.get("guidance_scale", 5.0),
+               preview_state.get("flow_shift", 3.0),
+               preview_state.get("lora_weight", 0.7),
+               preview_state.get("inference_steps", 30),
+               preview_state.get("enable_cpu_offload", True),
+               model_variant
+           )
+       except Exception as e:
+           logger.error(f"Error loading preview state: {e}")
+           # Return defaults if loading fails
+           return (
+               "",
+               "worst quality, low quality, blurry, jittery, distorted, ugly, deformed, disfigured, messy background",
+               DEFAULT_PROMPT_PREFIX,
+               832, 480, 49, 16, 5.0, 3.0, 0.7, 30, True,
+               self.get_default_variant(self.get_default_model_type())
+           )
+
+   def save_preview_state_value(self, value: Any) -> None:
+       """Save an individual preview state value"""
+       try:
+           # Get the component name from the event context
+           import inspect
+           frame = inspect.currentframe()
+           frame = inspect.getouterframes(frame)[1]
+           event_context = frame.frame.f_locals
+           component = event_context.get('component')
+
+           if component is None:
+               return
+
+           # Find the component name
+           component_name = None
+           for name, comp in self.components.items():
+               if comp == component:
+                   component_name = name
+                   break
+
+           if component_name is None:
+               return
+
+           # Load current state
+           state = self.app.training.load_ui_state()
+           if "preview" not in state:
+               state["preview"] = {}
+
+           # Update the value
+           state["preview"][component_name] = value
+
+           # Save state
+           self.app.training.save_ui_state(state)
+       except Exception as e:
+           logger.error(f"Error saving preview state: {e}")
+
    def generate_video(
        self,
        model_type: str,
+       model_variant: str,
        prompt: str,
        negative_prompt: str,
        prompt_prefix: str,
@@ -220,21 +468,76 @@
        lora_weight: float,
        inference_steps: int,
        enable_cpu_offload: bool,
-       fps: int
+       fps: int,
+       conditioning_image: Optional[str] = None
    ) -> Tuple[Optional[str], str, str]:
        """Handler for generate button click, delegates to preview service"""
-       return self.preview_service.generate_video(
-           model_type=model_type,
-           prompt=prompt,
-           negative_prompt=negative_prompt,
-           prompt_prefix=prompt_prefix,
-           width=width,
-           height=height,
-           num_frames=num_frames,
-           guidance_scale=guidance_scale,
-           flow_shift=flow_shift,
-           lora_weight=lora_weight,
-           inference_steps=inference_steps,
-           enable_cpu_offload=enable_cpu_offload,
-           fps=fps
-       )
+       # Save all the parameters to preview state before generating
+       try:
+           state = self.app.training.load_ui_state()
+           if "preview" not in state:
+               state["preview"] = {}
+
+           # Extract model ID from variant choice
+           model_variant_id = self.extract_model_id(model_variant)
+
+           # Update all values
+           preview_state = {
+               "prompt": prompt,
+               "negative_prompt": negative_prompt,
+               "prompt_prefix": prompt_prefix,
+               "model_type": model_type,
+               "model_variant": model_variant,
+               "width": width,
+               "height": height,
+               "num_frames": num_frames,
+               "fps": fps,
+               "guidance_scale": guidance_scale,
+               "flow_shift": flow_shift,
+               "lora_weight": lora_weight,
+               "inference_steps": inference_steps,
+               "enable_cpu_offload": enable_cpu_offload
+           }
+
+           state["preview"] = preview_state
+           self.app.training.save_ui_state(state)
+       except Exception as e:
+           logger.error(f"Error saving preview state before generation: {e}")
+
+       # Clear the log display at the start to make room for new logs
+       # Yield and sleep briefly to allow UI update
+       yield None, "Starting generation...", ""
+       time.sleep(0.1)
+
+       # Extract model ID from variant choice string
+       model_variant_id = self.extract_model_id(model_variant)
+
+       # Use streaming updates to provide real-time feedback during generation
+       def generate_with_updates():
+           # Initial UI update
+           yield None, "Initializing generation...", "Starting video generation process..."
+
+           # Start actual generation
+           result = self.app.previewing.generate_video(
+               model_type=model_type,
+               model_variant=model_variant_id,
+               prompt=prompt,
+               negative_prompt=negative_prompt,
+               prompt_prefix=prompt_prefix,
+               width=width,
+               height=height,
+               num_frames=num_frames,
+               guidance_scale=guidance_scale,
+               flow_shift=flow_shift,
+               lora_weight=lora_weight,
+               inference_steps=inference_steps,
+               enable_cpu_offload=enable_cpu_offload,
+               fps=fps,
+               conditioning_image=conditioning_image
+           )
+
+           # Return final result
+           return result
+
+       # Return the generator for streaming updates
+       return generate_with_updates()
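
The model_variant dropdown stores each entry as "<repo id> - <display name>", and extract_model_id recovers the repo id by splitting on the first " - ". A standalone round-trip sketch of that convention (format_choice is hypothetical; the parsing matches PreviewTab.extract_model_id above):

def format_choice(repo_id: str, name: str) -> str:
    # Same "<repo id> - <name>" format that get_variant_choices builds.
    return f"{repo_id} - {name}"

def extract_model_id(variant_choice: str) -> str:
    # Same parsing as PreviewTab.extract_model_id in the diff above.
    if " - " in variant_choice:
        return variant_choice.split(" - ")[0].strip()
    return variant_choice

choice = format_choice("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", "Wan 2.1 T2V 1.3B (text-only, smaller)")
assert extract_model_id(choice) == "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"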