|
import gzip |
|
import random |
|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW |
|
import sys |
|
import torch |
|
import transformers |
|
from torch.utils.data import Dataset, DataLoader |
|
from torch.cuda.amp import autocast |
|
import tqdm |
|
from datetime import datetime |
|
from shutil import copyfile |
|
import os |
|
|
|
|
|
import gzip |
|
from collections import defaultdict |
|
import logging |
|
import tqdm |
|
import numpy as np |
|
import sys |
|
import pytrec_eval |
|
from sentence_transformers import SentenceTransformer, util, CrossEncoder |
|
import torch |
|
|
|
|
|
# Name or path of the checkpoint to fine-tune, taken from the first CLI argument.
model_name = sys.argv[1]
# Maximum tokenized length of a (query, passage) pair in the training loop.
max_length = 350

# Evaluation queries: query id -> query text (first two tab-separated columns).
queries_filepath = 'msmarco-data/trec2019/msmarco-test2019-queries.tsv.gz'
queries_eval = {}
with gzip.open(queries_filepath, 'rt', encoding='utf8') as query_file:
    for row in query_file:
        query_id, query_text = row.strip().split("\t")[0:2]
        queries_eval[query_id] = query_text

# Relevance judgments: query id -> passage id -> score. Each row has four
# whitespace-separated columns (the second is ignored); only rows with a
# score above zero are stored.
rel = defaultdict(lambda: defaultdict(int))
with open('msmarco-data/trec2019/2019qrels-pass.txt') as qrels_file:
    for row in qrels_file:
        query_id, _, passage_id, grade = row.strip().split()
        grade = int(grade)
        if grade <= 0:
            continue
        rel[query_id][passage_id] = grade

# Queries that have at least one stored judgment. Indexing the defaultdict
# also creates an empty entry in `rel` for every query without judgments.
relevant_qid = [query_id for query_id in queries_eval if len(rel[query_id]) > 0]

# Candidate passages per query: query id -> list of [passage id, passage text].
# The third column (the query text) is not used here.
passage_cand = {}
with gzip.open('msmarco-data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as candidate_file:
    for row in candidate_file:
        query_id, passage_id, _query_text, passage_text = row.strip().split("\t")
        passage_cand.setdefault(query_id, []).append([passage_id, passage_text])
|
|
|
|
|
|
|
def eval_modal(model_path):
    """Score the candidate passages with a saved cross-encoder and report NDCG@10.

    A ``CrossEncoder`` is loaded from ``model_path`` and, for every query id in
    the module-level ``relevant_qid``, used to score each candidate passage
    from ``passage_cand``. The scores are evaluated against the module-level
    ``rel`` judgments with ``pytrec_eval``.

    Args:
        model_path: Model name or directory handed to ``CrossEncoder``.

    Returns:
        The mean ``ndcg_cut_10`` over the evaluated queries. The same value is
        printed as a percentage.
    """
    cross_encoder = CrossEncoder(model_path, max_length=512)

    run = {}
    for qid in relevant_qid:
        query_text = queries_eval[qid]
        candidates = passage_cand[qid]  # list of [pid, passage]
        pairs = [[query_text, passage] for _, passage in candidates]

        if cross_encoder.config.num_labels > 1:
            # Several output labels: rank by the softmax value at label index 1.
            predictions = cross_encoder.predict(pairs, apply_softmax=True)[:, 1]
        else:
            # Single output: rank by the raw value (identity activation).
            predictions = cross_encoder.predict(pairs, activation_fct=torch.nn.Identity())

        run[qid] = {
            pid: float(score)
            for (pid, _), score in zip(candidates, predictions.tolist())
        }

    evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
    per_query = evaluator.evaluate(run)
    scores_mean = np.mean([metrics["ndcg_cut_10"] for metrics in per_query.values()])

    print("NDCG@10: {:.2f}".format(scores_mean * 100))
    return scores_mean
|
|
|
|
|
|
|
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the checkpoint with `num_labels` overridden to 1 on its config.
config = AutoConfig.from_pretrained(model_name)
config.num_labels = 1
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

queries = {}
corpus = {}

# Two output directories: `output_save_path` and a sibling with the suffix
# "-latest". The directory name embeds the model name and a timestamp.
run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_save_path = 'output/train_cross-encoder_mse-{}-{}'.format(model_name.replace("/", "_"), run_stamp)
output_save_path_latest = output_save_path+"-latest"

# The tokenizer is written first; the script copy below relies on both
# directories existing by then (presumably created by save_pretrained).
for save_dir in (output_save_path, output_save_path_latest):
    tokenizer.save_pretrained(save_dir)

# Store a copy of this script in each output directory and append the
# command line it was started with.
for save_dir in (output_save_path, output_save_path_latest):
    train_script_path = os.path.join(save_dir, 'train_script.py')
    copyfile(__file__, train_script_path)
    with open(train_script_path, 'a') as fOut:
        fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
|
|
|
|
|
|
|
|
class MultilingualDataset(Dataset):
    """Store texts per id and language and return a random one on lookup.

    ``examples`` is a two-level mapping ``id -> language -> list of texts``
    that is filled by :meth:`add`. Items are looked up by the id read from the
    files, not by an integer position.
    """

    def __init__(self):
        self.examples = defaultdict(lambda: defaultdict(list))

    def add(self, lang, filepath):
        """Read a two-column TSV file and register its texts under ``lang``.

        Each line must consist of exactly two tab-separated fields: the id and
        the text. Calling ``add`` repeatedly with the same ``lang`` appends
        further texts for that language.

        Args:
            lang: Key under which the texts of this file are stored.
            filepath: Path of the TSV file; paths ending in ``.gz`` are opened
                with ``gzip``.

        Raises:
            ValueError: If a line does not split into exactly two fields.
        """
        open_method = gzip.open if filepath.endswith('.gz') else open
        # Decode explicitly as UTF-8, like the other text readers in this
        # script, instead of depending on the locale's default encoding.
        with open_method(filepath, 'rt', encoding='utf8') as fIn:
            for line in fIn:
                pid, passage = line.strip().split("\t")
                self.examples[pid][lang].append(passage)

    def __len__(self):
        """Return the number of distinct ids."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return a randomly chosen text for the id ``item``.

        A language is drawn uniformly first, then one of its texts.

        Raises:
            IndexError: If no text is stored for ``item``.
        """
        # Use .get() so that an unknown id does not insert an empty entry into
        # the defaultdict (which would silently inflate len(self)).
        all_examples = self.examples.get(item)
        if not all_examples:
            # IndexError is what random.choice raised here before; the type is
            # kept so existing handlers still match.
            raise IndexError("no examples stored for id {!r}".format(item))
        lang_examples = random.choice(list(all_examples.values()))
        return random.choice(lang_examples)
|
|
|
|
|
# Training passages: one 'en' file and two 'de' files, all keyed by passage id.
train_corpus = MultilingualDataset()
for lang, path in (
    ('en', 'msmarco-data/collection.tsv'),
    ('de', 'msmarco-data/de/collection.de.opus-mt.tsv.gz'),
    ('de', 'msmarco-data/de/collection.de.wmt19.tsv.gz'),
):
    train_corpus.add(lang, path)

# Training queries: one 'en' file and two 'de' files, all keyed by query id.
train_queries = MultilingualDataset()
for lang, path in (
    ('en', 'msmarco-data/queries.train.tsv'),
    ('de', 'msmarco-data/de/queries.train.de.opus-mt.tsv.gz'),
    ('de', 'msmarco-data/de/queries.train.de.wmt19.tsv.gz'),
):
    train_queries.add(lang, path)
|
|
|
|
|
class MSEDataset(Dataset):
    """Training examples read from a five-column TSV score file.

    Every line holds ``pos_score, neg_score, qid, pid1, pid2`` separated by
    tabs. Each example is stored as ``[qid, pid1, pid2, margin]`` where
    ``margin`` is ``float(pos_score) - float(neg_score)``.
    """

    def __init__(self, filepath):
        """Load all examples from ``filepath`` into ``self.examples``."""
        super().__init__()
        with open(filepath) as score_file:
            self.examples = [self._parse_row(row) for row in score_file]

    @staticmethod
    def _parse_row(row):
        """Turn one TSV line into ``[qid, pid1, pid2, margin]``."""
        pos_score, neg_score, qid, first_pid, second_pid = row.strip().split("\t")
        margin = float(pos_score) - float(neg_score)
        return [qid, first_pid, second_pid, margin]

    def __len__(self):
        """Return the number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, item):
        """Return the example stored at position ``item``."""
        return self.examples[item]
|
|
|
# Training data: each example is [qid, pid1, pid2, margin] (see MSEDataset).
# Incomplete final batches are dropped.
train_batch_size = 16
train_dataset = MSEDataset('msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
train_dataloader = DataLoader(train_dataset, drop_last=True, shuffle=True, batch_size=train_batch_size)

# Optimizer: parameters whose name contains one of the `no_decay` substrings
# get weight_decay 0.0, all others get `weight_decay`.
weight_decay = 0.01
max_grad_norm = 1
param_optimizer = list(model.named_parameters())

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]

optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
# Linear schedule with 1000 warmup steps over a single pass of the dataloader.
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader))
scaler = torch.cuda.amp.GradScaler()

loss_fct = torch.nn.MSELoss()

model.to(device)
# from_pretrained() hands the model back in evaluation mode; switch to
# training mode so dropout is active while fine-tuning. eval_modal() loads its
# own CrossEncoder from disk, so this mode is never changed by evaluation.
model.train()

# Every `auto_save` steps the model is written to the "-latest" directory and
# evaluated; it is copied to `output_save_path` when NDCG@10 did not get worse.
auto_save = 10000
best_ndcg_score = 0
for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
    # The batch is indexed by column: query ids, first passage ids, second
    # passage ids and the target score margins. Each id lookup draws a random
    # language/text from the MultilingualDataset.
    batch_queries = [train_queries[qid] for qid in batch[0]]
    batch_pos = [train_corpus[cid] for cid in batch[1]]
    batch_neg = [train_corpus[cid] for cid in batch[2]]
    scores = batch[3].float().to(device)

    with autocast():
        inp_pos = tokenizer(batch_queries, batch_pos, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
        # squeeze(-1) removes only the label dimension, so the result keeps
        # its batch dimension even when a batch holds a single example.
        pred_pos = model(**inp_pos).logits.squeeze(-1)

        inp_neg = tokenizer(batch_queries, batch_neg, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
        pred_neg = model(**inp_neg).logits.squeeze(-1)

        # Regress the predicted score difference onto the target margin.
        pred_diff = pred_pos - pred_neg
        loss_value = loss_fct(pred_diff, scores)

    scaler.scale(loss_value).backward()
    # Unscale before clipping so max_grad_norm applies to the true gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    scaler.step(optimizer)
    scaler.update()

    optimizer.zero_grad()
    scheduler.step()

    if (step_idx+1) % auto_save == 0:
        print("Step:", step_idx+1)
        model.save_pretrained(output_save_path_latest)
        ndcg_score = eval_modal(output_save_path_latest)

        if ndcg_score >= best_ndcg_score:
            best_ndcg_score = ndcg_score
            print("Save to:", output_save_path)
            model.save_pretrained(output_save_path)

# NOTE(review): this final save writes the last weights to `output_save_path`,
# replacing whatever best-scoring checkpoint was stored there during the loop.
# Kept as in the original - confirm that this is intended.
model.save_pretrained(output_save_path)
|
|
|
|
|
|
|
|