vishred18's picture
Upload 364 files
d5ee97c
raw
history blame
3.05 kB
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Utility functions."""
import fnmatch
import os
import re
import tempfile
from pathlib import Path
import tensorflow as tf
# Well-known artifact file names (weights, configuration, processor).
MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"

# Library identifier; also used as the cache sub-directory name below.
LIBRARY_NAME = "tensorflow_tts"

# Per-user cache location: <home>/.cache/<LIBRARY_NAME>, stored as a str.
CACHE_DIRECTORY = str(Path.home().joinpath(".cache", LIBRARY_NAME))
def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Walks ``root_dir`` (following symbolic links) and collects every file
    whose name matches the shell-style pattern ``query``.

    Args:
        root_dir (str): Root directory to search.
        query (str): ``fnmatch``-style filename pattern to match.
        include_root_dir (bool): If False, the returned paths are expressed
            relative to ``root_dir`` instead of being prefixed with it.

    Returns:
        list: List of found filenames.

    """
    files = []
    for dirpath, _, filenames in os.walk(root_dir, followlinks=True):
        files.extend(
            os.path.join(dirpath, filename)
            for filename in fnmatch.filter(filenames, query)
        )

    if not include_root_dir:
        # Use relpath rather than str.replace(root_dir + "/", ""): replace()
        # would also delete later occurrences of the same substring inside
        # the path, and would strip nothing when root_dir ends with a
        # separator or the platform separator is not "/".
        files = [os.path.relpath(path, root_dir) for path in files]

    return files
def _path_requires_gfile(filepath):
"""Checks if the given path requires use of GFile API.
Args:
filepath (str): Path to check.
Returns:
bool: True if the given path needs GFile API to access, such as
"s3://some/path" and "gs://some/path".
"""
# If the filepath contains a protocol (e.g. "gs://"), it should be handled
# using TensorFlow GFile API.
return bool(re.match(r"^[a-z]+://", filepath))
def save_weights(model, filepath):
    """Write the weights of ``model`` to ``filepath``.

    Behaves like ``model.save_weights(filepath)``, and additionally supports
    targets on S3 or GCS buckets by going through the TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model whose weights are saved.
        filepath (str): Destination path of the weights file.

    """
    if _path_requires_gfile(filepath):
        # Remote target: write the weights to a local temporary file that
        # keeps the original extension, then copy it over with GFile.
        suffix = os.path.splitext(filepath)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix) as local_copy:
            model.save_weights(local_copy.name)
            # overwrite=True mirrors model.save_weights, which replaces an
            # existing target file.
            tf.io.gfile.copy(local_copy.name, filepath, overwrite=True)
    else:
        model.save_weights(filepath)
def load_weights(model, filepath):
    """Read weights from ``filepath`` into ``model``.

    Behaves like ``model.load_weights(filepath)``, and additionally supports
    sources on S3 or GCS buckets by going through the TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model that receives the weights.
        filepath (str): Path of the weights file to read.

    """
    if _path_requires_gfile(filepath):
        # Remote source: fetch it into a local temporary file that keeps the
        # original extension, then load from that copy.
        suffix = os.path.splitext(filepath)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix) as local_copy:
            # NamedTemporaryFile has already created the file on disk, so
            # the copy has to be allowed to overwrite it.
            tf.io.gfile.copy(filepath, local_copy.name, overwrite=True)
            model.load_weights(local_copy.name)
    else:
        model.load_weights(filepath)