# ms-marco-TinyBERT-L-6 / train_script.py
# nreimers
# upload
# 67d6b31
from torch.utils.data import DataLoader
from sentence_transformers import LoggingHandler
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
from sentence_transformers import InputExample
import logging
from datetime import datetime
import gzip
import sys
import numpy as np
import os
from shutil import copyfile
import csv
import tqdm
# Send timestamped INFO-level log records to stdout through the
# sentence-transformers LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    handlers=[LoggingHandler()],
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
# Define our Cross-Encoder. The model name (e.g. 'google/electra-small-discriminator')
# is the single required command-line argument; fail with a usage message rather
# than an IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit("Usage: python {} <model_name>".format(sys.argv[0]))
model_name = sys.argv[1]
train_batch_size = 32
num_epochs = 1
model_save_path = ('output/training_ms-marco_cross-encoder-'
                   + model_name.replace("/", "-")
                   + '-' + datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))

# We set num_labels=1, which predicts a continuous score between 0 and 1
model = CrossEncoder(model_name, num_labels=1, max_length=512)

# Write a copy of this script next to the model, followed by the exact command
# line used, so the run can be reproduced later.
os.makedirs(model_save_path, exist_ok=True)
train_script_path = os.path.join(model_save_path, 'train_script.py')
copyfile(__file__, train_script_path)
with open(train_script_path, 'a', encoding='utf8') as fOut:
    fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
# Lookup tables from id (kept as string) to text.
corpus = {}   # passage id -> passage text
queries = {}  # query id -> query text

#### Read the passage collection and the training queries (tab-separated id/text lines).
# Decode explicitly as UTF-8 instead of the platform's locale default, which
# differs between systems (e.g. cp1252 on Windows) and can fail on non-ASCII text.
with gzip.open('../data/collection.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        pid, passage = line.strip().split("\t")
        corpus[pid] = passage

with open('../data/queries.train.tsv', 'r', encoding='utf8') as fIn:
    for line in fIn:
        qid, query = line.strip().split("\t")
        queries[qid] = query
# Training triples are consumed round-robin: every 5th triple contributes its
# positive passage (label 1), the other four their negative passage (label 0),
# i.e. a 1:4 positive:negative ratio.
pos_neg_ration = (4+1)
cnt = 0
train_samples = []
# Stop reading training triples once this many (query, passage) pairs are collected.
max_train_samples = 20000000

# Dev set: the first `num_dev_queries` distinct queries seen in the train-eval
# triples, with every positive passage seen for them and at most
# `num_max_dev_negatives` distinct negatives each.
dev_samples = {}
num_dev_queries = 125
num_max_dev_negatives = 200

with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in fIn:
        qid, pos_id, neg_id = line.strip().split()
        if qid not in dev_samples and len(dev_samples) < num_dev_queries:
            dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
        if qid in dev_samples:
            dev_samples[qid]['positive'].add(corpus[pos_id])
            if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
                dev_samples[qid]['negative'].add(corpus[neg_id])

# NOTE(review): dev queries are not filtered out of the training triples below;
# presumably the train-eval file is disjoint from the train file — confirm.
with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt', encoding='utf8') as fIn:
    for line in tqdm.tqdm(fIn, unit_scale=True):
        cnt += 1
        qid, pos_id, neg_id = line.strip().split()
        query = queries[qid]
        if (cnt % pos_neg_ration) == 0:
            passage = corpus[pos_id]
            label = 1
        else:
            passage = corpus[neg_id]
            label = 0
        train_samples.append(InputExample(texts=[query, passage], label=label))
        if len(train_samples) >= max_train_samples:
            break

train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
# We add an evaluator, which evaluates the performance during training
class CERerankingEvaluator:
    """Evaluate a cross-encoder by re-ranking positive and negative passages per query.

    For every query, each (query, passage) pair is scored by the model and the
    passages are sorted by descending score. The reciprocal rank of the FIRST
    relevant passage within the top ``mrr_at_k`` results is that query's score
    (0 if no relevant passage is in the top ``mrr_at_k``). The mean over all
    evaluable queries (MRR@k) is logged, optionally appended to a CSV file, and
    returned.

    :param samples: a list of dicts, or a dict whose values are such dicts, each
        with keys 'query' (str), 'positive' and 'negative' (iterables of passages).
    :param mrr_at_k: rank cut-off for the MRR computation.
    :param name: optional name used in log messages and in the CSV file name.
    """

    def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
        self.samples = samples
        self.name = name
        self.mrr_at_k = mrr_at_k

        # Accept a {qid: sample} mapping as well as a plain list of samples.
        if isinstance(self.samples, dict):
            self.samples = list(self.samples.values())

        self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
        self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        """Compute MRR@k of ``model`` on the stored samples.

        :param model: object exposing ``predict(pairs, convert_to_numpy=..., show_progress_bar=...)``
            that returns one score per [query, passage] pair.
        :param output_path: if given, a row (epoch, steps, MRR) is appended to
            ``self.csv_file`` inside this directory (header written on creation).
        :param epoch: epoch number used for logging/CSV; -1 means "not during training".
        :param steps: step number within the epoch; -1 means "end of epoch".
        :return: mean MRR@k over queries that have at least one positive and one
            negative passage, or 0.0 when there is no such query.
        """
        if epoch != -1:
            if steps == -1:
                out_txt = " after epoch {}:".format(epoch)
            else:
                out_txt = " in epoch {} after {} steps:".format(epoch, steps)
        else:
            out_txt = ":"

        logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)

        all_mrr_scores = []
        num_queries = 0
        num_positives = []
        num_negatives = []

        for instance in self.samples:
            query = instance['query']
            positive = list(instance['positive'])
            negative = list(instance['negative'])

            # A ranking is only meaningful with at least one positive and one negative.
            if len(positive) == 0 or len(negative) == 0:
                continue

            docs = positive + negative
            is_relevant = [True] * len(positive) + [False] * len(negative)

            num_queries += 1
            num_positives.append(len(positive))
            num_negatives.append(len(negative))

            model_input = [[query, doc] for doc in docs]
            pred_scores = np.asarray(model.predict(model_input, convert_to_numpy=True, show_progress_bar=False))
            pred_scores_argsort = np.argsort(-pred_scores)  # Sort in decreasing order

            # MRR uses the rank of the FIRST relevant passage, so stop at the first hit.
            mrr_score = 0
            for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
                if is_relevant[index]:
                    mrr_score = 1 / (rank + 1)
                    break

            all_mrr_scores.append(mrr_score)

        if num_queries == 0:
            # np.min/np.max raise on empty input; report 0 instead of aborting training.
            logging.warning("CERerankingEvaluator: no query has both positive and negative passages; MRR set to 0")
            mean_mrr = 0.0
        else:
            mean_mrr = np.mean(all_mrr_scores)
            logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
        logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr * 100))

        if output_path is not None:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # newline='' lets the csv module control line endings (avoids blank rows on Windows).
            with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8", newline='') as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, mean_mrr])

        return mean_mrr
# Re-ranking evaluator over the held-out dev queries (reports MRR@10).
evaluator = CERerankingEvaluator(dev_samples)

# Training configuration: number of warm-up steps passed to fit().
warmup_steps = 5000
logging.info("Warmup-steps: {}".format(warmup_steps))

# Train the model, running the evaluator every 5000 steps.
model.fit(
    train_dataloader=train_dataloader,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    evaluation_steps=5000,
    output_path=model_save_path,
    use_amp=True,  # automatic mixed precision
)

# Store the final weights in a separate '-latest' directory next to model_save_path.
latest_model_path = model_save_path + '-latest'
model.save(latest_model_path)
# Script was called via:
#python train_cross-encoder.py nreimers/TinyBERT_L-6_H-768_v2