TaiYa1 committed
Commit f41ab23 · verified · 1 Parent(s): cfe1020

Update fish_speech/utils/utils.py

Files changed (1)
  1. fish_speech/utils/utils.py +136 -114
fish_speech/utils/utils.py CHANGED
@@ -1,114 +1,136 @@
- import warnings
- from importlib.util import find_spec
- from typing import Callable
-
- from omegaconf import DictConfig
-
- from .logger import RankedLogger
- from .rich_utils import enforce_tags, print_config_tree
-
- log = RankedLogger(__name__, rank_zero_only=True)
-
-
- def extras(cfg: DictConfig) -> None:
-     """Applies optional utilities before the task is started.
-
-     Utilities:
-     - Ignoring python warnings
-     - Setting tags from command line
-     - Rich config printing
-     """
-
-     # return if no `extras` config
-     if not cfg.get("extras"):
-         log.warning("Extras config not found! <cfg.extras=null>")
-         return
-
-     # disable python warnings
-     if cfg.extras.get("ignore_warnings"):
-         log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
-         warnings.filterwarnings("ignore")
-
-     # prompt user to input tags from command line if none are provided in the config
-     if cfg.extras.get("enforce_tags"):
-         log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
-         enforce_tags(cfg, save_to_file=True)
-
-     # pretty print config tree using Rich library
-     if cfg.extras.get("print_config"):
-         log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
-         print_config_tree(cfg, resolve=True, save_to_file=True)
-
-
- def task_wrapper(task_func: Callable) -> Callable:
-     """Optional decorator that controls the failure behavior when executing the task function.
-
-     This wrapper can be used to:
-     - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
-     - save the exception to a `.log` file
-     - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
-     - etc. (adjust depending on your needs)
-
-     Example:
-     ```
-     @utils.task_wrapper
-     def train(cfg: DictConfig) -> Tuple[dict, dict]:
-
-         ...
-
-         return metric_dict, object_dict
-     ```
-     """  # noqa: E501
-
-     def wrap(cfg: DictConfig):
-         # execute the task
-         try:
-             metric_dict, object_dict = task_func(cfg=cfg)
-
-         # things to do if exception occurs
-         except Exception as ex:
-             # save exception to `.log` file
-             log.exception("")
-
-             # some hyperparameter combinations might be invalid or
-             # cause out-of-memory errors so when using hparam search
-             # plugins like Optuna, you might want to disable
-             # raising the below exception to avoid multirun failure
-             raise ex
-
-         # things to always do after either success or exception
-         finally:
-             # display output dir path in terminal
-             log.info(f"Output dir: {cfg.paths.run_dir}")
-
-             # always close wandb run (even if exception occurs so multirun won't fail)
-             if find_spec("wandb"):  # check if wandb is installed
-                 import wandb
-
-                 if wandb.run:
-                     log.info("Closing wandb!")
-                     wandb.finish()
-
-         return metric_dict, object_dict
-
-     return wrap
-
-
- def get_metric_value(metric_dict: dict, metric_name: str) -> float:
-     """Safely retrieves value of the metric logged in LightningModule."""
-
-     if not metric_name:
-         log.info("Metric name is None! Skipping metric value retrieval...")
-         return None
-
-     if metric_name not in metric_dict:
-         raise Exception(
-             f"Metric value not found! <metric_name={metric_name}>\n"
-             "Make sure metric name logged in LightningModule is correct!\n"
-             "Make sure `optimized_metric` name in `hparams_search` config is correct!"
-         )
-
-     metric_value = metric_dict[metric_name].item()
-     log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
-
-     return metric_value
+ import random
+ import warnings
+ from importlib.util import find_spec
+ from typing import Callable
+
+ import numpy as np
+ import torch
+ from omegaconf import DictConfig
+
+ from .logger import RankedLogger
+ from .rich_utils import enforce_tags, print_config_tree
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+
+ def extras(cfg: DictConfig) -> None:
+     """Applies optional utilities before the task is started.
+
+     Utilities:
+     - Ignoring python warnings
+     - Setting tags from command line
+     - Rich config printing
+     """
+
+     # return if no `extras` config
+     if not cfg.get("extras"):
+         log.warning("Extras config not found! <cfg.extras=null>")
+         return
+
+     # disable python warnings
+     if cfg.extras.get("ignore_warnings"):
+         log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
+         warnings.filterwarnings("ignore")
+
+     # prompt user to input tags from command line if none are provided in the config
+     if cfg.extras.get("enforce_tags"):
+         log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
+         enforce_tags(cfg, save_to_file=True)
+
+     # pretty print config tree using Rich library
+     if cfg.extras.get("print_config"):
+         log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
+         print_config_tree(cfg, resolve=True, save_to_file=True)
+
+
+ def task_wrapper(task_func: Callable) -> Callable:
+     """Optional decorator that controls the failure behavior when executing the task function.
+
+     This wrapper can be used to:
+     - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
+     - save the exception to a `.log` file
+     - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
+     - etc. (adjust depending on your needs)
+
+     Example:
+     ```
+     @utils.task_wrapper
+     def train(cfg: DictConfig) -> Tuple[dict, dict]:
+
+         ...
+
+         return metric_dict, object_dict
+     ```
+     """  # noqa: E501
+
+     def wrap(cfg: DictConfig):
+         # execute the task
+         try:
+             metric_dict, object_dict = task_func(cfg=cfg)
+
+         # things to do if exception occurs
+         except Exception as ex:
+             # save exception to `.log` file
+             log.exception("")
+
+             # some hyperparameter combinations might be invalid or
+             # cause out-of-memory errors so when using hparam search
+             # plugins like Optuna, you might want to disable
+             # raising the below exception to avoid multirun failure
+             raise ex
+
+         # things to always do after either success or exception
+         finally:
+             # display output dir path in terminal
+             log.info(f"Output dir: {cfg.paths.run_dir}")
+
+             # always close wandb run (even if exception occurs so multirun won't fail)
+             if find_spec("wandb"):  # check if wandb is installed
+                 import wandb
+
+                 if wandb.run:
+                     log.info("Closing wandb!")
+                     wandb.finish()
+
+         return metric_dict, object_dict
+
+     return wrap
+
+
+ def get_metric_value(metric_dict: dict, metric_name: str) -> float:
+     """Safely retrieves value of the metric logged in LightningModule."""
+
+     if not metric_name:
+         log.info("Metric name is None! Skipping metric value retrieval...")
+         return None
+
+     if metric_name not in metric_dict:
+         raise Exception(
+             f"Metric value not found! <metric_name={metric_name}>\n"
+             "Make sure metric name logged in LightningModule is correct!\n"
+             "Make sure `optimized_metric` name in `hparams_search` config is correct!"
+         )
+
+     metric_value = metric_dict[metric_name].item()
+     log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
+
+     return metric_value
+
+
+ def set_seed(seed: int):
+     if seed < 0:
+         seed = -seed
+     if seed > (1 << 31):
+         seed = 1 << 31
+
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
+
+     if torch.backends.cudnn.is_available():
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
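The substantive change in this commit is the new `set_seed` helper (plus the `random`, `numpy`, and `torch` imports it requires); the rest of the file is carried over unchanged. A minimal usage sketch follows — the import path simply mirrors the file location in this diff, and the seed value and call site are illustrative assumptions, not part of the commit:

```python
# Sketch only: exercising the set_seed helper added in this commit.
# The import path follows fish_speech/utils/utils.py as shown above;
# the seed value and call site are assumptions for illustration.
import torch

from fish_speech.utils.utils import set_seed

# Seeds Python's `random`, NumPy, the torch CPU RNG, and all CUDA devices
# when CUDA is available.
set_seed(42)

# With the cuDNN flags set_seed applies (deterministic=True, benchmark=False),
# repeated runs of the same op sequence on the same hardware should be
# much more reproducible.
print(torch.randn(2, 3))
```

Note that the helper maps a negative seed to its absolute value and clamps anything above `1 << 31`, presumably to stay inside the range `np.random.seed` accepts.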