# VideoModelStudio: vms/utils/finetrainers_utils.py
import gradio as gr
from pathlib import Path
import json
import logging
import shutil
from typing import Any, Optional, Dict, List, Union, Tuple
from ..config import (
STORAGE_PATH, TRAINING_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH, HF_API_TOKEN, MODEL_TYPES,
DEFAULT_VALIDATION_NB_STEPS,
DEFAULT_VALIDATION_HEIGHT,
DEFAULT_VALIDATION_WIDTH,
DEFAULT_VALIDATION_NB_FRAMES,
DEFAULT_VALIDATION_FRAMERATE
)
from .utils import get_video_fps, extract_scene_info, make_archive, is_image_file, is_video_file
logger = logging.getLogger(__name__)
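# Rough data flow implemented below (summary of the functions in this module):
#   1. copy_files_to_training_dir() copies captioned clips and images from
#      STAGING_PATH into TRAINING_VIDEOS_PATH, merging the prompt prefix, an
#      FPS hint and the parent caption into each caption file.
#   2. prepare_finetrainers_dataset() writes the videos.txt / prompts.txt lists
#      that Finetrainers expects next to the videos/ directory.
#   3. create_validation_config() writes a validation_config.json under
#      OUTPUT_PATH, reusing a small subset of the training clips.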
def prepare_finetrainers_dataset() -> Tuple[Optional[Path], Optional[Path]]:
"""Prepare a Finetrainers-compatible dataset structure
Creates:
training/
β”œβ”€β”€ prompt.txt # All captions, one per line
β”œβ”€β”€ videos.txt # All video paths, one per line
└── videos/ # Directory containing all mp4 files
β”œβ”€β”€ 00000.mp4
β”œβ”€β”€ 00001.mp4
└── ...
Returns:
Tuple of (videos_file_path, prompts_file_path)
"""
    # Ensure the videos subdirectory exists
    TRAINING_VIDEOS_PATH.mkdir(parents=True, exist_ok=True)
# Clear existing training lists
for f in TRAINING_PATH.glob("*"):
if f.is_file():
if f.name in ["videos.txt", "prompts.txt", "prompt.txt"]:
f.unlink()
videos_file = TRAINING_PATH / "videos.txt"
prompts_file = TRAINING_PATH / "prompts.txt" # Finetrainers can use either prompts.txt or prompt.txt
media_files = []
captions = []
# Process all video files from the videos subdirectory
    for file in sorted(TRAINING_VIDEOS_PATH.glob("*.mp4")):
caption_file = file.with_suffix('.txt')
if caption_file.exists():
# Normalize caption to single line
caption = caption_file.read_text().strip()
caption = ' '.join(caption.split())
# Use relative path from training root
relative_path = f"videos/{file.name}"
media_files.append(relative_path)
captions.append(caption)
# Write files if we have content
if media_files and captions:
videos_file.write_text('\n'.join(media_files))
prompts_file.write_text('\n'.join(captions))
logger.info(f"Created dataset with {len(media_files)} video/caption pairs")
else:
logger.warning("No valid video/caption pairs found in training directory")
return None, None
# Verify file contents
with open(videos_file) as vf:
video_lines = [l.strip() for l in vf.readlines() if l.strip()]
with open(prompts_file) as pf:
prompt_lines = [l.strip() for l in pf.readlines() if l.strip()]
if len(video_lines) != len(prompt_lines):
logger.error(f"Mismatch in generated files: {len(video_lines)} videos vs {len(prompt_lines)} prompts")
return None, None
return videos_file, prompts_file
def copy_files_to_training_dir(prompt_prefix: str) -> int:
"""Just copy files over, with no destruction"""
gr.Info("Copying assets to the training dataset..")
# Find files needing captions
video_files = list(STAGING_PATH.glob("*.mp4"))
image_files = [f for f in STAGING_PATH.glob("*") if is_image_file(f)]
all_files = video_files + image_files
nb_copied_pairs = 0
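    # For every staged file, the caption written next to the training copy is
    # assembled as: the optional prompt_prefix, then an FPS hint (videos only),
    # then the clip's own caption, then the parent video's caption on a new
    # line when the file is a scene clip. Pairs without any caption are skipped.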
for file_path in all_files:
caption = ""
file_caption_path = file_path.with_suffix('.txt')
if file_caption_path.exists():
logger.debug(f"Found caption file: {file_caption_path}")
            caption = file_caption_path.read_text().strip()
# Get parent caption if this is a clip
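        # Scene clips are assumed to follow the "<parent>___<scene>" naming used
        # by the scene-splitting step; extract_scene_info() recovers the parent name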
parent_caption = ""
if "___" in file_path.stem:
parent_name, _ = extract_scene_info(file_path.stem)
#print(f"parent_name is {parent_name}")
parent_caption_path = STAGING_PATH / f"{parent_name}.txt"
if parent_caption_path.exists():
logger.debug(f"Found parent caption file: {parent_caption_path}")
parent_caption = parent_caption_path.read_text().strip()
target_file_path = TRAINING_VIDEOS_PATH / file_path.name
target_caption_path = target_file_path.with_suffix('.txt')
if parent_caption and not caption.endswith(parent_caption):
caption = f"{caption}\n{parent_caption}"
# Add FPS information for videos
if is_video_file(file_path) and caption:
# Only add FPS if not already present
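            # get_video_fps() is expected to return a caption prefix containing
            # "FPS, " (exact wording assumed here), hence the substring check below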
if not any(f"FPS, " in line for line in caption.split('\n')):
fps_info = get_video_fps(file_path)
if fps_info:
caption = f"{fps_info}{caption}"
if prompt_prefix and not caption.startswith(prompt_prefix):
caption = f"{prompt_prefix}{caption}"
# make sure we only copy over VALID pairs
if caption:
try:
target_caption_path.write_text(caption)
shutil.copy2(file_path, target_file_path)
nb_copied_pairs += 1
            except Exception as e:
                logger.warning(f"Failed to copy one of the pairs: {e}")
prepare_finetrainers_dataset()
gr.Info(f"Successfully generated the training dataset ({nb_copied_pairs} pairs)")
return nb_copied_pairs
def create_validation_config() -> Optional[Path]:
"""Create a validation configuration JSON file for Finetrainers
Creates a validation dataset file with a subset of the training data
Returns:
Path to the validation JSON file, or None if no training files exist
"""
# Ensure training dataset exists
if not TRAINING_VIDEOS_PATH.exists() or not any(TRAINING_VIDEOS_PATH.glob("*.mp4")):
logger.warning("No training videos found for validation")
return None
# Get a subset of the training videos (up to 4) for validation
training_videos = list(TRAINING_VIDEOS_PATH.glob("*.mp4"))
    validation_videos = training_videos[:4]
if not validation_videos:
logger.warning("No validation videos selected")
return None
# Create validation data entries
validation_data = {"data": []}
for video_path in validation_videos:
# Get caption from matching text file
caption_path = video_path.with_suffix('.txt')
if not caption_path.exists():
logger.warning(f"Missing caption for {video_path}, skipping for validation")
continue
caption = caption_path.read_text().strip()
# Get video dimensions and properties
try:
# Use the most common default resolution and settings
data_entry = {
"caption": caption,
"image_path": "", # No input image for text-to-video
"video_path": str(video_path),
"num_inference_steps": DEFAULT_VALIDATION_NB_STEPS,
"height": DEFAULT_VALIDATION_HEIGHT,
"width": DEFAULT_VALIDATION_WIDTH,
"num_frames": DEFAULT_VALIDATION_NB_FRAMES,
"frame_rate": DEFAULT_VALIDATION_FRAMERATE
}
validation_data["data"].append(data_entry)
except Exception as e:
logger.warning(f"Error adding validation entry for {video_path}: {e}")
if not validation_data["data"]:
logger.warning("No valid validation entries created")
return None
# Write validation config to file
validation_file = OUTPUT_PATH / "validation_config.json"
with open(validation_file, 'w') as f:
json.dump(validation_data, f, indent=2)
logger.info(f"Created validation config with {len(validation_data['data'])} entries")
return validation_file