|
|
|
|
|
|
|
|
|
import sys |
|
sys.path.append("..") |
|
|
|
from jinja2 import Template |
|
from utils import PERTURBATIONS, CHECKPOINT_WRITE_PATH, \ |
|
PAREN_MODELS, PAREN_MODEL_PATH |
|
import argparse |
|
import os |
|
|
|
|
|
def build_arg_parser(perturbation_choices, paren_model_choices):
    """Build the command-line parser for the yaml generator.

    Args:
        perturbation_choices: Allowed values for ``perturbation_type``.
        paren_model_choices: Allowed values for ``paren_model``.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    # NOTE(review): the 'all' defaults below are kept for backward
    # compatibility, but 'all' is not one of the literal ``train_set``
    # choices, and whether it is a valid perturbation / paren-model key
    # depends on ``utils`` -- confirm that omitting these positionals is
    # actually meant to be supported.
    parser = argparse.ArgumentParser(
        prog='Generate yaml for training',
        description='Generate train and dataset yaml configs for mistral training')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=perturbation_choices,
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('train_set',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=["100M", "10M"],
                        help='BabyLM train set')
    parser.add_argument('random_seed', type=int, help="Random seed")
    parser.add_argument('paren_model',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=paren_model_choices,
                        help='Parenthesis model')
    parser.add_argument('-np', '--no_pos_encodings', action='store_true',
                        help="Train GPT-2 with no positional encodings")
    return parser


def resolve_output_paths(perturbation_type, train_set, random_seed,
                         paren_model_name, no_pos_encodings_str,
                         no_pos_encodings_underscore, ckpt_root):
    """Compute every directory and file path the generator writes to.

    Args:
        perturbation_type: Perturbation name used in directory/file names.
        train_set: Train-set tag used in directory/file names.
        random_seed: Seed used for the per-seed directory and file suffixes.
        paren_model_name: Parenthesis-model name used in directory/file names.
        no_pos_encodings_str: Dash-style tag inserted in the model file name
            (empty string when positional encodings are kept).
        no_pos_encodings_underscore: Underscore-style tag appended to the
            run name (empty string when positional encodings are kept).
        ckpt_root: Root under which the checkpoint directory is created.

    Returns:
        A dict with the keys ``yaml_directory``, ``model_file``,
        ``train_file``, ``dataset_file`` and ``ckpt_directory``.
    """
    # The run name is shared by the conf directory, the train yaml file
    # name and the checkpoint directory, so build it exactly once.
    run_name = (f"{perturbation_type}_{train_set}_{paren_model_name}"
                f"{no_pos_encodings_underscore}")
    experiment_directory = f"conf/babylm_{run_name}"
    yaml_directory = f"{experiment_directory}/seed{random_seed}"
    return {
        "yaml_directory": yaml_directory,
        # The model yaml is shared across seeds, so it lives one level above
        # the per-seed directory.
        "model_file": (f"{experiment_directory}/gpt2{no_pos_encodings_str}"
                       f"-small-{perturbation_type}-{paren_model_name}.yaml"),
        "train_file": f"{yaml_directory}/train_{run_name}_seed{random_seed}.yaml",
        "dataset_file": (f"{yaml_directory}/dataset_{perturbation_type}"
                         f"_{train_set}_seed{random_seed}.yaml"),
        "ckpt_directory": ckpt_root + f"/babylm_{run_name}",
    }


def render_template(template_path, output_path, context):
    """Render the Jinja2 template at ``template_path`` into ``output_path``.

    Args:
        template_path: Path of the template file to read.
        output_path: Path of the file to (over)write with the rendered text.
        context: Mapping of template variable names to values.
    """
    with open(template_path) as template_file:
        template_source = template_file.read()
    # Render before opening the destination so that a rendering error does
    # not leave a truncated output file behind.
    rendered = Template(template_source).render(**context)
    with open(output_path, "w") as output_file:
        output_file.write(rendered)


def main():
    """Generate the model, train and dataset yaml configs for one run.

    Parses the command line, renders the three templates under
    ``conf/template`` into the run's ``conf`` directory, and creates the
    checkpoint directory for the run.
    """
    parser = build_arg_parser(
        PERTURBATIONS.keys(),
        list(PAREN_MODELS.keys()) + ["randinit"],
    )
    args = parser.parse_args()

    # "randinit" means no pretrained parenthesis model; the literal string
    # "null" is what gets substituted into the yaml template.
    if args.paren_model != "randinit":
        paren_model_path = PAREN_MODEL_PATH + PAREN_MODELS[args.paren_model] + "/checkpoint-5000"
    else:
        paren_model_path = "null"
    paren_model_name = args.paren_model
    no_pos_encodings_str = "-no-positional-encodings" if args.no_pos_encodings else ""
    no_pos_encodings_underscore = "_no_positional_encodings" if args.no_pos_encodings else ""

    paths = resolve_output_paths(
        args.perturbation_type,
        args.train_set,
        args.random_seed,
        paren_model_name,
        no_pos_encodings_str,
        no_pos_encodings_underscore,
        CHECKPOINT_WRITE_PATH,
    )

    # exist_ok avoids the check-then-create race when several runs are
    # generated concurrently.
    os.makedirs(paths["yaml_directory"], exist_ok=True)

    print("Generating GPT-2 model yaml file...")
    tokenizer = PERTURBATIONS[args.perturbation_type]["gpt2_tokenizer"]
    render_template(
        "conf/template/gpt2-small-template.yaml",
        paths["model_file"],
        {
            "perturbation": args.perturbation_type,
            "vocab_size": len(tokenizer),
            "paren_model": paren_model_name,
            "paren_model_path": paren_model_path,
            "no_pos_encodings": no_pos_encodings_str,
        },
    )

    print("Generating train yaml file...")
    render_template(
        "conf/template/babylm_train_template.yaml",
        paths["train_file"],
        {
            "perturbation": args.perturbation_type,
            "seed": args.random_seed,
            "ckpt_path": CHECKPOINT_WRITE_PATH,
            "train_set": args.train_set,
            "paren_model": paren_model_name,
            "no_pos_encodings": no_pos_encodings_str,
            "no_pos_encodings_underscore": no_pos_encodings_underscore,
        },
    )

    print("Generating dataset yaml file...")
    render_template(
        "conf/template/babylm_dataset_template.yaml",
        paths["dataset_file"],
        {
            "perturbation": args.perturbation_type,
            "train_set": args.train_set,
            "seed": args.random_seed,
        },
    )

    os.makedirs(paths["ckpt_directory"], exist_ok=True)


if __name__ == "__main__":
    main()