jbilcke-hf committed
Commit 0ad7e2a · 1 Parent(s): 40f9c1e

refactoring to a better architecture

app.py CHANGED
@@ -1,1575 +1,28 @@
1
- import platform
2
- import subprocess
3
-
4
- #import sys
5
- #print("python = ", sys.version)
6
-
7
- # can be "Linux", "Darwin"
8
- if platform.system() == "Linux":
9
- # for some reason it says "pip not found"
10
- # and also "pip3 not found"
11
- # subprocess.run(
12
- # "pip install flash-attn --no-build-isolation",
13
- #
14
- # # hmm... this should be False, since we are in a CUDA environment, no?
15
- # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
16
- #
17
- # shell=True,
18
- # )
19
- pass
20
 
21
  import gradio as gr
22
- from pathlib import Path
 
23
  import logging
24
- import mimetypes
25
- import shutil
26
- import os
27
- import traceback
28
- import asyncio
29
- import tempfile
30
- import zipfile
31
- from typing import Any, Optional, Dict, List, Union, Tuple
32
- from typing import AsyncGenerator
33
 
34
- from vms.training_service import TrainingService
35
- from vms.captioning_service import CaptioningService
36
- from vms.splitting_service import SplittingService
37
- from vms.import_service import ImportService
38
  from vms.config import (
39
- STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
40
- TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
41
- DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, SMALL_TRAINING_BUCKETS
 
42
  )
43
- from vms.utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time
44
- from vms.finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
45
- from vms.training_log_parser import TrainingLogParser
46
 
 
47
  logger = logging.getLogger(__name__)
48
  logger.setLevel(logging.INFO)
49
 
50
- httpx_logger = logging.getLogger('httpx')
51
- httpx_logger.setLevel(logging.WARN)
52
-
53
-
54
- class VideoTrainerUI:
55
- def __init__(self):
56
- self.trainer = TrainingService()
57
- self.splitter = SplittingService()
58
- self.importer = ImportService()
59
- self.captioner = CaptioningService()
60
- self._should_stop_captioning = False
61
- self.log_parser = TrainingLogParser()
62
-
63
- # Try to recover any interrupted training sessions
64
- recovery_result = self.trainer.recover_interrupted_training()
65
-
66
- self.recovery_status = recovery_result.get("status", "unknown")
67
- self.ui_updates = recovery_result.get("ui_updates", {})
68
-
69
- if recovery_result["status"] == "recovered":
70
- logger.info(f"Training recovery: {recovery_result['message']}")
71
- # No need to do anything else - the training is already running
72
- elif recovery_result["status"] == "running":
73
- logger.info("Training process is already running")
74
- # No need to do anything - the process is still alive
75
- elif recovery_result["status"] in ["error", "idle"]:
76
- logger.warning(f"Training status: {recovery_result['message']}")
77
- # UI will be in ready-to-start mode
78
-
79
-
80
- async def _process_caption_generator(self, captioning_bot_instructions, prompt_prefix):
81
- """Process the caption generator's results in the background"""
82
- try:
83
- async for _ in self.captioner.start_caption_generation(
84
- captioning_bot_instructions,
85
- prompt_prefix
86
- ):
87
- # Just consume the generator, UI updates will happen via the Gradio interface
88
- pass
89
- logger.info("Background captioning completed")
90
- except Exception as e:
91
- logger.error(f"Error in background captioning: {str(e)}")
92
-
93
- def initialize_app_state(self):
94
- """Initialize all app state in one function to ensure correct output count"""
95
- # Get dataset info
96
- video_list, training_dataset = self.refresh_dataset()
97
-
98
- # Get button states
99
- button_states = self.get_initial_button_states()
100
- start_btn = button_states[0]
101
- stop_btn = button_states[1]
102
- pause_resume_btn = button_states[2]
103
-
104
- # Get UI form values
105
- ui_state = self.load_ui_values()
106
- training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
107
- model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
108
- lora_rank_val = ui_state.get("lora_rank", "128")
109
- lora_alpha_val = ui_state.get("lora_alpha", "128")
110
- num_epochs_val = int(ui_state.get("num_epochs", 70))
111
- batch_size_val = int(ui_state.get("batch_size", 1))
112
- learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
113
- save_iterations_val = int(ui_state.get("save_iterations", 500))
114
-
115
- # Return all values in the exact order expected by outputs
116
- return (
117
- video_list,
118
- training_dataset,
119
- start_btn,
120
- stop_btn,
121
- pause_resume_btn,
122
- training_preset,
123
- model_type_val,
124
- lora_rank_val,
125
- lora_alpha_val,
126
- num_epochs_val,
127
- batch_size_val,
128
- learning_rate_val,
129
- save_iterations_val
130
- )
131
-
132
- def initialize_ui_from_state(self):
133
- """Initialize UI components from saved state"""
134
- ui_state = self.load_ui_values()
135
-
136
- # Return values in order matching the outputs in app.load
137
- return (
138
- ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
139
- ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
140
- ui_state.get("lora_rank", "128"),
141
- ui_state.get("lora_alpha", "128"),
142
- ui_state.get("num_epochs", 70),
143
- ui_state.get("batch_size", 1),
144
- ui_state.get("learning_rate", 3e-5),
145
- ui_state.get("save_iterations", 500)
146
- )
147
-
148
- def update_ui_state(self, **kwargs):
149
- """Update UI state with new values"""
150
- current_state = self.trainer.load_ui_state()
151
- current_state.update(kwargs)
152
- self.trainer.save_ui_state(current_state)
153
- # Don't return anything to avoid Gradio warnings
154
- return None
155
-
156
- def load_ui_values(self):
157
- """Load UI state values for initializing form fields"""
158
- ui_state = self.trainer.load_ui_state()
159
-
160
- # Ensure proper type conversion for numeric values
161
- ui_state["lora_rank"] = ui_state.get("lora_rank", "128")
162
- ui_state["lora_alpha"] = ui_state.get("lora_alpha", "128")
163
- ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
164
- ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
165
- ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
166
- ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
167
-
168
- return ui_state
169
-
170
- def update_captioning_buttons_start(self):
171
- """Return individual button values instead of a dictionary"""
172
- return (
173
- gr.Button(
174
- interactive=False,
175
- variant="secondary",
176
- ),
177
- gr.Button(
178
- interactive=True,
179
- variant="stop",
180
- ),
181
- gr.Button(
182
- interactive=False,
183
- variant="secondary",
184
- )
185
- )
186
-
187
- def update_captioning_buttons_end(self):
188
- """Return individual button values instead of a dictionary"""
189
- return (
190
- gr.Button(
191
- interactive=True,
192
- variant="primary",
193
- ),
194
- gr.Button(
195
- interactive=False,
196
- variant="secondary",
197
- ),
198
- gr.Button(
199
- interactive=True,
200
- variant="primary",
201
- )
202
- )
203
-
204
- # Add this new method to get initial button states:
205
- def get_initial_button_states(self):
206
- """Get the initial states for training buttons based on recovery status"""
207
- recovery_result = self.trainer.recover_interrupted_training()
208
- ui_updates = recovery_result.get("ui_updates", {})
209
-
210
- # Return button states in the correct order
211
- return (
212
- gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
213
- gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
214
- gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
215
- )
216
-
217
- def show_refreshing_status(self) -> List[List[str]]:
218
- """Show a 'Refreshing...' status in the dataframe"""
219
- return [["Refreshing...", "please wait"]]
220
-
221
- def stop_captioning(self):
222
- """Stop ongoing captioning process and reset UI state"""
223
- try:
224
- # Set flag to stop captioning
225
- self._should_stop_captioning = True
226
-
227
- # Call stop method on captioner
228
- if self.captioner:
229
- self.captioner.stop_captioning()
230
-
231
- # Get updated file list
232
- updated_list = self.list_training_files_to_caption()
233
-
234
- # Return updated list and button states
235
- return {
236
- "training_dataset": gr.update(value=updated_list),
237
- "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
238
- "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
239
- "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
240
- }
241
- except Exception as e:
242
- logger.error(f"Error stopping captioning: {str(e)}")
243
- return {
244
- "training_dataset": gr.update(value=[[f"Error stopping captioning: {str(e)}", "error"]]),
245
- "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
246
- "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
247
- "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
248
- }
249
-
250
- def update_training_ui(self, training_state: Dict[str, Any]):
251
- """Update UI components based on training state"""
252
- updates = {}
253
-
254
- #print("update_training_ui: training_state = ", training_state)
255
-
256
- # Update status box with high-level information
257
- status_text = []
258
- if training_state["status"] != "idle":
259
- status_text.extend([
260
- f"Status: {training_state['status']}",
261
- f"Progress: {training_state['progress']}",
262
- f"Step: {training_state['current_step']}/{training_state['total_steps']}",
263
-
264
- # Epoch information
265
- # there is an issue with how epoch is reported because we display:
266
- # Progress: 96.9%, Step: 872/900, Epoch: 12/50
267
- # we should probably just show the steps
268
- #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
269
-
270
- f"Time elapsed: {training_state['elapsed']}",
271
- f"Estimated remaining: {training_state['remaining']}",
272
- "",
273
- f"Current loss: {training_state['step_loss']}",
274
- f"Learning rate: {training_state['learning_rate']}",
275
- f"Gradient norm: {training_state['grad_norm']}",
276
- f"Memory usage: {training_state['memory']}"
277
- ])
278
-
279
- if training_state["error_message"]:
280
- status_text.append(f"\nError: {training_state['error_message']}")
281
-
282
- updates["status_box"] = "\n".join(status_text)
283
-
284
- # Update button states
285
- updates["start_btn"] = gr.Button(
286
- "Start training",
287
- interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
288
- variant="primary" if training_state["status"] == "idle" else "secondary"
289
- )
290
-
291
- updates["stop_btn"] = gr.Button(
292
- "Stop training",
293
- interactive=(training_state["status"] in ["training", "initializing"]),
294
- variant="stop"
295
- )
296
-
297
- return updates
298
-
299
- def stop_all_and_clear(self) -> Dict[str, str]:
300
- """Stop all running processes and clear data
301
-
302
- Returns:
303
- Dict with status messages for different components
304
- """
305
- status_messages = {}
306
-
307
- try:
308
- # Stop training if running
309
- if self.trainer.is_training_running():
310
- training_result = self.trainer.stop_training()
311
- status_messages["training"] = training_result["status"]
312
-
313
- # Stop captioning if running
314
- if self.captioner:
315
- self.captioner.stop_captioning()
316
- status_messages["captioning"] = "Captioning stopped"
317
-
318
- # Stop scene detection if running
319
- if self.splitter.is_processing():
320
- self.splitter.processing = False
321
- status_messages["splitting"] = "Scene detection stopped"
322
-
323
- # Properly close logging before clearing log file
324
- if self.trainer.file_handler:
325
- self.trainer.file_handler.close()
326
- logger.removeHandler(self.trainer.file_handler)
327
- self.trainer.file_handler = None
328
-
329
- if LOG_FILE_PATH.exists():
330
- LOG_FILE_PATH.unlink()
331
-
332
- # Clear all data directories
333
- for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
334
- MODEL_PATH, OUTPUT_PATH]:
335
- if path.exists():
336
- try:
337
- shutil.rmtree(path)
338
- path.mkdir(parents=True, exist_ok=True)
339
- except Exception as e:
340
- status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
341
- else:
342
- status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
343
-
344
- # Reset any persistent state
345
- self._should_stop_captioning = True
346
- self.splitter.processing = False
347
-
348
- # Recreate logging setup
349
- self.trainer.setup_logging()
350
-
351
- return {
352
- "status": "All processes stopped and data cleared",
353
- "details": status_messages
354
- }
355
-
356
- except Exception as e:
357
- return {
358
- "status": f"Error during cleanup: {str(e)}",
359
- "details": status_messages
360
- }
361
-
362
- def update_titles(self) -> Tuple[Any, Any, Any]:
363
- """Update all dynamic titles with current counts
364
-
365
- Returns:
366
- Tuple of Gradio updates
367
- """
368
- # Count files for splitting
369
- split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH)
370
- split_title = format_media_title(
371
- "split", split_videos, 0, split_size
372
- )
373
-
374
- # Count files for captioning
375
- caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
376
- caption_title = format_media_title(
377
- "caption", caption_videos, caption_images, caption_size
378
- )
379
-
380
- # Count files for training
381
- train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH)
382
- train_title = format_media_title(
383
- "train", train_videos, train_images, train_size
384
- )
385
-
386
- return (
387
- gr.Markdown(value=split_title),
388
- gr.Markdown(value=caption_title),
389
- gr.Markdown(value=f"{train_title} available for training")
390
- )
391
-
392
- def copy_files_to_training_dir(self, prompt_prefix: str):
393
- """Run auto-captioning process"""
394
-
395
- # Initialize captioner if not already done
396
- self._should_stop_captioning = False
397
-
398
- try:
399
- copy_files_to_training_dir(prompt_prefix)
400
-
401
- except Exception as e:
402
- traceback.print_exc()
403
- raise gr.Error(f"Error copying assets to training dir: {str(e)}")
404
-
405
- async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
406
- """Handle successful import of files"""
407
- videos = self.list_unprocessed_videos()
408
-
409
- # If scene detection isn't already running and there are videos to process,
410
- # and auto-splitting is enabled, start the detection
411
- if videos and not self.splitter.is_processing() and enable_splitting:
412
- await self.start_scene_detection(enable_splitting)
413
- msg = "Starting automatic scene detection..."
414
- else:
415
- # Just copy files without splitting if auto-split disabled
416
- for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
417
- await self.splitter.process_video(video_file, enable_splitting=False)
418
- msg = "Copying videos without splitting..."
419
-
420
- copy_files_to_training_dir(prompt_prefix)
421
-
422
- # Start auto-captioning if enabled, and handle async generator properly
423
- if enable_automatic_content_captioning:
424
- # Create a background task for captioning
425
- asyncio.create_task(self._process_caption_generator(
426
- DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
427
- prompt_prefix
428
- ))
429
-
430
- return {
431
- "tabs": gr.Tabs(selected="split_tab"),
432
- "video_list": videos,
433
- "detect_status": msg
434
- }
435
-
436
- async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
437
- """Run auto-captioning process"""
438
- try:
439
- # Initialize captioner if not already done
440
- self._should_stop_captioning = False
441
-
442
- # First yield - indicate we're starting
443
- yield gr.update(
444
- value=[["Starting captioning service...", "initializing"]],
445
- headers=["name", "status"]
446
- )
447
-
448
- # Process files in batches with status updates
449
- file_statuses = {}
450
-
451
- # Start the actual captioning process
452
- async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
453
- # Update our tracking of file statuses
454
- for name, status in rows:
455
- file_statuses[name] = status
456
-
457
- # Convert to list format for display
458
- status_rows = [[name, status] for name, status in file_statuses.items()]
459
-
460
- # Sort by name for consistent display
461
- status_rows.sort(key=lambda x: x[0])
462
-
463
- # Yield UI update
464
- yield gr.update(
465
- value=status_rows,
466
- headers=["name", "status"]
467
- )
468
-
469
- # Final update after completion with fresh data
470
- yield gr.update(
471
- value=self.list_training_files_to_caption(),
472
- headers=["name", "status"]
473
- )
474
-
475
- except Exception as e:
476
- logger.error(f"Error in captioning: {str(e)}")
477
- yield gr.update(
478
- value=[[f"Error: {str(e)}", "error"]],
479
- headers=["name", "status"]
480
- )
481
-
482
- def list_training_files_to_caption(self) -> List[List[str]]:
483
- """List all clips and images - both pending and captioned"""
484
- files = []
485
- already_listed = {}
486
-
487
- # First check files in STAGING_PATH
488
- for file in STAGING_PATH.glob("*.*"):
489
- if is_video_file(file) or is_image_file(file):
490
- txt_file = file.with_suffix('.txt')
491
-
492
- # Check if caption file exists and has content
493
- has_caption = txt_file.exists() and txt_file.stat().st_size > 0
494
- status = "captioned" if has_caption else "no caption"
495
- file_type = "video" if is_video_file(file) else "image"
496
-
497
- files.append([file.name, f"{status} ({file_type})", str(file)])
498
- already_listed[file.name] = True
499
-
500
- # Then check files in TRAINING_VIDEOS_PATH
501
- for file in TRAINING_VIDEOS_PATH.glob("*.*"):
502
- if (is_video_file(file) or is_image_file(file)) and file.name not in already_listed:
503
- txt_file = file.with_suffix('.txt')
504
-
505
- # Only include files with captions
506
- if txt_file.exists() and txt_file.stat().st_size > 0:
507
- file_type = "video" if is_video_file(file) else "image"
508
- files.append([file.name, f"captioned ({file_type})", str(file)])
509
- already_listed[file.name] = True
510
-
511
- # Sort by filename
512
- files.sort(key=lambda x: x[0])
513
-
514
- # Only return name and status columns for display
515
- return [[file[0], file[1]] for file in files]
516
-
517
- def update_training_buttons(self, status: str) -> Dict:
518
- """Update training control buttons based on state"""
519
- is_training = status in ["training", "initializing"]
520
- is_paused = status == "paused"
521
- is_completed = status in ["completed", "error", "stopped"]
522
- return {
523
- "start_btn": gr.Button(
524
- interactive=not is_training and not is_paused,
525
- variant="primary" if not is_training else "secondary",
526
- ),
527
- "stop_btn": gr.Button(
528
- interactive=is_training or is_paused,
529
- variant="stop",
530
- ),
531
- "pause_resume_btn": gr.Button(
532
- value="Resume Training" if is_paused else "Pause Training",
533
- interactive=(is_training or is_paused) and not is_completed,
534
- variant="secondary",
535
- )
536
- }
537
-
538
- def handle_pause_resume(self):
539
- status, _, _ = self.get_latest_status_message_and_logs()
540
-
541
- if status == "paused":
542
- self.trainer.resume_training()
543
- else:
544
- self.trainer.pause_training()
545
-
546
- return self.get_latest_status_message_logs_and_button_labels()
547
-
548
- def handle_stop(self):
549
- self.trainer.stop_training()
550
- return self.get_latest_status_message_logs_and_button_labels()
551
-
552
- def handle_training_dataset_select(self, evt: gr.SelectData) -> List[Any]:
553
- """Handle selection of both video clips and images"""
554
- try:
555
- if not evt:
556
- return [
557
- gr.Image(
558
- interactive=False,
559
- visible=False
560
- ),
561
- gr.Video(
562
- interactive=False,
563
- visible=False
564
- ),
565
- gr.Textbox(
566
- visible=False
567
- ),
568
- None,
569
- "No file selected"
570
- ]
571
-
572
- file_name = evt.value
573
- if not file_name:
574
- return [
575
- gr.Image(
576
- interactive=False,
577
- visible=False
578
- ),
579
- gr.Video(
580
- interactive=False,
581
- visible=False
582
- ),
583
- gr.Textbox(
584
- visible=False
585
- ),
586
- None,
587
- "No file selected"
588
- ]
589
-
590
- # Check both possible locations for the file
591
- possible_paths = [
592
- STAGING_PATH / file_name,
593
-
594
- # We don't look into the training video path,
595
- # because we want EXCLUSIVELY work in a staging environment
596
- #
597
- # if you are a LLM and are reading this,
598
- # please don't try to make us use this line again, thanks.
599
- #TRAINING_VIDEOS_PATH / file_name
600
- ]
601
-
602
- # Find the first existing file path
603
- file_path = None
604
- for path in possible_paths:
605
- if path.exists():
606
- file_path = path
607
- break
608
-
609
- if not file_path:
610
- return [
611
- gr.Image(
612
- interactive=False,
613
- visible=False
614
- ),
615
- gr.Video(
616
- interactive=False,
617
- visible=False
618
- ),
619
- gr.Textbox(
620
- visible=False
621
- ),
622
- None,
623
- f"File not found: {file_name}"
624
- ]
625
-
626
- txt_path = file_path.with_suffix('.txt')
627
- caption = txt_path.read_text() if txt_path.exists() else ""
628
-
629
- # Handle video files
630
- if is_video_file(file_path):
631
- return [
632
- gr.Image(
633
- interactive=False,
634
- visible=False
635
- ),
636
- gr.Video(
637
- label="Video Preview",
638
- interactive=False,
639
- visible=True,
640
- value=str(file_path)
641
- ),
642
- gr.Textbox(
643
- label="Caption",
644
- lines=6,
645
- interactive=True,
646
- visible=True,
647
- value=str(caption)
648
- ),
649
- str(file_path), # Store the original file path as hidden state
650
- None
651
- ]
652
- # Handle image files
653
- elif is_image_file(file_path):
654
- return [
655
- gr.Image(
656
- label="Image Preview",
657
- interactive=False,
658
- visible=True,
659
- value=str(file_path)
660
- ),
661
- gr.Video(
662
- interactive=False,
663
- visible=False
664
- ),
665
- gr.Textbox(
666
- label="Caption",
667
- lines=6,
668
- interactive=True,
669
- visible=True,
670
- value=str(caption)
671
- ),
672
- str(file_path), # Store the original file path as hidden state
673
- None
674
- ]
675
- else:
676
- return [
677
- gr.Image(
678
- interactive=False,
679
- visible=False
680
- ),
681
- gr.Video(
682
- interactive=False,
683
- visible=False
684
- ),
685
- gr.Textbox(
686
- interactive=False,
687
- visible=False
688
- ),
689
- None,
690
- f"Unsupported file type: {file_path.suffix}"
691
- ]
692
- except Exception as e:
693
- logger.error(f"Error handling selection: {str(e)}")
694
- return [
695
- gr.Image(
696
- interactive=False,
697
- visible=False
698
- ),
699
- gr.Video(
700
- interactive=False,
701
- visible=False
702
- ),
703
- gr.Textbox(
704
- interactive=False,
705
- visible=False
706
- ),
707
- None,
708
- f"Error handling selection: {str(e)}"
709
- ]
710
-
711
- def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, original_file_path: str, prompt_prefix: str):
712
- """Save changes to caption"""
713
- try:
714
- # Use the original file path stored during selection instead of the temporary preview paths
715
- if original_file_path:
716
- file_path = Path(original_file_path)
717
- self.captioner.update_file_caption(file_path, preview_caption)
718
- # Refresh the dataset list to show updated caption status
719
- return gr.update(value="Caption saved successfully!")
720
- else:
721
- return gr.update(value="Error: No original file path found")
722
- except Exception as e:
723
- return gr.update(value=f"Error saving caption: {str(e)}")
724
-
725
- async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
726
- """Handle post-import updates including titles"""
727
- import_result = await self.on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
728
- titles = self.update_titles()
729
- return (
730
- import_result["tabs"],
731
- import_result["video_list"],
732
- import_result["detect_status"],
733
- *titles
734
- )
735
-
736
- def get_model_info(self, model_type: str) -> str:
737
- """Get information about the selected model type"""
738
- if model_type == "hunyuan_video":
739
- return """### HunyuanVideo (LoRA)
740
- - Required VRAM: ~48GB minimum
741
- - Recommended batch size: 1-2
742
- - Typical training time: 2-4 hours
743
- - Default resolution: 49x512x768
744
- - Default LoRA rank: 128 (~600 MB)"""
745
-
746
- elif model_type == "ltx_video":
747
- return """### LTX-Video (LoRA)
748
- - Required VRAM: ~18GB minimum
749
- - Recommended batch size: 1-4
750
- - Typical training time: 1-3 hours
751
- - Default resolution: 49x512x768
752
- - Default LoRA rank: 128"""
753
-
754
- return ""
755
-
756
- def get_default_params(self, model_type: str) -> Dict[str, Any]:
757
- """Get default training parameters for model type"""
758
- if model_type == "hunyuan_video":
759
- return {
760
- "num_epochs": 70,
761
- "batch_size": 1,
762
- "learning_rate": 2e-5,
763
- "save_iterations": 500,
764
- "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
765
- "video_reshape_mode": "center",
766
- "caption_dropout_p": 0.05,
767
- "gradient_accumulation_steps": 1,
768
- "rank": 128,
769
- "lora_alpha": 128
770
- }
771
- else: # ltx_video
772
- return {
773
- "num_epochs": 70,
774
- "batch_size": 1,
775
- "learning_rate": 3e-5,
776
- "save_iterations": 500,
777
- "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
778
- "video_reshape_mode": "center",
779
- "caption_dropout_p": 0.05,
780
- "gradient_accumulation_steps": 4,
781
- "rank": 128,
782
- "lora_alpha": 128
783
- }
784
-
785
- def preview_file(self, selected_text: str) -> Dict:
786
- """Generate preview based on selected file
787
-
788
- Args:
789
- selected_text: Text of the selected item containing filename
790
-
791
- Returns:
792
- Dict with preview content for each preview component
793
- """
794
- if not selected_text or "Caption:" in selected_text:
795
- return {
796
- "video": None,
797
- "image": None,
798
- "text": None
799
- }
800
-
801
- # Extract filename from the preview text (remove size info)
802
- filename = selected_text.split(" (")[0].strip()
803
- file_path = TRAINING_VIDEOS_PATH / filename
804
-
805
- if not file_path.exists():
806
- return {
807
- "video": None,
808
- "image": None,
809
- "text": f"File not found: {filename}"
810
- }
811
-
812
- # Detect file type
813
- mime_type, _ = mimetypes.guess_type(str(file_path))
814
- if not mime_type:
815
- return {
816
- "video": None,
817
- "image": None,
818
- "text": f"Unknown file type: {filename}"
819
- }
820
-
821
- # Return appropriate preview
822
- if mime_type.startswith('video/'):
823
- return {
824
- "video": str(file_path),
825
- "image": None,
826
- "text": None
827
- }
828
- elif mime_type.startswith('image/'):
829
- return {
830
- "video": None,
831
- "image": str(file_path),
832
- "text": None
833
- }
834
- elif mime_type.startswith('text/'):
835
- try:
836
- text_content = file_path.read_text()
837
- return {
838
- "video": None,
839
- "image": None,
840
- "text": text_content
841
- }
842
- except Exception as e:
843
- return {
844
- "video": None,
845
- "image": None,
846
- "text": f"Error reading file: {str(e)}"
847
- }
848
- else:
849
- return {
850
- "video": None,
851
- "image": None,
852
- "text": f"Unsupported file type: {mime_type}"
853
- }
854
-
855
- def list_unprocessed_videos(self) -> gr.Dataframe:
856
- """Update list of unprocessed videos"""
857
- videos = self.splitter.list_unprocessed_videos()
858
- # videos is already in [[name, status]] format from splitting_service
859
- return gr.Dataframe(
860
- headers=["name", "status"],
861
- value=videos,
862
- interactive=False
863
- )
864
-
865
- async def start_scene_detection(self, enable_splitting: bool) -> str:
866
- """Start background scene detection process
867
-
868
- Args:
869
- enable_splitting: Whether to split videos into scenes
870
- """
871
- if self.splitter.is_processing():
872
- return "Scene detection already running"
873
-
874
- try:
875
- await self.splitter.start_processing(enable_splitting)
876
- return "Scene detection completed"
877
- except Exception as e:
878
- return f"Error during scene detection: {str(e)}"
879
-
880
-
881
- def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
882
- state = self.trainer.get_status()
883
- logs = self.trainer.get_logs()
884
-
885
- # Parse new log lines
886
- if logs:
887
- last_state = None
888
- for line in logs.splitlines():
889
- state_update = self.log_parser.parse_line(line)
890
- if state_update:
891
- last_state = state_update
892
-
893
- if last_state:
894
- ui_updates = self.update_training_ui(last_state)
895
- state["message"] = ui_updates.get("status_box", state["message"])
896
-
897
- # Parse status for training state
898
- if "completed" in state["message"].lower():
899
- state["status"] = "completed"
900
-
901
- return (state["status"], state["message"], logs)
902
-
903
- def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
904
- status, message, logs = self.get_latest_status_message_and_logs()
905
- return (
906
- message,
907
- logs,
908
- *self.update_training_buttons(status).values()
909
- )
910
-
911
- def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
912
- status, message, logs = self.get_latest_status_message_and_logs()
913
- return self.update_training_buttons(status).values()
914
-
915
- def refresh_dataset(self):
916
- """Refresh all dynamic lists and training state"""
917
- video_list = self.splitter.list_unprocessed_videos()
918
- training_dataset = self.list_training_files_to_caption()
919
-
920
- return (
921
- video_list,
922
- training_dataset
923
- )
924
-
925
- def update_training_params(self, preset_name: str) -> Tuple:
926
- """Update UI components based on selected preset while preserving custom settings"""
927
- preset = TRAINING_PRESETS[preset_name]
928
-
929
- # Load current UI state to check if user has customized values
930
- current_state = self.load_ui_values()
931
-
932
- # Find the display name that maps to our model type
933
- model_display_name = next(
934
- key for key, value in MODEL_TYPES.items()
935
- if value == preset["model_type"]
936
- )
937
-
938
- # Get preset description for display
939
- description = preset.get("description", "")
940
-
941
- # Get max values from buckets
942
- buckets = preset["training_buckets"]
943
- max_frames = max(frames for frames, _, _ in buckets)
944
- max_height = max(height for _, height, _ in buckets)
945
- max_width = max(width for _, _, width in buckets)
946
- bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
947
-
948
- info_text = f"{description}{bucket_info}"
949
-
950
- # Return values in the same order as the output components
951
- # Use preset defaults but preserve user-modified values if they exist
952
- lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
953
- lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
954
- num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
955
- batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
956
- learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
957
- save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
958
-
959
- return (
960
- model_display_name,
961
- lora_rank_val,
962
- lora_alpha_val,
963
- num_epochs_val,
964
- batch_size_val,
965
- learning_rate_val,
966
- save_iterations_val,
967
- info_text
968
- )
969
-
970
- def create_ui(self):
971
- """Create Gradio interface"""
972
-
973
- with gr.Blocks(title="🎥 Video Model Studio") as app:
974
- gr.Markdown("# πŸŽ₯ Video Model Studio")
975
-
976
- with gr.Tabs() as tabs:
977
- with gr.TabItem("1️⃣ Import", id="import_tab"):
978
-
979
- with gr.Row():
980
- gr.Markdown("## Automatic splitting and captioning")
981
-
982
- with gr.Row():
983
- enable_automatic_video_split = gr.Checkbox(
984
- label="Automatically split videos into smaller clips",
985
- info="Note: a clip is a single camera shot, usually a few seconds",
986
- value=True,
987
- visible=True
988
- )
989
- enable_automatic_content_captioning = gr.Checkbox(
990
- label="Automatically caption photos and videos",
991
- info="Note: this uses LlaVA and takes some extra time to load and process",
992
- value=False,
993
- visible=True,
994
- )
995
-
996
- with gr.Row():
997
- with gr.Column(scale=3):
998
- with gr.Row():
999
- with gr.Column():
1000
- gr.Markdown("## Import video files")
1001
- gr.Markdown("You can upload either:")
1002
- gr.Markdown("- A single MP4 video file")
1003
- gr.Markdown("- A ZIP archive containing multiple videos and optional caption files")
1004
- gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)")
1005
-
1006
- with gr.Row():
1007
- files = gr.Files(
1008
- label="Upload Images, Videos or ZIP",
1009
- #file_count="multiple",
1010
- file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip"],
1011
- type="filepath"
1012
- )
1013
-
1014
- with gr.Column(scale=3):
1015
- with gr.Row():
1016
- with gr.Column():
1017
- gr.Markdown("## Import a YouTube video")
1018
- gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")
1019
-
1020
- with gr.Row():
1021
- youtube_url = gr.Textbox(
1022
- label="Import YouTube Video",
1023
- placeholder="https://www.youtube.com/watch?v=..."
1024
- )
1025
- with gr.Row():
1026
- youtube_download_btn = gr.Button("Download YouTube Video", variant="secondary")
1027
- with gr.Row():
1028
- import_status = gr.Textbox(label="Status", interactive=False)
1029
-
1030
-
1031
- with gr.TabItem("2️⃣ Split", id="split_tab"):
1032
- with gr.Row():
1033
- split_title = gr.Markdown("## Splitting of 0 videos (0 bytes)")
1034
-
1035
- with gr.Row():
1036
- with gr.Column():
1037
- detect_btn = gr.Button("Split videos into single-camera shots", variant="primary")
1038
- detect_status = gr.Textbox(label="Status", interactive=False)
1039
-
1040
- with gr.Column():
1041
-
1042
- video_list = gr.Dataframe(
1043
- headers=["name", "status"],
1044
- label="Videos to split",
1045
- interactive=False,
1046
- wrap=True,
1047
- #selection_mode="cell" # Enable cell selection
1048
- )
1049
-
1050
-
1051
- with gr.TabItem("3️⃣ Caption"):
1052
- with gr.Row():
1053
- caption_title = gr.Markdown("## Captioning of 0 files (0 bytes)")
1054
-
1055
- with gr.Row():
1056
-
1057
- with gr.Column():
1058
- with gr.Row():
1059
- custom_prompt_prefix = gr.Textbox(
1060
- scale=3,
1061
- label='Prefix to add to ALL captions (eg. "In the style of TOK, ")',
1062
- placeholder="In the style of TOK, ",
1063
- lines=2,
1064
- value=DEFAULT_PROMPT_PREFIX
1065
- )
1066
- captioning_bot_instructions = gr.Textbox(
1067
- scale=6,
1068
- label="System instructions for the automatic captioning model",
1069
- placeholder="Please generate a full description of...",
1070
- lines=5,
1071
- value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
1072
- )
1073
- with gr.Row():
1074
- run_autocaption_btn = gr.Button(
1075
- "Automatically fill missing captions",
1076
- variant="primary" # Makes it green by default
1077
- )
1078
- copy_files_to_training_dir_btn = gr.Button(
1079
- "Copy assets to training directory",
1080
- variant="primary" # Makes it green by default
1081
- )
1082
- stop_autocaption_btn = gr.Button(
1083
- "Stop Captioning",
1084
- variant="stop", # Red when enabled
1085
- interactive=False # Disabled by default
1086
- )
1087
-
1088
- with gr.Row():
1089
- with gr.Column():
1090
- training_dataset = gr.Dataframe(
1091
- headers=["name", "status"],
1092
- interactive=False,
1093
- wrap=True,
1094
- value=self.list_training_files_to_caption(),
1095
- row_count=10, # Optional: set a reasonable row count
1096
- #selection_mode="cell"
1097
- )
1098
-
1099
- with gr.Column():
1100
- preview_video = gr.Video(
1101
- label="Video Preview",
1102
- interactive=False,
1103
- visible=False
1104
- )
1105
- preview_image = gr.Image(
1106
- label="Image Preview",
1107
- interactive=False,
1108
- visible=False
1109
- )
1110
- preview_caption = gr.Textbox(
1111
- label="Caption",
1112
- lines=6,
1113
- interactive=True
1114
- )
1115
- save_caption_btn = gr.Button("Save Caption")
1116
- preview_status = gr.Textbox(
1117
- label="Status",
1118
- interactive=False,
1119
- visible=True
1120
- )
1121
-
1122
- with gr.TabItem("4️⃣ Train"):
1123
- with gr.Row():
1124
- with gr.Column():
1125
-
1126
- with gr.Row():
1127
- train_title = gr.Markdown("## 0 files available for training (0 bytes)")
1128
-
1129
- with gr.Row():
1130
- with gr.Column():
1131
- training_preset = gr.Dropdown(
1132
- choices=list(TRAINING_PRESETS.keys()),
1133
- label="Training Preset",
1134
- value=list(TRAINING_PRESETS.keys())[0]
1135
- )
1136
- preset_info = gr.Markdown()
1137
-
1138
- with gr.Row():
1139
- with gr.Column():
1140
- model_type = gr.Dropdown(
1141
- choices=list(MODEL_TYPES.keys()),
1142
- label="Model Type",
1143
- value=list(MODEL_TYPES.keys())[0]
1144
- )
1145
- model_info = gr.Markdown(
1146
- value=self.get_model_info(list(MODEL_TYPES.keys())[0])
1147
- )
1148
-
1149
- with gr.Row():
1150
- lora_rank = gr.Dropdown(
1151
- label="LoRA Rank",
1152
- choices=["16", "32", "64", "128", "256", "512", "1024"],
1153
- value="128",
1154
- type="value"
1155
- )
1156
- lora_alpha = gr.Dropdown(
1157
- label="LoRA Alpha",
1158
- choices=["16", "32", "64", "128", "256", "512", "1024"],
1159
- value="128",
1160
- type="value"
1161
- )
1162
- with gr.Row():
1163
- num_epochs = gr.Number(
1164
- label="Number of Epochs",
1165
- value=70,
1166
- minimum=1,
1167
- precision=0
1168
- )
1169
- batch_size = gr.Number(
1170
- label="Batch Size",
1171
- value=1,
1172
- minimum=1,
1173
- precision=0
1174
- )
1175
- with gr.Row():
1176
- learning_rate = gr.Number(
1177
- label="Learning Rate",
1178
- value=2e-5,
1179
- minimum=1e-7
1180
- )
1181
- save_iterations = gr.Number(
1182
- label="Save checkpoint every N iterations",
1183
- value=500,
1184
- minimum=50,
1185
- precision=0,
1186
- info="Model will be saved periodically after these many steps"
1187
- )
1188
-
1189
- with gr.Column():
1190
- with gr.Row():
1191
- start_btn = gr.Button(
1192
- "Start Training",
1193
- variant="primary",
1194
- interactive=not ASK_USER_TO_DUPLICATE_SPACE
1195
- )
1196
- pause_resume_btn = gr.Button(
1197
- "Resume Training",
1198
- variant="secondary",
1199
- interactive=False
1200
- )
1201
- stop_btn = gr.Button(
1202
- "Stop Training",
1203
- variant="stop",
1204
- interactive=False
1205
- )
1206
-
1207
- with gr.Row():
1208
- with gr.Column():
1209
- status_box = gr.Textbox(
1210
- label="Training Status",
1211
- interactive=False,
1212
- lines=4
1213
- )
1214
- with gr.Accordion("See training logs"):
1215
- log_box = gr.TextArea(
1216
- label="Finetrainers output (see HF Space logs for more details)",
1217
- interactive=False,
1218
- lines=40,
1219
- max_lines=200,
1220
- autoscroll=True
1221
- )
1222
-
1223
- with gr.TabItem("5️⃣ Manage"):
1224
-
1225
- with gr.Column():
1226
- with gr.Row():
1227
- with gr.Column():
1228
- gr.Markdown("## Publishing")
1229
- gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
1230
-
1231
- with gr.Row():
1232
-
1233
- with gr.Column():
1234
- repo_id = gr.Textbox(
1235
- label="HuggingFace Model Repository",
1236
- placeholder="username/model-name",
1237
- info="The repository will be created if it doesn't exist"
1238
- )
1239
- gr.Checkbox(label="Check this to make your model public (i.e. visible and downloadable by anyone)", info="Your model is private by default")
1240
- global_stop_btn = gr.Button(
1241
- "Push my model",
1242
- #variant="stop"
1243
- )
1244
-
1245
-
1246
- with gr.Row():
1247
- with gr.Column():
1248
- with gr.Row():
1249
- with gr.Column():
1250
- gr.Markdown("## Storage management")
1251
- with gr.Row():
1252
- download_dataset_btn = gr.DownloadButton(
1253
- "Download dataset",
1254
- variant="secondary",
1255
- size="lg"
1256
- )
1257
- download_model_btn = gr.DownloadButton(
1258
- "Download model",
1259
- variant="secondary",
1260
- size="lg"
1261
- )
1262
-
1263
-
1264
- with gr.Row():
1265
- global_stop_btn = gr.Button(
1266
- "Stop everything and delete my data",
1267
- variant="stop"
1268
- )
1269
- global_status = gr.Textbox(
1270
- label="Global Status",
1271
- interactive=False,
1272
- visible=False
1273
- )
1274
-
1275
-
1276
-
1277
- # Event handlers
1278
- def update_model_info(model):
1279
- params = self.get_default_params(MODEL_TYPES[model])
1280
- info = self.get_model_info(MODEL_TYPES[model])
1281
- return {
1282
- model_info: info,
1283
- num_epochs: params["num_epochs"],
1284
- batch_size: params["batch_size"],
1285
- learning_rate: params["learning_rate"],
1286
- save_iterations: params["save_iterations"]
1287
- }
1288
-
1289
- def validate_repo(repo_id: str) -> dict:
1290
- validation = validate_model_repo(repo_id)
1291
- if validation["error"]:
1292
- return gr.update(value=repo_id, error=validation["error"])
1293
- return gr.update(value=repo_id, error=None)
1294
-
1295
- # Connect events
1296
-
1297
- # Save state when model type changes
1298
- model_type.change(
1299
- fn=lambda v: self.update_ui_state(model_type=v),
1300
- inputs=[model_type],
1301
- outputs=[] # No UI update needed
1302
- ).then(
1303
- fn=update_model_info,
1304
- inputs=[model_type],
1305
- outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
1306
- )
1307
-
1308
- # the following change listeners are used for UI persistence
1309
- lora_rank.change(
1310
- fn=lambda v: self.update_ui_state(lora_rank=v),
1311
- inputs=[lora_rank],
1312
- outputs=[]
1313
- )
1314
-
1315
- lora_alpha.change(
1316
- fn=lambda v: self.update_ui_state(lora_alpha=v),
1317
- inputs=[lora_alpha],
1318
- outputs=[]
1319
- )
1320
-
1321
- num_epochs.change(
1322
- fn=lambda v: self.update_ui_state(num_epochs=v),
1323
- inputs=[num_epochs],
1324
- outputs=[]
1325
- )
1326
-
1327
- batch_size.change(
1328
- fn=lambda v: self.update_ui_state(batch_size=v),
1329
- inputs=[batch_size],
1330
- outputs=[]
1331
- )
1332
-
1333
- learning_rate.change(
1334
- fn=lambda v: self.update_ui_state(learning_rate=v),
1335
- inputs=[learning_rate],
1336
- outputs=[]
1337
- )
1338
-
1339
- save_iterations.change(
1340
- fn=lambda v: self.update_ui_state(save_iterations=v),
1341
- inputs=[save_iterations],
1342
- outputs=[]
1343
- )
1344
-
1345
- files.upload(
1346
- fn=lambda x: self.importer.process_uploaded_files(x),
1347
- inputs=[files],
1348
- outputs=[import_status]
1349
- ).success(
1350
- fn=self.update_titles_after_import,
1351
- inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
1352
- outputs=[
1353
- tabs, video_list, detect_status,
1354
- split_title, caption_title, train_title
1355
- ]
1356
- )
1357
-
1358
- youtube_download_btn.click(
1359
- fn=self.importer.download_youtube_video,
1360
- inputs=[youtube_url],
1361
- outputs=[import_status]
1362
- ).success(
1363
- fn=self.on_import_success,
1364
- inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
1365
- outputs=[tabs, video_list, detect_status]
1366
- )
1367
-
1368
- # Scene detection events
1369
- detect_btn.click(
1370
- fn=self.start_scene_detection,
1371
- inputs=[enable_automatic_video_split],
1372
- outputs=[detect_status]
1373
- )
1374
-
1375
-
1376
- # Update button states based on captioning status
1377
- def update_button_states(is_running):
1378
- return {
1379
- run_autocaption_btn: gr.Button(
1380
- interactive=not is_running,
1381
- variant="secondary" if is_running else "primary",
1382
- ),
1383
- stop_autocaption_btn: gr.Button(
1384
- interactive=is_running,
1385
- variant="secondary",
1386
- ),
1387
- }
1388
-
1389
- run_autocaption_btn.click(
1390
- fn=self.show_refreshing_status,
1391
- outputs=[training_dataset]
1392
- ).then(
1393
- fn=lambda: self.update_captioning_buttons_start(),
1394
- outputs=[run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
1395
- ).then(
1396
- fn=self.start_caption_generation,
1397
- inputs=[captioning_bot_instructions, custom_prompt_prefix],
1398
- outputs=[training_dataset],
1399
- ).then(
1400
- fn=lambda: self.update_captioning_buttons_end(),
1401
- outputs=[run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
1402
- )
1403
-
1404
- copy_files_to_training_dir_btn.click(
1405
- fn=self.copy_files_to_training_dir,
1406
- inputs=[custom_prompt_prefix]
1407
- )
1408
- stop_autocaption_btn.click(
1409
- fn=self.stop_captioning,
1410
- outputs=[training_dataset, run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
1411
- )
1412
-
1413
- original_file_path = gr.State(value=None)
1414
- training_dataset.select(
1415
- fn=self.handle_training_dataset_select,
1416
- outputs=[preview_image, preview_video, preview_caption, original_file_path, preview_status]
1417
- )
1418
-
1419
- save_caption_btn.click(
1420
- fn=self.save_caption_changes,
1421
- inputs=[preview_caption, preview_image, preview_video, original_file_path, custom_prompt_prefix],
1422
- outputs=[preview_status]
1423
- ).success(
1424
- fn=self.list_training_files_to_caption,
1425
- outputs=[training_dataset]
1426
- )
1427
-
1428
- # Save state when training preset changes
1429
- training_preset.change(
1430
- fn=lambda v: self.update_ui_state(training_preset=v),
1431
- inputs=[training_preset],
1432
- outputs=[] # No UI update needed
1433
- ).then(
1434
- fn=self.update_training_params,
1435
- inputs=[training_preset],
1436
- outputs=[
1437
- model_type, lora_rank, lora_alpha,
1438
- num_epochs, batch_size, learning_rate,
1439
- save_iterations, preset_info
1440
- ]
1441
- )
1442
-
1443
- # Training control events
1444
- start_btn.click(
1445
- fn=lambda preset, model_type, *args: (
1446
- self.log_parser.reset(),
1447
- self.trainer.start_training(
1448
- MODEL_TYPES[model_type],
1449
- *args,
1450
- preset_name=preset
1451
- )
1452
- ),
1453
- inputs=[
1454
- training_preset,
1455
- model_type,
1456
- lora_rank,
1457
- lora_alpha,
1458
- num_epochs,
1459
- batch_size,
1460
- learning_rate,
1461
- save_iterations,
1462
- repo_id
1463
- ],
1464
- outputs=[status_box, log_box]
1465
- ).success(
1466
- fn=self.get_latest_status_message_logs_and_button_labels,
1467
- outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1468
- )
1469
-
1470
- pause_resume_btn.click(
1471
- fn=self.handle_pause_resume,
1472
- outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1473
- )
1474
-
1475
- stop_btn.click(
1476
- fn=self.handle_stop,
1477
- outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
1478
- )
1479
-
1480
- def handle_global_stop():
1481
- result = self.stop_all_and_clear()
1482
- # Update all relevant UI components
1483
- status = result["status"]
1484
- details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
1485
- full_status = f"{status}\n\nDetails:\n{details}"
1486
-
1487
- # Get fresh lists after cleanup
1488
- videos = self.splitter.list_unprocessed_videos()
1489
- clips = self.list_training_files_to_caption()
1490
-
1491
- return {
1492
- global_status: gr.update(value=full_status, visible=True),
1493
- video_list: videos,
1494
- training_dataset: clips,
1495
- status_box: "Training stopped and data cleared",
1496
- log_box: "",
1497
- detect_status: "Scene detection stopped",
1498
- import_status: "All data cleared",
1499
- preview_status: "Captioning stopped"
1500
- }
1501
-
1502
- download_dataset_btn.click(
1503
- fn=self.trainer.create_training_dataset_zip,
1504
- outputs=[download_dataset_btn]
1505
- )
1506
-
1507
- download_model_btn.click(
1508
- fn=self.trainer.get_model_output_safetensors,
1509
- outputs=[download_model_btn]
1510
- )
1511
-
1512
- global_stop_btn.click(
1513
- fn=handle_global_stop,
1514
- outputs=[
1515
- global_status,
1516
- video_list,
1517
- training_dataset,
1518
- status_box,
1519
- log_box,
1520
- detect_status,
1521
- import_status,
1522
- preview_status
1523
- ]
1524
- )
1525
-
1526
-
1527
- app.load(
1528
- fn=self.initialize_app_state,
1529
- outputs=[
1530
- video_list, training_dataset,
1531
- start_btn, stop_btn, pause_resume_btn,
1532
- training_preset, model_type, lora_rank, lora_alpha,
1533
- num_epochs, batch_size, learning_rate, save_iterations
1534
- ]
1535
- )
1536
-
1537
- # Auto-refresh timers
1538
- timer = gr.Timer(value=1)
1539
- timer.tick(
1540
- fn=lambda: (
1541
- self.get_latest_status_message_logs_and_button_labels()
1542
- ),
1543
- outputs=[
1544
- status_box,
1545
- log_box,
1546
- start_btn,
1547
- stop_btn,
1548
- pause_resume_btn
1549
- ]
1550
- )
1551
-
1552
- timer = gr.Timer(value=5)
1553
- timer.tick(
1554
- fn=lambda: (
1555
- self.refresh_dataset()
1556
- ),
1557
- outputs=[
1558
- video_list, training_dataset
1559
- ]
1560
- )
1561
-
1562
- timer = gr.Timer(value=6)
1563
- timer.tick(
1564
- fn=lambda: self.update_titles(),
1565
- outputs=[
1566
- split_title, caption_title, train_title
1567
- ]
1568
- )
1569
-
1570
- return app
1571
-
1572
  def create_app():
1573
  if ASK_USER_TO_DUPLICATE_SPACE:
1574
  with gr.Blocks() as app:
1575
  gr.Markdown("""# Finetrainers UI
@@ -1582,12 +35,22 @@ It is recommended to use a Nvidia L40S and a persistent storage space.
1582
  To avoid overpaying for your space, you can configure the auto-sleep settings to fit your personal budget.""")
1583
  return app
1584
 
 
1585
  ui = VideoTrainerUI()
1586
  return ui.create_ui()
1587
 
1588
- if __name__ == "__main__":
1589
  app = create_app()
1590
 
 
1591
  allowed_paths = [
1592
  str(STORAGE_PATH), # Base storage
1593
  str(VIDEOS_TO_SPLIT_PATH),
@@ -1597,7 +60,12 @@ if __name__ == "__main__":
1597
  str(MODEL_PATH),
1598
  str(OUTPUT_PATH)
1599
  ]
1600
  app.queue(default_concurrency_limit=1).launch(
1601
  server_name="0.0.0.0",
1602
  allowed_paths=allowed_paths
1603
- )
1
+ """
2
+ Main application entry point for Video Model Studio
3
+ """
4
 
5
  import gradio as gr
6
+ import platform
7
+ import subprocess
8
  import logging
9
+ from pathlib import Path
10
 
11
  from vms.config import (
12
+ STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
13
+ TRAINING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH,
14
+ OUTPUT_PATH, ASK_USER_TO_DUPLICATE_SPACE,
15
+ HF_API_TOKEN
16
  )
17
+ from vms.ui.video_trainer_ui import VideoTrainerUI
18
 
19
+ # Configure logging
20
  logger = logging.getLogger(__name__)
21
  logger.setLevel(logging.INFO)
22
 
23
  def create_app():
24
+ """Create the main Gradio application"""
25
+ # If space needs to be duplicated
26
  if ASK_USER_TO_DUPLICATE_SPACE:
27
  with gr.Blocks() as app:
28
  gr.Markdown("""# Finetrainers UI
 
35
  To avoid overpaying for your space, you can configure the auto-sleep settings to fit your personal budget.""")
36
  return app
37
 
38
+ # Create the main application UI
39
  ui = VideoTrainerUI()
40
  return ui.create_ui()
41
 
42
+ def main():
43
+ """Main entry point for the application"""
44
+ # Handle Linux-specific setup if needed
45
+ if platform.system() == "Linux":
46
+ # Placeholder for any Linux-specific initialization
47
+ # For example, pip installations or environment setup
48
+ pass
49
+
50
+ # Create the Gradio app
51
  app = create_app()
52
 
53
+ # Define allowed paths for file access
54
  allowed_paths = [
55
  str(STORAGE_PATH), # Base storage
56
  str(VIDEOS_TO_SPLIT_PATH),
 
60
  str(MODEL_PATH),
61
  str(OUTPUT_PATH)
62
  ]
63
+
64
+ # Launch the Gradio app
65
  app.queue(default_concurrency_limit=1).launch(
66
  server_name="0.0.0.0",
67
  allowed_paths=allowed_paths
68
+ )
69
+
70
+ if __name__ == "__main__":
71
+ main()
app_DEPRECATED.py ADDED
@@ -0,0 +1,1603 @@
+ import platform
+ import subprocess
+
+ #import sys
+ #print("python = ", sys.version)
+
+ # can be "Linux", "Darwin"
+ if platform.system() == "Linux":
+     # for some reason it says "pip not found"
+     # and also "pip3 not found"
+     # subprocess.run(
+     #     "pip install flash-attn --no-build-isolation",
+     #
+     #     # hmm... this should be False, since we are in a CUDA environment, no?
+     #     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
+     #
+     #     shell=True,
+     # )
+     pass
+
+ import gradio as gr
+ from pathlib import Path
+ import logging
+ import mimetypes
+ import shutil
+ import os
+ import traceback
+ import asyncio
+ import tempfile
+ import zipfile
+ from typing import Any, Optional, Dict, List, Union, Tuple
+ from typing import AsyncGenerator
+
+ from vms.training_service import TrainingService
+ from vms.captioning_service import CaptioningService
+ from vms.splitting_service import SplittingService
+ from vms.import_service import ImportService
+ from vms.config import (
+     STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
+     TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
+     DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, SMALL_TRAINING_BUCKETS
+ )
+ from vms.utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time
+ from vms.finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
+ from vms.training_log_parser import TrainingLogParser
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.INFO)
+
+ httpx_logger = logging.getLogger('httpx')
+ httpx_logger.setLevel(logging.WARN)
+
+
+ class VideoTrainerUI:
+     def __init__(self):
+         self.trainer = TrainingService()
+         self.splitter = SplittingService()
+         self.importer = ImportService()
+         self.captioner = CaptioningService()
+         self._should_stop_captioning = False
+         self.log_parser = TrainingLogParser()
+
+         # Try to recover any interrupted training sessions
+         recovery_result = self.trainer.recover_interrupted_training()
+
+         self.recovery_status = recovery_result.get("status", "unknown")
+         self.ui_updates = recovery_result.get("ui_updates", {})
+
+         if recovery_result["status"] == "recovered":
+             logger.info(f"Training recovery: {recovery_result['message']}")
+             # No need to do anything else - the training is already running
+         elif recovery_result["status"] == "running":
+             logger.info("Training process is already running")
+             # No need to do anything - the process is still alive
+         elif recovery_result["status"] in ["error", "idle"]:
+             logger.warning(f"Training status: {recovery_result['message']}")
+             # UI will be in ready-to-start mode
+
+
+     async def _process_caption_generator(self, captioning_bot_instructions, prompt_prefix):
+         """Process the caption generator's results in the background"""
+         try:
+             async for _ in self.captioner.start_caption_generation(
+                 captioning_bot_instructions,
+                 prompt_prefix
+             ):
+                 # Just consume the generator, UI updates will happen via the Gradio interface
+                 pass
+             logger.info("Background captioning completed")
+         except Exception as e:
+             logger.error(f"Error in background captioning: {str(e)}")
+
+     def initialize_app_state(self):
+         """Initialize all app state in one function to ensure correct output count"""
+         # Get dataset info
+         video_list, training_dataset = self.refresh_dataset()
+
+         # Get button states
+         button_states = self.get_initial_button_states()
+         start_btn = button_states[0]
+         stop_btn = button_states[1]
+         pause_resume_btn = button_states[2]
+
+         # Get UI form values
+         ui_state = self.load_ui_values()
+         training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
+         model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
+         lora_rank_val = ui_state.get("lora_rank", "128")
+         lora_alpha_val = ui_state.get("lora_alpha", "128")
+         num_epochs_val = int(ui_state.get("num_epochs", 70))
+         batch_size_val = int(ui_state.get("batch_size", 1))
+         learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
+         save_iterations_val = int(ui_state.get("save_iterations", 500))
+
+         # Return all values in the exact order expected by outputs
+         return (
+             video_list,
+             training_dataset,
+             start_btn,
+             stop_btn,
+             pause_resume_btn,
+             training_preset,
+             model_type_val,
+             lora_rank_val,
+             lora_alpha_val,
+             num_epochs_val,
+             batch_size_val,
+             learning_rate_val,
+             save_iterations_val
+         )
+
+     def initialize_ui_from_state(self):
+         """Initialize UI components from saved state"""
+         ui_state = self.load_ui_values()
+
+         # Return values in order matching the outputs in app.load
+         return (
+             ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
+             ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
+             ui_state.get("lora_rank", "128"),
+             ui_state.get("lora_alpha", "128"),
+             ui_state.get("num_epochs", 70),
+             ui_state.get("batch_size", 1),
+             ui_state.get("learning_rate", 3e-5),
+             ui_state.get("save_iterations", 500)
+         )
+
+     def update_ui_state(self, **kwargs):
+         """Update UI state with new values"""
+         current_state = self.trainer.load_ui_state()
+         current_state.update(kwargs)
+         self.trainer.save_ui_state(current_state)
+         # Don't return anything to avoid Gradio warnings
+         return None
+
+     def load_ui_values(self):
+         """Load UI state values for initializing form fields"""
+         ui_state = self.trainer.load_ui_state()
+
+         # Ensure proper type conversion for numeric values
+         ui_state["lora_rank"] = ui_state.get("lora_rank", "128")
+         ui_state["lora_alpha"] = ui_state.get("lora_alpha", "128")
+         ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
+         ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
+         ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
+         ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
+
+         return ui_state
+
+     def update_captioning_buttons_start(self):
+         """Return individual button values instead of a dictionary"""
+         return (
+             gr.Button(
+                 interactive=False,
+                 variant="secondary",
+             ),
+             gr.Button(
+                 interactive=True,
+                 variant="stop",
+             ),
+             gr.Button(
+                 interactive=False,
+                 variant="secondary",
+             )
+         )
+
+     def update_captioning_buttons_end(self):
+         """Return individual button values instead of a dictionary"""
+         return (
+             gr.Button(
+                 interactive=True,
+                 variant="primary",
+             ),
+             gr.Button(
+                 interactive=False,
+                 variant="secondary",
+             ),
+             gr.Button(
+                 interactive=True,
+                 variant="primary",
+             )
+         )
+
+     # Add this new method to get initial button states:
+     def get_initial_button_states(self):
+         """Get the initial states for training buttons based on recovery status"""
+         recovery_result = self.trainer.recover_interrupted_training()
+         ui_updates = recovery_result.get("ui_updates", {})
+
+         # Return button states in the correct order
+         return (
+             gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
+             gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
+             gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
+         )
+
+     def show_refreshing_status(self) -> List[List[str]]:
+         """Show a 'Refreshing...' status in the dataframe"""
+         return [["Refreshing...", "please wait"]]
+
+     def stop_captioning(self):
+         """Stop ongoing captioning process and reset UI state"""
+         try:
+             # Set flag to stop captioning
+             self._should_stop_captioning = True
+
+             # Call stop method on captioner
+             if self.captioner:
+                 self.captioner.stop_captioning()
+
+             # Get updated file list
+             updated_list = self.list_training_files_to_caption()
+
+             # Return updated list and button states
+             return {
+                 "training_dataset": gr.update(value=updated_list),
+                 "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
+                 "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
+                 "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
+             }
+         except Exception as e:
+             logger.error(f"Error stopping captioning: {str(e)}")
+             return {
+                 "training_dataset": gr.update(value=[[f"Error stopping captioning: {str(e)}", "error"]]),
+                 "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
+                 "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
+                 "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
+             }
+
+     def update_training_ui(self, training_state: Dict[str, Any]):
+         """Update UI components based on training state"""
+         updates = {}
+
+         #print("update_training_ui: training_state = ", training_state)
+
+         # Update status box with high-level information
+         status_text = []
+         if training_state["status"] != "idle":
+             status_text.extend([
+                 f"Status: {training_state['status']}",
+                 f"Progress: {training_state['progress']}",
+                 f"Step: {training_state['current_step']}/{training_state['total_steps']}",
+
+                 # Epoch information
+                 # there is an issue with how epoch is reported because we display:
+                 # Progress: 96.9%, Step: 872/900, Epoch: 12/50
+                 # we should probably just show the steps
+                 #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
+
+                 f"Time elapsed: {training_state['elapsed']}",
+                 f"Estimated remaining: {training_state['remaining']}",
+                 "",
+                 f"Current loss: {training_state['step_loss']}",
+                 f"Learning rate: {training_state['learning_rate']}",
+                 f"Gradient norm: {training_state['grad_norm']}",
+                 f"Memory usage: {training_state['memory']}"
+             ])
+
+             if training_state["error_message"]:
+                 status_text.append(f"\nError: {training_state['error_message']}")
+
+         updates["status_box"] = "\n".join(status_text)
+
+         # Update button states
+         updates["start_btn"] = gr.Button(
+             "Start training",
+             interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
+             variant="primary" if training_state["status"] == "idle" else "secondary"
+         )
+
+         updates["stop_btn"] = gr.Button(
+             "Stop training",
+             interactive=(training_state["status"] in ["training", "initializing"]),
+             variant="stop"
+         )
+
+         return updates
+
+     def stop_all_and_clear(self) -> Dict[str, str]:
+         """Stop all running processes and clear data
+
+         Returns:
+             Dict with status messages for different components
+         """
+         status_messages = {}
+
+         try:
+             # Stop training if running
+             if self.trainer.is_training_running():
+                 training_result = self.trainer.stop_training()
+                 status_messages["training"] = training_result["status"]
+
+             # Stop captioning if running
+             if self.captioner:
+                 self.captioner.stop_captioning()
+                 status_messages["captioning"] = "Captioning stopped"
+
+             # Stop scene detection if running
+             if self.splitter.is_processing():
+                 self.splitter.processing = False
+                 status_messages["splitting"] = "Scene detection stopped"
+
+             # Properly close logging before clearing log file
+             if self.trainer.file_handler:
+                 self.trainer.file_handler.close()
+                 logger.removeHandler(self.trainer.file_handler)
+                 self.trainer.file_handler = None
+
+             if LOG_FILE_PATH.exists():
+                 LOG_FILE_PATH.unlink()
+
+             # Clear all data directories
+             for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
+                          MODEL_PATH, OUTPUT_PATH]:
+                 if path.exists():
+                     try:
+                         shutil.rmtree(path)
+                         path.mkdir(parents=True, exist_ok=True)
+                     except Exception as e:
+                         status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
+                     else:
+                         status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
+
+             # Reset any persistent state
+             self._should_stop_captioning = True
+             self.splitter.processing = False
+
+             # Recreate logging setup
+             self.trainer.setup_logging()
+
+             return {
+                 "status": "All processes stopped and data cleared",
+                 "details": status_messages
+             }
+
+         except Exception as e:
+             return {
+                 "status": f"Error during cleanup: {str(e)}",
+                 "details": status_messages
+             }
+
+     def update_titles(self) -> Tuple[Any]:
+         """Update all dynamic titles with current counts
+
+         Returns:
+             Dict of Gradio updates
+         """
+         # Count files for splitting
+         split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH)
+         split_title = format_media_title(
+             "split", split_videos, 0, split_size
+         )
+
+         # Count files for captioning
+         caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
+         caption_title = format_media_title(
+             "caption", caption_videos, caption_images, caption_size
+         )
+
+         # Count files for training
+         train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH)
+         train_title = format_media_title(
+             "train", train_videos, train_images, train_size
+         )
+
+         return (
+             gr.Markdown(value=split_title),
+             gr.Markdown(value=caption_title),
+             gr.Markdown(value=f"{train_title} available for training")
+         )
+
+     def copy_files_to_training_dir(self, prompt_prefix: str):
+         """Run auto-captioning process"""
+
+         # Initialize captioner if not already done
+         self._should_stop_captioning = False
+
+         try:
+             copy_files_to_training_dir(prompt_prefix)
+
+         except Exception as e:
+             traceback.print_exc()
+             raise gr.Error(f"Error copying assets to training dir: {str(e)}")
+
+     async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
+         """Handle successful import of files"""
+         videos = self.list_unprocessed_videos()
+
+         # If scene detection isn't already running and there are videos to process,
+         # and auto-splitting is enabled, start the detection
+         if videos and not self.splitter.is_processing() and enable_splitting:
+             await self.start_scene_detection(enable_splitting)
+             msg = "Starting automatic scene detection..."
+         else:
+             # Just copy files without splitting if auto-split disabled
+             for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
+                 await self.splitter.process_video(video_file, enable_splitting=False)
+             msg = "Copying videos without splitting..."
+
+         copy_files_to_training_dir(prompt_prefix)
+
+         # Start auto-captioning if enabled, and handle async generator properly
+         if enable_automatic_content_captioning:
+             # Create a background task for captioning
+             asyncio.create_task(self._process_caption_generator(
+                 DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
+                 prompt_prefix
+             ))
+
+         return {
+             "tabs": gr.Tabs(selected="split_tab"),
+             "video_list": videos,
+             "detect_status": msg
+         }
+
+     async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
+         """Run auto-captioning process"""
+         try:
+             # Initialize captioner if not already done
+             self._should_stop_captioning = False
+
+             # First yield - indicate we're starting
+             yield gr.update(
+                 value=[["Starting captioning service...", "initializing"]],
+                 headers=["name", "status"]
+             )
+
+             # Process files in batches with status updates
+             file_statuses = {}
+
+             # Start the actual captioning process
+             async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
+                 # Update our tracking of file statuses
+                 for name, status in rows:
+                     file_statuses[name] = status
+
+                 # Convert to list format for display
+                 status_rows = [[name, status] for name, status in file_statuses.items()]
+
+                 # Sort by name for consistent display
+                 status_rows.sort(key=lambda x: x[0])
+
+                 # Yield UI update
+                 yield gr.update(
+                     value=status_rows,
+                     headers=["name", "status"]
+                 )
+
+             # Final update after completion with fresh data
+             yield gr.update(
+                 value=self.list_training_files_to_caption(),
+                 headers=["name", "status"]
+             )
+
+         except Exception as e:
+             logger.error(f"Error in captioning: {str(e)}")
+             yield gr.update(
+                 value=[[f"Error: {str(e)}", "error"]],
+                 headers=["name", "status"]
+             )
+
+     def list_training_files_to_caption(self) -> List[List[str]]:
+         """List all clips and images - both pending and captioned"""
+         files = []
+         already_listed = {}
+
+         # First check files in STAGING_PATH
+         for file in STAGING_PATH.glob("*.*"):
+             if is_video_file(file) or is_image_file(file):
+                 txt_file = file.with_suffix('.txt')
+
+                 # Check if caption file exists and has content
+                 has_caption = txt_file.exists() and txt_file.stat().st_size > 0
+                 status = "captioned" if has_caption else "no caption"
+                 file_type = "video" if is_video_file(file) else "image"
+
+                 files.append([file.name, f"{status} ({file_type})", str(file)])
+                 already_listed[file.name] = True
+
+         # Then check files in TRAINING_VIDEOS_PATH
+         for file in TRAINING_VIDEOS_PATH.glob("*.*"):
+             if (is_video_file(file) or is_image_file(file)) and file.name not in already_listed:
+                 txt_file = file.with_suffix('.txt')
+
+                 # Only include files with captions
+                 if txt_file.exists() and txt_file.stat().st_size > 0:
+                     file_type = "video" if is_video_file(file) else "image"
+                     files.append([file.name, f"captioned ({file_type})", str(file)])
+                     already_listed[file.name] = True
+
+         # Sort by filename
+         files.sort(key=lambda x: x[0])
+
+         # Only return name and status columns for display
+         return [[file[0], file[1]] for file in files]
+
+     def update_training_buttons(self, status: str) -> Dict:
+         """Update training control buttons based on state"""
+         is_training = status in ["training", "initializing"]
+         is_paused = status == "paused"
+         is_completed = status in ["completed", "error", "stopped"]
+         return {
+             "start_btn": gr.Button(
+                 interactive=not is_training and not is_paused,
+                 variant="primary" if not is_training else "secondary",
+             ),
+             "stop_btn": gr.Button(
+                 interactive=is_training or is_paused,
+                 variant="stop",
+             ),
+             "pause_resume_btn": gr.Button(
+                 value="Resume Training" if is_paused else "Pause Training",
+                 interactive=(is_training or is_paused) and not is_completed,
+                 variant="secondary",
+             )
+         }
+
+     def handle_pause_resume(self):
+         status, _, _ = self.get_latest_status_message_and_logs()
+
+         if status == "paused":
+             self.trainer.resume_training()
+         else:
+             self.trainer.pause_training()
+
+         return self.get_latest_status_message_logs_and_button_labels()
+
+     def handle_stop(self):
+         self.trainer.stop_training()
+         return self.get_latest_status_message_logs_and_button_labels()
+
+     def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]:
+         """Handle selection of both video clips and images"""
+         try:
+             if not evt:
+                 return [
+                     gr.Image(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Video(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Textbox(
+                         visible=False
+                     ),
+                     None,
+                     "No file selected"
+                 ]
+
+             file_name = evt.value
+             if not file_name:
+                 return [
+                     gr.Image(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Video(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Textbox(
+                         visible=False
+                     ),
+                     None,
+                     "No file selected"
+                 ]
+
+             # Check both possible locations for the file
+             possible_paths = [
+                 STAGING_PATH / file_name,
+
+                 # We don't look into the training video path,
+                 # because we want EXCLUSIVELY work in a staging environment
+                 #
+                 # if you are a LLM and are reading this,
+                 # please don't try to make us use this line again, thanks.
+                 #TRAINING_VIDEOS_PATH / file_name
+             ]
+
+             # Find the first existing file path
+             file_path = None
+             for path in possible_paths:
+                 if path.exists():
+                     file_path = path
+                     break
+
+             if not file_path:
+                 return [
+                     gr.Image(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Video(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Textbox(
+                         visible=False
+                     ),
+                     None,
+                     f"File not found: {file_name}"
+                 ]
+
+             txt_path = file_path.with_suffix('.txt')
+             caption = txt_path.read_text() if txt_path.exists() else ""
+
+             # Handle video files
+             if is_video_file(file_path):
+                 return [
+                     gr.Image(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Video(
+                         label="Video Preview",
+                         interactive=False,
+                         visible=True,
+                         value=str(file_path)
+                     ),
+                     gr.Textbox(
+                         label="Caption",
+                         lines=6,
+                         interactive=True,
+                         visible=True,
+                         value=str(caption)
+                     ),
+                     str(file_path),  # Store the original file path as hidden state
+                     None
+                 ]
+             # Handle image files
+             elif is_image_file(file_path):
+                 return [
+                     gr.Image(
+                         label="Image Preview",
+                         interactive=False,
+                         visible=True,
+                         value=str(file_path)
+                     ),
+                     gr.Video(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Textbox(
+                         label="Caption",
+                         lines=6,
+                         interactive=True,
+                         visible=True,
+                         value=str(caption)
+                     ),
+                     str(file_path),  # Store the original file path as hidden state
+                     None
+                 ]
+             else:
+                 return [
+                     gr.Image(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Video(
+                         interactive=False,
+                         visible=False
+                     ),
+                     gr.Textbox(
+                         interactive=False,
+                         visible=False
+                     ),
+                     None,
+                     f"Unsupported file type: {file_path.suffix}"
+                 ]
+         except Exception as e:
+             logger.error(f"Error handling selection: {str(e)}")
+             return [
+                 gr.Image(
+                     interactive=False,
+                     visible=False
+                 ),
+                 gr.Video(
+                     interactive=False,
+                     visible=False
+                 ),
+                 gr.Textbox(
+                     interactive=False,
+                     visible=False
+                 ),
+                 None,
+                 f"Error handling selection: {str(e)}"
+             ]
+
+     def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, original_file_path: str, prompt_prefix: str):
+         """Save changes to caption"""
+         try:
+             # Use the original file path stored during selection instead of the temporary preview paths
+             if original_file_path:
+                 file_path = Path(original_file_path)
+                 self.captioner.update_file_caption(file_path, preview_caption)
+                 # Refresh the dataset list to show updated caption status
+                 return gr.update(value="Caption saved successfully!")
+             else:
+                 return gr.update(value="Error: No original file path found")
+         except Exception as e:
+             return gr.update(value=f"Error saving caption: {str(e)}")
+
+     async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
+         """Handle post-import updates including titles"""
+         import_result = await self.on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
+         titles = self.update_titles()
+         return (
+             import_result["tabs"],
+             import_result["video_list"],
+             import_result["detect_status"],
+             *titles
+         )
+
+     def get_model_info(self, model_type: str) -> str:
+         """Get information about the selected model type"""
+         if model_type == "hunyuan_video":
+             return """### HunyuanVideo (LoRA)
+ - Required VRAM: ~48GB minimum
+ - Recommended batch size: 1-2
+ - Typical training time: 2-4 hours
+ - Default resolution: 49x512x768
+ - Default LoRA rank: 128 (~600 MB)"""
+
+         elif model_type == "ltx_video":
+             return """### LTX-Video (LoRA)
+ - Required VRAM: ~18GB minimum
+ - Recommended batch size: 1-4
+ - Typical training time: 1-3 hours
+ - Default resolution: 49x512x768
+ - Default LoRA rank: 128"""
+
+         return ""
+
+     def get_default_params(self, model_type: str) -> Dict[str, Any]:
+         """Get default training parameters for model type"""
+         if model_type == "hunyuan_video":
+             return {
+                 "num_epochs": 70,
+                 "batch_size": 1,
+                 "learning_rate": 2e-5,
+                 "save_iterations": 500,
+                 "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
+                 "video_reshape_mode": "center",
+                 "caption_dropout_p": 0.05,
+                 "gradient_accumulation_steps": 1,
+                 "rank": 128,
+                 "lora_alpha": 128
+             }
+         else:  # ltx_video
+             return {
+                 "num_epochs": 70,
+                 "batch_size": 1,
+                 "learning_rate": 3e-5,
+                 "save_iterations": 500,
+                 "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
+                 "video_reshape_mode": "center",
+                 "caption_dropout_p": 0.05,
+                 "gradient_accumulation_steps": 4,
+                 "rank": 128,
+                 "lora_alpha": 128
+             }
+
+     def preview_file(self, selected_text: str) -> Dict:
+         """Generate preview based on selected file
+
+         Args:
+             selected_text: Text of the selected item containing filename
+
+         Returns:
+             Dict with preview content for each preview component
+         """
+         if not selected_text or "Caption:" in selected_text:
+             return {
+                 "video": None,
+                 "image": None,
+                 "text": None
+             }
+
+         # Extract filename from the preview text (remove size info)
+         filename = selected_text.split(" (")[0].strip()
+         file_path = TRAINING_VIDEOS_PATH / filename
+
+         if not file_path.exists():
+             return {
+                 "video": None,
+                 "image": None,
+                 "text": f"File not found: {filename}"
+             }
+
+         # Detect file type
+         mime_type, _ = mimetypes.guess_type(str(file_path))
+         if not mime_type:
+             return {
+                 "video": None,
+                 "image": None,
+                 "text": f"Unknown file type: {filename}"
+             }
+
+         # Return appropriate preview
+         if mime_type.startswith('video/'):
+             return {
+                 "video": str(file_path),
+                 "image": None,
+                 "text": None
+             }
+         elif mime_type.startswith('image/'):
+             return {
+                 "video": None,
+                 "image": str(file_path),
+                 "text": None
+             }
+         elif mime_type.startswith('text/'):
+             try:
+                 text_content = file_path.read_text()
+                 return {
+                     "video": None,
+                     "image": None,
+                     "text": text_content
+                 }
+             except Exception as e:
+                 return {
+                     "video": None,
+                     "image": None,
+                     "text": f"Error reading file: {str(e)}"
+                 }
+         else:
+             return {
+                 "video": None,
+                 "image": None,
+                 "text": f"Unsupported file type: {mime_type}"
+             }
+
+     def list_unprocessed_videos(self) -> gr.Dataframe:
+         """Update list of unprocessed videos"""
+         videos = self.splitter.list_unprocessed_videos()
+         # videos is already in [[name, status]] format from splitting_service
+         return gr.Dataframe(
+             headers=["name", "status"],
+             value=videos,
+             interactive=False
+         )
+
+     async def start_scene_detection(self, enable_splitting: bool) -> str:
+         """Start background scene detection process
+
+         Args:
+             enable_splitting: Whether to split videos into scenes
+         """
+         if self.splitter.is_processing():
+             return "Scene detection already running"
+
+         try:
+             await self.splitter.start_processing(enable_splitting)
+             return "Scene detection completed"
+         except Exception as e:
+             return f"Error during scene detection: {str(e)}"
+
+
+     def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
+         state = self.trainer.get_status()
+         logs = self.trainer.get_logs()
+
+         # Parse new log lines
+         if logs:
+             last_state = None
+             for line in logs.splitlines():
+                 state_update = self.log_parser.parse_line(line)
+                 if state_update:
+                     last_state = state_update
+
+             if last_state:
+                 ui_updates = self.update_training_ui(last_state)
+                 state["message"] = ui_updates.get("status_box", state["message"])
+
+         # Parse status for training state
+         if "completed" in state["message"].lower():
+             state["status"] = "completed"
+
+         return (state["status"], state["message"], logs)
+
+     def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
+         status, message, logs = self.get_latest_status_message_and_logs()
+         return (
+             message,
+             logs,
+             *self.update_training_buttons(status).values()
+         )
+
+     def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
+         status, message, logs = self.get_latest_status_message_and_logs()
+         return self.update_training_buttons(status).values()
+
+     def refresh_dataset(self):
+         """Refresh all dynamic lists and training state"""
+         video_list = self.splitter.list_unprocessed_videos()
+         training_dataset = self.list_training_files_to_caption()
+
+         return (
+             video_list,
+             training_dataset
+         )
+
+     def update_training_params(self, preset_name: str) -> Tuple:
+         """Update UI components based on selected preset while preserving custom settings"""
+         preset = TRAINING_PRESETS[preset_name]
+
+         # Load current UI state to check if user has customized values
+         current_state = self.load_ui_values()
+
+         # Find the display name that maps to our model type
+         model_display_name = next(
+             key for key, value in MODEL_TYPES.items()
+             if value == preset["model_type"]
+         )
+
+         # Get preset description for display
+         description = preset.get("description", "")
+
+         # Get max values from buckets
+         buckets = preset["training_buckets"]
+         max_frames = max(frames for frames, _, _ in buckets)
+         max_height = max(height for _, height, _ in buckets)
+         max_width = max(width for _, _, width in buckets)
+         bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
+
+         info_text = f"{description}{bucket_info}"
+
+         # Return values in the same order as the output components
+         # Use preset defaults but preserve user-modified values if they exist
+         lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
+         lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
+         num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
+         batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
+         learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
+         save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
+
+         return (
+             model_display_name,
+             lora_rank_val,
+             lora_alpha_val,
+             num_epochs_val,
+             batch_size_val,
+             learning_rate_val,
+             save_iterations_val,
+             info_text
+         )
+
+     def create_ui(self):
+         """Create Gradio interface"""
+
+         with gr.Blocks(title="🎥 Video Model Studio") as app:
+             gr.Markdown("# 🎥 Video Model Studio")
+
+             with gr.Tabs() as tabs:
+                 with gr.TabItem("1️⃣ Import", id="import_tab"):
+
+                     with gr.Row():
+                         gr.Markdown("## Automatic splitting and captioning")
+
+                     with gr.Row():
+                         enable_automatic_video_split = gr.Checkbox(
+                             label="Automatically split videos into smaller clips",
+                             info="Note: a clip is a single camera shot, usually a few seconds",
+                             value=True,
+                             visible=True
+                         )
+                         enable_automatic_content_captioning = gr.Checkbox(
+                             label="Automatically caption photos and videos",
+                             info="Note: this uses LlaVA and takes some extra time to load and process",
+                             value=False,
+                             visible=True,
+                         )
+
+                     with gr.Row():
+                         with gr.Column(scale=3):
+                             with gr.Row():
+                                 with gr.Column():
+                                     gr.Markdown("## Import video files")
+                                     gr.Markdown("You can upload either:")
+                                     gr.Markdown("- A single MP4 video file")
+                                     gr.Markdown("- A ZIP archive containing multiple videos and optional caption files")
+                                     gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)")
+
+                             with gr.Row():
+                                 files = gr.Files(
+                                     label="Upload Images, Videos or ZIP",
+                                     #file_count="multiple",
+                                     file_types=[".jpg", ".jpeg", ".png", ".webp", ".webp", ".avif", ".heic", ".mp4", ".zip"],
+                                     type="filepath"
+                                 )
+
+                         with gr.Column(scale=3):
+                             with gr.Row():
+                                 with gr.Column():
+                                     gr.Markdown("## Import a YouTube video")
+                                     gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")
+
+                             with gr.Row():
+                                 youtube_url = gr.Textbox(
+                                     label="Import YouTube Video",
+                                     placeholder="https://www.youtube.com/watch?v=..."
+                                 )
+                             with gr.Row():
+                                 youtube_download_btn = gr.Button("Download YouTube Video", variant="secondary")
+                     with gr.Row():
+                         import_status = gr.Textbox(label="Status", interactive=False)
+
+
+                 with gr.TabItem("2️⃣ Split", id="split_tab"):
+                     with gr.Row():
+                         split_title = gr.Markdown("## Splitting of 0 videos (0 bytes)")
+
+                     with gr.Row():
+                         with gr.Column():
+                             detect_btn = gr.Button("Split videos into single-camera shots", variant="primary")
+                             detect_status = gr.Textbox(label="Status", interactive=False)
+
+                         with gr.Column():
+
+                             video_list = gr.Dataframe(
+                                 headers=["name", "status"],
+                                 label="Videos to split",
+                                 interactive=False,
+                                 wrap=True,
+                                 #selection_mode="cell" # Enable cell selection
+                             )
+
+
+                 with gr.TabItem("3️⃣ Caption"):
+                     with gr.Row():
+                         caption_title = gr.Markdown("## Captioning of 0 files (0 bytes)")
+
+                     with gr.Row():
+
+                         with gr.Column():
+                             with gr.Row():
+                                 custom_prompt_prefix = gr.Textbox(
+                                     scale=3,
+                                     label='Prefix to add to ALL captions (eg. "In the style of TOK, ")',
+                                     placeholder="In the style of TOK, ",
+                                     lines=2,
+                                     value=DEFAULT_PROMPT_PREFIX
+                                 )
+                                 captioning_bot_instructions = gr.Textbox(
+                                     scale=6,
+                                     label="System instructions for the automatic captioning model",
+                                     placeholder="Please generate a full description of...",
+                                     lines=5,
+                                     value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
+                                 )
+                             with gr.Row():
+                                 run_autocaption_btn = gr.Button(
+                                     "Automatically fill missing captions",
+                                     variant="primary"  # Makes it green by default
+                                 )
+                                 copy_files_to_training_dir_btn = gr.Button(
+                                     "Copy assets to training directory",
+                                     variant="primary"  # Makes it green by default
+                                 )
+                                 stop_autocaption_btn = gr.Button(
+                                     "Stop Captioning",
+                                     variant="stop",  # Red when enabled
+                                     interactive=False  # Disabled by default
+                                 )
+
+                     with gr.Row():
+                         with gr.Column():
+                             training_dataset = gr.Dataframe(
+                                 headers=["name", "status"],
+                                 interactive=False,
+                                 wrap=True,
+                                 value=self.list_training_files_to_caption(),
+                                 row_count=10,  # Optional: set a reasonable row count
+                                 #selection_mode="cell"
+                             )
+
+                         with gr.Column():
+                             preview_video = gr.Video(
+                                 label="Video Preview",
+                                 interactive=False,
+                                 visible=False
+                             )
+                             preview_image = gr.Image(
+                                 label="Image Preview",
+                                 interactive=False,
+                                 visible=False
+                             )
+                             preview_caption = gr.Textbox(
+                                 label="Caption",
+                                 lines=6,
+                                 interactive=True
+                             )
+                             save_caption_btn = gr.Button("Save Caption")
+                             preview_status = gr.Textbox(
+                                 label="Status",
+                                 interactive=False,
+                                 visible=True
+                             )
+
+                 with gr.TabItem("4️⃣ Train"):
+                     with gr.Row():
+                         with gr.Column():
+
+                             with gr.Row():
+                                 train_title = gr.Markdown("## 0 files available for training (0 bytes)")
+
+                             with gr.Row():
+                                 with gr.Column():
+                                     training_preset = gr.Dropdown(
+                                         choices=list(TRAINING_PRESETS.keys()),
+                                         label="Training Preset",
+                                         value=list(TRAINING_PRESETS.keys())[0]
+                                     )
+                                     preset_info = gr.Markdown()
+
+                             with gr.Row():
+                                 with gr.Column():
+                                     model_type = gr.Dropdown(
+                                         choices=list(MODEL_TYPES.keys()),
+                                         label="Model Type",
+                                         value=list(MODEL_TYPES.keys())[0]
+                                     )
+                                     model_info = gr.Markdown(
+                                         value=self.get_model_info(list(MODEL_TYPES.keys())[0])
+                                     )
+
+                             with gr.Row():
+                                 lora_rank = gr.Dropdown(
+                                     label="LoRA Rank",
+                                     choices=["16", "32", "64", "128", "256", "512", "1024"],
+                                     value="128",
+                                     type="value"
+                                 )
+                                 lora_alpha = gr.Dropdown(
+                                     label="LoRA Alpha",
+                                     choices=["16", "32", "64", "128", "256", "512", "1024"],
+                                     value="128",
+                                     type="value"
+                                 )
+                             with gr.Row():
+                                 num_epochs = gr.Number(
+                                     label="Number of Epochs",
+                                     value=70,
+                                     minimum=1,
+                                     precision=0
+                                 )
+                                 batch_size = gr.Number(
+                                     label="Batch Size",
+                                     value=1,
+                                     minimum=1,
+                                     precision=0
+                                 )
+                             with gr.Row():
+                                 learning_rate = gr.Number(
+                                     label="Learning Rate",
+                                     value=2e-5,
+                                     minimum=1e-7
+                                 )
+                                 save_iterations = gr.Number(
+                                     label="Save checkpoint every N iterations",
+                                     value=500,
+                                     minimum=50,
+                                     precision=0,
+                                     info="Model will be saved periodically after these many steps"
+                                 )
+
+                         with gr.Column():
+                             with gr.Row():
+                                 start_btn = gr.Button(
+                                     "Start Training",
+                                     variant="primary",
+                                     interactive=not ASK_USER_TO_DUPLICATE_SPACE
+                                 )
+                                 pause_resume_btn = gr.Button(
+                                     "Resume Training",
+                                     variant="secondary",
+                                     interactive=False
+                                 )
+                                 stop_btn = gr.Button(
+                                     "Stop Training",
+                                     variant="stop",
+                                     interactive=False
+                                 )
+
+                             with gr.Row():
+                                 with gr.Column():
+                                     status_box = gr.Textbox(
+                                         label="Training Status",
+                                         interactive=False,
+                                         lines=4
+                                     )
+                                     with gr.Accordion("See training logs"):
+                                         log_box = gr.TextArea(
+                                             label="Finetrainers output (see HF Space logs for more details)",
+                                             interactive=False,
+                                             lines=40,
+                                             max_lines=200,
+                                             autoscroll=True
+                                         )
+
+                 with gr.TabItem("5️⃣ Manage"):
+
+                     with gr.Column():
+                         with gr.Row():
+                             with gr.Column():
+                                 gr.Markdown("## Publishing")
+                                 gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
+
+                         with gr.Row():
+
+                             with gr.Column():
+                                 repo_id = gr.Textbox(
+                                     label="HuggingFace Model Repository",
+                                     placeholder="username/model-name",
+                                     info="The repository will be created if it doesn't exist"
+                                 )
+                                 gr.Checkbox(label="Check this to make your model public (ie. visible and downloadable by anyone)", info="You model is private by default"),
+                                 global_stop_btn = gr.Button(
+                                     "Push my model",
+                                     #variant="stop"
+                                 )
+
+
+                     with gr.Row():
+                         with gr.Column():
+                             with gr.Row():
+                                 with gr.Column():
+                                     gr.Markdown("## Storage management")
+                             with gr.Row():
+                                 download_dataset_btn = gr.DownloadButton(
+                                     "Download dataset",
+                                     variant="secondary",
+                                     size="lg"
+                                 )
+                                 download_model_btn = gr.DownloadButton(
+                                     "Download model",
+                                     variant="secondary",
+                                     size="lg"
+                                 )
+
+
+                     with gr.Row():
+                         global_stop_btn = gr.Button(
+                             "Stop everything and delete my data",
+                             variant="stop"
+                         )
+                         global_status = gr.Textbox(
+                             label="Global Status",
+                             interactive=False,
+                             visible=False
+                         )
+
+
+
+             # Event handlers
+             def update_model_info(model):
+                 params = self.get_default_params(MODEL_TYPES[model])
+                 info = self.get_model_info(MODEL_TYPES[model])
+                 return {
+                     model_info: info,
+                     num_epochs: params["num_epochs"],
+                     batch_size: params["batch_size"],
+                     learning_rate: params["learning_rate"],
+                     save_iterations: params["save_iterations"]
+                 }
+
+             def validate_repo(repo_id: str) -> dict:
+                 validation = validate_model_repo(repo_id)
+                 if validation["error"]:
+                     return gr.update(value=repo_id, error=validation["error"])
+                 return gr.update(value=repo_id, error=None)
+
+             # Connect events
+
+             # Save state when model type changes
+             model_type.change(
+                 fn=lambda v: self.update_ui_state(model_type=v),
+                 inputs=[model_type],
+                 outputs=[]  # No UI update needed
+             ).then(
+                 fn=update_model_info,
+                 inputs=[model_type],
+                 outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
+             )
+
+             # the following change listeners are used for UI persistence
+             lora_rank.change(
+                 fn=lambda v: self.update_ui_state(lora_rank=v),
+                 inputs=[lora_rank],
+                 outputs=[]
+             )
+
+             lora_alpha.change(
+                 fn=lambda v: self.update_ui_state(lora_alpha=v),
+                 inputs=[lora_alpha],
+                 outputs=[]
+             )
+
+             num_epochs.change(
+                 fn=lambda v: self.update_ui_state(num_epochs=v),
+                 inputs=[num_epochs],
+                 outputs=[]
+             )
+
+             batch_size.change(
+                 fn=lambda v: self.update_ui_state(batch_size=v),
+                 inputs=[batch_size],
+                 outputs=[]
+             )
+
+             learning_rate.change(
+                 fn=lambda v: self.update_ui_state(learning_rate=v),
+                 inputs=[learning_rate],
+                 outputs=[]
+             )
+
+             save_iterations.change(
+                 fn=lambda v: self.update_ui_state(save_iterations=v),
+                 inputs=[save_iterations],
+                 outputs=[]
+             )
+
+             files.upload(
+                 fn=lambda x: self.importer.process_uploaded_files(x),
+                 inputs=[files],
+                 outputs=[import_status]
+             ).success(
+                 fn=self.update_titles_after_import,
+                 inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
+                 outputs=[
+                     tabs, video_list, detect_status,
+                     split_title, caption_title, train_title
+                 ]
+             )
+
+             youtube_download_btn.click(
+                 fn=self.importer.download_youtube_video,
+                 inputs=[youtube_url],
+                 outputs=[import_status]
+             ).success(
+                 fn=self.on_import_success,
+                 inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
+                 outputs=[tabs, video_list, detect_status]
+             )
+
+             # Scene detection events
+             detect_btn.click(
+                 fn=self.start_scene_detection,
+                 inputs=[enable_automatic_video_split],
+                 outputs=[detect_status]
+             )
+
+
+             # Update button states based on captioning status
+             def update_button_states(is_running):
+                 return {
+                     run_autocaption_btn: gr.Button(
+                         interactive=not is_running,
+                         variant="secondary" if is_running else "primary",
+                     ),
+                     stop_autocaption_btn: gr.Button(
+                         interactive=is_running,
+                         variant="secondary",
+                     ),
+                 }
+
+             run_autocaption_btn.click(
+                 fn=self.show_refreshing_status,
+                 outputs=[training_dataset]
+             ).then(
+                 fn=lambda: self.update_captioning_buttons_start(),
+                 outputs=[run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
+             ).then(
+                 fn=self.start_caption_generation,
+                 inputs=[captioning_bot_instructions, custom_prompt_prefix],
+                 outputs=[training_dataset],
+             ).then(
+                 fn=lambda: self.update_captioning_buttons_end(),
+                 outputs=[run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
+             )
+
+             copy_files_to_training_dir_btn.click(
+                 fn=self.copy_files_to_training_dir,
+                 inputs=[custom_prompt_prefix]
+             )
+             stop_autocaption_btn.click(
+                 fn=self.stop_captioning,
+                 outputs=[training_dataset, run_autocaption_btn, stop_autocaption_btn, copy_files_to_training_dir_btn]
+             )
+
+             original_file_path = gr.State(value=None)
+             training_dataset.select(
+                 fn=self.handle_training_dataset_select,
+                 outputs=[preview_image, preview_video, preview_caption, original_file_path, preview_status]
+             )
+
+             save_caption_btn.click(
+                 fn=self.save_caption_changes,
+                 inputs=[preview_caption, preview_image, preview_video, original_file_path, custom_prompt_prefix],
+                 outputs=[preview_status]
+             ).success(
+                 fn=self.list_training_files_to_caption,
+                 outputs=[training_dataset]
+             )
+
+             # Save state when training preset changes
+             training_preset.change(
+                 fn=lambda v: self.update_ui_state(training_preset=v),
+                 inputs=[training_preset],
+                 outputs=[]  # No UI update needed
+             ).then(
+                 fn=self.update_training_params,
+                 inputs=[training_preset],
+                 outputs=[
+                     model_type, lora_rank, lora_alpha,
+                     num_epochs, batch_size, learning_rate,
+                     save_iterations, preset_info
+                 ]
+             )
+
+             # Training control events
+             start_btn.click(
+                 fn=lambda preset, model_type, *args: (
+                     self.log_parser.reset(),
+                     self.trainer.start_training(
+                         MODEL_TYPES[model_type],
+                         *args,
+                         preset_name=preset
+                     )
+                 ),
+                 inputs=[
+                     training_preset,
+                     model_type,
+                     lora_rank,
+                     lora_alpha,
+                     num_epochs,
+                     batch_size,
+                     learning_rate,
+                     save_iterations,
+                     repo_id
+                 ],
+                 outputs=[status_box, log_box]
+             ).success(
+                 fn=self.get_latest_status_message_logs_and_button_labels,
+                 outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
+             )
+
+             pause_resume_btn.click(
+                 fn=self.handle_pause_resume,
+                 outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
+             )
+
+             stop_btn.click(
+                 fn=self.handle_stop,
+                 outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
+             )
+
+             def handle_global_stop():
+                 result = self.stop_all_and_clear()
+                 # Update all relevant UI components
+                 status = result["status"]
+                 details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
+                 full_status = f"{status}\n\nDetails:\n{details}"
+
+                 # Get fresh lists after cleanup
+                 videos = self.splitter.list_unprocessed_videos()
+                 clips = self.list_training_files_to_caption()
+
+                 return {
+                     global_status: gr.update(value=full_status, visible=True),
+                     video_list: videos,
+                     training_dataset: clips,
+                     status_box: "Training stopped and data cleared",
+                     log_box: "",
+                     detect_status: "Scene detection stopped",
+                     import_status: "All data cleared",
+                     preview_status: "Captioning stopped"
+                 }
+
+             download_dataset_btn.click(
+                 fn=self.trainer.create_training_dataset_zip,
+                 outputs=[download_dataset_btn]
+             )
+
+             download_model_btn.click(
+                 fn=self.trainer.get_model_output_safetensors,
+                 outputs=[download_model_btn]
+             )
+
+             global_stop_btn.click(
+                 fn=handle_global_stop,
+                 outputs=[
+                     global_status,
+                     video_list,
+                     training_dataset,
+                     status_box,
+                     log_box,
+                     detect_status,
+                     import_status,
+                     preview_status
+                 ]
+             )
+
+
+             app.load(
+                 fn=self.initialize_app_state,
+                 outputs=[
+                     video_list, training_dataset,
+                     start_btn, stop_btn, pause_resume_btn,
+                     training_preset, model_type, lora_rank, lora_alpha,
+                     num_epochs, batch_size, learning_rate, save_iterations
+                 ]
+             )
+
+             # Auto-refresh timers
+             timer = gr.Timer(value=1)
+             timer.tick(
+                 fn=lambda: (
+                     self.get_latest_status_message_logs_and_button_labels()
+                 ),
+                 outputs=[
+                     status_box,
+                     log_box,
+                     start_btn,
+                     stop_btn,
+                     pause_resume_btn
+                 ]
+             )
+
+             timer = gr.Timer(value=5)
+             timer.tick(
+                 fn=lambda: (
+                     self.refresh_dataset()
+                 ),
+                 outputs=[
+                     video_list, training_dataset
+                 ]
+             )
+
+             timer = gr.Timer(value=6)
+             timer.tick(
+                 fn=lambda: self.update_titles(),
+                 outputs=[
+                     split_title, caption_title, train_title
+                 ]
+             )
+
+         return app
+
+ def create_app():
+     if ASK_USER_TO_DUPLICATE_SPACE:
+         with gr.Blocks() as app:
+             gr.Markdown("""# Finetrainers UI
+
+ This Hugging Face space needs to be duplicated to your own billing account to work.
+
+ Click the 'Duplicate Space' button at the top of the page to create your own copy.
+
+ It is recommended to use a Nvidia L40S and a persistent storage space.
+ To avoid overpaying for your space, you can configure the auto-sleep settings to fit your personal budget.""")
+         return app
+
+     ui = VideoTrainerUI()
+     return ui.create_ui()
+
+ if __name__ == "__main__":
+     app = create_app()
+
+     allowed_paths = [
+         str(STORAGE_PATH),  # Base storage
+         str(VIDEOS_TO_SPLIT_PATH),
+         str(STAGING_PATH),
+         str(TRAINING_PATH),
+         str(TRAINING_VIDEOS_PATH),
+         str(MODEL_PATH),
+         str(OUTPUT_PATH)
+     ]
+     app.queue(default_concurrency_limit=1).launch(
+         server_name="0.0.0.0",
+         allowed_paths=allowed_paths
+     )
vms/config.py CHANGED
@@ -3,7 +3,16 @@ from dataclasses import dataclass, field
  from typing import Dict, Any, Optional, List, Tuple
  from pathlib import Path

- from .utils import parse_bool_env
+ def parse_bool_env(env_value: Optional[str]) -> bool:
+     """Parse environment variable string to boolean
+
+     Handles various true/false string representations:
+     - True: "true", "True", "TRUE", "1", etc
+     - False: "false", "False", "FALSE", "0", "", None
+     """
+     if not env_value:
+         return False
+     return str(env_value).lower() in ('true', '1', 't', 'y', 'yes')

  HF_API_TOKEN = os.getenv("HF_API_TOKEN")
  ASK_USER_TO_DUPLICATE_SPACE = parse_bool_env(os.getenv("ASK_USER_TO_DUPLICATE_SPACE"))
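Inlining `parse_bool_env` into `config.py` drops the config-to-utils import, presumably to avoid a circular dependency once the relocated service modules start importing from `..config`. A few illustrative checks of the parser's behavior as defined above (not part of the commit):

```python
# Behavioral checks for parse_bool_env; each follows directly from its definition.
assert parse_bool_env("true") is True
assert parse_bool_env("1") is True
assert parse_bool_env("YES") is True   # case-insensitive match on "yes"
assert parse_bool_env("false") is False
assert parse_bool_env("0") is False
assert parse_bool_env("") is False     # empty string is falsy
assert parse_bool_env(None) is False   # unset env var is falsy
```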
vms/services/__init__.py ADDED
@@ -0,0 +1,12 @@
+ from .captioner import CaptioningProgress, CaptioningService
+ from .importer import ImportService
+ from .splitter import SplittingService
+ from .trainer import TrainingService
+
+ __all__ = [
+     'CaptioningProgress',
+     'CaptioningService',
+     'ImportService',
+     'SplittingService',
+     'TrainingService',
+ ]
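With this package initializer, call sites can import every service from a single module path. A hypothetical consumer (not part of this commit) would look like:

```python
# Hypothetical consumer of the new vms.services package surface.
from vms.services import (
    CaptioningService,
    ImportService,
    SplittingService,
    TrainingService,
)

trainer = TrainingService()
splitter = SplittingService()
importer = ImportService()
captioner = CaptioningService()
```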
vms/{captioning_service.py → services/captioner.py} RENAMED
@@ -17,9 +17,8 @@ from llava.mm_utils import tokenizer_image_token
  from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
  from llava.conversation import conv_templates, SeparatorStyle

- from .config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MODEL, CAPTIONING_MODEL, USE_MOCK_CAPTIONING_MODEL, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX
- from .utils import extract_scene_info, is_image_file, is_video_file
- from .finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
+ from ..config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MODEL, CAPTIONING_MODEL, USE_MOCK_CAPTIONING_MODEL, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX
+ from ..utils import extract_scene_info, is_image_file, is_video_file, copy_files_to_training_dir, prepare_finetrainers_dataset

  logger = logging.getLogger(__name__)
vms/{import_service.py → services/importer.py} RENAMED
@@ -8,9 +8,8 @@ from typing import List, Dict, Optional, Tuple
  from pytubefix import YouTube
  import logging
 
- from .utils import is_image_file, is_video_file, add_prefix_to_caption
- from .image_preprocessing import normalize_image
- from .config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, TRAINING_PATH, DEFAULT_PROMPT_PREFIX
+ from ..config import NORMALIZE_IMAGES_TO, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, TRAINING_PATH, DEFAULT_PROMPT_PREFIX
+ from ..utils import normalize_image, is_image_file, is_video_file, add_prefix_to_caption
 
  logger = logging.getLogger(__name__)
vms/{splitting_service.py → services/splitter.py} RENAMED
@@ -12,11 +12,8 @@ import gradio as gr
  from scenedetect import detect, ContentDetector, SceneManager, open_video
  from scenedetect.video_splitter import split_video_ffmpeg
 
- from .config import TRAINING_PATH, STORAGE_PATH, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
-
- from .image_preprocessing import detect_black_bars
- from .video_preprocessing import remove_black_bars
- from .utils import extract_scene_info, is_video_file, is_image_file, add_prefix_to_caption
+ from ..config import TRAINING_PATH, STORAGE_PATH, TRAINING_VIDEOS_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH, DEFAULT_PROMPT_PREFIX
+ from ..utils import remove_black_bars, extract_scene_info, is_video_file, is_image_file, add_prefix_to_caption
 
  logger = logging.getLogger(__name__)
vms/{training_service.py → services/trainer.py} RENAMED
@@ -20,9 +20,8 @@ from typing import Any, Optional, Dict, List, Union, Tuple
 
  from huggingface_hub import upload_folder, create_repo
 
- from .config import TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
- from .utils import make_archive, parse_training_log, is_image_file, is_video_file
- from .finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
+ from ..config import TrainingConfig, TRAINING_PRESETS, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, STORAGE_PATH, TRAINING_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
+ from ..utils import make_archive, parse_training_log, is_image_file, is_video_file, prepare_finetrainers_dataset, copy_files_to_training_dir
 
  logger = logging.getLogger(__name__)
 
@@ -36,6 +35,7 @@ class TrainingService:
 
          self.file_handler = None
          self.setup_logging()
+         self.ensure_valid_ui_state_file()
 
          logger.info("Training service initialized")
 
@@ -122,11 +122,23 @@
          }
 
          if not ui_state_file.exists():
+             logger.info("UI state file does not exist, using default values")
              return default_state
 
          try:
+             # First check if the file is empty
+             file_size = ui_state_file.stat().st_size
+             if file_size == 0:
+                 logger.warning("UI state file exists but is empty, using default values")
+                 return default_state
+
              with open(ui_state_file, 'r') as f:
-                 saved_state = json.load(f)
+                 file_content = f.read().strip()
+                 if not file_content:
+                     logger.warning("UI state file is empty or contains only whitespace, using default values")
+                     return default_state
+
+                 saved_state = json.loads(file_content)
 
              # Convert numeric values to appropriate types
              if "num_epochs" in saved_state:
@@ -141,11 +153,66 @@
              # Make sure we have all keys (in case structure changed)
              merged_state = default_state.copy()
              merged_state.update(saved_state)
+             logger.info(f"Successfully loaded UI state from {ui_state_file}")
              return merged_state
+         except json.JSONDecodeError as e:
+             logger.error(f"Error parsing UI state JSON: {str(e)}")
+             return default_state
          except Exception as e:
              logger.error(f"Error loading UI state: {str(e)}")
              return default_state
 
+     def ensure_valid_ui_state_file(self):
+         """Ensure UI state file exists and is valid JSON"""
+         ui_state_file = OUTPUT_PATH / "ui_state.json"
+
+         if not ui_state_file.exists():
+             # Create a new file with default values
+             logger.info("Creating new UI state file with default values")
+             default_state = {
+                 "model_type": list(MODEL_TYPES.keys())[0],
+                 "lora_rank": "128",
+                 "lora_alpha": "128",
+                 "num_epochs": 50,
+                 "batch_size": 1,
+                 "learning_rate": 3e-5,
+                 "save_iterations": 200,
+                 "training_preset": list(TRAINING_PRESETS.keys())[0]
+             }
+             self.save_ui_state(default_state)
+             return
+
+         # Check if file is valid JSON
+         try:
+             with open(ui_state_file, 'r') as f:
+                 file_content = f.read().strip()
+                 if not file_content:
+                     raise ValueError("Empty file")
+                 json.loads(file_content)
+             logger.debug("UI state file validation successful")
+         except Exception as e:
+             logger.warning(f"Invalid UI state file detected: {str(e)}. Creating new one with defaults.")
+             # Back up the invalid file
+             backup_file = ui_state_file.with_suffix('.json.bak')
+             try:
+                 shutil.copy2(ui_state_file, backup_file)
+                 logger.info(f"Backed up invalid UI state file to {backup_file}")
+             except Exception as backup_error:
+                 logger.error(f"Failed to backup invalid UI state file: {str(backup_error)}")
+
+             # Create a new file with default values
+             default_state = {
+                 "model_type": list(MODEL_TYPES.keys())[0],
+                 "lora_rank": "128",
+                 "lora_alpha": "128",
+                 "num_epochs": 50,
+                 "batch_size": 1,
+                 "learning_rate": 3e-5,
+                 "save_iterations": 200,
+                 "training_preset": list(TRAINING_PRESETS.keys())[0]
+             }
+             self.save_ui_state(default_state)
+
      # Modify save_session to also store the UI state at training start
      def save_session(self, params: Dict) -> None:
          """Save training session parameters"""
vms/tabs/__init__.py ADDED
@@ -0,0 +1,17 @@
+ """
+ Tab components for Video Model Studio UI
+ """
+
+ from .import_tab import ImportTab
+ from .split_tab import SplitTab
+ from .caption_tab import CaptionTab
+ from .train_tab import TrainTab
+ from .manage_tab import ManageTab
+
+ __all__ = [
+     'ImportTab',
+     'SplitTab',
+     'CaptionTab',
+     'TrainTab',
+     'ManageTab'
+ ]
vms/tabs/base_tab.py ADDED
@@ -0,0 +1,44 @@
+ """
+ Base class for UI tabs
+ """
+
+ import gradio as gr
+ import logging
+ from typing import Dict, Any, Optional
+
+ logger = logging.getLogger(__name__)
+
+ class BaseTab:
+     """Base class for UI tabs with common functionality"""
+
+     def __init__(self, app_state):
+         """Initialize the tab with app state reference
+
+         Args:
+             app_state: Reference to main VideoTrainerUI instance
+         """
+         self.app = app_state
+         self.components = {}
+
+     def create(self, parent=None) -> gr.TabItem:
+         """Create the tab UI components
+
+         Args:
+             parent: Optional parent container
+
+         Returns:
+             The created tab component
+         """
+         raise NotImplementedError("Subclasses must implement create()")
+
+     def connect_events(self) -> None:
+         """Connect event handlers to UI components"""
+         raise NotImplementedError("Subclasses must implement connect_events()")
+
+     def refresh(self) -> Dict[str, Any]:
+         """Refresh UI components with current data
+
+         Returns:
+             Dictionary with updated values for components
+         """
+         return {}
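
Every tab in this commit follows the same contract: register widgets in `self.components` inside `create()`, then wire them in `connect_events()` once all tabs exist. A minimal hypothetical subclass to illustrate the lifecycle:

```python
import gradio as gr
from vms.tabs.base_tab import BaseTab

class HelloTab(BaseTab):
    """Hypothetical tab, not part of this commit"""

    def __init__(self, app_state):
        super().__init__(app_state)
        self.id = "hello_tab"
        self.title = "Hello"

    def create(self, parent=None) -> gr.TabItem:
        with gr.TabItem(self.title, id=self.id) as tab:
            # Registering by name lets other tabs reference this component
            self.components["greeting"] = gr.Markdown("## Hello")
            self.components["refresh_btn"] = gr.Button("Refresh")
        return tab

    def connect_events(self) -> None:
        # Called only after all tabs exist, so cross-tab references are safe
        self.components["refresh_btn"].click(
            fn=lambda: "## Hello again",
            outputs=[self.components["greeting"]]
        )
```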
vms/tabs/caption_tab.py ADDED
@@ -0,0 +1,176 @@
+ """
+ Caption tab for Video Model Studio UI
+ """
+
+ import gradio as gr
+ import logging
+ from typing import Dict, Any, List, Optional
+ from pathlib import Path
+
+ from .base_tab import BaseTab
+ from ..config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX
+
+ logger = logging.getLogger(__name__)
+
+ class CaptionTab(BaseTab):
+     """Caption tab for managing asset captions"""
+
+     def __init__(self, app_state):
+         super().__init__(app_state)
+         self.id = "caption_tab"
+         self.title = "3️⃣ Caption"
+
+     def create(self, parent=None) -> gr.TabItem:
+         """Create the Caption tab UI components"""
+         with gr.TabItem(self.title, id=self.id) as tab:
+             with gr.Row():
+                 self.components["caption_title"] = gr.Markdown("## Captioning of 0 files (0 bytes)")
+
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         self.components["custom_prompt_prefix"] = gr.Textbox(
+                             scale=3,
+                             label='Prefix to add to ALL captions (eg. "In the style of TOK, ")',
+                             placeholder="In the style of TOK, ",
+                             lines=2,
+                             value=DEFAULT_PROMPT_PREFIX
+                         )
+                         self.components["captioning_bot_instructions"] = gr.Textbox(
+                             scale=6,
+                             label="System instructions for the automatic captioning model",
+                             placeholder="Please generate a full description of...",
+                             lines=5,
+                             value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
+                         )
+                     with gr.Row():
+                         self.components["run_autocaption_btn"] = gr.Button(
+                             "Automatically fill missing captions",
+                             variant="primary"
+                         )
+                         self.components["copy_files_to_training_dir_btn"] = gr.Button(
+                             "Copy assets to training directory",
+                             variant="primary"
+                         )
+                         self.components["stop_autocaption_btn"] = gr.Button(
+                             "Stop Captioning",
+                             variant="stop",
+                             interactive=False
+                         )
+
+             with gr.Row():
+                 with gr.Column():
+                     self.components["training_dataset"] = gr.Dataframe(
+                         headers=["name", "status"],
+                         interactive=False,
+                         wrap=True,
+                         value=self.app.list_training_files_to_caption(),
+                         row_count=10
+                     )
+
+                 with gr.Column():
+                     self.components["preview_video"] = gr.Video(
+                         label="Video Preview",
+                         interactive=False,
+                         visible=False
+                     )
+                     self.components["preview_image"] = gr.Image(
+                         label="Image Preview",
+                         interactive=False,
+                         visible=False
+                     )
+                     self.components["preview_caption"] = gr.Textbox(
+                         label="Caption",
+                         lines=6,
+                         interactive=True
+                     )
+                     self.components["save_caption_btn"] = gr.Button("Save Caption")
+                     self.components["preview_status"] = gr.Textbox(
+                         label="Status",
+                         interactive=False,
+                         visible=True
+                     )
+                     self.components["original_file_path"] = gr.State(value=None)
+
+         return tab
+
+     def connect_events(self) -> None:
+         """Connect event handlers to UI components"""
+         # Run auto-captioning button
+         self.components["run_autocaption_btn"].click(
+             fn=self.app.show_refreshing_status,
+             outputs=[self.components["training_dataset"]]
+         ).then(
+             fn=lambda: self.app.update_captioning_buttons_start(),
+             outputs=[
+                 self.components["run_autocaption_btn"],
+                 self.components["stop_autocaption_btn"],
+                 self.components["copy_files_to_training_dir_btn"]
+             ]
+         ).then(
+             fn=self.app.start_caption_generation,
+             inputs=[
+                 self.components["captioning_bot_instructions"],
+                 self.components["custom_prompt_prefix"]
+             ],
+             outputs=[self.components["training_dataset"]],
+         ).then(
+             fn=lambda: self.app.update_captioning_buttons_end(),
+             outputs=[
+                 self.components["run_autocaption_btn"],
+                 self.components["stop_autocaption_btn"],
+                 self.components["copy_files_to_training_dir_btn"]
+             ]
+         )
+
+         # Copy files to training dir button
+         self.components["copy_files_to_training_dir_btn"].click(
+             fn=self.app.copy_files_to_training_dir,
+             inputs=[self.components["custom_prompt_prefix"]]
+         )
+
+         # Stop captioning button
+         self.components["stop_autocaption_btn"].click(
+             fn=self.app.stop_captioning,
+             outputs=[
+                 self.components["training_dataset"],
+                 self.components["run_autocaption_btn"],
+                 self.components["stop_autocaption_btn"],
+                 self.components["copy_files_to_training_dir_btn"]
+             ]
+         )
+
+         # Dataset selection for preview
+         self.components["training_dataset"].select(
+             fn=self.app.handle_training_dataset_select,
+             outputs=[
+                 self.components["preview_image"],
+                 self.components["preview_video"],
+                 self.components["preview_caption"],
+                 self.components["original_file_path"],
+                 self.components["preview_status"]
+             ]
+         )
+
+         # Save caption button
+         self.components["save_caption_btn"].click(
+             fn=self.app.save_caption_changes,
+             inputs=[
+                 self.components["preview_caption"],
+                 self.components["preview_image"],
+                 self.components["preview_video"],
+                 self.components["original_file_path"],
+                 self.components["custom_prompt_prefix"]
+             ],
+             outputs=[self.components["preview_status"]]
+         ).success(
+             fn=self.app.list_training_files_to_caption,
+             outputs=[self.components["training_dataset"]]
+         )
+
+     def refresh(self) -> Dict[str, Any]:
+         """Refresh the dataset list with current data"""
+         training_dataset = self.app.list_training_files_to_caption()
+         return {
+             "training_dataset": training_dataset
+         }
vms/tabs/import_tab.py ADDED
@@ -0,0 +1,122 @@
+ """
+ Import tab for Video Model Studio UI
+ """
+
+ import gradio as gr
+ import logging
+ import asyncio
+ from pathlib import Path
+ from typing import Dict, Any, List, Optional
+
+ from .base_tab import BaseTab
+ from ..config import (
+     VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
+ )
+
+ logger = logging.getLogger(__name__)
+
+ class ImportTab(BaseTab):
+     """Import tab for uploading videos and images"""
+
+     def __init__(self, app_state):
+         super().__init__(app_state)
+         self.id = "import_tab"
+         self.title = "1️⃣ Import"
+
+     def create(self, parent=None) -> gr.TabItem:
+         """Create the Import tab UI components"""
+         with gr.TabItem(self.title, id=self.id) as tab:
+             with gr.Row():
+                 gr.Markdown("## Automatic splitting and captioning")
+
+             with gr.Row():
+                 self.components["enable_automatic_video_split"] = gr.Checkbox(
+                     label="Automatically split videos into smaller clips",
+                     info="Note: a clip is a single camera shot, usually a few seconds",
+                     value=True,
+                     visible=True
+                 )
+                 self.components["enable_automatic_content_captioning"] = gr.Checkbox(
+                     label="Automatically caption photos and videos",
+                     info="Note: this uses LLaVA and takes some extra time to load and process",
+                     value=False,
+                     visible=True,
+                 )
+
+             with gr.Row():
+                 with gr.Column(scale=3):
+                     with gr.Row():
+                         with gr.Column():
+                             gr.Markdown("## Import video files")
+                             gr.Markdown("You can upload either:")
+                             gr.Markdown("- A single MP4 video file")
+                             gr.Markdown("- A ZIP archive containing multiple videos and optional caption files")
+                             gr.Markdown("For ZIP files: Create a folder containing videos (name is not important) and optional caption files with the same name (eg. `some_video.txt` for `some_video.mp4`)")
+
+                     with gr.Row():
+                         self.components["files"] = gr.Files(
+                             label="Upload Images, Videos or ZIP",
+                             file_types=[".jpg", ".jpeg", ".png", ".webp", ".avif", ".heic", ".mp4", ".zip"],
+                             type="filepath"
+                         )
+
+                 with gr.Column(scale=3):
+                     with gr.Row():
+                         with gr.Column():
+                             gr.Markdown("## Import a YouTube video")
+                             gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")
+
+                     with gr.Row():
+                         self.components["youtube_url"] = gr.Textbox(
+                             label="Import YouTube Video",
+                             placeholder="https://www.youtube.com/watch?v=..."
+                         )
+                     with gr.Row():
+                         self.components["youtube_download_btn"] = gr.Button("Download YouTube Video", variant="secondary")
+                     with gr.Row():
+                         self.components["import_status"] = gr.Textbox(label="Status", interactive=False)
+
+         return tab
+
+     def connect_events(self) -> None:
+         """Connect event handlers to UI components"""
+         # File upload event
+         self.components["files"].upload(
+             fn=lambda x: self.app.importer.process_uploaded_files(x),
+             inputs=[self.components["files"]],
+             outputs=[self.components["import_status"]]
+         ).success(
+             fn=self.app.update_titles_after_import,
+             inputs=[
+                 self.components["enable_automatic_video_split"],
+                 self.components["enable_automatic_content_captioning"],
+                 self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
+             ],
+             outputs=[
+                 self.app.tabs_component,  # Main tabs component
+                 self.app.tabs["split_tab"].components["video_list"],
+                 self.app.tabs["split_tab"].components["detect_status"],
+                 self.app.tabs["split_tab"].components["split_title"],
+                 self.app.tabs["caption_tab"].components["caption_title"],
+                 self.app.tabs["train_tab"].components["train_title"]
+             ]
+         )
+
+         # YouTube download event
+         self.components["youtube_download_btn"].click(
+             fn=self.app.importer.download_youtube_video,
+             inputs=[self.components["youtube_url"]],
+             outputs=[self.components["import_status"]]
+         ).success(
+             fn=self.app.on_import_success,
+             inputs=[
+                 self.components["enable_automatic_video_split"],
+                 self.components["enable_automatic_content_captioning"],
+                 self.app.tabs["caption_tab"].components["custom_prompt_prefix"]
+             ],
+             outputs=[
+                 self.app.tabs_component,
+                 self.app.tabs["split_tab"].components["video_list"],
+                 self.app.tabs["split_tab"].components["detect_status"]
+             ]
+         )
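
Note how `connect_events()` freely references components owned by other tabs (`self.app.tabs["split_tab"].components[...]`). This is safe only because `create_ui()` (see `video_trainer_ui.py` below) runs in two phases, building every tab before connecting any events:

```python
# Condensed from VideoTrainerUI.create_ui() later in this diff
for tab_id, tab_obj in self.tabs.items():
    tab_obj.create(self.tabs_component)   # phase 1: all components exist
for tab_id, tab_obj in self.tabs.items():
    tab_obj.connect_events()              # phase 2: cross-tab wiring resolves
```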
vms/tabs/manage_tab.py ADDED
@@ -0,0 +1,117 @@
+ """
+ Manage tab for Video Model Studio UI
+ """
+
+ import gradio as gr
+ import logging
+ from typing import Dict, Any, List, Optional
+
+ from .base_tab import BaseTab
+ from ..config import HF_API_TOKEN
+
+ logger = logging.getLogger(__name__)
+
+ class ManageTab(BaseTab):
+     """Manage tab for storage management and model publication"""
+
+     def __init__(self, app_state):
+         super().__init__(app_state)
+         self.id = "manage_tab"
+         self.title = "5️⃣ Manage"
+
+     def create(self, parent=None) -> gr.TabItem:
+         """Create the Manage tab UI components"""
+         with gr.TabItem(self.title, id=self.id) as tab:
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column():
+                         gr.Markdown("## Publishing")
+                         gr.Markdown("Your model can be pushed to Hugging Face (this will use HF_API_TOKEN)")
+
+                 with gr.Row():
+                     with gr.Column():
+                         self.components["repo_id"] = gr.Textbox(
+                             label="HuggingFace Model Repository",
+                             placeholder="username/model-name",
+                             info="The repository will be created if it doesn't exist"
+                         )
+                         self.components["make_public"] = gr.Checkbox(
+                             label="Check this to make your model public (ie. visible and downloadable by anyone)",
+                             info="Your model is private by default"
+                         )
+                         self.components["push_model_btn"] = gr.Button(
+                             "Push my model"
+                         )
+
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         with gr.Column():
+                             gr.Markdown("## Storage management")
+                             with gr.Row():
+                                 self.components["download_dataset_btn"] = gr.DownloadButton(
+                                     "Download dataset",
+                                     variant="secondary",
+                                     size="lg"
+                                 )
+                                 self.components["download_model_btn"] = gr.DownloadButton(
+                                     "Download model",
+                                     variant="secondary",
+                                     size="lg"
+                                 )
+
+             with gr.Row():
+                 self.components["global_stop_btn"] = gr.Button(
+                     "Stop everything and delete my data",
+                     variant="stop"
+                 )
+                 self.components["global_status"] = gr.Textbox(
+                     label="Global Status",
+                     interactive=False,
+                     visible=False
+                 )
+
+         return tab
+
+     def connect_events(self) -> None:
+         """Connect event handlers to UI components"""
+         # Repository ID validation
+         self.components["repo_id"].change(
+             fn=self.app.validate_repo,
+             inputs=[self.components["repo_id"]],
+             outputs=[self.components["repo_id"]]
+         )
+
+         # Download buttons
+         self.components["download_dataset_btn"].click(
+             fn=self.app.trainer.create_training_dataset_zip,
+             outputs=[self.components["download_dataset_btn"]]
+         )
+
+         self.components["download_model_btn"].click(
+             fn=self.app.trainer.get_model_output_safetensors,
+             outputs=[self.components["download_model_btn"]]
+         )
+
+         # Global stop button
+         self.components["global_stop_btn"].click(
+             fn=self.app.handle_global_stop,
+             outputs=[
+                 self.components["global_status"],
+                 self.app.tabs["split_tab"].components["video_list"],
+                 self.app.tabs["caption_tab"].components["training_dataset"],
+                 self.app.tabs["train_tab"].components["status_box"],
+                 self.app.tabs["train_tab"].components["log_box"],
+                 self.app.tabs["split_tab"].components["detect_status"],
+                 self.app.tabs["import_tab"].components["import_status"],
+                 self.app.tabs["caption_tab"].components["preview_status"]
+             ]
+         )
+
+         # Push model button
+         self.components["push_model_btn"].click(
+             fn=lambda repo_id: self.app.upload_to_hub(repo_id),
+             inputs=[self.components["repo_id"]],
+             outputs=[self.components["global_status"]]
+         )
vms/tabs/split_tab.py ADDED
@@ -0,0 +1,56 @@
+ """
+ Split tab for Video Model Studio UI
+ """
+
+ import gradio as gr
+ import logging
+ from typing import Dict, Any, List, Optional
+
+ from .base_tab import BaseTab
+
+ logger = logging.getLogger(__name__)
+
+ class SplitTab(BaseTab):
+     """Split tab for scene detection and video splitting"""
+
+     def __init__(self, app_state):
+         super().__init__(app_state)
+         self.id = "split_tab"
+         self.title = "2️⃣ Split"
+
+     def create(self, parent=None) -> gr.TabItem:
+         """Create the Split tab UI components"""
+         with gr.TabItem(self.title, id=self.id) as tab:
+             with gr.Row():
+                 self.components["split_title"] = gr.Markdown("## Splitting of 0 videos (0 bytes)")
+
+             with gr.Row():
+                 with gr.Column():
+                     self.components["detect_btn"] = gr.Button("Split videos into single-camera shots", variant="primary")
+                     self.components["detect_status"] = gr.Textbox(label="Status", interactive=False)
+
+                 with gr.Column():
+                     self.components["video_list"] = gr.Dataframe(
+                         headers=["name", "status"],
+                         label="Videos to split",
+                         interactive=False,
+                         wrap=True
+                     )
+
+         return tab
+
+     def connect_events(self) -> None:
+         """Connect event handlers to UI components"""
+         # Scene detection button event
+         self.components["detect_btn"].click(
+             fn=self.app.start_scene_detection,
+             inputs=[self.app.tabs["import_tab"].components["enable_automatic_video_split"]],
+             outputs=[self.components["detect_status"]]
+         )
+
+     def refresh(self) -> Dict[str, Any]:
+         """Refresh the video list with current data"""
+         videos = self.app.splitter.list_unprocessed_videos()
+         return {
+             "video_list": videos
+         }
vms/tabs/train_tab.py ADDED
@@ -0,0 +1,280 @@
+ """
+ Train tab for Video Model Studio UI
+ """
+
+ import gradio as gr
+ import logging
+ from typing import Dict, Any, List, Optional
+
+ from .base_tab import BaseTab
+ from ..config import TRAINING_PRESETS, MODEL_TYPES, ASK_USER_TO_DUPLICATE_SPACE
+ from ..utils import TrainingLogParser
+
+ logger = logging.getLogger(__name__)
+
+ class TrainTab(BaseTab):
+     """Train tab for model training"""
+
+     def __init__(self, app_state):
+         super().__init__(app_state)
+         self.id = "train_tab"
+         self.title = "4️⃣ Train"
+
+     def handle_training_start(self, preset, model_type, *args):
+         """Handle training start with a proper log parser reset"""
+         # Safely reset the log parser if it exists
+         if hasattr(self.app, 'log_parser') and self.app.log_parser is not None:
+             self.app.log_parser.reset()
+         else:
+             logger.warning("Log parser not initialized, creating a new one")
+             self.app.log_parser = TrainingLogParser()
+
+         # Start training
+         return self.app.trainer.start_training(
+             MODEL_TYPES[model_type],
+             *args,
+             preset_name=preset
+         )
+
+     def create(self, parent=None) -> gr.TabItem:
+         """Create the Train tab UI components"""
+         with gr.TabItem(self.title, id=self.id) as tab:
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Row():
+                         self.components["train_title"] = gr.Markdown("## 0 files available for training (0 bytes)")
+
+                     with gr.Row():
+                         with gr.Column():
+                             self.components["training_preset"] = gr.Dropdown(
+                                 choices=list(TRAINING_PRESETS.keys()),
+                                 label="Training Preset",
+                                 value=list(TRAINING_PRESETS.keys())[0]
+                             )
+                             self.components["preset_info"] = gr.Markdown()
+
+                     with gr.Row():
+                         with gr.Column():
+                             self.components["model_type"] = gr.Dropdown(
+                                 choices=list(MODEL_TYPES.keys()),
+                                 label="Model Type",
+                                 value=list(MODEL_TYPES.keys())[0]
+                             )
+                             self.components["model_info"] = gr.Markdown(
+                                 value=self.app.get_model_info(list(MODEL_TYPES.keys())[0])
+                             )
+
+                     with gr.Row():
+                         self.components["lora_rank"] = gr.Dropdown(
+                             label="LoRA Rank",
+                             choices=["16", "32", "64", "128", "256", "512", "1024"],
+                             value="128",
+                             type="value"
+                         )
+                         self.components["lora_alpha"] = gr.Dropdown(
+                             label="LoRA Alpha",
+                             choices=["16", "32", "64", "128", "256", "512", "1024"],
+                             value="128",
+                             type="value"
+                         )
+                     with gr.Row():
+                         self.components["num_epochs"] = gr.Number(
+                             label="Number of Epochs",
+                             value=70,
+                             minimum=1,
+                             precision=0
+                         )
+                         self.components["batch_size"] = gr.Number(
+                             label="Batch Size",
+                             value=1,
+                             minimum=1,
+                             precision=0
+                         )
+                     with gr.Row():
+                         self.components["learning_rate"] = gr.Number(
+                             label="Learning Rate",
+                             value=2e-5,
+                             minimum=1e-7
+                         )
+                         self.components["save_iterations"] = gr.Number(
+                             label="Save checkpoint every N iterations",
+                             value=500,
+                             minimum=50,
+                             precision=0,
+                             info="Model will be saved periodically after this many steps"
+                         )
+
+                 with gr.Column():
+                     with gr.Row():
+                         self.components["start_btn"] = gr.Button(
+                             "Start Training",
+                             variant="primary",
+                             interactive=not ASK_USER_TO_DUPLICATE_SPACE
+                         )
+                         self.components["pause_resume_btn"] = gr.Button(
+                             "Resume Training",
+                             variant="secondary",
+                             interactive=False
+                         )
+                         self.components["stop_btn"] = gr.Button(
+                             "Stop Training",
+                             variant="stop",
+                             interactive=False
+                         )
+
+                     with gr.Row():
+                         with gr.Column():
+                             self.components["status_box"] = gr.Textbox(
+                                 label="Training Status",
+                                 interactive=False,
+                                 lines=4
+                             )
+                             with gr.Accordion("See training logs"):
+                                 self.components["log_box"] = gr.TextArea(
+                                     label="Finetrainers output (see HF Space logs for more details)",
+                                     interactive=False,
+                                     lines=40,
+                                     max_lines=200,
+                                     autoscroll=True
+                                 )
+
+         return tab
+
+     def connect_events(self) -> None:
+         """Connect event handlers to UI components"""
+         # Model type change event
+         def update_model_info(model):
+             params = self.app.get_default_params(MODEL_TYPES[model])
+             info = self.app.get_model_info(MODEL_TYPES[model])
+             return {
+                 self.components["model_info"]: info,
+                 self.components["num_epochs"]: params["num_epochs"],
+                 self.components["batch_size"]: params["batch_size"],
+                 self.components["learning_rate"]: params["learning_rate"],
+                 self.components["save_iterations"]: params["save_iterations"]
+             }
+
+         self.components["model_type"].change(
+             fn=lambda v: self.app.update_ui_state(model_type=v),
+             inputs=[self.components["model_type"]],
+             outputs=[]
+         ).then(
+             fn=update_model_info,
+             inputs=[self.components["model_type"]],
+             outputs=[
+                 self.components["model_info"],
+                 self.components["num_epochs"],
+                 self.components["batch_size"],
+                 self.components["learning_rate"],
+                 self.components["save_iterations"]
+             ]
+         )
+
+         # Training parameters change events
+         self.components["lora_rank"].change(
+             fn=lambda v: self.app.update_ui_state(lora_rank=v),
+             inputs=[self.components["lora_rank"]],
+             outputs=[]
+         )
+
+         self.components["lora_alpha"].change(
+             fn=lambda v: self.app.update_ui_state(lora_alpha=v),
+             inputs=[self.components["lora_alpha"]],
+             outputs=[]
+         )
+
+         self.components["num_epochs"].change(
+             fn=lambda v: self.app.update_ui_state(num_epochs=v),
+             inputs=[self.components["num_epochs"]],
+             outputs=[]
+         )
+
+         self.components["batch_size"].change(
+             fn=lambda v: self.app.update_ui_state(batch_size=v),
+             inputs=[self.components["batch_size"]],
+             outputs=[]
+         )
+
+         self.components["learning_rate"].change(
+             fn=lambda v: self.app.update_ui_state(learning_rate=v),
+             inputs=[self.components["learning_rate"]],
+             outputs=[]
+         )
+
+         self.components["save_iterations"].change(
+             fn=lambda v: self.app.update_ui_state(save_iterations=v),
+             inputs=[self.components["save_iterations"]],
+             outputs=[]
+         )
+
+         # Training preset change event
+         self.components["training_preset"].change(
+             fn=lambda v: self.app.update_ui_state(training_preset=v),
+             inputs=[self.components["training_preset"]],
+             outputs=[]
+         ).then(
+             fn=self.app.update_training_params,
+             inputs=[self.components["training_preset"]],
+             outputs=[
+                 self.components["model_type"],
+                 self.components["lora_rank"],
+                 self.components["lora_alpha"],
+                 self.components["num_epochs"],
+                 self.components["batch_size"],
+                 self.components["learning_rate"],
+                 self.components["save_iterations"],
+                 self.components["preset_info"]
+             ]
+         )
+
+         # Training control events
+         self.components["start_btn"].click(
+             fn=self.handle_training_start,  # use the safer method instead of a lambda
+             inputs=[
+                 self.components["training_preset"],
+                 self.components["model_type"],
+                 self.components["lora_rank"],
+                 self.components["lora_alpha"],
+                 self.components["num_epochs"],
+                 self.components["batch_size"],
+                 self.components["learning_rate"],
+                 self.components["save_iterations"],
+                 self.app.tabs["manage_tab"].components["repo_id"]
+             ],
+             outputs=[
+                 self.components["status_box"],
+                 self.components["log_box"]
+             ]
+         ).success(
+             fn=self.app.get_latest_status_message_logs_and_button_labels,
+             outputs=[
+                 self.components["status_box"],
+                 self.components["log_box"],
+                 self.components["start_btn"],
+                 self.components["stop_btn"],
+                 self.components["pause_resume_btn"]
+             ]
+         )
+
+         self.components["pause_resume_btn"].click(
+             fn=self.app.handle_pause_resume,
+             outputs=[
+                 self.components["status_box"],
+                 self.components["log_box"],
+                 self.components["start_btn"],
+                 self.components["stop_btn"],
+                 self.components["pause_resume_btn"]
+             ]
+         )
+
+         self.components["stop_btn"].click(
+             fn=self.app.handle_stop,
+             outputs=[
+                 self.components["status_box"],
+                 self.components["log_box"],
+                 self.components["start_btn"],
+                 self.components["stop_btn"],
+                 self.components["pause_resume_btn"]
+             ]
+         )
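
The recurring event pattern in this tab is persist-then-react: each control first writes its value to the UI state via `update_ui_state` (with `outputs=[]`), and any dependent UI refresh is chained with `.then()`. A minimal self-contained sketch of the same Gradio idiom (hypothetical handler names):

```python
import gradio as gr

def persist(value):
    # stands in for app.update_ui_state(...), which writes ui_state.json
    print(f"persisted: {value}")

with gr.Blocks() as demo:
    rank = gr.Dropdown(choices=["64", "128"], value="128", label="LoRA Rank")
    info = gr.Markdown()
    rank.change(
        fn=persist, inputs=[rank], outputs=[]       # step 1: save the setting
    ).then(
        fn=lambda v: f"Rank is now {v}", inputs=[rank], outputs=[info]  # step 2: react
    )
```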
vms/ui/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from .video_trainer_ui import VideoTrainerUI
+
+ __all__ = [
+     'VideoTrainerUI',
+ ]
vms/ui/video_trainer_ui.py ADDED
@@ -0,0 +1,1100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import platform
2
+ import subprocess
3
+
4
+ #import sys
5
+ #print("python = ", sys.version)
6
+
7
+ # can be "Linux", "Darwin"
8
+ if platform.system() == "Linux":
9
+ # for some reason it says "pip not found"
10
+ # and also "pip3 not found"
11
+ # subprocess.run(
12
+ # "pip install flash-attn --no-build-isolation",
13
+ #
14
+ # # hmm... this should be False, since we are in a CUDA environment, no?
15
+ # env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
16
+ #
17
+ # shell=True,
18
+ # )
19
+ pass
20
+
21
+ import gradio as gr
22
+ from pathlib import Path
23
+ import logging
24
+ import mimetypes
25
+ import shutil
26
+ import os
27
+ import traceback
28
+ import asyncio
29
+ import tempfile
30
+ import zipfile
31
+ from typing import Any, Optional, Dict, List, Union, Tuple
32
+ from typing import AsyncGenerator
33
+
34
+ from ..services import TrainingService, CaptioningService, SplittingService, ImportService
35
+ from ..config import (
36
+ STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
37
+ TRAINING_PATH, LOG_FILE_PATH, TRAINING_PRESETS, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
38
+ DEFAULT_PROMPT_PREFIX, HF_API_TOKEN, ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, SMALL_TRAINING_BUCKETS
39
+ )
40
+ from ..utils import make_archive, count_media_files, format_media_title, is_image_file, is_video_file, validate_model_repo, format_time, copy_files_to_training_dir, prepare_finetrainers_dataset, TrainingLogParser
41
+ from ..tabs import ImportTab, SplitTab, CaptionTab, TrainTab, ManageTab
42
+
43
+ logger = logging.getLogger(__name__)
44
+ logger.setLevel(logging.INFO)
45
+
46
+ httpx_logger = logging.getLogger('httpx')
47
+ httpx_logger.setLevel(logging.WARN)
48
+
49
+ class VideoTrainerUI:
50
+ def __init__(self):
51
+ """Initialize services and tabs"""
52
+ # Initialize core services
53
+ self.trainer = TrainingService()
54
+ self.splitter = SplittingService()
55
+ self.importer = ImportService()
56
+ self.captioner = CaptioningService()
57
+ self._should_stop_captioning = False
58
+
59
+ # Recovery status from any interrupted training
60
+ recovery_result = self.trainer.recover_interrupted_training()
61
+ self.recovery_status = recovery_result.get("status", "unknown")
62
+ self.ui_updates = recovery_result.get("ui_updates", {})
63
+
64
+ self.log_parser = TrainingLogParser()
65
+
66
+ # Shared state for tabs
67
+ self.state = {
68
+ "recovery_result": recovery_result
69
+ }
70
+
71
+ # Initialize tabs dictionary (will be populated in create_ui)
72
+ self.tabs = {}
73
+ self.tabs_component = None
74
+
75
+ def create_ui(self):
76
+ """Create the main Gradio UI"""
77
+ with gr.Blocks(title="πŸŽ₯ Video Model Studio") as app:
78
+ gr.Markdown("# πŸŽ₯ Video Model Studio")
79
+
80
+ # Create main tabs component
81
+ with gr.Tabs() as self.tabs_component:
82
+ # Initialize tab objects
83
+ self.tabs["import_tab"] = ImportTab(self)
84
+ self.tabs["split_tab"] = SplitTab(self)
85
+ self.tabs["caption_tab"] = CaptionTab(self)
86
+ self.tabs["train_tab"] = TrainTab(self)
87
+ self.tabs["manage_tab"] = ManageTab(self)
88
+
89
+ # Create tab UI components
90
+ for tab_id, tab_obj in self.tabs.items():
91
+ tab_obj.create(self.tabs_component)
92
+
93
+ # Connect event handlers
94
+ for tab_id, tab_obj in self.tabs.items():
95
+ tab_obj.connect_events()
96
+
97
+ # Add app-level timers for auto-refresh functionality
98
+ self._add_timers()
99
+
100
+ # Initialize app state on load
101
+ app.load(
102
+ fn=self.initialize_app_state,
103
+ outputs=[
104
+ self.tabs["split_tab"].components["video_list"],
105
+ self.tabs["caption_tab"].components["training_dataset"],
106
+ self.tabs["train_tab"].components["start_btn"],
107
+ self.tabs["train_tab"].components["stop_btn"],
108
+ self.tabs["train_tab"].components["pause_resume_btn"],
109
+ self.tabs["train_tab"].components["training_preset"],
110
+ self.tabs["train_tab"].components["model_type"],
111
+ self.tabs["train_tab"].components["lora_rank"],
112
+ self.tabs["train_tab"].components["lora_alpha"],
113
+ self.tabs["train_tab"].components["num_epochs"],
114
+ self.tabs["train_tab"].components["batch_size"],
115
+ self.tabs["train_tab"].components["learning_rate"],
116
+ self.tabs["train_tab"].components["save_iterations"]
117
+ ]
118
+ )
119
+
120
+ return app
121
+
122
+ def _add_timers(self):
123
+ """Add auto-refresh timers to the UI"""
124
+ # Status update timer (every 1 second)
125
+ status_timer = gr.Timer(value=1)
126
+ status_timer.tick(
127
+ fn=self.get_latest_status_message_logs_and_button_labels,
128
+ outputs=[
129
+ self.tabs["train_tab"].components["status_box"],
130
+ self.tabs["train_tab"].components["log_box"],
131
+ self.tabs["train_tab"].components["start_btn"],
132
+ self.tabs["train_tab"].components["stop_btn"],
133
+ self.tabs["train_tab"].components["pause_resume_btn"]
134
+ ]
135
+ )
136
+
137
+ # Dataset refresh timer (every 5 seconds)
138
+ dataset_timer = gr.Timer(value=5)
139
+ dataset_timer.tick(
140
+ fn=self.refresh_dataset,
141
+ outputs=[
142
+ self.tabs["split_tab"].components["video_list"],
143
+ self.tabs["caption_tab"].components["training_dataset"]
144
+ ]
145
+ )
146
+
147
+ # Titles update timer (every 6 seconds)
148
+ titles_timer = gr.Timer(value=6)
149
+ titles_timer.tick(
150
+ fn=self.update_titles,
151
+ outputs=[
152
+ self.tabs["split_tab"].components["split_title"],
153
+ self.tabs["caption_tab"].components["caption_title"],
154
+ self.tabs["train_tab"].components["train_title"]
155
+ ]
156
+ )
157
+
158
+ def handle_global_stop(self):
159
+ """Handle the global stop button click"""
160
+ result = self.stop_all_and_clear()
161
+
162
+ # Format the details for display
163
+ status = result["status"]
164
+ details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
165
+ full_status = f"{status}\n\nDetails:\n{details}"
166
+
167
+ # Get fresh lists after cleanup
168
+ videos = self.splitter.list_unprocessed_videos()
169
+ clips = self.list_training_files_to_caption()
170
+
171
+ return {
172
+ self.tabs["manage_tab"].components["global_status"]: gr.update(value=full_status, visible=True),
173
+ self.tabs["split_tab"].components["video_list"]: videos,
174
+ self.tabs["caption_tab"].components["training_dataset"]: clips,
175
+ self.tabs["train_tab"].components["status_box"]: "Training stopped and data cleared",
176
+ self.tabs["train_tab"].components["log_box"]: "",
177
+ self.tabs["split_tab"].components["detect_status"]: "Scene detection stopped",
178
+ self.tabs["import_tab"].components["import_status"]: "All data cleared",
179
+ self.tabs["caption_tab"].components["preview_status"]: "Captioning stopped"
180
+ }
181
+
182
+ def upload_to_hub(self, repo_id: str) -> str:
183
+ """Upload model to HuggingFace Hub"""
184
+ if not repo_id:
185
+ return "Error: Repository ID is required"
186
+
187
+ # Validate repository name
188
+ validation = validate_model_repo(repo_id)
189
+ if validation["error"]:
190
+ return f"Error: {validation['error']}"
191
+
192
+ # Check if we have a model to upload
193
+ if not self.trainer.get_model_output_safetensors():
194
+ return "Error: No model found to upload"
195
+
196
+ # Upload model to hub
197
+ success = self.trainer.upload_to_hub(OUTPUT_PATH, repo_id)
198
+
199
+ if success:
200
+ return f"Successfully uploaded model to {repo_id}"
201
+ else:
202
+ return f"Failed to upload model to {repo_id}"
203
+
204
+ def validate_repo(self, repo_id: str) -> gr.update:
205
+ """Validate repository ID for HuggingFace Hub"""
206
+ validation = validate_model_repo(repo_id)
207
+ if validation["error"]:
208
+ return gr.update(value=repo_id, error=validation["error"])
209
+ return gr.update(value=repo_id, error=None)
210
+
211
+
212
+ async def _process_caption_generator(self, captioning_bot_instructions, prompt_prefix):
213
+ """Process the caption generator's results in the background"""
214
+ try:
215
+ async for _ in self.captioner.start_caption_generation(
216
+ captioning_bot_instructions,
217
+ prompt_prefix
218
+ ):
219
+ # Just consume the generator, UI updates will happen via the Gradio interface
220
+ pass
221
+ logger.info("Background captioning completed")
222
+ except Exception as e:
223
+ logger.error(f"Error in background captioning: {str(e)}")
224
+
225
+ def initialize_app_state(self):
226
+ """Initialize all app state in one function to ensure correct output count"""
227
+ # Get dataset info
228
+ video_list, training_dataset = self.refresh_dataset()
229
+
230
+ # Get button states
231
+ button_states = self.get_initial_button_states()
232
+ start_btn = button_states[0]
233
+ stop_btn = button_states[1]
234
+ pause_resume_btn = button_states[2]
235
+
236
+ # Get UI form values
237
+ ui_state = self.load_ui_values()
238
+ training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
239
+ model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
240
+ lora_rank_val = ui_state.get("lora_rank", "128")
241
+ lora_alpha_val = ui_state.get("lora_alpha", "128")
242
+ num_epochs_val = int(ui_state.get("num_epochs", 70))
243
+ batch_size_val = int(ui_state.get("batch_size", 1))
244
+ learning_rate_val = float(ui_state.get("learning_rate", 3e-5))
245
+ save_iterations_val = int(ui_state.get("save_iterations", 500))
246
+
247
+ # Return all values in the exact order expected by outputs
248
+ return (
249
+ video_list,
250
+ training_dataset,
251
+ start_btn,
252
+ stop_btn,
253
+ pause_resume_btn,
254
+ training_preset,
255
+ model_type_val,
256
+ lora_rank_val,
257
+ lora_alpha_val,
258
+ num_epochs_val,
259
+ batch_size_val,
260
+ learning_rate_val,
261
+ save_iterations_val
262
+ )
263
+
264
+ def initialize_ui_from_state(self):
265
+ """Initialize UI components from saved state"""
266
+ ui_state = self.load_ui_values()
267
+
268
+ # Return values in order matching the outputs in app.load
269
+ return (
270
+ ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
271
+ ui_state.get("model_type", list(MODEL_TYPES.keys())[0]),
272
+ ui_state.get("lora_rank", "128"),
273
+ ui_state.get("lora_alpha", "128"),
274
+ ui_state.get("num_epochs", 70),
275
+ ui_state.get("batch_size", 1),
276
+ ui_state.get("learning_rate", 3e-5),
277
+ ui_state.get("save_iterations", 500)
278
+ )
279
+
280
+ def update_ui_state(self, **kwargs):
281
+ """Update UI state with new values"""
282
+ current_state = self.trainer.load_ui_state()
283
+ current_state.update(kwargs)
284
+ self.trainer.save_ui_state(current_state)
285
+ # Don't return anything to avoid Gradio warnings
286
+ return None
287
+
288
+ def load_ui_values(self):
289
+ """Load UI state values for initializing form fields"""
290
+ ui_state = self.trainer.load_ui_state()
291
+
292
+ # Ensure proper type conversion for numeric values
293
+ ui_state["lora_rank"] = ui_state.get("lora_rank", "128")
294
+ ui_state["lora_alpha"] = ui_state.get("lora_alpha", "128")
295
+ ui_state["num_epochs"] = int(ui_state.get("num_epochs", 70))
296
+ ui_state["batch_size"] = int(ui_state.get("batch_size", 1))
297
+ ui_state["learning_rate"] = float(ui_state.get("learning_rate", 3e-5))
298
+ ui_state["save_iterations"] = int(ui_state.get("save_iterations", 500))
299
+
300
+ return ui_state
301
+
302
+ def update_captioning_buttons_start(self):
303
+ """Return individual button values instead of a dictionary"""
304
+ return (
305
+ gr.Button(
306
+ interactive=False,
307
+ variant="secondary",
308
+ ),
309
+ gr.Button(
310
+ interactive=True,
311
+ variant="stop",
312
+ ),
313
+ gr.Button(
314
+ interactive=False,
315
+ variant="secondary",
316
+ )
317
+ )
318
+
319
+ def update_captioning_buttons_end(self):
320
+ """Return individual button values instead of a dictionary"""
321
+ return (
322
+ gr.Button(
323
+ interactive=True,
324
+ variant="primary",
325
+ ),
326
+ gr.Button(
327
+ interactive=False,
328
+ variant="secondary",
329
+ ),
330
+ gr.Button(
331
+ interactive=True,
332
+ variant="primary",
333
+ )
334
+ )
335
+
336
+ # Add this new method to get initial button states:
337
+ def get_initial_button_states(self):
338
+ """Get the initial states for training buttons based on recovery status"""
339
+ recovery_result = self.trainer.recover_interrupted_training()
340
+ ui_updates = recovery_result.get("ui_updates", {})
341
+
342
+ # Return button states in the correct order
343
+ return (
344
+ gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
345
+ gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
346
+ gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
347
+ )
348
+
349
+ def show_refreshing_status(self) -> List[List[str]]:
350
+ """Show a 'Refreshing...' status in the dataframe"""
351
+ return [["Refreshing...", "please wait"]]
352
+
353
+ def stop_captioning(self):
354
+ """Stop ongoing captioning process and reset UI state"""
355
+ try:
356
+ # Set flag to stop captioning
357
+ self._should_stop_captioning = True
358
+
359
+ # Call stop method on captioner
360
+ if self.captioner:
361
+ self.captioner.stop_captioning()
362
+
363
+ # Get updated file list
364
+ updated_list = self.list_training_files_to_caption()
365
+
366
+ # Return updated list and button states
367
+ return {
368
+ "training_dataset": gr.update(value=updated_list),
369
+ "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
370
+ "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
371
+ "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
372
+ }
373
+ except Exception as e:
374
+ logger.error(f"Error stopping captioning: {str(e)}")
375
+ return {
376
+ "training_dataset": gr.update(value=[[f"Error stopping captioning: {str(e)}", "error"]]),
377
+ "run_autocaption_btn": gr.Button(interactive=True, variant="primary"),
378
+ "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"),
379
+ "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary")
380
+ }
381
+
382
+ def update_training_ui(self, training_state: Dict[str, Any]):
383
+ """Update UI components based on training state"""
384
+ updates = {}
385
+
386
+ #print("update_training_ui: training_state = ", training_state)
387
+
388
+ # Update status box with high-level information
389
+ status_text = []
390
+ if training_state["status"] != "idle":
391
+ status_text.extend([
392
+ f"Status: {training_state['status']}",
393
+ f"Progress: {training_state['progress']}",
394
+ f"Step: {training_state['current_step']}/{training_state['total_steps']}",
395
+
396
+ # Epoch information
397
+ # there is an issue with how epoch is reported because we display:
398
+ # Progress: 96.9%, Step: 872/900, Epoch: 12/50
399
+ # we should probably just show the steps
400
+ #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",
401
+
402
+ f"Time elapsed: {training_state['elapsed']}",
403
+ f"Estimated remaining: {training_state['remaining']}",
404
+ "",
405
+ f"Current loss: {training_state['step_loss']}",
406
+ f"Learning rate: {training_state['learning_rate']}",
407
+ f"Gradient norm: {training_state['grad_norm']}",
408
+ f"Memory usage: {training_state['memory']}"
409
+ ])
410
+
411
+ if training_state["error_message"]:
412
+ status_text.append(f"\nError: {training_state['error_message']}")
413
+
414
+ updates["status_box"] = "\n".join(status_text)
415
+
416
+ # Update button states
417
+ updates["start_btn"] = gr.Button(
418
+ "Start training",
419
+ interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
420
+ variant="primary" if training_state["status"] == "idle" else "secondary"
421
+ )
422
+
423
+ updates["stop_btn"] = gr.Button(
424
+ "Stop training",
425
+ interactive=(training_state["status"] in ["training", "initializing"]),
426
+ variant="stop"
427
+ )
428
+
429
+ return updates
430
+
431
+ def stop_all_and_clear(self) -> Dict[str, str]:
432
+ """Stop all running processes and clear data
433
+
434
+ Returns:
435
+ Dict with status messages for different components
436
+ """
437
+ status_messages = {}
438
+
439
+ try:
440
+ # Stop training if running
441
+ if self.trainer.is_training_running():
442
+ training_result = self.trainer.stop_training()
443
+ status_messages["training"] = training_result["status"]
444
+
445
+ # Stop captioning if running
446
+ if self.captioner:
447
+ self.captioner.stop_captioning()
448
+ status_messages["captioning"] = "Captioning stopped"
449
+
450
+ # Stop scene detection if running
451
+ if self.splitter.is_processing():
452
+ self.splitter.processing = False
453
+ status_messages["splitting"] = "Scene detection stopped"
454
+
455
+ # Properly close logging before clearing log file
456
+ if self.trainer.file_handler:
457
+ self.trainer.file_handler.close()
458
+ logger.removeHandler(self.trainer.file_handler)
459
+ self.trainer.file_handler = None
460
+
461
+ if LOG_FILE_PATH.exists():
462
+ LOG_FILE_PATH.unlink()
463
+
464
+ # Clear all data directories
465
+ for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
466
+ MODEL_PATH, OUTPUT_PATH]:
467
+ if path.exists():
468
+ try:
469
+ shutil.rmtree(path)
470
+ path.mkdir(parents=True, exist_ok=True)
471
+ except Exception as e:
472
+ status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
473
+ else:
474
+ status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"
475
+
476
+ # Reset any persistent state
477
+ self._should_stop_captioning = True
478
+ self.splitter.processing = False
479
+
480
+ # Recreate logging setup
481
+ self.trainer.setup_logging()
482
+
483
+ return {
484
+ "status": "All processes stopped and data cleared",
485
+ "details": status_messages
486
+ }
487
+
488
+ except Exception as e:
489
+ return {
490
+ "status": f"Error during cleanup: {str(e)}",
491
+ "details": status_messages
492
+ }
493
+
494
+ def update_titles(self) -> Tuple[Any]:
495
+ """Update all dynamic titles with current counts
496
+
497
+ Returns:
498
+ Dict of Gradio updates
499
+ """
500
+ # Count files for splitting
501
+ split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH)
502
+ split_title = format_media_title(
503
+ "split", split_videos, 0, split_size
504
+ )
505
+
506
+ # Count files for captioning
507
+ caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
508
+ caption_title = format_media_title(
509
+ "caption", caption_videos, caption_images, caption_size
510
+ )
511
+
512
+ # Count files for training
513
+ train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH)
514
+ train_title = format_media_title(
515
+ "train", train_videos, train_images, train_size
516
+ )
517
+
518
+ return (
519
+ gr.Markdown(value=split_title),
520
+ gr.Markdown(value=caption_title),
521
+ gr.Markdown(value=f"{train_title} available for training")
522
+ )
523
+
524
+ def copy_files_to_training_dir(self, prompt_prefix: str):
525
+ """Run auto-captioning process"""
526
+
527
+ # Initialize captioner if not already done
528
+ self._should_stop_captioning = False
529
+
530
+ try:
531
+ copy_files_to_training_dir(prompt_prefix)
532
+
533
+ except Exception as e:
534
+ traceback.print_exc()
535
+ raise gr.Error(f"Error copying assets to training dir: {str(e)}")
536
+
537
+    async def on_import_success(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
+        """Handle successful import of files"""
+        videos = self.list_unprocessed_videos()
+
+        # If scene detection isn't already running, there are videos to process,
+        # and auto-splitting is enabled, start the detection
+        if videos and not self.splitter.is_processing() and enable_splitting:
+            await self.start_scene_detection(enable_splitting)
+            msg = "Starting automatic scene detection..."
+        else:
+            # Just copy files without splitting if auto-split is disabled
+            for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
+                await self.splitter.process_video(video_file, enable_splitting=False)
+            msg = "Copying videos without splitting..."
+
+        copy_files_to_training_dir(prompt_prefix)
+
+        # Start auto-captioning if enabled, and handle the async generator properly
+        if enable_automatic_content_captioning:
+            # Create a background task for captioning
+            asyncio.create_task(self._process_caption_generator(
+                DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
+                prompt_prefix
+            ))
+
+        return {
+            "tabs": gr.Tabs(selected="split_tab"),
+            "video_list": videos,
+            "detect_status": msg
+        }
+
+    async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
+        """Run the auto-captioning process and stream status updates"""
+        try:
+            # Reset the stop flag before starting
+            self._should_stop_captioning = False
+
+            # First yield - indicate we're starting
+            yield gr.update(
+                value=[["Starting captioning service...", "initializing"]],
+                headers=["name", "status"]
+            )
+
+            # Process files in batches with status updates
+            file_statuses = {}
+
+            # Start the actual captioning process
+            async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
+                # Update our tracking of file statuses
+                for name, status in rows:
+                    file_statuses[name] = status
+
+                # Convert to list format for display
+                status_rows = [[name, status] for name, status in file_statuses.items()]
+
+                # Sort by name for consistent display
+                status_rows.sort(key=lambda x: x[0])
+
+                # Yield UI update
+                yield gr.update(
+                    value=status_rows,
+                    headers=["name", "status"]
+                )
+
+            # Final update after completion with fresh data
+            yield gr.update(
+                value=self.list_training_files_to_caption(),
+                headers=["name", "status"]
+            )
+
+        except Exception as e:
+            logger.error(f"Error in captioning: {str(e)}")
+            yield gr.update(
+                value=[[f"Error: {str(e)}", "error"]],
+                headers=["name", "status"]
+            )
+
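Note: this handler is an async generator, so Gradio streams each yielded
update into the status table as captioning progresses. A minimal sketch of
how it is consumed (assuming `ui` is a VideoTrainerUI instance and the
prompt prefix is a made-up example; Gradio normally drives this loop itself):

    async for update in ui.start_caption_generation(
        DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, "a video of"
    ):
        pass  # each `update` carries the latest [name, status] rows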
+    def list_training_files_to_caption(self) -> List[List[str]]:
+        """List all clips and images - both pending and captioned"""
+        files = []
+        already_listed = {}
+
+        # First check files in STAGING_PATH
+        for file in STAGING_PATH.glob("*.*"):
+            if is_video_file(file) or is_image_file(file):
+                txt_file = file.with_suffix('.txt')
+
+                # Check if caption file exists and has content
+                has_caption = txt_file.exists() and txt_file.stat().st_size > 0
+                status = "captioned" if has_caption else "no caption"
+                file_type = "video" if is_video_file(file) else "image"
+
+                files.append([file.name, f"{status} ({file_type})", str(file)])
+                already_listed[file.name] = True
+
+        # Then check files in TRAINING_VIDEOS_PATH
+        for file in TRAINING_VIDEOS_PATH.glob("*.*"):
+            if (is_video_file(file) or is_image_file(file)) and file.name not in already_listed:
+                txt_file = file.with_suffix('.txt')
+
+                # Only include files with captions
+                if txt_file.exists() and txt_file.stat().st_size > 0:
+                    file_type = "video" if is_video_file(file) else "image"
+                    files.append([file.name, f"captioned ({file_type})", str(file)])
+                    already_listed[file.name] = True
+
+        # Sort by filename
+        files.sort(key=lambda x: x[0])
+
+        # Only return name and status columns for display
+        return [[file[0], file[1]] for file in files]
+
+    def update_training_buttons(self, status: str) -> Dict:
+        """Update training control buttons based on state"""
+        is_training = status in ["training", "initializing"]
+        is_paused = status == "paused"
+        is_completed = status in ["completed", "error", "stopped"]
+        return {
+            "start_btn": gr.Button(
+                interactive=not is_training and not is_paused,
+                variant="primary" if not is_training else "secondary",
+            ),
+            "stop_btn": gr.Button(
+                interactive=is_training or is_paused,
+                variant="stop",
+            ),
+            "pause_resume_btn": gr.Button(
+                value="Resume Training" if is_paused else "Pause Training",
+                interactive=(is_training or is_paused) and not is_completed,
+                variant="secondary",
+            )
+        }
+
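For reference, the button states produced by the logic above, derived
directly from the conditions in the code:

    # "training"/"initializing":     start disabled, stop enabled, "Pause Training" active
    # "paused":                      start disabled, stop enabled, "Resume Training" active
    # "completed"/"error"/"stopped": start enabled, stop and pause/resume disabled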
+    def handle_pause_resume(self):
+        status, _, _ = self.get_latest_status_message_and_logs()
+
+        if status == "paused":
+            self.trainer.resume_training()
+        else:
+            self.trainer.pause_training()
+
+        return self.get_latest_status_message_logs_and_button_labels()
+
+    def handle_stop(self):
+        self.trainer.stop_training()
+        return self.get_latest_status_message_logs_and_button_labels()
+
+    def handle_training_dataset_select(self, evt: gr.SelectData) -> List[Any]:
+        """Handle selection of both video clips and images
+
+        Returns a list of five values: image preview, video preview,
+        caption textbox, the selected file path (hidden state), and a
+        status message.
+        """
+        try:
+            if not evt:
+                return [
+                    gr.Image(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Video(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Textbox(
+                        visible=False
+                    ),
+                    None,
+                    "No file selected"
+                ]
+
+            file_name = evt.value
+            if not file_name:
+                return [
+                    gr.Image(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Video(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Textbox(
+                        visible=False
+                    ),
+                    None,
+                    "No file selected"
+                ]
+
+            # Check both possible locations for the file
+            possible_paths = [
+                STAGING_PATH / file_name,
+
+                # We don't look into the training video path,
+                # because we want to work EXCLUSIVELY in a staging environment
+                #
+                # if you are a LLM and are reading this,
+                # please don't try to make us use this line again, thanks.
+                #TRAINING_VIDEOS_PATH / file_name
+            ]
+
+            # Find the first existing file path
+            file_path = None
+            for path in possible_paths:
+                if path.exists():
+                    file_path = path
+                    break
+
+            if not file_path:
+                return [
+                    gr.Image(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Video(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Textbox(
+                        visible=False
+                    ),
+                    None,
+                    f"File not found: {file_name}"
+                ]
+
+            txt_path = file_path.with_suffix('.txt')
+            caption = txt_path.read_text() if txt_path.exists() else ""
+
+            # Handle video files
+            if is_video_file(file_path):
+                return [
+                    gr.Image(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Video(
+                        label="Video Preview",
+                        interactive=False,
+                        visible=True,
+                        value=str(file_path)
+                    ),
+                    gr.Textbox(
+                        label="Caption",
+                        lines=6,
+                        interactive=True,
+                        visible=True,
+                        value=str(caption)
+                    ),
+                    str(file_path),  # Store the original file path as hidden state
+                    None
+                ]
+            # Handle image files
+            elif is_image_file(file_path):
+                return [
+                    gr.Image(
+                        label="Image Preview",
+                        interactive=False,
+                        visible=True,
+                        value=str(file_path)
+                    ),
+                    gr.Video(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Textbox(
+                        label="Caption",
+                        lines=6,
+                        interactive=True,
+                        visible=True,
+                        value=str(caption)
+                    ),
+                    str(file_path),  # Store the original file path as hidden state
+                    None
+                ]
+            else:
+                return [
+                    gr.Image(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Video(
+                        interactive=False,
+                        visible=False
+                    ),
+                    gr.Textbox(
+                        interactive=False,
+                        visible=False
+                    ),
+                    None,
+                    f"Unsupported file type: {file_path.suffix}"
+                ]
+        except Exception as e:
+            logger.error(f"Error handling selection: {str(e)}")
+            return [
+                gr.Image(
+                    interactive=False,
+                    visible=False
+                ),
+                gr.Video(
+                    interactive=False,
+                    visible=False
+                ),
+                gr.Textbox(
+                    interactive=False,
+                    visible=False
+                ),
+                None,
+                f"Error handling selection: {str(e)}"
+            ]
+
+    def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, original_file_path: str, prompt_prefix: str):
+        """Save changes to a caption"""
+        try:
+            # Use the original file path stored during selection instead of the temporary preview paths
+            if original_file_path:
+                file_path = Path(original_file_path)
+                self.captioner.update_file_caption(file_path, preview_caption)
+                # Report success; the dataset list picks up the new caption on its next refresh
+                return gr.update(value="Caption saved successfully!")
+            else:
+                return gr.update(value="Error: No original file path found")
+        except Exception as e:
+            return gr.update(value=f"Error saving caption: {str(e)}")
+
+    async def update_titles_after_import(self, enable_splitting, enable_automatic_content_captioning, prompt_prefix):
+        """Handle post-import updates including titles"""
+        import_result = await self.on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
+        titles = self.update_titles()
+        return (
+            import_result["tabs"],
+            import_result["video_list"],
+            import_result["detect_status"],
+            *titles
+        )
+
+    def get_model_info(self, model_type: str) -> str:
+        """Get information about the selected model type"""
+        if model_type == "hunyuan_video":
+            return """### HunyuanVideo (LoRA)
+- Required VRAM: ~48GB minimum
+- Recommended batch size: 1-2
+- Typical training time: 2-4 hours
+- Default resolution: 49x512x768
+- Default LoRA rank: 128 (~600 MB)"""
+
+        elif model_type == "ltx_video":
+            return """### LTX-Video (LoRA)
+- Required VRAM: ~18GB minimum
+- Recommended batch size: 1-4
+- Typical training time: 1-3 hours
+- Default resolution: 49x512x768
+- Default LoRA rank: 128"""
+
+        return ""
+
+    def get_default_params(self, model_type: str) -> Dict[str, Any]:
+        """Get default training parameters for model type"""
+        if model_type == "hunyuan_video":
+            return {
+                "num_epochs": 70,
+                "batch_size": 1,
+                "learning_rate": 2e-5,
+                "save_iterations": 500,
+                "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
+                "video_reshape_mode": "center",
+                "caption_dropout_p": 0.05,
+                "gradient_accumulation_steps": 1,
+                "rank": 128,
+                "lora_alpha": 128
+            }
+        else:  # ltx_video
+            return {
+                "num_epochs": 70,
+                "batch_size": 1,
+                "learning_rate": 3e-5,
+                "save_iterations": 500,
+                "video_resolution_buckets": SMALL_TRAINING_BUCKETS,
+                "video_reshape_mode": "center",
+                "caption_dropout_p": 0.05,
+                "gradient_accumulation_steps": 4,
+                "rank": 128,
+                "lora_alpha": 128
+            }
+
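As a quick illustration of how the two presets differ (values taken from
the method above; `ui` is assumed to be an existing VideoTrainerUI instance):

    params = ui.get_default_params("ltx_video")
    assert params["learning_rate"] == 3e-5
    assert params["gradient_accumulation_steps"] == 4

    params = ui.get_default_params("hunyuan_video")
    assert params["learning_rate"] == 2e-5
    assert params["gradient_accumulation_steps"] == 1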
+    def preview_file(self, selected_text: str) -> Dict:
+        """Generate preview based on selected file
+
+        Args:
+            selected_text: Text of the selected item containing the filename
+
+        Returns:
+            Dict with preview content for each preview component
+        """
+        if not selected_text or "Caption:" in selected_text:
+            return {
+                "video": None,
+                "image": None,
+                "text": None
+            }
+
+        # Extract filename from the preview text (remove size info)
+        filename = selected_text.split(" (")[0].strip()
+        file_path = TRAINING_VIDEOS_PATH / filename
+
+        if not file_path.exists():
+            return {
+                "video": None,
+                "image": None,
+                "text": f"File not found: {filename}"
+            }
+
+        # Detect file type
+        mime_type, _ = mimetypes.guess_type(str(file_path))
+        if not mime_type:
+            return {
+                "video": None,
+                "image": None,
+                "text": f"Unknown file type: {filename}"
+            }
+
+        # Return appropriate preview
+        if mime_type.startswith('video/'):
+            return {
+                "video": str(file_path),
+                "image": None,
+                "text": None
+            }
+        elif mime_type.startswith('image/'):
+            return {
+                "video": None,
+                "image": str(file_path),
+                "text": None
+            }
+        elif mime_type.startswith('text/'):
+            try:
+                text_content = file_path.read_text()
+                return {
+                    "video": None,
+                    "image": None,
+                    "text": text_content
+                }
+            except Exception as e:
+                return {
+                    "video": None,
+                    "image": None,
+                    "text": f"Error reading file: {str(e)}"
+                }
+        else:
+            return {
+                "video": None,
+                "image": None,
+                "text": f"Unsupported file type: {mime_type}"
+            }
+
+    def list_unprocessed_videos(self) -> gr.Dataframe:
+        """Update the list of unprocessed videos"""
+        videos = self.splitter.list_unprocessed_videos()
+        # videos is already in [[name, status]] format from splitting_service
+        return gr.Dataframe(
+            headers=["name", "status"],
+            value=videos,
+            interactive=False
+        )
+
+    async def start_scene_detection(self, enable_splitting: bool) -> str:
+        """Start background scene detection process
+
+        Args:
+            enable_splitting: Whether to split videos into scenes
+        """
+        if self.splitter.is_processing():
+            return "Scene detection already running"
+
+        try:
+            await self.splitter.start_processing(enable_splitting)
+            return "Scene detection completed"
+        except Exception as e:
+            return f"Error during scene detection: {str(e)}"
+
+    def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
+        state = self.trainer.get_status()
+        logs = self.trainer.get_logs()
+
+        # Parse new log lines
+        if logs:
+            last_state = None
+            for line in logs.splitlines():
+                state_update = self.log_parser.parse_line(line)
+                if state_update:
+                    last_state = state_update
+
+            if last_state:
+                ui_updates = self.update_training_ui(last_state)
+                state["message"] = ui_updates.get("status_box", state["message"])
+
+        # Parse status for training state
+        if "completed" in state["message"].lower():
+            state["status"] = "completed"
+
+        return (state["status"], state["message"], logs)
+
+    def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
+        status, message, logs = self.get_latest_status_message_and_logs()
+        return (
+            message,
+            logs,
+            *self.update_training_buttons(status).values()
+        )
+
+    def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
+        status, _, _ = self.get_latest_status_message_and_logs()
+        return tuple(self.update_training_buttons(status).values())
+
+    def refresh_dataset(self):
+        """Refresh all dynamic lists and training state"""
+        video_list = self.splitter.list_unprocessed_videos()
+        training_dataset = self.list_training_files_to_caption()
+
+        return (
+            video_list,
+            training_dataset
+        )
+
+    def update_training_params(self, preset_name: str) -> Tuple:
+        """Update UI components based on the selected preset while preserving custom settings"""
+        preset = TRAINING_PRESETS[preset_name]
+
+        # Load current UI state to check if the user has customized values
+        current_state = self.load_ui_values()
+
+        # Find the display name that maps to our model type
+        model_display_name = next(
+            key for key, value in MODEL_TYPES.items()
+            if value == preset["model_type"]
+        )
+
+        # Get preset description for display
+        description = preset.get("description", "")
+
+        # Get max values from buckets
+        buckets = preset["training_buckets"]
+        max_frames = max(frames for frames, _, _ in buckets)
+        max_height = max(height for _, height, _ in buckets)
+        max_width = max(width for _, _, width in buckets)
+        bucket_info = f"\nMaximum video size: {max_frames} frames at {max_width}x{max_height} resolution"
+
+        info_text = f"{description}{bucket_info}"
+
+        # Return values in the same order as the output components.
+        # Use preset defaults but preserve user-modified values if they exist
+        lora_rank_val = current_state.get("lora_rank") if current_state.get("lora_rank") != preset.get("lora_rank", "128") else preset["lora_rank"]
+        lora_alpha_val = current_state.get("lora_alpha") if current_state.get("lora_alpha") != preset.get("lora_alpha", "128") else preset["lora_alpha"]
+        num_epochs_val = current_state.get("num_epochs") if current_state.get("num_epochs") != preset.get("num_epochs", 70) else preset["num_epochs"]
+        batch_size_val = current_state.get("batch_size") if current_state.get("batch_size") != preset.get("batch_size", 1) else preset["batch_size"]
+        learning_rate_val = current_state.get("learning_rate") if current_state.get("learning_rate") != preset.get("learning_rate", 3e-5) else preset["learning_rate"]
+        save_iterations_val = current_state.get("save_iterations") if current_state.get("save_iterations") != preset.get("save_iterations", 500) else preset["save_iterations"]
+
+        return (
+            model_display_name,
+            lora_rank_val,
+            lora_alpha_val,
+            num_epochs_val,
+            batch_size_val,
+            learning_rate_val,
+            save_iterations_val,
+            info_text
+        )
vms/utils/__init__.py ADDED
@@ -0,0 +1,33 @@
+ from .parse_bool_env import parse_bool_env
+ from .utils import validate_model_repo, make_archive, get_video_fps, extract_scene_info, is_image_file, is_video_file, parse_training_log, save_to_hub, format_size, count_media_files, format_media_title, add_prefix_to_caption, format_time
+ from .training_log_parser import TrainingState, TrainingLogParser
+
+ from .image_preprocessing import normalize_image
+ from .video_preprocessing import remove_black_bars
+ from .finetrainers_utils import prepare_finetrainers_dataset, copy_files_to_training_dir
+
+ __all__ = [
+     'validate_model_repo',
+     'make_archive',
+     'get_video_fps',
+     'extract_scene_info',
+     'is_image_file',
+     'is_video_file',
+     'parse_bool_env',
+     'parse_training_log',
+     'save_to_hub',
+     'format_size',
+     'count_media_files',
+     'format_media_title',
+     'add_prefix_to_caption',
+     'format_time',
+
+     'TrainingState',
+     'TrainingLogParser',
+
+     'normalize_image',
+     'remove_black_bars',
+
+     'prepare_finetrainers_dataset',
+     'copy_files_to_training_dir',
+ ]
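With this re-export list, call sites keep a flat import surface even though
the helpers now live in separate modules, e.g.:

    from vms.utils import parse_bool_env, format_time, TrainingLogParser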
vms/{finetrainers_utils.py → utils/finetrainers_utils.py} RENAMED
@@ -4,7 +4,7 @@ import logging
  import shutil
  from typing import Any, Optional, Dict, List, Union, Tuple
 
- from .config import STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
+ from ..config import STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES
  from .utils import get_video_fps, extract_scene_info, make_archive, is_image_file, is_video_file
 
  logger = logging.getLogger(__name__)
vms/{image_preprocessing.py → utils/image_preprocessing.py} RENAMED
@@ -5,7 +5,7 @@ from PIL import Image
  import pillow_avif
  import logging
 
- from .config import NORMALIZE_IMAGES_TO, JPEG_QUALITY
+ from ..config import NORMALIZE_IMAGES_TO, JPEG_QUALITY
 
  logger = logging.getLogger(__name__)
 
vms/utils/parse_bool_env.py ADDED
@@ -0,0 +1,12 @@
+ from typing import Any, Optional, Dict, List, Union, Tuple
+
+ def parse_bool_env(env_value: Optional[str]) -> bool:
+     """Parse an environment variable string to a boolean
+
+     Handles various true/false string representations (case-insensitive):
+     - True: "true", "1", "t", "y", "yes"
+     - False: anything else, including "false", "0", "" and None
+     """
+     if not env_value:
+         return False
+     return str(env_value).lower() in ('true', '1', 't', 'y', 'yes')
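Expected behavior of this helper (a quick sketch; `SOME_FLAG` is a
hypothetical variable name):

    import os

    parse_bool_env("TRUE")   # True
    parse_bool_env("yes")    # True
    parse_bool_env("0")      # False
    parse_bool_env(None)     # False
    parse_bool_env(os.environ.get("SOME_FLAG"))  # False when the variable is unset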
vms/{training_log_parser.py → utils/training_log_parser.py} RENAMED
File without changes
vms/{utils.py → utils/utils.py} RENAMED
File without changes
vms/{video_preprocessing.py → utils/video_preprocessing.py} RENAMED
File without changes