"""*********************************************************************************************""" |
|
|
|
|
|
|
|
|
|
"""*********************************************************************************************""" |
|
import os
import sys
import math
import torch
import shutil
import builtins
import numpy as np
from time import time
from typing import List
from pathlib import Path
from datetime import datetime
from collections import defaultdict
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import is_initialized, get_rank, get_world_size


def is_leader_process():
    """Return True when torch.distributed is not initialized, or when this is rank 0."""
    return not is_initialized() or get_rank() == 0


def count_parameters(model):
    """Count the trainable (requires_grad) parameters of a model."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def count_used_parameters(model):
    """Count the parameters that received a gradient; only meaningful after a backward pass."""
    return sum(p.numel() for p in model.parameters() if p.grad is not None)


def get_time_tag():
    """Return the current time as a filename-friendly tag, e.g. 2021-05-01-12-00-00."""
    return datetime.fromtimestamp(time()).strftime('%Y-%m-%d-%H-%M-%S')


def backup(src_path, tgt_dir):
    """Copy src_path into tgt_dir, appending a time tag to the filename."""
    stem = Path(src_path).stem
    suffix = Path(src_path).suffix
    shutil.copyfile(src_path, os.path.join(tgt_dir, f'{stem}_{get_time_tag()}{suffix}'))
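
# Example (hypothetical paths): back up a config file before it gets overwritten.
#   backup('config/upstream.yaml', 'result/')
#   -> result/upstream_2021-05-01-12-00-00.yaml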
|
|
|
def get_model_state(model):
    """Return the state dict of a model, unwrapping the DDP wrapper if necessary."""
    if isinstance(model, DDP):
        return model.module.state_dict()
    return model.state_dict()


def show(*args, **kwargs):
    """print() that is silent on all processes except the leader."""
    if is_leader_process():
        print(*args, **kwargs)


def hack_isinstance():
    # DDP's scatter logic rebuilds anything that passes an isinstance(obj, dict)
    # check as a plain dict, which silently drops the default_factory of a
    # defaultdict. Monkey-patch builtins.isinstance so that a defaultdict only
    # matches defaultdict (sub)classes and no longer passes plain dict checks.
    _isinstance = builtins.isinstance

    def isinstance(obj, cls):
        if _isinstance(obj, defaultdict):
            return _isinstance(obj, cls) and issubclass(cls, defaultdict)
        return _isinstance(obj, cls)

    builtins.isinstance = isinstance
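
# Example: once hack_isinstance() has been called, a defaultdict no longer
# passes a plain dict check:
#   hack_isinstance()
#   isinstance(defaultdict(list), dict)         # False
#   isinstance(defaultdict(list), defaultdict)  # True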
|
|
|
def override(string, args, config):
    """
    Example usage:
    -o "config.optimizer.lr=1.0e-3,,config.optimizer.name='AdamW',,config.runner.eval_dataloaders=['dev', 'test']"
    """
    options = string.split(',,')
    for option in options:
        option = option.strip()
        key, value_str = option.split('=', 1)
        key, value_str = key.strip(), value_str.strip()
        first_field, *remaining = key.split('.')

        try:
            value = eval(value_str)
        except Exception:
            # Fall back to the raw string when it is not a Python literal
            value = value_str

        print(f'[Override] - {key} = {value}', file=sys.stderr)

        if first_field == 'args':
            assert len(remaining) == 1
            setattr(args, remaining[0], value)
        elif first_field == 'config':
            target_config = config
            for i, field_name in enumerate(remaining):
                if i == len(remaining) - 1:
                    target_config[field_name] = value
                else:
                    target_config.setdefault(field_name, {})
                    target_config = target_config[field_name]
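
# Example (hypothetical Namespace/config objects): apply an override string in place.
#   args, config = argparse.Namespace(expname='base'), {}
#   override("args.expname='large',,config.optimizer.lr=1.0e-3", args, config)
#   -> args.expname == 'large' and config == {'optimizer': {'lr': 0.001}}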
|
|
|
def zero_mean_unit_var_norm(input_values: List[np.ndarray]) -> List[np.ndarray]:
    """
    Every array in the list is normalized to have zero mean and unit variance.
    Taken from huggingface to ensure the same behavior across s3prl and huggingface.
    Reference: https://github.com/huggingface/transformers/blob/a26f4d620874b32d898a5b712006a4c856d07de1/src/transformers/models/wav2vec2/feature_extraction_wav2vec2.py#L81-L86
    """
    return [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in input_values]
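
# Example: each array is standardized independently; the 1e-5 term guards
# against division by zero for near-constant inputs.
#   zero_mean_unit_var_norm([np.array([1.0, 2.0, 3.0])])
#   -> [array([-1.2247, 0.0, 1.2247])] (approximately; zero mean, unit variance)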
|
|
def parse_prune_heads(config):
    if 'prune_headids' in config['transformer'] and config['transformer']['prune_headids'] != 'None':
        heads_int = []
        spans = config['transformer']['prune_headids'].split(',')
        for span in spans:
            endpoints = span.split('-')
            if len(endpoints) == 1:
                heads_int.append(int(endpoints[0]))
            elif len(endpoints) == 2:
                # A span like '12-16' expands to [12, 13, 14, 15] (end exclusive)
                heads_int += torch.arange(int(endpoints[0]), int(endpoints[1])).tolist()
            else:
                raise ValueError(f'Invalid prune_headids span: {span}')
        print(f'[PRUNING] - heads {heads_int} will be pruned')
        config['transformer']['prune_headids'] = heads_int
    else:
        config['transformer']['prune_headids'] = None
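
# Example: '0,2,12-16' in config['transformer']['prune_headids'] is rewritten
# in place to the list [0, 2, 12, 13, 14, 15].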
|
|
|
def get_transformer_tester(from_path='result/result_transformer/libri_sd1337_fmllrBase960-F-N-K-RA/model-1000000.ckpt', display_settings=False):
    ''' Wrapper that loads the transformer model from checkpoint path '''

    # Load the checkpoint along with the settings stored next to the weights
    all_states = torch.load(from_path, map_location='cpu')
    config = all_states['Settings']['Config']
    paras = all_states['Settings']['Paras']

    # Backward compatibility for checkpoints saved before these fields existed
    if not hasattr(paras, 'multi_gpu'):
        setattr(paras, 'multi_gpu', False)
    if 'prune_headids' not in config['transformer']:
        config['transformer']['prune_headids'] = None

    # Optionally display the stored settings
    if display_settings:
        for cluster in config:
            print(cluster + ':')
            for item in config[cluster]:
                print('\t' + str(item) + ': ', config[cluster][item])
        print('paras:')
        v_paras = vars(paras)
        for item in v_paras:
            print('\t' + str(item) + ': ', v_paras[item])

    # Build the tester and load the model weights for inference
    from transformer.solver import Tester
    tester = Tester(config, paras)
    tester.set_model(inference=True, with_head=False, from_path=from_path)
    return tester
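

if __name__ == '__main__':
    # A minimal, optional smoke test for the lightweight helpers above
    # (a sketch: runs without a distributed context or any checkpoint).
    import torch.nn as nn

    model = nn.Linear(4, 2)
    show(f'trainable parameters: {count_parameters(model)}')  # 4 * 2 + 2 = 10
    show(f'time tag: {get_time_tag()}')

    config = {}
    override("config.optimizer.lr=1.0e-3", args=None, config=config)
    show(config)  # {'optimizer': {'lr': 0.001}}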
|
|