VideoModelStudio / utils.py
jbilcke-hf's picture
jbilcke-hf HF staff
initial commit log 🪵🦫
91fb4ef
import os
import shutil
from huggingface_hub import HfApi, create_repo
from pathlib import Path
import json
import re
from typing import Any, Optional, Dict, List, Union, Tuple
def make_archive(source: str | Path, destination: str | Path):
source = str(source)
destination = str(destination)
#print(f"make_archive({source}, {destination})")
base = os.path.basename(destination)
name = base.split('.')[0]
format = base.split('.')[1]
archive_from = os.path.dirname(source)
archive_to = os.path.basename(source.strip(os.sep))
shutil.make_archive(name, format, archive_from, archive_to)
shutil.move('%s.%s'%(name,format), destination)
def extract_scene_info(filename: str) -> Tuple[str, Optional[int]]:
"""Extract base name and scene number from filename
Args:
filename: Input filename like "my_cool_video_1___001.mp4"
Returns:
Tuple of (base_name, scene_number)
e.g. ("my_cool_video_1", 1)
"""
# Match numbers at the end of the filename before extension
match = re.search(r'(.+?)___(\d+)$', Path(filename).stem)
if match:
return match.group(1), int(match.group(2))
return Path(filename).stem, None
def is_image_file(file_path: Path) -> bool:
"""Check if file is an image based on extension
Args:
file_path: Path to check
Returns:
bool: True if file has image extension
"""
image_extensions = {'.jpg', '.jpeg', '.png', '.webp', '.avif', '.heic'}
return file_path.suffix.lower() in image_extensions
def is_video_file(file_path: Path) -> bool:
"""Check if file is a video based on extension
Args:
file_path: Path to check
Returns:
bool: True if file has video extension
"""
video_extensions = {'.mp4', '.webm'}
return file_path.suffix.lower() in video_extensions
def parse_bool_env(env_value: Optional[str]) -> bool:
"""Parse environment variable string to boolean
Handles various true/false string representations:
- True: "true", "True", "TRUE", "1", etc
- False: "false", "False", "FALSE", "0", "", None
"""
if not env_value:
return False
return str(env_value).lower() in ('true', '1', 't', 'y', 'yes')
def validate_model_repo(repo_id: str) -> Dict[str, str]:
"""Validate HuggingFace model repository name
Args:
repo_id: Repository ID in format "username/model-name"
Returns:
Dict with error message if invalid, or None if valid
"""
if not repo_id:
return {"error": "Repository name is required"}
if "/" not in repo_id:
return {"error": "Repository name must be in format username/model-name"}
# Check characters
invalid_chars = set('<>:"/\\|?*')
if any(c in repo_id for c in invalid_chars):
return {"error": "Repository name contains invalid characters"}
return {"error": None}
def save_to_hub(model_path: Path, repo_id: str, token: str, commit_message: str = "Update model") -> bool:
"""Save model files to Hugging Face Hub
Args:
model_path: Path to model files
repo_id: Repository ID (username/model-name)
token: HuggingFace API token
commit_message: Optional commit message
Returns:
bool: True if successful, False if failed
"""
try:
api = HfApi(token=token)
# Validate repo_id
validation = validate_model_repo(repo_id)
if validation["error"]:
return False
# Create or get repo
try:
create_repo(repo_id, token=token, repo_type="model", exist_ok=True)
except Exception as e:
print(f"Error creating repo: {e}")
return False
# Upload all files
api.upload_folder(
folder_path=str(model_path),
repo_id=repo_id,
repo_type="model",
commit_message=commit_message
)
return True
except Exception as e:
print(f"Error uploading to hub: {e}")
return False
def parse_training_log(line: str) -> Dict:
"""Parse a training log line for metrics
Args:
line: Log line from training output
Returns:
Dict with parsed metrics (epoch, step, loss, etc)
"""
metrics = {}
try:
# Extract step/epoch info
if "step=" in line:
step = int(line.split("step=")[1].split()[0].strip(","))
metrics["step"] = step
if "epoch=" in line:
epoch = int(line.split("epoch=")[1].split()[0].strip(","))
metrics["epoch"] = epoch
if "loss=" in line:
loss = float(line.split("loss=")[1].split()[0].strip(","))
metrics["loss"] = loss
if "lr=" in line:
lr = float(line.split("lr=")[1].split()[0].strip(","))
metrics["learning_rate"] = lr
except:
pass
return metrics
def format_size(size_bytes: int) -> str:
"""Format bytes into human readable string with appropriate unit
Args:
size_bytes: Size in bytes
Returns:
Formatted string (e.g. "1.5 Gb")
"""
units = ['bytes', 'Kb', 'Mb', 'Gb', 'Tb']
unit_index = 0
size = float(size_bytes)
while size >= 1024 and unit_index < len(units) - 1:
size /= 1024
unit_index += 1
# Special case for bytes - no decimal places
if unit_index == 0:
return f"{int(size)} {units[unit_index]}"
return f"{size:.1f} {units[unit_index]}"
def count_media_files(path: Path) -> Tuple[int, int, int]:
"""Count videos and images in directory
Args:
path: Directory to scan
Returns:
Tuple of (video_count, image_count, total_size)
"""
video_count = 0
image_count = 0
total_size = 0
for file in path.glob("*"):
# Skip hidden files and caption files
if file.name.startswith('.') or file.suffix.lower() == '.txt':
continue
if is_video_file(file):
video_count += 1
total_size += file.stat().st_size
elif is_image_file(file):
image_count += 1
total_size += file.stat().st_size
return video_count, image_count, total_size
def format_media_title(action: str, video_count: int, image_count: int, total_size: int) -> str:
"""Format title with media counts and size
Args:
action: Action (eg "split", "caption")
video_count: Number of videos
image_count: Number of images
total_size: Total size in bytes
Returns:
Formatted title string
"""
parts = []
if image_count > 0:
parts.append(f"{image_count:,} photo{'s' if image_count != 1 else ''}")
if video_count > 0:
parts.append(f"{video_count:,} video{'s' if video_count != 1 else ''}")
if not parts:
return f"## 0 files to {action} (0 bytes)"
return f"## {' and '.join(parts)} to {action} ({format_size(total_size)})"
def add_prefix_to_caption(caption: str, prefix: str) -> str:
"""Add prefix to caption if not already present"""
if not prefix or not caption:
return caption
if caption.startswith(prefix):
return caption
return f"{prefix}{caption}"
def format_time(seconds: float) -> str:
"""Format time duration in seconds to human readable string
Args:
seconds: Time in seconds
Returns:
Formatted string (e.g. "2h 30m 45s")
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
parts = []
if hours > 0:
parts.append(f"{hours}h")
if minutes > 0:
parts.append(f"{minutes}m")
if secs > 0 or not parts:
parts.append(f"{secs}s")
return " ".join(parts)