#!/usr/bin/env python3
# -*- encoding: utf-8 -*-

import os
import sys
import torch
import hydra
import logging
import argparse
from io import BytesIO

import torch.distributed as dist
from collections.abc import Sequence
from omegaconf import DictConfig, OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from funasr_detach.register import tables
from funasr_detach.optimizers import optim_classes
from funasr_detach.train_utils.trainer import Trainer
from funasr_detach.schedulers import scheduler_classes
from funasr_detach.train_utils.initialize import initialize
from funasr_detach.download.download_from_hub import download_model
from funasr_detach.models.lora.utils import mark_only_lora_as_trainable
from funasr_detach.train_utils.set_all_random_seed import set_all_random_seed
from funasr_detach.train_utils.load_pretrained_model import load_pretrained_model

# from funasr_detach.tokenizer.build_tokenizer import build_tokenizer
# from funasr_detach.tokenizer.token_id_converter import TokenIDConverter
# from funasr_detach.tokenizer.funtoken import build_tokenizer


@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
    # hydra collects command-line overrides into `kwargs` as a DictConfig
    if kwargs.get("debug", False):
        import pdb

        pdb.set_trace()

    assert "model" in kwargs
    if "model_conf" not in kwargs:
        logging.info(
            "download models from model hub: {}".format(kwargs.get("model_hub", "ms"))
        )
        kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)

    main(**kwargs)


def main(**kwargs):
    print(kwargs)

    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get(
        "cudnn_enabled", torch.backends.cudnn.enabled
    )
    torch.backends.cudnn.benchmark = kwargs.get(
        "cudnn_benchmark", torch.backends.cudnn.benchmark
    )
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
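    # LOCAL_RANK / WORLD_SIZE are set by the distributed launcher (e.g. torchrun);
    # without one, they fall back to single-process defaults.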
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()

    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
    if use_ddp or use_fsdp:
        dist.init_process_group(
            backend=kwargs.get("backend", "nccl"), init_method="env://"
        )
        torch.cuda.set_device(local_rank)

    # save config.yaml
    if ((use_ddp or use_fsdp) and dist.get_rank() == 0) or (
        not (use_ddp or use_fsdp) and local_rank == 0
    ):
        os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
        yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
        OmegaConf.save(config=kwargs, f=yaml_file)
        logging.info("config.yaml is saved to: %s", yaml_file)
    tokenizer = kwargs.get("tokenizer", None)
    if tokenizer is not None:
        tokenizer_class = tables.tokenizer_classes.get(tokenizer)
        tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
        kwargs["tokenizer"] = tokenizer

    # build frontend if frontend is not None
    frontend = kwargs.get("frontend", None)
    if frontend is not None:
        frontend_class = tables.frontend_classes.get(frontend)
        frontend = frontend_class(**kwargs["frontend_conf"])
        kwargs["frontend"] = frontend
        kwargs["input_size"] = frontend.output_size()

    # build model
    model_class = tables.model_classes.get(kwargs["model"])
    model = model_class(
        **kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list)
    )

    # init_param
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if not isinstance(init_param, (list, tuple)):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            logging.info(f"Loading pretrained params from {p}")
            load_pretrained_model(
                model=model,
                path=p,
                ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                oss_bucket=kwargs.get("oss_bucket", None),
                scope_map=kwargs.get("scope_map", None),
                excludes=kwargs.get("excludes", None),
            )
    else:
        initialize(model, kwargs.get("init", "kaiming_normal"))

    # freeze_param
    freeze_param = kwargs.get("freeze_param", None)
    if freeze_param is not None:
        if isinstance(freeze_param, str):
            try:
                # the option may be a python-literal string, e.g. "('encoder',)"
                freeze_param = eval(freeze_param)
            except (NameError, SyntaxError):
                # otherwise treat it as a single parameter-name prefix
                pass
        if isinstance(freeze_param, str) or not isinstance(freeze_param, Sequence):
            freeze_param = (freeze_param,)
        logging.info("freeze_param is not None: %s", freeze_param)
        for t in freeze_param:
            for k, p in model.named_parameters():
                if k.startswith(t + ".") or k == t:
                    logging.info(f"Setting {k}.requires_grad = False")
                    p.requires_grad = False
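
    # wrap the model for distributed training: DDP keeps a full replica per GPU,
    # while FSDP shards parameters across ranks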
    if use_ddp:
        model = model.cuda(local_rank)
        model = DDP(
            model,
            device_ids=[local_rank],
            find_unused_parameters=kwargs.get("train_conf", {}).get(
                "find_unused_parameters", False
            ),
        )
    elif use_fsdp:
        model = FSDP(model).cuda(local_rank)
    else:
        model = model.to(device=kwargs.get("device", "cuda"))

    # optim
    optim = kwargs.get("optim", "adam")
    assert optim in optim_classes
    optim_class = optim_classes.get(optim)
    optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))

    # scheduler
    scheduler = kwargs.get("scheduler", "warmuplr")
    assert scheduler in scheduler_classes
    scheduler_class = scheduler_classes.get(scheduler)
    scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

    # dataset
    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
    dataset_tr = dataset_class(
        kwargs.get("train_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=True,
        **kwargs.get("dataset_conf"),
    )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf"),
    )

    # dataloader
    batch_sampler = kwargs["dataset_conf"].get(
        "batch_sampler", "DynamicBatchLocalShuffleSampler"
    )
    batch_sampler_val = None
    if batch_sampler is not None:
        batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
        batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
        batch_sampler_val = batch_sampler_class(
            dataset_val, is_training=False, **kwargs.get("dataset_conf")
        )
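    # the samplers yield batches of indices; each dataset's collator assembles
    # (and typically pads) the corresponding samples into batched tensors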
    dataloader_tr = torch.utils.data.DataLoader(
        dataset_tr,
        collate_fn=dataset_tr.collator,
        batch_sampler=batch_sampler,
        num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
        pin_memory=True,
    )
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        collate_fn=dataset_val.collator,
        batch_sampler=batch_sampler_val,
        num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
        pin_memory=True,
    )
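
    # build the trainer: it drives the train/validation loop, checkpointing to
    # output_dir and optionally resuming from an earlier run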
    trainer = Trainer(
        model=model,
        optim=optim,
        scheduler=scheduler,
        dataloader_train=dataloader_tr,
        dataloader_val=dataloader_val,
        local_rank=local_rank,
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        output_dir=kwargs.get("output_dir", "./exp"),
        resume=kwargs.get("resume", True),
        **kwargs.get("train_conf"),
    )
    trainer.run()

    if use_ddp or use_fsdp:
        dist.destroy_process_group()
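

# Example launch (illustrative only; the override keys follow hydra syntax and
# must match your own config/paths):
#   torchrun --nproc_per_node=2 train.py \
#       ++model=<model_name> ++train_data_set_list=<train_list> \
#       ++valid_data_set_list=<valid_list> ++output_dir=./exp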

if __name__ == "__main__":
    main_hydra()