File size: 3,053 Bytes
d5ee97c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
#  MIT License (https://opensource.org/licenses/MIT)
"""Utility functions."""

import fnmatch
import os
import re
import tempfile
from pathlib import Path

import tensorflow as tf

# Canonical filenames for a packaged model checkpoint. NOTE(review): none of
# these constants are used by the functions in this module; presumably they
# are consumed by model-hub/download helpers elsewhere in the package.
MODEL_FILE_NAME = "model.h5"
CONFIG_FILE_NAME = "config.yml"
PROCESSOR_FILE_NAME = "processor.json"
LIBRARY_NAME = "tensorflow_tts"
# Per-user download cache, e.g. ~/.cache/tensorflow_tts.
CACHE_DIRECTORY = os.path.join(Path.home(), ".cache", LIBRARY_NAME)


def find_files(root_dir, query="*.wav", include_root_dir=True):
    """Find files recursively.

    Args:
        root_dir (str): Root directory to search under (symlinks followed).
        query (str): fnmatch-style pattern to match filenames against.
        include_root_dir (bool): If False, returned paths are relative
            to root_dir instead of including it.

    Returns:
        list: List of found filenames.
    """
    files = []
    for root, _, filenames in os.walk(root_dir, followlinks=True):
        for filename in fnmatch.filter(filenames, query):
            files.append(os.path.join(root, filename))
    if not include_root_dir:
        # Use relpath rather than the old `file_.replace(root_dir + "/", "")`:
        # the string replace silently returned absolute paths when root_dir
        # carried a trailing slash, and never matched on Windows where
        # os.path.join inserts a backslash.
        files = [os.path.relpath(file_, root_dir) for file_ in files]

    return files


def _path_requires_gfile(filepath):
    """Checks if the given path requires use of GFile API.

    Args:
        filepath (str): Path to check.
    Returns:
        bool: True if the given path needs GFile API to access, such as
            "s3://some/path" and "gs://some/path".
    """
    # If the filepath contains a protocol (e.g. "gs://"), it should be handled
    # using TensorFlow GFile API.
    return bool(re.match(r"^[a-z]+://", filepath))


def save_weights(model, filepath):
    """Save model weights.

    Behaves like ``model.save_weights(filepath)`` but additionally supports
    remote destinations (e.g. S3 or GCS buckets) via the TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model whose weights are saved.
        filepath (str): Local or remote path to write the weights to.
    """
    if _path_requires_gfile(filepath):
        # Remote destination: Keras cannot write there directly, so dump the
        # weights to a local temp file and upload it with tf.io.gfile.
        # overwrite=True preserves plain save_weights clobber semantics.
        suffix = os.path.splitext(filepath)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix) as local_copy:
            model.save_weights(local_copy.name)
            tf.io.gfile.copy(local_copy.name, filepath, overwrite=True)
    else:
        model.save_weights(filepath)


def load_weights(model, filepath):
    """Load model weights.

    Behaves like ``model.load_weights(filepath)`` but additionally supports
    remote sources (e.g. S3 or GCS buckets) via the TensorFlow GFile API.

    Args:
        model (tf.keras.Model): Model that receives the weights.
        filepath (str): Local or remote path of the weights file.
    """
    if _path_requires_gfile(filepath):
        # Remote source: download to a local temp file first, then load it.
        # NamedTemporaryFile pre-creates the file, hence overwrite=True.
        suffix = os.path.splitext(filepath)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix) as local_copy:
            tf.io.gfile.copy(filepath, local_copy.name, overwrite=True)
            model.load_weights(local_copy.name)
    else:
        model.load_weights(filepath)