nreimers
commited on
Commit
·
2054d27
1
Parent(s):
c542599
upload
Browse files- README.md +65 -0
- config.json +32 -0
- pytorch_model.bin +3 -0
- sentencepiece.bpe.model +3 -0
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- train_script.py +253 -0
README.md
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Cross-Encoder for MS MARCO - EN-DE
|
2 |
+
|
3 |
+
This is a cross-lingual Cross-Encoder model for EN-DE that can be used for passage re-ranking. It was trained on the [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
|
4 |
+
|
5 |
+
The model can be used for Information Retrieval: See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html).
|
6 |
+
|
7 |
+
The training code is available in this repository, see `train_script.py`.
|
8 |
+
|
9 |
+
|
10 |
+
## Usage with SentenceTransformers
|
11 |
+
|
12 |
+
When you have [SentenceTransformers](https://www.sbert.net/) installed, you can use the model like this:
|
13 |
+
```python
|
14 |
+
from sentence_transformers import CrossEncoder
|
15 |
+
model = CrossEncoder('model_name', max_length=512)
|
16 |
+
query = 'How many people live in Berlin?'
|
17 |
+
docs = ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.']
|
18 |
+
pairs = [(query, doc) for doc in docs]
|
19 |
+
scores = model.predict(pairs)
|
20 |
+
```
|
21 |
+
|
22 |
+
|
23 |
+
## Usage with Transformers
|
24 |
+
With the transformers library, you can use the model like this:
|
25 |
+
|
26 |
+
```python
|
27 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
28 |
+
import torch
|
29 |
+
|
30 |
+
model = AutoModelForSequenceClassification.from_pretrained('model_name')
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained('model_name')
|
32 |
+
|
33 |
+
features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")
|
34 |
+
|
35 |
+
model.eval()
|
36 |
+
with torch.no_grad():
|
37 |
+
scores = model(**features).logits
|
38 |
+
print(scores)
|
39 |
+
```
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
## Performance
|
45 |
+
The performance was evaluated on three datasets:
|
46 |
+
- **TREC-DL19 EN-EN**: The original [TREC 2019 Deep Learning Track](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2019.html): Given an English query and 1000 documents (retrieved by BM25 lexical search), rank documents with according to their relevance. We compute NDCG@10. BM25 achieves a score of 45.46, a perfect re-ranker can achieve a score 95.47.
|
47 |
+
- **TREC-DL19 DE-EN**: The English queries of TREC-DL19 have been translated by a German native speaker to German. We rank the German queries versus the English passages from the original TREC-DL19 setup. We compute NDCG@10.
|
48 |
+
- **GermanDPR DE-DE**: The [GermanDPR](https://www.deepset.ai/germanquad) dataset provides German queries and German passages from Wikipedia. We indexed the 2.8 Million paragraphs from German Wikipedia and retrieved for each query the top 100 most relevant passages using BM25 lexical search with Elasticsearch. We compute MRR@10. BM25 achieves a score of 35.85, a perfect re-ranker can achieve a score 76.27.
|
49 |
+
|
50 |
+
We also check the performance of bi-encoders using the same evaluation: The retrieved documents from BM25 lexical search are re-ranked using query & passage embeddings with cosine-similarity. Bi-Encoders can also be used for end-to-end semantic search.
|
51 |
+
|
52 |
+
|
53 |
+
| Model-Name | TREC-DL19 EN-EN | TREC-DL19 DE-EN | GermanDPR DE-DE | Docs / Sec |
|
54 |
+
| ------------- |:-------------:| :-----: | :---: | :----: |
|
55 |
+
| BM25 | 45.46 | - | 35.85 | -|
|
56 |
+
| **Cross-Encoder Re-Rankers** | | | |
|
57 |
+
| [cross-encoder/msmarco-MiniLM-L6-en-de-v1](https://huggingface.co/cross-encoder/msmarco-MiniLM-L6-en-de-v1) | 72.43 | 65.53 | 46.77 | 1600 |
|
58 |
+
| [cross-encoder/msmarco-MiniLM-L12-en-de-v1](https://huggingface.co/cross-encoder/msmarco-MiniLM-L12-en-de-v1) | 72.94 | 66.07 | 49.91 | 900 |
|
59 |
+
| [svalabs/cross-electra-ms-marco-german-uncased](https://huggingface.co/svalabs/cross-electra-ms-marco-german-uncased) (DE only) | - | - | 53.67 | 260 |
|
60 |
+
| **Bi-Encoders (re-ranking)** | | | |
|
61 |
+
| [sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-lng-aligned](https://huggingface.co/sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-lng-aligned) | 63.38 | 58.28 | 37.88 | 940 |
|
62 |
+
| [sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-trained-scratch](https://huggingface.co/sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-trained-scratch) | 65.51 | 58.69 | 38.32 | 940 |
|
63 |
+
| [svalabs/bi-electra-ms-marco-german-uncased](https://huggingface.co/svalabs/bi-electra-ms-marco-german-uncased) (DE only) | - | - | 34.31 | 450 |
|
64 |
+
|
65 |
+
Note: Docs / Sec gives the number of (query, document) pairs we can re-rank within a second on a V100 GPU.
|
config.json
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "microsoft/Multilingual-MiniLM-L12-H384",
|
3 |
+
"architectures": [
|
4 |
+
"BertForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"gradient_checkpointing": false,
|
8 |
+
"hidden_act": "gelu",
|
9 |
+
"hidden_dropout_prob": 0.1,
|
10 |
+
"hidden_size": 384,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0"
|
13 |
+
},
|
14 |
+
"initializer_range": 0.02,
|
15 |
+
"intermediate_size": 1536,
|
16 |
+
"label2id": {
|
17 |
+
"LABEL_0": 0
|
18 |
+
},
|
19 |
+
"layer_norm_eps": 1e-12,
|
20 |
+
"max_position_embeddings": 512,
|
21 |
+
"model_type": "bert",
|
22 |
+
"num_attention_heads": 12,
|
23 |
+
"num_hidden_layers": 12,
|
24 |
+
"pad_token_id": 0,
|
25 |
+
"position_embedding_type": "absolute",
|
26 |
+
"tokenizer_class": "XLMRobertaTokenizer",
|
27 |
+
"transformers_version": "4.6.1",
|
28 |
+
"type_vocab_size": 2,
|
29 |
+
"use_cache": true,
|
30 |
+
"vocab_size": 250037,
|
31 |
+
"sbert_ce_default_activation_function": "torch.nn.modules.linear.Identity"
|
32 |
+
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fd9a04351955aa9f56f6b22ca61b14d1a91588ca72ef2769a63e847928b8309f
|
3 |
+
size 470705929
|
sentencepiece.bpe.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
|
3 |
+
size 5069051
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>"}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "special_tokens_map_file": "/root/.cache/huggingface/transformers/8ed73a1ab9ef4e90a9451497bf96cfc38d34354352838a143f2dda1c81aed5ca.0dc5b1041f62041ebbd23b1297f2f573769d5c97d8b7c28180ec86b8f6185aa8", "name_or_path": "microsoft/Multilingual-MiniLM-L12-H384", "sp_model_kwargs": {}}
|
train_script.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import random
|
3 |
+
|
4 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW
|
5 |
+
import sys
|
6 |
+
import torch
|
7 |
+
import transformers
|
8 |
+
from torch.utils.data import Dataset, DataLoader
|
9 |
+
from torch.cuda.amp import autocast
|
10 |
+
import tqdm
|
11 |
+
from datetime import datetime
|
12 |
+
from shutil import copyfile
|
13 |
+
import os
|
14 |
+
####################################
|
15 |
+
|
16 |
+
import gzip
|
17 |
+
from collections import defaultdict
|
18 |
+
import logging
|
19 |
+
import tqdm
|
20 |
+
import numpy as np
|
21 |
+
import sys
|
22 |
+
import pytrec_eval
|
23 |
+
from sentence_transformers import SentenceTransformer, util, CrossEncoder
|
24 |
+
import torch
|
25 |
+
|
26 |
+
|
27 |
+
model_name = sys.argv[1]
|
28 |
+
max_length = 350
|
29 |
+
|
30 |
+
######### Evaluation
|
31 |
+
queries_filepath = 'msmarco-data/trec2019/msmarco-test2019-queries.tsv.gz'
|
32 |
+
queries_eval = {}
|
33 |
+
with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
|
34 |
+
for line in fIn:
|
35 |
+
qid, query = line.strip().split("\t")[0:2]
|
36 |
+
queries_eval[qid] = query
|
37 |
+
|
38 |
+
rel = defaultdict(lambda: defaultdict(int))
|
39 |
+
|
40 |
+
with open('msmarco-data/trec2019/2019qrels-pass.txt') as fIn:
|
41 |
+
for line in fIn:
|
42 |
+
qid, _, pid, score = line.strip().split()
|
43 |
+
score = int(score)
|
44 |
+
if score > 0:
|
45 |
+
rel[qid][pid] = score
|
46 |
+
|
47 |
+
relevant_qid = []
|
48 |
+
for qid in queries_eval:
|
49 |
+
if len(rel[qid]) > 0:
|
50 |
+
relevant_qid.append(qid)
|
51 |
+
|
52 |
+
# Read top 1k
|
53 |
+
passage_cand = {}
|
54 |
+
|
55 |
+
with gzip.open('msmarco-data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as fIn:
|
56 |
+
for line in fIn:
|
57 |
+
qid, pid, query, passage = line.strip().split("\t")
|
58 |
+
if qid not in passage_cand:
|
59 |
+
passage_cand[qid] = []
|
60 |
+
|
61 |
+
passage_cand[qid].append([pid, passage])
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
def eval_modal(model_path):
|
66 |
+
run = {}
|
67 |
+
model = CrossEncoder(model_path, max_length=512)
|
68 |
+
|
69 |
+
for qid in relevant_qid:
|
70 |
+
query = queries_eval[qid]
|
71 |
+
|
72 |
+
cand = passage_cand[qid]
|
73 |
+
pids = [c[0] for c in cand]
|
74 |
+
corpus_sentences = [c[1] for c in cand]
|
75 |
+
|
76 |
+
## CrossEncoder
|
77 |
+
cross_inp = [[query, sent] for sent in corpus_sentences]
|
78 |
+
if model.config.num_labels > 1:
|
79 |
+
cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
|
80 |
+
else:
|
81 |
+
cross_scores = model.predict(cross_inp, activation_fct=torch.nn.Identity()).tolist()
|
82 |
+
|
83 |
+
cross_scores_sparse = {}
|
84 |
+
for idx, pid in enumerate(pids):
|
85 |
+
cross_scores_sparse[pid] = cross_scores[idx]
|
86 |
+
|
87 |
+
sparse_scores = cross_scores_sparse
|
88 |
+
run[qid] = {}
|
89 |
+
for pid in sparse_scores:
|
90 |
+
run[qid][pid] = float(sparse_scores[pid])
|
91 |
+
|
92 |
+
evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
|
93 |
+
scores = evaluator.evaluate(run)
|
94 |
+
scores_mean = np.mean([ele["ndcg_cut_10"] for ele in scores.values()])
|
95 |
+
|
96 |
+
print("NDCG@10: {:.2f}".format(scores_mean * 100))
|
97 |
+
return scores_mean
|
98 |
+
|
99 |
+
################################
|
100 |
+
|
101 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
102 |
+
config = AutoConfig.from_pretrained(model_name)
|
103 |
+
config.num_labels = 1
|
104 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
|
105 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
#######################
|
111 |
+
|
112 |
+
queries = {}
|
113 |
+
corpus = {}
|
114 |
+
|
115 |
+
output_save_path = 'output/train_cross-encoder_mse-{}-{}'.format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
|
116 |
+
output_save_path_latest = output_save_path+"-latest"
|
117 |
+
tokenizer.save_pretrained(output_save_path)
|
118 |
+
tokenizer.save_pretrained(output_save_path_latest)
|
119 |
+
|
120 |
+
|
121 |
+
# Write self to path
|
122 |
+
train_script_path = os.path.join(output_save_path, 'train_script.py')
|
123 |
+
copyfile(__file__, train_script_path)
|
124 |
+
with open(train_script_path, 'a') as fOut:
|
125 |
+
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
126 |
+
|
127 |
+
|
128 |
+
####
|
129 |
+
train_script_path = os.path.join(output_save_path_latest, 'train_script.py')
|
130 |
+
copyfile(__file__, train_script_path)
|
131 |
+
with open(train_script_path, 'a') as fOut:
|
132 |
+
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
#### Read train files
|
137 |
+
class MultilingualDataset(Dataset):
|
138 |
+
def __init__(self):
|
139 |
+
self.examples = defaultdict(lambda: defaultdict(list)) #[id][lang] => [samples...]
|
140 |
+
|
141 |
+
def add(self, lang, filepath):
|
142 |
+
open_method = gzip.open if filepath.endswith('.gz') else open
|
143 |
+
with open_method(filepath, 'rt') as fIn:
|
144 |
+
for line in fIn:
|
145 |
+
pid, passage = line.strip().split("\t")
|
146 |
+
self.examples[pid][lang].append(passage)
|
147 |
+
|
148 |
+
|
149 |
+
def __len__(self):
|
150 |
+
return len(self.examples)
|
151 |
+
|
152 |
+
def __getitem__(self, item):
|
153 |
+
all_examples = self.examples[item] #All examples in all languages
|
154 |
+
lang_examples = random.choice(list(all_examples.values())) #Examples in on specific language
|
155 |
+
return random.choice(lang_examples) #One random example
|
156 |
+
|
157 |
+
|
158 |
+
train_corpus = MultilingualDataset()
|
159 |
+
train_corpus.add('en', 'msmarco-data/collection.tsv')
|
160 |
+
train_corpus.add('de', 'msmarco-data/de/collection.de.opus-mt.tsv.gz')
|
161 |
+
train_corpus.add('de', 'msmarco-data/de/collection.de.wmt19.tsv.gz')
|
162 |
+
|
163 |
+
|
164 |
+
train_queries = MultilingualDataset()
|
165 |
+
train_queries.add('en', 'msmarco-data/queries.train.tsv')
|
166 |
+
train_queries.add('de', 'msmarco-data/de/queries.train.de.opus-mt.tsv.gz')
|
167 |
+
train_queries.add('de', 'msmarco-data/de/queries.train.de.wmt19.tsv.gz')
|
168 |
+
|
169 |
+
############## MSE Dataset
|
170 |
+
class MSEDataset(Dataset):
|
171 |
+
def __init__(self, filepath):
|
172 |
+
super().__init__()
|
173 |
+
|
174 |
+
self.examples = []
|
175 |
+
with open(filepath) as fIn:
|
176 |
+
for line in fIn:
|
177 |
+
pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")
|
178 |
+
self.examples.append([qid, pid1, pid2, float(pos_score)-float(neg_score)])
|
179 |
+
|
180 |
+
def __len__(self):
|
181 |
+
return len(self.examples)
|
182 |
+
|
183 |
+
def __getitem__(self, item):
|
184 |
+
return self.examples[item]
|
185 |
+
|
186 |
+
train_batch_size = 16
|
187 |
+
train_dataset = MSEDataset('msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
|
188 |
+
train_dataloader = DataLoader(train_dataset, drop_last=True, shuffle=True, batch_size=train_batch_size)
|
189 |
+
|
190 |
+
|
191 |
+
############## Optimizer
|
192 |
+
|
193 |
+
weight_decay = 0.01
|
194 |
+
max_grad_norm = 1
|
195 |
+
param_optimizer = list(model.named_parameters())
|
196 |
+
|
197 |
+
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
|
198 |
+
optimizer_grouped_parameters = [
|
199 |
+
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
|
200 |
+
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
201 |
+
]
|
202 |
+
|
203 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
|
204 |
+
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader))
|
205 |
+
scaler = torch.cuda.amp.GradScaler()
|
206 |
+
|
207 |
+
loss_fct = torch.nn.MSELoss()
|
208 |
+
### Start training
|
209 |
+
model.to(device)
|
210 |
+
|
211 |
+
auto_save = 10000
|
212 |
+
best_ndcg_score = 0
|
213 |
+
for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
|
214 |
+
batch_queries = [train_queries[qid] for qid in batch[0]]
|
215 |
+
batch_pos = [train_corpus[cid] for cid in batch[1]]
|
216 |
+
batch_neg = [train_corpus[cid] for cid in batch[2]]
|
217 |
+
scores = batch[3].float().to(device) #torch.tensor(batch[3], dtype=torch.float, device=device)
|
218 |
+
|
219 |
+
with autocast():
|
220 |
+
inp_pos = tokenizer(batch_queries, batch_pos, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
|
221 |
+
pred_pos = model(**inp_pos).logits.squeeze()
|
222 |
+
|
223 |
+
inp_neg = tokenizer(batch_queries, batch_neg, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
|
224 |
+
pred_neg = model(**inp_neg).logits.squeeze()
|
225 |
+
|
226 |
+
pred_diff = pred_pos - pred_neg
|
227 |
+
loss_value = loss_fct(pred_diff, scores)
|
228 |
+
|
229 |
+
|
230 |
+
scaler.scale(loss_value).backward()
|
231 |
+
scaler.unscale_(optimizer)
|
232 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
|
233 |
+
scaler.step(optimizer)
|
234 |
+
scaler.update()
|
235 |
+
|
236 |
+
optimizer.zero_grad()
|
237 |
+
scheduler.step()
|
238 |
+
|
239 |
+
if (step_idx+1) % auto_save == 0:
|
240 |
+
print("Step:", step_idx+1)
|
241 |
+
model.save_pretrained(output_save_path_latest)
|
242 |
+
ndcg_score = eval_modal(output_save_path_latest)
|
243 |
+
|
244 |
+
if ndcg_score >= best_ndcg_score:
|
245 |
+
best_ndcg_score = ndcg_score
|
246 |
+
print("Save to:", output_save_path)
|
247 |
+
model.save_pretrained(output_save_path)
|
248 |
+
|
249 |
+
model.save_pretrained(output_save_path)
|
250 |
+
|
251 |
+
|
252 |
+
# Script was called via:
|
253 |
+
#python train_cross-encoder_mse_multilingual.py microsoft/Multilingual-MiniLM-L12-H384
|