File size: 8,507 Bytes
83d190a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 |
import glob
import logging
import os
import re
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union
import numpy as np
import torch
from numpy.typing import NDArray
from style_bert_vits2.logging import logger
from style_bert_vits2.models.utils import checkpoints # type: ignore
from style_bert_vits2.models.utils import safetensors # type: ignore
if TYPE_CHECKING:
# tensorboard はライブラリとしてインストールされている場合は依存関係に含まれないため、型チェック時のみインポートする
from torch.utils.tensorboard import SummaryWriter
__is_matplotlib_imported = False
def summarize(
writer: "SummaryWriter",
global_step: int,
scalars: dict[str, float] = {},
histograms: dict[str, Any] = {},
images: dict[str, Any] = {},
audios: dict[str, Any] = {},
audio_sampling_rate: int = 22050,
) -> None:
"""
指定されたデータを TensorBoard にまとめて追加する
Args:
writer (SummaryWriter): TensorBoard への書き込みを行うオブジェクト
global_step (int): グローバルステップ数
scalars (dict[str, float]): スカラー値の辞書
histograms (dict[str, Any]): ヒストグラムの辞書
images (dict[str, Any]): 画像データの辞書
audios (dict[str, Any]): 音声データの辞書
audio_sampling_rate (int): 音声データのサンプリングレート
"""
for k, v in scalars.items():
writer.add_scalar(k, v, global_step)
for k, v in histograms.items():
writer.add_histogram(k, v, global_step)
for k, v in images.items():
writer.add_image(k, v, global_step, dataformats="HWC")
for k, v in audios.items():
writer.add_audio(k, v, global_step, audio_sampling_rate)
def is_resuming(dir_path: Union[str, Path]) -> bool:
"""
指定されたディレクトリパスに再開可能なモデルが存在するかどうかを返す
Args:
dir_path: チェックするディレクトリのパス
Returns:
bool: 再開可能なモデルが存在するかどうか
"""
# JP-ExtraバージョンではDURがなくWDがあったり変わるため、Gのみで判断する
g_list = glob.glob(os.path.join(dir_path, "G_*.pth"))
# d_list = glob.glob(os.path.join(dir_path, "D_*.pth"))
# dur_list = glob.glob(os.path.join(dir_path, "DUR_*.pth"))
return len(g_list) > 0
def plot_spectrogram_to_numpy(spectrogram: NDArray[Any]) -> NDArray[Any]:
"""
指定されたスペクトログラムを画像データに変換する
Args:
spectrogram (NDArray[Any]): スペクトログラム
Returns:
NDArray[Any]: 画像データ
"""
global __is_matplotlib_imported
if not __is_matplotlib_imported:
import matplotlib
matplotlib.use("Agg")
__is_matplotlib_imported = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
import numpy as np
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
plt.colorbar(im, ax=ax)
plt.xlabel("Frames")
plt.ylabel("Channels")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") # type: ignore
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def plot_alignment_to_numpy(
alignment: NDArray[Any], info: Optional[str] = None
) -> NDArray[Any]:
"""
指定されたアライメントを画像データに変換する
Args:
alignment (NDArray[Any]): アライメント
info (Optional[str]): 画像に追加する情報
Returns:
NDArray[Any]: 画像データ
"""
global __is_matplotlib_imported
if not __is_matplotlib_imported:
import matplotlib
matplotlib.use("Agg")
__is_matplotlib_imported = True
mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)
import matplotlib.pylab as plt
fig, ax = plt.subplots(figsize=(6, 4))
im = ax.imshow(
alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
)
fig.colorbar(im, ax=ax)
xlabel = "Decoder timestep"
if info is not None:
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel("Encoder timestep")
plt.tight_layout()
fig.canvas.draw()
data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") # type: ignore
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
plt.close()
return data
def load_wav_to_torch(full_path: Union[str, Path]) -> tuple[torch.FloatTensor, int]:
"""
指定された音声ファイルを読み込み、PyTorch のテンソルに変換して返す
Args:
full_path (Union[str, Path]): 音声ファイルのパス
Returns:
tuple[torch.FloatTensor, int]: 音声データのテンソルとサンプリングレート
"""
# この関数は学習時以外使われないため、ライブラリとしての style_bert_vits2 が
# 重たい scipy に依存しないように遅延 import する
try:
from scipy.io.wavfile import read
except ImportError:
raise ImportError("scipy is required to load wav file")
sampling_rate, data = read(full_path)
return torch.FloatTensor(data.astype(np.float32)), sampling_rate
def load_filepaths_and_text(
filename: Union[str, Path], split: str = "|"
) -> list[list[str]]:
"""
指定されたファイルからファイルパスとテキストを読み込む
Args:
filename (Union[str, Path]): ファイルのパス
split (str): ファイルの区切り文字 (デフォルト: "|")
Returns:
list[list[str]]: ファイルパスとテキストのリスト
"""
with open(filename, encoding="utf-8") as f:
filepaths_and_text = [line.strip().split(split) for line in f]
return filepaths_and_text
def get_logger(
model_dir_path: Union[str, Path], filename: str = "train.log"
) -> logging.Logger:
"""
ロガーを取得する
Args:
model_dir_path (Union[str, Path]): ログを保存するディレクトリのパス
filename (str): ログファイルの名前 (デフォルト: "train.log")
Returns:
logging.Logger: ロガー
"""
global logger
logger = logging.getLogger(os.path.basename(model_dir_path))
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
if not os.path.exists(model_dir_path):
os.makedirs(model_dir_path)
h = logging.FileHandler(os.path.join(model_dir_path, filename))
h.setLevel(logging.DEBUG)
h.setFormatter(formatter)
logger.addHandler(h)
return logger
def get_steps(model_path: Union[str, Path]) -> Optional[int]:
"""
モデルのパスからイテレーション回数を取得する
Args:
model_path (Union[str, Path]): モデルのパス
Returns:
Optional[int]: イテレーション回数
"""
matches = re.findall(r"\d+", model_path) # type: ignore
return matches[-1] if matches else None
def check_git_hash(model_dir_path: Union[str, Path]) -> None:
"""
モデルのディレクトリに .git ディレクトリが存在する場合、ハッシュ値を比較する
Args:
model_dir_path (Union[str, Path]): モデルのディレクトリのパス
"""
source_dir = os.path.dirname(os.path.realpath(__file__))
if not os.path.exists(os.path.join(source_dir, ".git")):
logger.warning(
f"{source_dir} is not a git repository, therefore hash value comparison will be ignored."
)
return
cur_hash = subprocess.getoutput("git rev-parse HEAD")
path = os.path.join(model_dir_path, "githash")
if os.path.exists(path):
with open(path, encoding="utf-8") as f:
saved_hash = f.read()
if saved_hash != cur_hash:
logger.warning(
f"git hash values are different. {saved_hash[:8]}(saved) != {cur_hash[:8]}(current)"
)
else:
with open(path, "w", encoding="utf-8") as f:
f.write(cur_hash)
|