import logging
import os
import re

from typing import List, Optional

import huggingface_hub
import requests

from tqdm.auto import tqdm

# Maps the short model names accepted by download_model() to their
# Hugging Face Hub repository IDs.
_MODELS = {
    "tiny.en": "Systran/faster-whisper-tiny.en",
    "tiny": "Systran/faster-whisper-tiny",
    "base.en": "Systran/faster-whisper-base.en",
    "base": "Systran/faster-whisper-base",
    "small.en": "Systran/faster-whisper-small.en",
    "small": "Systran/faster-whisper-small",
    "medium.en": "Systran/faster-whisper-medium.en",
    "medium": "Systran/faster-whisper-medium",
    "large-v1": "Systran/faster-whisper-large-v1",
    "large-v2": "Systran/faster-whisper-large-v2",
    "large-v3": "Systran/faster-whisper-large-v3",
    "large": "Systran/faster-whisper-large-v3",
    "distil-large-v2": "Systran/faster-distil-whisper-large-v2",
    "distil-medium.en": "Systran/faster-distil-whisper-medium.en",
    "distil-small.en": "Systran/faster-distil-whisper-small.en",
    "distil-large-v3": "Systran/faster-distil-whisper-large-v3",
}


def available_models() -> List[str]:
    """Returns the short names of all models listed in ``_MODELS``."""
    return list(_MODELS)


def get_assets_path():
    """Returns the path of the ``assets`` directory next to this file."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "assets")


def get_logger():
    """Returns the logger registered under the name ``faster_whisper``."""
    return logging.getLogger("faster_whisper")


def download_model(
    size_or_id: str,
    output_dir: Optional[str] = None,
    local_files_only: bool = False,
    cache_dir: Optional[str] = None,
):
    """Downloads a CTranslate2 Whisper model from the Hugging Face Hub.

    Args:
      size_or_id: Size of the model to download from https://huggingface.co/Systran
        (tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium,
        medium.en, distil-medium.en, large-v1, large-v2, large-v3, large,
        distil-large-v2, distil-large-v3), or a CTranslate2-converted model ID from
        the Hugging Face Hub (e.g. Systran/faster-whisper-large-v3). Any value
        containing a "/" is used as a repository ID without further lookup.
      output_dir: Directory where the model should be saved. If not set, the model
        is saved in the cache directory.
      local_files_only: If True, avoid downloading the file and return the path to
        the local cached file if it exists.
      cache_dir: Path to the folder where cached files are stored.

    Returns:
      The path to the downloaded model.

    Raises:
      ValueError: if the model size is invalid.
    """
    # A value that looks like "owner/name" is already a repository ID;
    # anything else has to be one of the known short names.
    looks_like_repo_id = re.match(r".*/.*", size_or_id) is not None
    repo_id = size_or_id if looks_like_repo_id else _MODELS.get(size_or_id)
    if repo_id is None:
        raise ValueError(
            "Invalid model size '%s', expected one of: %s"
            % (size_or_id, ", ".join(_MODELS.keys()))
        )

    download_options = {
        "local_files_only": local_files_only,
        # Only these files are requested from the repository.
        "allow_patterns": [
            "config.json",
            "preprocessor_config.json",
            "model.bin",
            "tokenizer.json",
            "vocabulary.*",
        ],
        "tqdm_class": disabled_tqdm,
    }
    if output_dir is not None:
        download_options.update(
            local_dir=output_dir,
            local_dir_use_symlinks=False,
        )
    if cache_dir is not None:
        download_options["cache_dir"] = cache_dir

    try:
        return huggingface_hub.snapshot_download(repo_id, **download_options)
    except (
        huggingface_hub.utils.HfHubHTTPError,
        requests.exceptions.ConnectionError,
    ) as err:
        log = get_logger()
        log.warning(
            "An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
            repo_id,
            err,
        )
        log.warning(
            "Trying to load the model directly from the local cache, if it exists."
        )

        # Second attempt: same options, but restricted to local files.
        offline_options = {**download_options, "local_files_only": True}
        return huggingface_hub.snapshot_download(repo_id, **offline_options)


def format_timestamp(
    seconds: float,
    always_include_hours: bool = False,
    decimal_marker: str = ".",
) -> str:
    """Formats a non-negative duration in seconds as ``[HH:]MM:SS<marker>mmm``.

    Args:
      seconds: Duration to format; rounded to the nearest millisecond.
      always_include_hours: If True, emit the hours field even when it is zero.
      decimal_marker: Separator placed between the seconds and milliseconds.

    Returns:
      The formatted timestamp string.
    """
    assert seconds >= 0, "non-negative timestamp expected"

    total_ms = round(seconds * 1000.0)
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    whole_seconds, millis = divmod(remainder, 1_000)

    if always_include_hours or hours > 0:
        hours_marker = f"{hours:02d}:"
    else:
        hours_marker = ""

    return (
        f"{hours_marker}{minutes:02d}:{whole_seconds:02d}{decimal_marker}{millis:03d}"
    )


class disabled_tqdm(tqdm):
    """A ``tqdm`` subclass that always passes ``disable=True`` to its parent."""

    def __init__(self, *args, **kwargs):
        kwargs.update(disable=True)
        super().__init__(*args, **kwargs)


def get_end(segments: List[dict]) -> Optional[float]:
    """Returns the ``"end"`` value of the last word found in ``segments``.

    Segments are scanned from last to first, and each segment's ``"words"``
    from last to first. If no segment contains any word, the ``"end"`` of the
    last segment is returned instead, or None when ``segments`` is empty.
    """
    newest_first = reversed(segments)
    # Fallback: end of the last segment (None when there are no segments).
    fallback = segments[-1]["end"] if segments else None

    for segment in newest_first:
        for word in reversed(segment["words"]):
            return word["end"]
    return fallback