Update fish_speech/utils/utils.py
Browse files- fish_speech/utils/utils.py +136 -114
fish_speech/utils/utils.py
CHANGED
@@ -1,114 +1,136 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
-
from
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
from
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
#
|
74 |
-
|
75 |
-
|
76 |
-
#
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
#
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import warnings
|
3 |
+
from importlib.util import find_spec
|
4 |
+
from typing import Callable
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from omegaconf import DictConfig
|
9 |
+
|
10 |
+
from .logger import RankedLogger
|
11 |
+
from .rich_utils import enforce_tags, print_config_tree
|
12 |
+
|
13 |
+
log = RankedLogger(__name__, rank_zero_only=True)
|
14 |
+
|
15 |
+
|
16 |
+
def extras(cfg: DictConfig) -> None:
|
17 |
+
"""Applies optional utilities before the task is started.
|
18 |
+
|
19 |
+
Utilities:
|
20 |
+
- Ignoring python warnings
|
21 |
+
- Setting tags from command line
|
22 |
+
- Rich config printing
|
23 |
+
"""
|
24 |
+
|
25 |
+
# return if no `extras` config
|
26 |
+
if not cfg.get("extras"):
|
27 |
+
log.warning("Extras config not found! <cfg.extras=null>")
|
28 |
+
return
|
29 |
+
|
30 |
+
# disable python warnings
|
31 |
+
if cfg.extras.get("ignore_warnings"):
|
32 |
+
log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
|
33 |
+
warnings.filterwarnings("ignore")
|
34 |
+
|
35 |
+
# prompt user to input tags from command line if none are provided in the config
|
36 |
+
if cfg.extras.get("enforce_tags"):
|
37 |
+
log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
|
38 |
+
enforce_tags(cfg, save_to_file=True)
|
39 |
+
|
40 |
+
# pretty print config tree using Rich library
|
41 |
+
if cfg.extras.get("print_config"):
|
42 |
+
log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
|
43 |
+
print_config_tree(cfg, resolve=True, save_to_file=True)
|
44 |
+
|
45 |
+
|
46 |
+
def task_wrapper(task_func: Callable) -> Callable:
|
47 |
+
"""Optional decorator that controls the failure behavior when executing the task function.
|
48 |
+
|
49 |
+
This wrapper can be used to:
|
50 |
+
- make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
|
51 |
+
- save the exception to a `.log` file
|
52 |
+
- mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
|
53 |
+
- etc. (adjust depending on your needs)
|
54 |
+
|
55 |
+
Example:
|
56 |
+
```
|
57 |
+
@utils.task_wrapper
|
58 |
+
def train(cfg: DictConfig) -> Tuple[dict, dict]:
|
59 |
+
|
60 |
+
...
|
61 |
+
|
62 |
+
return metric_dict, object_dict
|
63 |
+
```
|
64 |
+
""" # noqa: E501
|
65 |
+
|
66 |
+
def wrap(cfg: DictConfig):
|
67 |
+
# execute the task
|
68 |
+
try:
|
69 |
+
metric_dict, object_dict = task_func(cfg=cfg)
|
70 |
+
|
71 |
+
# things to do if exception occurs
|
72 |
+
except Exception as ex:
|
73 |
+
# save exception to `.log` file
|
74 |
+
log.exception("")
|
75 |
+
|
76 |
+
# some hyperparameter combinations might be invalid or
|
77 |
+
# cause out-of-memory errors so when using hparam search
|
78 |
+
# plugins like Optuna, you might want to disable
|
79 |
+
# raising the below exception to avoid multirun failure
|
80 |
+
raise ex
|
81 |
+
|
82 |
+
# things to always do after either success or exception
|
83 |
+
finally:
|
84 |
+
# display output dir path in terminal
|
85 |
+
log.info(f"Output dir: {cfg.paths.run_dir}")
|
86 |
+
|
87 |
+
# always close wandb run (even if exception occurs so multirun won't fail)
|
88 |
+
if find_spec("wandb"): # check if wandb is installed
|
89 |
+
import wandb
|
90 |
+
|
91 |
+
if wandb.run:
|
92 |
+
log.info("Closing wandb!")
|
93 |
+
wandb.finish()
|
94 |
+
|
95 |
+
return metric_dict, object_dict
|
96 |
+
|
97 |
+
return wrap
|
98 |
+
|
99 |
+
|
100 |
+
def get_metric_value(metric_dict: dict, metric_name: str) -> float:
|
101 |
+
"""Safely retrieves value of the metric logged in LightningModule."""
|
102 |
+
|
103 |
+
if not metric_name:
|
104 |
+
log.info("Metric name is None! Skipping metric value retrieval...")
|
105 |
+
return None
|
106 |
+
|
107 |
+
if metric_name not in metric_dict:
|
108 |
+
raise Exception(
|
109 |
+
f"Metric value not found! <metric_name={metric_name}>\n"
|
110 |
+
"Make sure metric name logged in LightningModule is correct!\n"
|
111 |
+
"Make sure `optimized_metric` name in `hparams_search` config is correct!"
|
112 |
+
)
|
113 |
+
|
114 |
+
metric_value = metric_dict[metric_name].item()
|
115 |
+
log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
|
116 |
+
|
117 |
+
return metric_value
|
118 |
+
|
119 |
+
|
120 |
+
def set_seed(seed: int):
|
121 |
+
if seed < 0:
|
122 |
+
seed = -seed
|
123 |
+
if seed > (1 << 31):
|
124 |
+
seed = 1 << 31
|
125 |
+
|
126 |
+
random.seed(seed)
|
127 |
+
np.random.seed(seed)
|
128 |
+
torch.manual_seed(seed)
|
129 |
+
|
130 |
+
if torch.cuda.is_available():
|
131 |
+
torch.cuda.manual_seed(seed)
|
132 |
+
torch.cuda.manual_seed_all(seed)
|
133 |
+
|
134 |
+
if torch.backends.cudnn.is_available():
|
135 |
+
torch.backends.cudnn.deterministic = True
|
136 |
+
torch.backends.cudnn.benchmark = False
|