|
|
|
|
|
|
|
|
|
"""Utility functions.""" |
|
|
|
import fnmatch |
|
import os |
|
import re |
|
import tempfile |
|
from pathlib import Path |
|
|
|
import tensorflow as tf |
|
|
|
# Default file names used for the artifacts of a saved model.
MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"

# Library name; also used as the name of the per-user cache sub-directory.
LIBRARY_NAME = "tensorflow_tts"

# Cache location: <home>/.cache/<LIBRARY_NAME>, kept as a plain string.
CACHE_DIRECTORY = str(Path.home().joinpath(".cache", LIBRARY_NAME))
|
|
|
|
|
def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Walks ``root_dir`` (following symbolic links) and collects every file
    whose name matches ``query``.

    Args:
        root_dir (str): Root directory to search.
        query (str): fnmatch-style pattern matched against file names.
        include_root_dir (bool): If False, the returned paths are relative
            to ``root_dir`` instead of being prefixed with it.

    Returns:
        list: List of found filenames.

    """
    files = [
        os.path.join(root, filename)
        for root, _, filenames in os.walk(root_dir, followlinks=True)
        for filename in fnmatch.filter(filenames, query)
    ]

    if not include_root_dir:
        # Use relpath rather than str.replace: replace() removes *every*
        # occurrence of "<root_dir>/" (not just the leading one), strips
        # nothing when root_dir ends with a separator, and hard-codes "/".
        files = [os.path.relpath(file_, root_dir) for file_ in files]

    return files
|
|
|
|
|
def _path_requires_gfile(filepath):
    """Checks if the given path requires use of GFile API.

    A path is considered to need GFile when it starts with a URI scheme
    followed by ``://``.

    Args:
        filepath (str): Path to check.

    Returns:
        bool: True if the given path needs GFile API to access, such as
            "s3://some/path" and "gs://some/path".

    """
    # The scheme must start with a letter but may continue with digits,
    # "+", "-" or "." (RFC 3986); a letters-only pattern rejects "s3://".
    return bool(re.match(r"^[a-z][a-z0-9+.-]*://", filepath))
|
|
|
|
|
def save_weights(model, filepath):
    """Save model weights.

    Behaves like model.save_weights(filepath), with added support for
    destinations (e.g. S3 or GCS buckets) that are reached through the
    TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to save.
        filepath (str): Path to save the model weights to.

    """
    if _path_requires_gfile(filepath):
        # Write the weights to a local temporary file that keeps the target's
        # extension, then copy that file to the destination via GFile.
        suffix = os.path.splitext(filepath)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix) as local_copy:
            model.save_weights(local_copy.name)
            # overwrite=True replaces any file already present at filepath.
            tf.io.gfile.copy(local_copy.name, filepath, overwrite=True)
    else:
        model.save_weights(filepath)
|
|
|
|
|
def load_weights(model, filepath):
    """Load model weights.

    Behaves like model.load_weights(filepath), with added support for
    sources (e.g. S3 or GCS buckets) that are reached through the
    TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to load weights to.
        filepath (str): Path to the weights file.

    """
    if _path_requires_gfile(filepath):
        # Copy the weights file to a local temporary file that keeps the
        # source's extension, then load the weights from that local copy.
        suffix = os.path.splitext(filepath)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix) as local_copy:
            # overwrite=True because the temporary file already exists.
            tf.io.gfile.copy(filepath, local_copy.name, overwrite=True)
            model.load_weights(local_copy.name)
    else:
        model.load_weights(filepath)
|
|