import os |
import sys |
import math |
import glob |
import uuid |
import shutil |
import random |
import tempfile |
import importlib |
from pathlib import Path |
import torch |
import torchaudio |
import numpy as np |
from tqdm import tqdm |
from tensorboardX import SummaryWriter |
from torch.utils.data import DistributedSampler |
from torch.nn.parallel import DistributedDataParallel as DDP |
from torch.distributed import is_initialized, get_rank, get_world_size |
from s3prl import hub |
from s3prl.optimizers import get_optimizer |
from s3prl.schedulers import get_scheduler |
from s3prl.upstream.interfaces import Featurizer |
from s3prl.utility.helper import is_leader_process, get_model_state, show, defaultdict |
from huggingface_hub import HfApi, HfFolder, Repository |
SAMPLE_RATE = 16000 |
datasets: |
- superb |
tags: |
- library:s3prl |
- benchmark:superb |
- type:model |
--- |
# Fine-tuned s3prl model |
Upstream Model: {upstream_model} |
## Model description |
[More information needed] |
## Intended uses & limitations |
[More information needed] |
## How to use |
[More information needed] |
## Limitations and bias |
[More information needed] |
## Training data |
[More information needed] |
## Training procedure |
[More information needed] |
## Evaluation results |
[More information needed] |
""" |
class ModelEntry: |
def __init__(self, model, name, trainable, interfaces): |
self.model = model |
self.name = name |
self.trainable = trainable |
self.interfaces = interfaces |
class Runner(): |
""" |
Used to handle high-level concepts of a ML experiment |
eg. training loop, evaluation loop, upstream propagation, optimization, logging, checkpoint saving |
""" |
def __init__(self, args, config): |
self.args = args |
self.config = config |
self.init_ckpt = torch.load(self.args.init_ckpt, map_location='cpu') if self.args.init_ckpt else {} |
self.upstream = self._get_upstream() |
self.featurizer = self._get_featurizer() |
self.downstream = self._get_downstream() |
self.all_entries = [self.upstream, self.featurizer, self.downstream] |
def _load_weight(self, model, name): |
init_weight = self.init_ckpt.get(name) |
if init_weight: |
show(f'[Runner] - Loading {name} weights from the previous experiment') |
model.load_state_dict(init_weight) |
def _init_model(self, model, name, trainable, interfaces=None): |
for interface in interfaces or []: |
assert hasattr(model, interface), interface |
self._load_weight(model, name) |
if is_initialized() and trainable and any((p.requires_grad for p in model.parameters())): |
model = DDP(model, device_ids=[self.args.local_rank], find_unused_parameters=True) |
for interface in interfaces or []: |
setattr(model, interface, getattr(model.module, interface)) |
return ModelEntry(model, name, trainable, interfaces) |
def _get_upstream(self): |
if "from_hf_hub" in self.args and self.args.from_hf_hub == True: |
from huggingface_hub import snapshot_download |
print(f'[Runner] - Downloading upstream model {self.args.upstream} from the Hugging Face Hub') |
filepath = snapshot_download(self.args.upstream, self.args.upstream_revision, use_auth_token=True) |
sys.path.append(filepath) |
dependencies = (Path(filepath) / 'requirements.txt').resolve() |
print("[Dependency] - The downloaded upstream model requires the following dependencies. Please make sure they are installed:") |
for idx, line in enumerate((Path(filepath) / "requirements.txt").open().readlines()): |
print(f"{idx}. {line.strip()}") |
print(f"You can install them by:") |
print() |
print(f"pip install -r {dependencies}") |
print() |
from expert import UpstreamExpert |
Upstream = UpstreamExpert |
ckpt_path = os.path.join(filepath, self.args.upstream_model_name) |
else: |
Upstream = getattr(hub, self.args.upstream) |
ckpt_path = self.args.upstream_ckpt |
upstream_refresh = self.args.upstream_refresh |
if is_initialized() and get_rank() > 0: |
torch.distributed.barrier() |
upstream_refresh = False |
model = Upstream( |
ckpt = ckpt_path, |
model_config = self.args.upstream_model_config, |
refresh = upstream_refresh, |
).to(self.args.device) |
if is_initialized() and get_rank() == 0: |
torch.distributed.barrier() |
return self._init_model( |
model = model, |
name = 'Upstream', |
trainable = self.args.upstream_trainable, |
interfaces = ["get_downsample_rates"] |
) |
def _get_featurizer(self): |
model = Featurizer( |
upstream = self.upstream.model, |
feature_selection = self.args.upstream_feature_selection, |
layer_selection = self.args.upstream_layer_selection, |
upstream_device = self.args.device, |
normalize = self.args.upstream_feature_normalize, |
).to(self.args.device) |
return self._init_model( |
model = model, |
name = 'Featurizer', |
trainable = True, |
interfaces = ['output_dim', 'downsample_rate'] |
) |
def _get_downstream(self): |
expert = importlib.import_module(f"s3prl.downstream.{self.args.downstream}.expert") |
Downstream = getattr(expert, "DownstreamExpert") |
model = Downstream( |
upstream_dim = self.featurizer.model.output_dim, |
upstream_rate = self.featurizer.model.downsample_rate, |
**self.config, |
**vars(self.args) |
).to(self.args.device) |
return self._init_model( |
model = model, |
name = 'Downstream', |
trainable = True, |
interfaces = ['get_dataloader', 'log_records'] |
) |
def _get_optimizer(self, model_params): |
optimizer = get_optimizer( |
model_params, |
self.config['runner']['total_steps'], |
self.config['optimizer'] |
) |
self._load_weight(optimizer, 'Optimizer') |
return optimizer |
def _get_scheduler(self, optimizer): |
scheduler = get_scheduler( |
optimizer, |
self.config['runner']['total_steps'], |
self.config['scheduler'] |
) |
self._load_weight(scheduler, 'Scheduler') |
return scheduler |
def _create_model_card(self, path): |
model_card = MODEL_CARD_MARKDOWN.format(upstream_model=self.args.upstream) |
with open(os.path.join(path, "README.md"), "w") as f: |
f.write(model_card) |
def train(self): |
trainable_models = [] |
trainable_paras = [] |
for entry in self.all_entries: |
if entry.trainable: |
entry.model.train().to(self.args.device) |
trainable_models.append(entry.model) |
trainable_paras += list(entry.model.parameters()) |
else: |
entry.model.eval() |
amp = self.config['runner'].get('fp16', False) |
if amp: |
print('[Runner] - Enabled fp16 training') |
scaler = torch.cuda.amp.GradScaler() |
optimizer = self._get_optimizer(trainable_models) |
scheduler = None |
if self.config.get('scheduler'): |
scheduler = self._get_scheduler(optimizer) |
specaug = None |
if self.config.get('specaug'): |
from .specaug import SpecAug |
specaug = SpecAug(**self.config["specaug"]) |
tqdm_file = sys.stderr if is_leader_process() else open(os.devnull, 'w') |
pbar = tqdm(total=self.config['runner']['total_steps'], dynamic_ncols=True, desc='overall', file=tqdm_file) |
init_step = self.init_ckpt.get('Step') |
if init_step: |
pbar.n = init_step |
if is_leader_process(): |
logger = SummaryWriter(self.args.expdir) |
batch_ids = [] |
backward_steps = 0 |
records = defaultdict(list) |
epoch = self.init_ckpt.get('Epoch', 0) |
train_split = self.config['runner'].get("train_dataloader", "train") |
while pbar.n < pbar.total: |
try: |
dataloader = self.downstream.model.get_dataloader(train_split, epoch=epoch) |
except TypeError as e: |
if "unexpected keyword argument 'epoch'" in str(e): |
dataloader = self.downstream.model.get_dataloader(train_split) |
if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, DistributedSampler): |
dataloader.sampler.set_epoch(epoch) |
else: |
raise |
for batch_id, (wavs, *others) in enumerate(tqdm(dataloader, dynamic_ncols=True, desc='train', file=tqdm_file)): |
try: |
if pbar.n >= pbar.total: |
break |
global_step = pbar.n + 1 |
wavs = [torch.FloatTensor(wav).to(self.args.device) for wav in wavs] |
with torch.cuda.amp.autocast(enabled=amp): |
if self.upstream.trainable: |
features = self.upstream.model(wavs) |
else: |
with torch.no_grad(): |
features = self.upstream.model(wavs) |
features = self.featurizer.model(wavs, features) |
if specaug: |
features, _ = specaug(features) |
loss = self.downstream.model( |
train_split, |
features, *others, |
records = records, |
) |
batch_ids.append(batch_id) |
gradient_accumulate_steps = self.config['runner'].get('gradient_accumulate_steps') |
loss = (loss / gradient_accumulate_steps) |
if amp: |
scaler.scale(loss).backward() |
else: |
loss.backward() |
del loss |
except RuntimeError as e: |
if 'CUDA out of memory' in str(e): |
print(f'[Runner] - CUDA out of memory at step {global_step}') |
if is_initialized(): |
raise |
with torch.cuda.device(self.args.device): |
torch.cuda.empty_cache() |
optimizer.zero_grad() |
continue |
else: |
raise |
backward_steps += 1 |
if backward_steps % gradient_accumulate_steps > 0: |
continue |
if amp: |
scaler.unscale_(optimizer) |
grad_norm = torch.nn.utils.clip_grad_norm_( |
trainable_paras, self.config['runner']['gradient_clipping']) |
if amp: |
scaler.step(optimizer) |
scaler.update() |
elif math.isnan(grad_norm): |
print(f'[Runner] - grad norm is NaN at step {global_step}') |
else: |
optimizer.step() |
optimizer.zero_grad() |
if scheduler: |
scheduler.step() |
if not is_leader_process(): |
batch_ids = [] |
records = defaultdict(list) |
continue |
if global_step % self.config['runner']['log_step'] == 0: |
self.downstream.model.log_records( |
train_split, |
records = records, |
logger = logger, |
global_step = global_step, |
batch_ids = batch_ids, |
total_batch_num = len(dataloader), |
) |
batch_ids = [] |
records = defaultdict(list) |
save_names = [] |
if global_step % self.config['runner']['eval_step'] == 0: |
for split in self.config['runner']['eval_dataloaders']: |
save_names += self.evaluate(split, logger, global_step) |
if global_step % self.config['runner']['save_step'] == 0: |
def check_ckpt_num(directory): |
max_keep = self.config['runner']['max_keep'] |
ckpt_pths = glob.glob(f'{directory}/states-*.ckpt') |
if len(ckpt_pths) >= max_keep: |
ckpt_pths = sorted(ckpt_pths, key=lambda pth: int(pth.split('-')[-1].split('.')[0])) |
for ckpt_pth in ckpt_pths[:len(ckpt_pths) - max_keep + 1]: |
os.remove(ckpt_pth) |
check_ckpt_num(self.args.expdir) |
save_names.append(f'states-{global_step}.ckpt') |
if len(save_names) > 0: |
all_states = { |
'Optimizer': optimizer.state_dict(), |
'Step': global_step, |
'Epoch': epoch, |
'Args': self.args, |
'Config': self.config, |
} |
for entry in self.all_entries: |
if entry.trainable: |
all_states[entry.name] = get_model_state(entry.model) |
if scheduler: |
all_states['Scheduler'] = scheduler.state_dict() |
if is_initialized(): |
all_states['WorldSize'] = get_world_size() |
save_paths = [os.path.join(self.args.expdir, name) for name in save_names] |
tqdm.write(f'[Runner] - Save the checkpoint to:') |
for i, path in enumerate(save_paths): |
tqdm.write(f'{i + 1}. {path}') |
torch.save(all_states, path) |
pbar.update(1) |
epoch += 1 |
pbar.close() |
if self.args.push_to_hf_hub: |
self.push_to_huggingface_hub() |
if is_leader_process(): |
logger.close() |
def evaluate(self, split=None, logger=None, global_step=0): |
"""evaluate function will always be called on a single process even during distributed training""" |
not_during_training = split is None and logger is None and global_step == 0 |
if not_during_training: |
split = self.args.evaluate_split |
tempdir = tempfile.mkdtemp() |
logger = SummaryWriter(tempdir) |
random.seed(self.args.seed) |
np.random.seed(self.args.seed) |
torch.manual_seed(self.args.seed) |
if torch.cuda.is_available(): |
torch.cuda.manual_seed_all(self.args.seed) |
with torch.cuda.device(self.args.device): |
torch.cuda.empty_cache() |
trainings = [] |
for entry in self.all_entries: |
trainings.append(entry.model.training) |
entry.model.eval() |
dataloader = self.downstream.model.get_dataloader(split) |
evaluate_ratio = float(self.config["runner"].get("evaluate_ratio", 1)) |
evaluate_steps = round(len(dataloader) * evaluate_ratio) |
batch_ids = [] |
records = defaultdict(list) |
for batch_id, (wavs, *others) in enumerate(tqdm(dataloader, dynamic_ncols=True, desc=split, total=evaluate_steps)): |
if batch_id > evaluate_steps: |
break |
wavs = [torch.FloatTensor(wav).to(self.args.device) for wav in wavs] |
with torch.no_grad(): |
features = self.upstream.model(wavs) |
features = self.featurizer.model(wavs, features) |
self.downstream.model( |
split, |
features, *others, |
records = records, |
batch_id = batch_id, |
) |
batch_ids.append(batch_id) |
save_names = self.downstream.model.log_records( |
split, |
records = records, |
logger = logger, |
global_step = global_step, |
batch_ids = batch_ids, |
total_batch_num = len(dataloader), |
) |
batch_ids = [] |
records = defaultdict(list) |
if torch.cuda.is_available(): |
with torch.cuda.device(self.args.device): |
torch.cuda.empty_cache() |
for entry, training in zip(self.all_entries, trainings): |
if training: |
entry.model.train().to(self.args.device) |
if not_during_training: |
logger.close() |
shutil.rmtree(tempdir) |
return [] if type(save_names) is not list else save_names |
def inference(self): |
filepath = Path(self.args.evaluate_split) |
assert filepath.is_file(), filepath |
filename = filepath.stem |
if hasattr(self.downstream.model, "load_audio"): |
wav = self.downstream.model.load_audio(filepath) |
else: |
wav, sr = torchaudio.load(str(filepath)) |
assert sr == SAMPLE_RATE, sr |
wavs = [wav.view(-1).to(self.args.device)] |
for entry in self.all_entries: |
entry.model.eval() |
with torch.no_grad(): |
features = self.upstream.model(wavs) |
features = self.featurizer.model(wavs, features) |
self.downstream.model.inference(features, [filename]) |
def push_to_huggingface_hub(self): |
"""Creates a downstream repository on the Hub and pushes training artifacts to it.""" |
if self.args.hf_hub_org.lower() != "none": |
organization = self.args.hf_hub_org |
else: |
organization = os.environ.get("HF_USERNAME") |
huggingface_token = HfFolder.get_token() |
print(f"[Runner] - Organisation to push fine-tuned model to: {organization}") |
if self.args.hub == "huggingface": |
model_info = HfApi().model_info(self.args.upstream, token=huggingface_token) |
downstream_model_id = model_info.sha |
upstream_model_id = model_info.modelId.replace("/", "__") |
else: |
upstream_model_id = self.args.upstream.replace("/", "__") |
downstream_model_id = str(uuid.uuid4())[:8] |
repo_name = f"{upstream_model_id}__{downstream_model_id}" |
repo_url = HfApi().create_repo( |
token=huggingface_token, |
name=repo_name, |
organization=organization, |
exist_ok=True, |
private=False, |
) |
print(f"[Runner] - Created Hub repo: {repo_url}") |
HF_HUB_DIR = "hf_hub" |
REPO_ROOT_DIR = os.path.join(self.args.expdir, HF_HUB_DIR, repo_name) |
REPO_TASK_DIR = os.path.join(REPO_ROOT_DIR, self.args.downstream, self.args.expname) |
print(f"[Runner] - Cloning Hub repo to {REPO_ROOT_DIR}") |
model_repo = Repository( |
local_dir=REPO_ROOT_DIR, clone_from=repo_url, use_auth_token=huggingface_token |
) |
model_repo.git_pull() |
shutil.copytree(self.args.expdir, REPO_TASK_DIR, dirs_exist_ok=True, ignore=shutil.ignore_patterns(HF_HUB_DIR)) |
checkpoints = list(Path(REPO_TASK_DIR).glob("*best*.ckpt")) |
if len(checkpoints) == 0: |
print("[Runner] - Did not find a best checkpoint! Using the final checkpoint instead ...") |
os.path.join(REPO_TASK_DIR, f"states-{self.config['runner']['total_steps']}.ckpt") |
) |
elif len(checkpoints) > 1: |
print(f"[Runner] - More than one best checkpoint found! Using {checkpoints[0]} as default ...") |
CKPT_PATH = checkpoints[0] |
else: |
print(f"[Runner] - Found best checkpoint {checkpoints[0]}!") |
CKPT_PATH = checkpoints[0] |
shutil.move(CKPT_PATH, os.path.join(REPO_TASK_DIR, "model.ckpt")) |
model_repo.lfs_track("*.ckpt") |
self._create_model_card(REPO_ROOT_DIR) |
print("[Runner] - Pushing model files to the Hub ...") |
model_repo.push_to_hub() |
print("[Runner] - Training run complete!") |