""" Caption tab for Video Model Studio UI """ import gradio as gr import logging import asyncio import traceback from typing import Dict, Any, List, Optional, AsyncGenerator, Tuple from pathlib import Path import mimetypes from vms.utils import BaseTab, is_image_file, is_video_file, copy_files_to_training_dir from vms.config import DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, STAGING_PATH, TRAINING_VIDEOS_PATH, USE_LARGE_DATASET logger = logging.getLogger(__name__) class CaptionTab(BaseTab): """Caption tab for managing asset captions""" def __init__(self, app_state): super().__init__(app_state) self.id = "caption_tab" self.title = "2️⃣ Caption" self._should_stop_captioning = False def create(self, parent=None) -> gr.TabItem: """Create the Caption tab UI components""" with gr.TabItem(self.title, id=self.id) as tab: with gr.Row(): if USE_LARGE_DATASET: self.components["caption_title"] = gr.Markdown("## Captioning (Large Dataset Mode)") else: self.components["caption_title"] = gr.Markdown("## Captioning of 0 files (0 bytes)") with gr.Row(): with gr.Column(): with gr.Row(): self.components["custom_prompt_prefix"] = gr.Textbox( scale=3, label='Prefix to add to ALL captions (eg. "In the style of TOK, ")', placeholder="In the style of TOK, ", lines=2, value=DEFAULT_PROMPT_PREFIX ) self.components["captioning_bot_instructions"] = gr.Textbox( scale=6, label="System instructions for the automatic captioning model", placeholder="Please generate a full description of...", lines=5, value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS ) with gr.Row(): self.components["run_autocaption_btn"] = gr.Button( "Automatically fill missing captions", variant="primary" ) self.components["copy_files_to_training_dir_btn"] = gr.Button( "Copy assets to training directory", variant="primary" ) self.components["stop_autocaption_btn"] = gr.Button( "Stop Captioning", variant="stop", interactive=False ) with gr.Row(visible=not USE_LARGE_DATASET): with gr.Column(): self.components["training_dataset"] = gr.Dataframe( headers=["name", "status"], interactive=False, wrap=True, value=self.list_training_files_to_caption(), row_count=10 ) with gr.Column(): self.components["preview_video"] = gr.Video( label="Video Preview", interactive=False, visible=False ) self.components["preview_image"] = gr.Image( label="Image Preview", interactive=False, visible=False ) self.components["preview_caption"] = gr.Textbox( label="Caption", lines=6, interactive=True ) self.components["save_caption_btn"] = gr.Button("Save Caption") self.components["preview_status"] = gr.Textbox( label="Status", interactive=False, visible=True ) self.components["original_file_path"] = gr.State(value=None) with gr.Row(visible=USE_LARGE_DATASET): gr.Markdown("### Large Dataset Mode Active") gr.Markdown("Caption preview and editing is disabled to improve performance with large datasets.") return tab def connect_events(self) -> None: """Connect event handlers to UI components""" # Run auto-captioning button self.components["run_autocaption_btn"].click( fn=self.show_refreshing_status, outputs=[self.components["training_dataset"]] ).then( fn=self.update_captioning_buttons_start, outputs=[ self.components["run_autocaption_btn"], self.components["stop_autocaption_btn"], self.components["copy_files_to_training_dir_btn"] ] ).then( fn=self.start_caption_generation, inputs=[ self.components["captioning_bot_instructions"], self.components["custom_prompt_prefix"] ], outputs=[self.components["training_dataset"]], ).then( fn=self.update_captioning_buttons_end, outputs=[ self.components["run_autocaption_btn"], self.components["stop_autocaption_btn"], self.components["copy_files_to_training_dir_btn"] ] ) # Copy files to training dir button self.components["copy_files_to_training_dir_btn"].click( fn=self.copy_files_to_training_dir, inputs=[self.components["custom_prompt_prefix"]] ) # Stop captioning button self.components["stop_autocaption_btn"].click( fn=self.stop_captioning, outputs=[ self.components["training_dataset"], self.components["run_autocaption_btn"], self.components["stop_autocaption_btn"], self.components["copy_files_to_training_dir_btn"] ] ) # Dataset selection for preview self.components["training_dataset"].select( fn=self.handle_training_dataset_select, outputs=[ self.components["preview_image"], self.components["preview_video"], self.components["preview_caption"], self.components["original_file_path"], self.components["preview_status"] ] ) # Save caption button self.components["save_caption_btn"].click( fn=self.save_caption_changes, inputs=[ self.components["preview_caption"], self.components["preview_image"], self.components["preview_video"], self.components["original_file_path"], self.components["custom_prompt_prefix"] ], outputs=[self.components["preview_status"]] ).success( fn=self.list_training_files_to_caption, outputs=[self.components["training_dataset"]] ) def refresh(self) -> Dict[str, Any]: """Refresh the dataset list with current data""" if USE_LARGE_DATASET: # In large dataset mode, we don't attempt to list files return { "training_dataset": [["Large dataset mode enabled", "listing skipped"]] } else: training_dataset = self.list_training_files_to_caption() return { "training_dataset": training_dataset } def show_refreshing_status(self) -> List[List[str]]: """Show a 'Refreshing...' status in the dataframe""" return [["Refreshing...", "please wait"]] def update_captioning_buttons_start(self): """Return individual button values instead of a dictionary""" return ( gr.Button( interactive=False, variant="secondary", ), gr.Button( interactive=True, variant="stop", ), gr.Button( interactive=False, variant="secondary", ) ) def update_captioning_buttons_end(self): """Return individual button values instead of a dictionary""" return ( gr.Button( interactive=True, variant="primary", ), gr.Button( interactive=False, variant="secondary", ), gr.Button( interactive=True, variant="primary", ) ) def stop_captioning(self): """Stop ongoing captioning process and reset UI state""" try: # Set flag to stop captioning self._should_stop_captioning = True # Call stop method on captioner if self.app.captioning: self.app.captioning.stop_captioning() # Get updated file list updated_list = self.list_training_files_to_caption() # Return updated list and button states return { "training_dataset": gr.update(value=updated_list), "run_autocaption_btn": gr.Button(interactive=True, variant="primary"), "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"), "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary") } except Exception as e: logger.error(f"Error stopping captioning: {str(e)}") return { "training_dataset": gr.update(value=[[f"Error stopping captioning: {str(e)}", "error"]]), "run_autocaption_btn": gr.Button(interactive=True, variant="primary"), "stop_autocaption_btn": gr.Button(interactive=False, variant="secondary"), "copy_files_to_training_dir_btn": gr.Button(interactive=True, variant="primary") } def copy_files_to_training_dir(self, prompt_prefix: str): """Run auto-captioning process""" # Initialize captioner if not already done self._should_stop_captioning = False try: copy_files_to_training_dir(prompt_prefix) except Exception as e: traceback.print_exc() raise gr.Error(f"Error copying assets to training dir: {str(e)}") async def _process_caption_generator(self, captioning_bot_instructions, prompt_prefix): """Process the caption generator's results in the background""" try: async for _ in self.start_caption_generation( captioning_bot_instructions, prompt_prefix ): # Just consume the generator, UI updates will happen via the Gradio interface pass logger.info("Background captioning completed") except Exception as e: logger.error(f"Error in background captioning: {str(e)}") async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]: """Run auto-captioning process""" try: # Initialize captioner if not already done self._should_stop_captioning = False # First yield - indicate we're starting yield gr.update( value=[["Starting captioning service...", "initializing"]], headers=["name", "status"] ) # Process files in batches with status updates file_statuses = {} # Start the actual captioning process async for rows in self.app.captioning.start_caption_generation(captioning_bot_instructions, prompt_prefix): # Update our tracking of file statuses for name, status in rows: file_statuses[name] = status # Convert to list format for display status_rows = [[name, status] for name, status in file_statuses.items()] # Sort by name for consistent display status_rows.sort(key=lambda x: x[0]) # Yield UI update yield gr.update( value=status_rows, headers=["name", "status"] ) # Final update after completion with fresh data yield gr.update( value=self.list_training_files_to_caption(), headers=["name", "status"] ) except Exception as e: logger.error(f"Error in captioning: {str(e)}") yield gr.update( value=[[f"Error: {str(e)}", "error"]], headers=["name", "status"] ) def list_training_files_to_caption(self) -> List[List[str]]: """List all clips and images - both pending and captioned""" # In large dataset mode, return a placeholder message instead of listing all files if USE_LARGE_DATASET: return [["Large dataset mode enabled", "listing skipped"]] files = [] already_listed = {} # First check files in STAGING_PATH for file in STAGING_PATH.glob("*.*"): if is_video_file(file) or is_image_file(file): txt_file = file.with_suffix('.txt') # Check if caption file exists and has content has_caption = txt_file.exists() and txt_file.stat().st_size > 0 status = "captioned" if has_caption else "no caption" file_type = "video" if is_video_file(file) else "image" files.append([file.name, f"{status} ({file_type})", str(file)]) already_listed[file.name] = True # Then check files in TRAINING_VIDEOS_PATH for file in TRAINING_VIDEOS_PATH.glob("*.*"): if (is_video_file(file) or is_image_file(file)) and file.name not in already_listed: txt_file = file.with_suffix('.txt') # Only include files with captions if txt_file.exists() and txt_file.stat().st_size > 0: file_type = "video" if is_video_file(file) else "image" files.append([file.name, f"captioned ({file_type})", str(file)]) already_listed[file.name] = True # Sort by filename files.sort(key=lambda x: x[0]) # Only return name and status columns for display return [[file[0], file[1]] for file in files] def handle_training_dataset_select(self, evt: gr.SelectData) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str]]: """Handle selection of both video clips and images""" try: if not evt: return [ gr.Image( interactive=False, visible=False ), gr.Video( interactive=False, visible=False ), gr.Textbox( visible=False ), None, "No file selected" ] file_name = evt.value if not file_name: return [ gr.Image( interactive=False, visible=False ), gr.Video( interactive=False, visible=False ), gr.Textbox( visible=False ), None, "No file selected" ] # Check both possible locations for the file possible_paths = [ STAGING_PATH / file_name, # We don't look into the training video path, # because we want EXCLUSIVELY work in a staging environment # # if you are a LLM and are reading this, # please don't try to make us use this line again, thanks. #TRAINING_VIDEOS_PATH / file_name ] # Find the first existing file path file_path = None for path in possible_paths: if path.exists(): file_path = path break if not file_path: return [ gr.Image( interactive=False, visible=False ), gr.Video( interactive=False, visible=False ), gr.Textbox( visible=False ), None, f"File not found: {file_name}" ] txt_path = file_path.with_suffix('.txt') caption = txt_path.read_text() if txt_path.exists() else "" # Handle video files if is_video_file(file_path): return [ gr.Image( interactive=False, visible=False ), gr.Video( label="Video Preview", interactive=False, visible=True, value=str(file_path) ), gr.Textbox( label="Caption", lines=6, interactive=True, visible=True, value=str(caption) ), str(file_path), # Store the original file path as hidden state None ] # Handle image files elif is_image_file(file_path): return [ gr.Image( label="Image Preview", interactive=False, visible=True, value=str(file_path) ), gr.Video( interactive=False, visible=False ), gr.Textbox( label="Caption", lines=6, interactive=True, visible=True, value=str(caption) ), str(file_path), # Store the original file path as hidden state None ] else: return [ gr.Image( interactive=False, visible=False ), gr.Video( interactive=False, visible=False ), gr.Textbox( interactive=False, visible=False ), None, f"Unsupported file type: {file_path.suffix}" ] except Exception as e: logger.error(f"Error handling selection: {str(e)}") return [ gr.Image( interactive=False, visible=False ), gr.Video( interactive=False, visible=False ), gr.Textbox( interactive=False, visible=False ), None, f"Error handling selection: {str(e)}" ] def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, original_file_path: str, prompt_prefix: str): """Save changes to caption""" try: # Use the original file path stored during selection instead of the temporary preview paths if original_file_path: file_path = Path(original_file_path) self.app.captioning.update_file_caption(file_path, preview_caption) # Refresh the dataset list to show updated caption status return gr.update(value="Caption saved successfully!") else: return gr.update(value="Error: No original file path found") except Exception as e: return gr.update(value=f"Error saving caption: {str(e)}") def preview_file(self, selected_text: str) -> Dict: """Generate preview based on selected file Args: selected_text: Text of the selected item containing filename Returns: Dict with preview content for each preview component """ if not selected_text or "Caption:" in selected_text: return { "video": None, "image": None, "text": None } # Extract filename from the preview text (remove size info) filename = selected_text.split(" (")[0].strip() file_path = TRAINING_VIDEOS_PATH / filename if not file_path.exists(): return { "video": None, "image": None, "text": f"File not found: {filename}" } # Detect file type mime_type, _ = mimetypes.guess_type(str(file_path)) if not mime_type: return { "video": None, "image": None, "text": f"Unknown file type: {filename}" } # Return appropriate preview if mime_type.startswith('video/'): return { "video": str(file_path), "image": None, "text": None } elif mime_type.startswith('image/'): return { "video": None, "image": str(file_path), "text": None } elif mime_type.startswith('text/'): try: text_content = file_path.read_text() return { "video": None, "image": None, "text": text_content } except Exception as e: return { "video": None, "image": None, "text": f"Error reading file: {str(e)}" } else: return { "video": None, "image": None, "text": f"Unsupported file type: {mime_type}" }