File size: 3,053 Bytes
d5ee97c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Utility functions."""
import fnmatch
import os
import re
import tempfile
from pathlib import Path
import tensorflow as tf
# File-name constants exported by this module.
MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"

# Library identifier; also used as the cache sub-directory name.
LIBRARY_NAME = "tensorflow_tts"
# Per-user cache directory, i.e. <home>/.cache/tensorflow_tts, as a plain str.
CACHE_DIRECTORY = str(Path.home() / ".cache" / LIBRARY_NAME)
def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Args:
        root_dir (str): Root directory to search. Symlinked directories are
            followed.
        query (str): Shell-style wildcard (``fnmatch``) matched against file
            names.
        include_root_dir (bool): If False, returned paths are relative to
            ``root_dir`` instead of being prefixed with it.

    Returns:
        list: List of found file paths, in ``os.walk`` traversal order.

    """
    files = []
    for root, _, filenames in os.walk(root_dir, followlinks=True):
        for filename in fnmatch.filter(filenames, query):
            files.append(os.path.join(root, filename))
    if not include_root_dir:
        # relpath strips exactly the leading root_dir, independent of a
        # trailing separator or the platform's path separator. A plain
        # str.replace(root_dir + "/", "") would also delete later occurrences
        # inside the path and would strip nothing when root_dir ends with "/".
        files = [os.path.relpath(file_, root_dir) for file_ in files]
    return files
def _path_requires_gfile(filepath):
    """Checks if the given path requires use of GFile API.

    Args:
        filepath (str or os.PathLike): Path to check.

    Returns:
        bool: True if the given path starts with a URL scheme (``scheme://``)
        and therefore needs the GFile API to access, such as
        "s3://some/path" and "gs://some/path".

    """
    # A URL scheme per RFC 3986 is a letter followed by letters, digits, "+",
    # "-" or ".". The previous pattern "[a-z]+" rejected digits, so "s3://..."
    # was wrongly treated as a local path.
    # os.fspath lets pathlib.Path objects through (they can never hold "//",
    # so they always resolve to local paths).
    return bool(re.match(r"^[a-zA-Z][a-zA-Z0-9+.\-]*://", os.fspath(filepath)))
def save_weights(model, filepath):
    """Save model weights.

    Same as ``model.save_weights(filepath)``, but also supports remote
    destinations such as S3 or GCS buckets (any ``scheme://`` path) by saving
    to a local temporary file and copying it with the TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model to save.
        filepath (str): Local path or remote URL to save the model weights to.

    """
    if not _path_requires_gfile(filepath):
        model.save_weights(filepath)
        return

    # Save into a private temporary directory rather than a NamedTemporaryFile:
    # re-opening a NamedTemporaryFile by name while it is still open (as
    # model.save_weights would) does not work on Windows. Reusing the remote
    # basename keeps the full suffix (e.g. ".weights.h5"), not just the last
    # extension, since Keras chooses the weight format from the file name.
    local_name = os.path.basename(filepath) or "weights"
    with tempfile.TemporaryDirectory() as temp_dir:
        local_path = os.path.join(temp_dir, local_name)
        model.save_weights(local_path)
        # NOTE: only the single file at local_path is uploaded; weight formats
        # that write several files next to a prefix are not supported here.
        # Overwrite to match model.save_weights semantics on local paths.
        tf.io.gfile.copy(local_path, filepath, overwrite=True)
def load_weights(model, filepath):
    """Load model weights.

    Same as ``model.load_weights(filepath)``, but also supports remote
    sources such as S3 or GCS buckets (any ``scheme://`` path) by copying the
    file to a local temporary location with the TensorFlow GFile API first.

    Args:
        model (tf.keras.Model): Model to load weights to.
        filepath (str): Local path or remote URL of the weights file.

    """
    if not _path_requires_gfile(filepath):
        model.load_weights(filepath)
        return

    # Download into a private temporary directory rather than a
    # NamedTemporaryFile: writing to and re-opening a NamedTemporaryFile by
    # name while it is still open does not work on Windows. Reusing the remote
    # basename keeps the full suffix (e.g. ".weights.h5") that Keras uses to
    # pick the weight format.
    local_name = os.path.basename(filepath) or "weights"
    with tempfile.TemporaryDirectory() as temp_dir:
        local_path = os.path.join(temp_dir, local_name)
        tf.io.gfile.copy(filepath, local_path, overwrite=True)
        model.load_weights(local_path)
|