import logging
import torch
import shutil
import gradio as gr
import numpy as np
from decord import VideoReader, cpu
from pathlib import Path
from typing import Any, Tuple, Dict, Optional, AsyncGenerator, List
import asyncio
from dataclasses import dataclass
from datetime import datetime
import cv2
import copy

from llava.model.builder import load_pretrained_model
from llava.mm_utils import tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle

from config import TRAINING_VIDEOS_PATH, STAGING_PATH, PRELOAD_CAPTIONING_MODEL, CAPTIONING_MODEL, USE_MOCK_CAPTIONING_MODEL, DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, VIDEOS_TO_SPLIT_PATH, DEFAULT_PROMPT_PREFIX
from utils import extract_scene_info, is_image_file, is_video_file
from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset

logger = logging.getLogger(__name__)

@dataclass
class CaptioningProgress:
    video_name: str
    total_frames: int
    processed_frames: int
    status: str
    started_at: datetime
    completed_at: Optional[datetime] = None
    error: Optional[str] = None

class CaptioningService:
    _instance = None
    _model = None
    _tokenizer = None
    _image_processor = None
    _model_loading = None
    _loop = None

    def __new__(cls, model_name=CAPTIONING_MODEL):
        if cls._instance is not None:
            return cls._instance

        instance = super().__new__(cls)

        if PRELOAD_CAPTIONING_MODEL:
            cls._instance = instance
            try:
                cls._loop = asyncio.get_running_loop()
            except RuntimeError:
                cls._loop = asyncio.new_event_loop()
                asyncio.set_event_loop(cls._loop)

            if not USE_MOCK_CAPTIONING_MODEL and cls._model_loading is None:
                cls._model_loading = cls._loop.create_task(cls._background_load_model(model_name))

        return instance

    def __init__(self, model_name=CAPTIONING_MODEL):
        if hasattr(self, 'model_name'):
            # Already initialized
            return

        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.image_processor = None
        self.active_tasks: Dict[str, CaptioningProgress] = {}
        self._should_stop = False
        self._model_loaded = False

    @classmethod
    async def _background_load_model(cls, model_name):
        """Background task to load the model"""
        try:
            logger.info("Starting background model loading...")
            if not cls._loop:
                cls._loop = asyncio.get_running_loop()

            def load_model():
                try:
                    tokenizer, model, image_processor, _ = load_pretrained_model(
                        model_name, None, "llava_qwen",
                        torch_dtype="bfloat16", device_map="auto"
                    )
                    model.eval()
                    return tokenizer, model, image_processor
                except Exception as e:
                    logger.error(f"Error in load_model: {str(e)}")
                    raise

            result = await cls._loop.run_in_executor(None, load_model)
            cls._tokenizer, cls._model, cls._image_processor = result
            logger.info("Background model loading completed successfully!")
        except Exception as e:
            logger.error(f"Background model loading failed: {str(e)}")
            cls._model_loading = None
            raise

    async def ensure_model_loaded(self):
        """Ensure model is loaded before processing"""
        if USE_MOCK_CAPTIONING_MODEL:
            logger.info("Using mock model, skipping model loading")
            self.__class__._model_loading = None
            self._model_loaded = True
            return

        if not self._model_loaded:
            try:
                if PRELOAD_CAPTIONING_MODEL and self.__class__._model_loading:
                    logger.info("Waiting for background model loading to complete...")
                    if self.__class__._loop and self.__class__._loop != asyncio.get_running_loop():
                        logger.warning("Different event loop detected, creating new loading task")
                        self.__class__._model_loading = None
                        await self._load_model_sync()
                    else:
                        await self.__class__._model_loading
                        self.model = self.__class__._model
                        self.tokenizer = self.__class__._tokenizer
                        self.image_processor = self.__class__._image_processor
                else:
                    await self._load_model_sync()

                self._model_loaded = True
                logger.info("Model loading completed!")
            except Exception as e:
                logger.error(f"Error loading model: {str(e)}")
                raise
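    # Startup sketch (illustrative, not from the original code), assuming
    # PRELOAD_CAPTIONING_MODEL is enabled in config:
    #
    #     service = CaptioningService()        # schedules the background load
    #     ...
    #     await service.ensure_model_loaded()  # awaits the shared loading task
    #
    # Requests that arrive before loading finishes simply wait on the same task.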
    async def _load_model_sync(self):
        """Synchronously load the model"""
        logger.info("Loading model synchronously...")
        current_loop = asyncio.get_running_loop()

        def load_model():
            return load_pretrained_model(
                self.model_name, None, "llava_qwen",
                torch_dtype="bfloat16", device_map="auto"
            )

        self.tokenizer, self.model, self.image_processor, _ = await current_loop.run_in_executor(
            None, load_model
        )
        self.model.eval()

    def _load_video(self, video_path: Path, max_frames_num: int = 64, fps: int = 1, force_sample: bool = True) -> tuple[np.ndarray, str, float]:
        """Load and preprocess video frames with strict limits

        Args:
            video_path: Path to video file
            max_frames_num: Maximum number of frames to extract (default: 64)
            fps: Frames per second to sample (default: 1)
            force_sample: Whether to force uniform sampling (default: True)

        Returns:
            Tuple of (frames, frame_times_str, video_time)
        """
        video_path_str = str(video_path)
        logger.debug(f"Loading video: {video_path_str}")

        # Handle empty video case
        if max_frames_num == 0:
            return np.zeros((1, 336, 336, 3)), "", 0

        vr = VideoReader(video_path_str, ctx=cpu(0), num_threads=1)
        total_frame_num = len(vr)
        video_time = total_frame_num / vr.get_avg_fps()

        # Calculate frame indices by taking every `interval`-th frame
        # (a distinct name avoids shadowing the `fps` argument)
        interval = round(vr.get_avg_fps() / fps)
        frame_idx = list(range(0, total_frame_num, interval))
        frame_time = [i / vr.get_avg_fps() for i in frame_idx]

        # Force uniform sampling if too many frames
        if len(frame_idx) > max_frames_num or force_sample:
            num_samples = max_frames_num
            uniform_sampled_frames = np.linspace(0, total_frame_num - 1, num_samples, dtype=int)
            frame_idx = uniform_sampled_frames.tolist()
            frame_time = [i / vr.get_avg_fps() for i in frame_idx]

        # Build the timestamp string outside the branch so it is always defined
        frame_time_str = ",".join([f"{i:.2f}s" for i in frame_time])

        try:
            frames = vr.get_batch(frame_idx).asnumpy()
            logger.debug(f"Loaded {len(frames)} frames with shape {frames.shape}")
            return frames, frame_time_str, video_time
        except Exception as e:
            logger.error(f"Error loading video frames: {str(e)}")
            raise
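    # Worked example (illustrative): for a 30 fps clip of 300 frames sampled at
    # fps=1, interval = round(30/1) = 30, so frame_idx is [0, 30, ..., 270]
    # (10 frames at 0s, 1s, ..., 9s). Because force_sample defaults to True,
    # np.linspace(0, 299, 64) then re-samples 64 uniform indices regardless.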
    async def process_video(self, video_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]:
        try:
            video_name = video_path.name
            logger.info(f"Starting processing of video: {video_name}")

            # Create the progress entry before any I/O so the except clause
            # below can always reference it
            progress = CaptioningProgress(
                video_name=video_name,
                total_frames=0,
                processed_frames=0,
                status="initializing",
                started_at=datetime.now()
            )
            self.active_tasks[video_name] = progress

            # Load video metadata with strict frame limits
            logger.debug(f"Loading video metadata for {video_name}")
            loop = asyncio.get_event_loop()
            vr = await loop.run_in_executor(None, lambda: VideoReader(str(video_path), ctx=cpu(0)))
            total_frames = len(vr)
            progress.total_frames = total_frames
            yield progress, None

            # Get parent caption if this is a clip
            parent_caption = ""
            if "___" in video_path.stem:
                parent_name, _ = extract_scene_info(video_path.stem)
                parent_txt_path = VIDEOS_TO_SPLIT_PATH / f"{parent_name}.txt"
                if parent_txt_path.exists():
                    parent_caption = parent_txt_path.read_text().strip()

            # Ensure model is loaded before processing
            await self.ensure_model_loaded()

            if USE_MOCK_CAPTIONING_MODEL:
                # Even in mock mode, we'll generate a caption that shows we processed parent info
                clip_caption = f"This is a test caption for {video_name}"

                # Combine clip caption with parent caption
                if parent_caption:
                    full_caption = f"{clip_caption}\n{parent_caption}"
                else:
                    full_caption = clip_caption

                if prompt_prefix and not full_caption.startswith(prompt_prefix):
                    full_caption = f"{prompt_prefix}{full_caption}"

                # Write the caption file
                txt_path = video_path.with_suffix('.txt')
                txt_path.write_text(full_caption)
                logger.debug(f"Mock mode: Saved caption to {txt_path}")

                progress.status = "completed"
                progress.processed_frames = total_frames
                progress.completed_at = datetime.now()
                yield progress, full_caption
            else:
                # Process frames with strict limits
                max_frames_num = 64  # Maximum frames supported by the model
                frames, frame_times_str, video_time = await loop.run_in_executor(
                    None,
                    lambda: self._load_video(video_path, max_frames_num, fps=1, force_sample=True)
                )

                # Process all frames at once using the image processor
                processed_frames = await loop.run_in_executor(
                    None,
                    lambda: self.image_processor.preprocess(
                        frames,
                        return_tensors="pt"
                    )["pixel_values"]
                )

                # Update progress
                progress.processed_frames = len(frames)
                progress.status = "generating caption"
                yield progress, None

                # Move processed frames to GPU
                video_tensor = processed_frames.to('cuda').bfloat16()

                # Use proper conversation template and tokens
                conv_template = "qwen_1_5"
                time_instruction = (f"The video lasts for {video_time:.2f} seconds, and {len(frames)} "
                                    f"frames are uniformly sampled from it. These frames are located at {frame_times_str}.")

                full_question = DEFAULT_IMAGE_TOKEN + f"{time_instruction}\n{prompt}"

                conv = copy.deepcopy(conv_templates[conv_template])
                conv.append_message(conv.roles[0], full_question)
                conv.append_message(conv.roles[1], None)
                prompt_question = conv.get_prompt()

                # Cap the output length to prevent hallucination
                max_new_tokens = 512  # Reasonable limit for caption length

                input_ids = await loop.run_in_executor(
                    None,
                    lambda: tokenizer_image_token(prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
                )

                # Generate caption with controlled parameters
                with torch.no_grad():
                    output = await loop.run_in_executor(
                        None,
                        lambda: self.model.generate(
                            input_ids,
                            images=[video_tensor],
                            modalities=["video"],
                            do_sample=False,
                            temperature=0,
                            max_new_tokens=max_new_tokens,
                        )
                    )

                clip_caption = await loop.run_in_executor(
                    None,
                    lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
                )

                # Remove the instruction/question part from the response
                if time_instruction in clip_caption:
                    clip_caption = clip_caption.split(time_instruction)[1].strip()
                if prompt in clip_caption:
                    clip_caption = clip_caption.split(prompt)[1].strip()

                # Combine captions with proper formatting
                if parent_caption:
                    full_caption = f"{clip_caption}\n{parent_caption}"
                else:
                    full_caption = clip_caption

                if prompt_prefix and not full_caption.startswith(prompt_prefix):
                    full_caption = f"{prompt_prefix}{full_caption}"

                # Write caption
                txt_path = video_path.with_suffix('.txt')
                txt_path.write_text(full_caption)

                progress.status = "completed"
                progress.completed_at = datetime.now()
                yield progress, full_caption

        except Exception as e:
            progress.status = "error"
            progress.error = str(e)
            progress.completed_at = datetime.now()
            yield progress, None
            raise
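    # Consumption sketch (illustrative): process_video is an async generator
    # and must be drained; the final yield carries the finished caption.
    #
    #     async for progress, caption in service.process_video(path, prompt):
    #         if caption is not None:
    #             ...  # completed; the caption is also written next to the video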
    async def process_image(self, image_path: Path, prompt: str, prompt_prefix: str = "") -> AsyncGenerator[tuple[CaptioningProgress, Optional[str]], None]:
        """Process a single image for captioning"""
        try:
            image_name = image_path.name
            logger.info(f"Starting processing of image: {image_name}")

            progress = CaptioningProgress(
                video_name=image_name,  # Reusing video_name field for images
                total_frames=1,
                processed_frames=0,
                status="initializing",
                started_at=datetime.now()
            )
            self.active_tasks[image_name] = progress
            yield progress, None

            # Ensure model is loaded
            await self.ensure_model_loaded()

            if USE_MOCK_CAPTIONING_MODEL:
                progress.status = "completed"
                progress.processed_frames = 1
                progress.completed_at = datetime.now()
                logger.debug("Mock mode: yielding placeholder image caption")
                yield progress, "This is a test image caption"
                return

            # Read and process image
            loop = asyncio.get_event_loop()
            image = await loop.run_in_executor(
                None,
                lambda: cv2.imread(str(image_path))
            )
            if image is None:
                raise ValueError(f"Could not read image: {str(image_path)}")

            # Convert BGR to RGB
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Process image
            processed_image = await loop.run_in_executor(
                None,
                lambda: self.image_processor.preprocess(
                    image,
                    return_tensors="pt"
                )["pixel_values"]
            )

            progress.processed_frames = 1
            progress.status = "generating caption"
            yield progress, None

            # Move to GPU and generate caption
            image_tensor = processed_image.to('cuda').bfloat16()

            # Build the prompt with the image token and conversation template,
            # mirroring process_video, so the image tensor is actually injected
            conv = copy.deepcopy(conv_templates["qwen_1_5"])
            conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + f"\n{prompt}")
            conv.append_message(conv.roles[1], None)
            prompt_question = conv.get_prompt()

            input_ids = await loop.run_in_executor(
                None,
                lambda: tokenizer_image_token(prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
            )

            with torch.no_grad():
                output = await loop.run_in_executor(
                    None,
                    lambda: self.model.generate(
                        input_ids,
                        images=[image_tensor],
                        modalities=["image"],
                        do_sample=False,
                        temperature=0,
                        max_new_tokens=4096,
                    )
                )

            caption = await loop.run_in_executor(
                None,
                lambda: self.tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
            )

            progress.status = "completed"
            progress.completed_at = datetime.now()
            gr.Info(f"Successfully generated caption for {image_name}")
            yield progress, caption

        except Exception as e:
            progress.status = "error"
            progress.error = str(e)
            progress.completed_at = datetime.now()
            yield progress, None
            raise gr.Error(f"Error processing image: {str(e)}")
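    # Output shape sketch (illustrative): start_caption_generation below yields
    # rows shaped for a two-column Gradio DataFrame, e.g.
    #
    #     [["scene___001.mp4", "Generating caption..."],
    #      ["photo.jpg", "Completed"]]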
    async def start_caption_generation(self, custom_prompt: str, prompt_prefix: str) -> AsyncGenerator[List[List[str]], None]:
        """Iterates over clips to auto-generate captions asynchronously."""
        try:
            logger.info("Starting auto-caption generation")

            # Use provided prompt or default
            default_prompt = DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
            prompt = custom_prompt.strip() or default_prompt
            logger.debug(f"Using prompt: {prompt}")

            # Find files needing captions
            video_files = list(STAGING_PATH.glob("*.mp4"))
            image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
            all_files = video_files + image_files

            # Filter for files missing captions or with empty caption files
            files_to_process = []
            for file_path in all_files:
                caption_path = file_path.with_suffix('.txt')
                needs_caption = (
                    not caption_path.exists()
                    or caption_path.stat().st_size == 0
                    or caption_path.read_text().strip() == ""
                )
                if needs_caption:
                    files_to_process.append(file_path)

            logger.info(f"Found {len(files_to_process)} files needing captions")
            if not files_to_process:
                logger.info("No files need captioning")
                yield []
                return

            self._should_stop = False
            self.active_tasks.clear()
            status_update: Dict[str, Dict[str, Any]] = {}

            # Only process the files that actually need captions
            for file_path in files_to_process:
                if self._should_stop:
                    break

                try:
                    logger.debug(f"Processing file: {str(file_path)}")

                    # Choose appropriate processing method based on file type
                    if is_video_file(file_path):
                        process_gen = self.process_video(file_path, prompt, prompt_prefix)
                    else:
                        process_gen = self.process_image(file_path, prompt, prompt_prefix)

                    async for progress, caption in process_gen:
                        logger.debug(f"Caption generator yielded: {caption}")
                        if caption and prompt_prefix and not caption.startswith(prompt_prefix):
                            caption = f"{prompt_prefix}{caption}"

                        # Save caption
                        if caption:
                            txt_path = file_path.with_suffix('.txt')
                            txt_path.write_text(caption)

                        logger.debug(f"Progress update: {progress.status}")

                        # Store progress info, including any error message, so the
                        # rows below don't read it from an unrelated progress object
                        status_update[file_path.name] = {
                            "status": progress.status,
                            "frames": progress.processed_frames,
                            "total": progress.total_frames,
                            "error": progress.error
                        }

                        # Convert to list format for Gradio DataFrame
                        rows = []
                        for file_name, info in status_update.items():
                            status = info["status"]
                            if status == "processing":
                                percent = (info["frames"] / info["total"]) * 100
                                status = f"Analyzing... {percent:.1f}% ({info['frames']}/{info['total']} frames)"
                            elif status == "generating caption":
                                status = "Generating caption..."
                            elif status == "error":
                                status = f"Error: {info['error']}"
                            elif status == "completed":
                                status = "Completed"
                            rows.append([file_name, status])

                        yield rows
                        await asyncio.sleep(0.1)

                except Exception as e:
                    logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
                    rows = [[str(file_path.name), f"Error: {str(e)}"]]
                    yield rows
                    continue

            logger.info("Auto-caption generation completed, copying assets to the training dir...")
            copy_files_to_training_dir(prompt_prefix)

        except Exception as e:
            logger.error(f"Error in start_caption_generation: {str(e)}")
            yield [[str(e), "error"]]
            raise

    def stop_captioning(self):
        """Stop all ongoing captioning tasks"""
        logger.info("Stopping all captioning tasks")
        self._should_stop = True

    def close(self):
        """Clean up resources"""
        logger.info("Cleaning up captioning service resources")
        if hasattr(self, 'model'):
            del self.model
        if hasattr(self, 'tokenizer'):
            del self.tokenizer
        if hasattr(self, 'image_processor'):
            del self.image_processor
        torch.cuda.empty_cache()
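# Usage sketch (illustrative, not part of the original module). It assumes
# USE_MOCK_CAPTIONING_MODEL is enabled in config, so no GPU or model weights
# are required; with a real model the same driver applies unchanged.
if __name__ == "__main__":
    async def _demo() -> None:
        service = CaptioningService()
        # An empty custom prompt falls back to DEFAULT_CAPTIONING_BOT_INSTRUCTIONS
        async for rows in service.start_caption_generation(
            custom_prompt="", prompt_prefix=DEFAULT_PROMPT_PREFIX
        ):
            for file_name, status in rows:
                print(f"{file_name}: {status}")
        service.close()

    asyncio.run(_demo())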