# Source snapshot aec7e95 (file size: 4,918 bytes).
import argparse
import logging
from torch.utils.data import Dataset, IterableDataset
import gzip
import json
from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
import sys
from datetime import datetime
import torch
import random
from shutil import copyfile
import os
import wandb
import random
import re
from datasets import load_dataset
import tqdm
# Configure the root logger once: records go to stdout, stamped with the
# time, level and originating logger name.
logging.basicConfig(
    handlers=[logging.StreamHandler(stream=sys.stdout)],
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
def _build_arg_parser():
    """Return the command-line parser for this training script.

    Flag names, types and defaults are unchanged; every flag now carries help
    text, and ``--help`` shows the default values.
    """
    arg_parser = argparse.ArgumentParser(
        description="Fine-tune a seq2seq model on mMARCO to generate queries from passages (doc2query).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    arg_parser.add_argument(
        "--lang", required=True,
        help="mMARCO language; selects the 'queries-<lang>' and 'collection-<lang>' dataset configs.",
    )
    arg_parser.add_argument(
        "--model_name", default="google/mt5-base",
        help="Hugging Face model id (or local path) of the seq2seq model to fine-tune.",
    )
    arg_parser.add_argument("--epochs", default=4, type=int, help="Number of training epochs.")
    arg_parser.add_argument("--batch_size", default=32, type=int, help="Per-device training batch size.")
    arg_parser.add_argument(
        "--max_source_length", default=320, type=int,
        help="Maximum passage length in tokens; longer passages are truncated.",
    )
    arg_parser.add_argument(
        "--max_target_length", default=64, type=int,
        help="Maximum query length in tokens; longer queries are truncated.",
    )
    arg_parser.add_argument(
        "--eval_size", default=1000, type=int,
        help="Number of leading qrels pairs held out for evaluation.",
    )
    return arg_parser


# Parsed at import time: the training entry point reads the module-level ``args``.
parser = _build_arg_parser()
args = parser.parse_args()
wandb.init(project="doc2query", name=f"{args.lang}-{args.model_name}")
def _read_training_pairs(queries, collection, eval_size, qrels_path="qrels.train.tsv"):
    """Build ``(query, passage)`` text pairs from a tab-separated qrels file.

    Each non-blank line must have exactly four tab-separated fields; the first
    is the query id and the third the passage id (both integers). The first
    ``eval_size`` pairs form the evaluation split, the rest the training split.

    Args:
        queries: mapping from integer query id to query text.
        collection: indexable by integer passage id; ``collection[pid]`` must
            return a mapping with ``'id'`` and ``'text'`` keys.
        eval_size: number of leading pairs reserved for evaluation.
        qrels_path: path of the qrels file (relative to the working directory).

    Returns:
        ``(train_pairs, eval_pairs)``: two lists of ``(query_text, passage_text)``.

    Raises:
        ValueError: on a malformed line, or when ``collection[pid]`` holds a
            row whose id is not ``pid``.
        KeyError: when a query id is missing from ``queries``.
    """
    train_pairs = []
    eval_pairs = []
    with open(qrels_path, encoding="utf-8") as f_in:
        for line in f_in:
            line = line.strip()
            if not line:  # tolerate blank lines, e.g. a trailing one
                continue
            qid, _, did, _ = line.split("\t")
            qid, did = int(qid), int(did)
            row = collection[did]  # fetch once: every row access decodes a record
            # Positional lookup is only correct if row i has id i. Check it
            # explicitly - an `assert` would be stripped under `python -O`.
            if row["id"] != did:
                raise ValueError(
                    f"collection row at index {did} has id {row['id']}; "
                    "expected ids to match row positions"
                )
            pair = (queries[qid], row["text"])
            if len(eval_pairs) < eval_size:
                eval_pairs.append(pair)
            else:
                train_pairs.append(pair)
    return train_pairs, eval_pairs


def _make_data_collator(tokenizer, max_source_length, max_target_length, pad_to_multiple_of=None):
    """Return a collate function turning ``(query, passage)`` pairs into a model batch.

    The passage (element 1) is the encoder input and the query (element 0) is
    the target, so the model learns passage -> query. Padding positions in the
    labels are set to -100 so they are ignored by the loss.
    """
    label_pad_token_id = -100

    def data_collator(examples):
        targets = [row[0] for row in examples]
        inputs = [row[1] for row in examples]
        model_inputs = tokenizer(
            inputs,
            max_length=max_source_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
            pad_to_multiple_of=pad_to_multiple_of,
        )
        # Set up the tokenizer for targets.
        with tokenizer.as_target_tokenizer():
            labels = tokenizer(
                targets,
                max_length=max_target_length,
                padding=True,
                truncation=True,
                return_tensors="pt",
                pad_to_multiple_of=pad_to_multiple_of,
            )["input_ids"]
        if tokenizer.pad_token_id is not None:
            labels = labels.masked_fill(labels == tokenizer.pad_token_id, label_pad_token_id)
        model_inputs["labels"] = labels
        return model_inputs

    return data_collator


def _make_output_dir(lang, model_name):
    """Create and return a fresh, timestamped run directory under ``output/``."""
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    output_dir = f"output/{lang}-{model_name.replace('/', '-')}-{stamp}"
    os.makedirs(output_dir, exist_ok=True)
    return output_dir


def _snapshot_script(output_dir):
    """Copy this script into ``output_dir`` and append the command line used to run it."""
    train_script_path = os.path.join(output_dir, "train_script.py")
    copyfile(__file__, train_script_path)
    with open(train_script_path, "a", encoding="utf-8") as f_out:
        f_out.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))


def main():
    """Fine-tune a seq2seq model to generate mMARCO queries from passages (doc2query).

    Configuration comes from the module-level ``args`` parsed from the command
    line. Relevance pairs are read from ``qrels.train.tsv`` in the working
    directory. The function creates a timestamped output directory, stores a
    copy of this script there, trains with ``Seq2SeqTrainer`` (reporting to
    wandb) and saves the final model into that directory.

    Raises:
        ValueError: if no training pairs remain after holding out ``--eval_size``.
    """
    # ----- Data
    queries = {
        row["id"]: row["text"]
        for row in tqdm.tqdm(load_dataset("unicamp-dl/mmarco", f"queries-{args.lang}")["train"])
    }
    collection = load_dataset("unicamp-dl/mmarco", f"collection-{args.lang}")["collection"]
    train_pairs, eval_pairs = _read_training_pairs(queries, collection, args.eval_size)
    print(f"Train pairs: {len(train_pairs)}")
    print(f"Eval pairs: {len(eval_pairs)}")
    if not train_pairs:
        raise ValueError(
            f"no training pairs left after reserving {args.eval_size} pairs for evaluation"
        )

    # ----- Model and run directory
    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    save_steps = 1000
    output_dir = _make_output_dir(args.lang, args.model_name)
    print("Output dir:", output_dir)
    _snapshot_script(output_dir)

    has_eval = bool(eval_pairs)
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        bf16=True,
        per_device_train_batch_size=args.batch_size,
        # Skip step-wise evaluation when --eval_size leaves no held-out pairs
        # instead of evaluating an empty dataset.
        evaluation_strategy="steps" if has_eval else "no",
        save_steps=save_steps,
        logging_steps=100,
        eval_steps=save_steps,
        warmup_steps=1000,
        save_total_limit=1,
        num_train_epochs=args.epochs,
        report_to="wandb",
    )

    print("Input:", train_pairs[0][1])
    print("Target:", train_pairs[0][0])
    if has_eval:
        print("Input:", eval_pairs[0][1])
        print("Target:", eval_pairs[0][0])

    # Pad to a multiple of 8 under either mixed-precision mode. Checking fp16
    # alone never triggered here, because this script trains in bf16.
    pad_multiple = 8 if (training_args.fp16 or training_args.bf16) else None
    data_collator = _make_data_collator(
        tokenizer, args.max_source_length, args.max_target_length, pad_multiple
    )

    # ----- Train and save
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_pairs,
        eval_dataset=eval_pairs if has_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model()
# Script entry point: training runs only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
# Script was called via:
#python train_hf_trainer_multilingual.py --lang arabic