import os
import pickle as pickle_tts
from typing import Any, Callable, Dict, Optional, Union

import fsspec
import torch

from TTS.utils.generic_utils import get_user_data_dir


class RenamingUnpickler(pickle_tts.Unpickler):
    """Overload the default unpickler to remap the legacy ``mozilla_voice_tts``
    module path to ``TTS`` when loading old checkpoints."""

    def find_class(self, module, name):
        return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)
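
# Illustrative sketch (not executed at import time): ``torch.load`` builds its
# unpickler from ``pickle_module.Unpickler``, so patching that attribute swaps
# in the renaming behavior for checkpoints saved under the old package name.
# ``load_checkpoint`` below uses exactly this fallback.
#
#     pickle_tts.Unpickler = RenamingUnpickler
#     state = torch.load(checkpoint_file, pickle_module=pickle_tts)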


class AttrDict(dict):
    """A dict subclass that exposes its keys as instance attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the instance ``__dict__`` at the dict itself so attribute and
        # item access share the same storage.
        self.__dict__ = self
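
# Usage sketch: keys set as mapping items are readable as attributes and
# vice versa (the key names below are illustrative only).
#
#     config = AttrDict({"sample_rate": 22050})
#     config.hop_length = 256
#     assert config.sample_rate == config["sample_rate"]
#     assert config["hop_length"] == 256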


def load_fsspec(
    path: str,
    map_location: Optional[
        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
    ] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Like ``torch.load`` but can load from other locations (e.g. s3://, gs://).

    Args:
        path: Any path or URL supported by fsspec.
        map_location: ``torch.device`` or str, forwarded to ``torch.load``.
        cache: If True, cache remote files locally for subsequent calls. Files are cached under
            ``get_user_data_dir()/tts_cache``. Defaults to True.
        **kwargs: Keyword arguments forwarded to ``torch.load``.

    Returns:
        Object stored in ``path``.
    """
    is_local = os.path.isdir(path) or os.path.isfile(path)
    if cache and not is_local:
        # Wrap the remote URL in fsspec's ``filecache`` protocol so the file is
        # downloaded once and reused on subsequent calls.
        with fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        ) as f:
            return torch.load(f, map_location=map_location, **kwargs)
    else:
        with fsspec.open(path, "rb") as f:
            return torch.load(f, map_location=map_location, **kwargs)
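
# Usage sketch: the same call works for local paths and fsspec URLs (given the
# matching fsspec backend, e.g. s3fs, is installed); the bucket name below is
# hypothetical.
#
#     state = load_fsspec("s3://my-bucket/model.pth", map_location="cpu")
#     state = load_fsspec("/tmp/model.pth", cache=False)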


def load_checkpoint(
    model, checkpoint_path, use_cuda=False, eval=False, cache=False
):  # pylint: disable=redefined-builtin
    """Load a checkpoint into ``model``, retrying with :class:`RenamingUnpickler`
    for checkpoints saved under the legacy ``mozilla_voice_tts`` package."""
    try:
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    except ModuleNotFoundError:
        # Legacy checkpoints reference ``mozilla_voice_tts`` modules; retry with
        # the renaming unpickler patched into the pickle module.
        pickle_tts.Unpickler = RenamingUnpickler
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache)
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state
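
# Usage sketch (``MyModel`` and the checkpoint path are placeholders):
#
#     model = MyModel()
#     model, state = load_checkpoint(model, "/path/to/checkpoint.pth", eval=True)
#     print(state.keys())  # e.g. dict_keys(["model", "optimizer", "step", ...])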