nreimers
commited on
Commit
•
67d6b31
1
Parent(s):
d0a17bf
upload
Browse files- CERerankingEvaluator_results.csv +127 -0
- README.md +34 -0
- config.json +27 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer_config.json +1 -0
- train_script.py +193 -0
- vocab.txt +0 -0
CERerankingEvaluator_results.csv
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
epoch,steps,MRR@10
|
2 |
+
0,5000,0.5650095238095237
|
3 |
+
0,10000,0.5849968253968254
|
4 |
+
0,15000,0.6097650793650794
|
5 |
+
0,20000,0.6246285714285715
|
6 |
+
0,25000,0.6100253968253967
|
7 |
+
0,30000,0.6270730158730159
|
8 |
+
0,35000,0.6138888888888888
|
9 |
+
0,40000,0.6240317460317462
|
10 |
+
0,45000,0.6327619047619049
|
11 |
+
0,50000,0.619631746031746
|
12 |
+
0,55000,0.5871142857142856
|
13 |
+
0,60000,0.6175809523809525
|
14 |
+
0,65000,0.6081968253968254
|
15 |
+
0,70000,0.6151301587301587
|
16 |
+
0,75000,0.6093269841269842
|
17 |
+
0,80000,0.6032571428571428
|
18 |
+
0,85000,0.6138063492063491
|
19 |
+
0,90000,0.6156952380952381
|
20 |
+
0,95000,0.6303523809523809
|
21 |
+
0,100000,0.6061523809523809
|
22 |
+
0,105000,0.6133174603174603
|
23 |
+
0,110000,0.6226063492063493
|
24 |
+
0,115000,0.6176349206349206
|
25 |
+
0,120000,0.6104761904761905
|
26 |
+
0,125000,0.6332253968253967
|
27 |
+
0,130000,0.6289523809523808
|
28 |
+
0,135000,0.6181809523809524
|
29 |
+
0,140000,0.6399841269841271
|
30 |
+
0,145000,0.623073015873016
|
31 |
+
0,150000,0.5963587301587302
|
32 |
+
0,155000,0.6157301587301588
|
33 |
+
0,160000,0.613120634920635
|
34 |
+
0,165000,0.6089936507936508
|
35 |
+
0,170000,0.6203301587301587
|
36 |
+
0,175000,0.6171269841269841
|
37 |
+
0,180000,0.5939269841269841
|
38 |
+
0,185000,0.6417873015873015
|
39 |
+
0,190000,0.6164476190476191
|
40 |
+
0,195000,0.6215841269841269
|
41 |
+
0,200000,0.6298984126984126
|
42 |
+
0,205000,0.6030507936507936
|
43 |
+
0,210000,0.6084730158730158
|
44 |
+
0,215000,0.6092730158730159
|
45 |
+
0,220000,0.5939650793650793
|
46 |
+
0,225000,0.6124190476190475
|
47 |
+
0,230000,0.6039269841269841
|
48 |
+
0,235000,0.6253301587301587
|
49 |
+
0,240000,0.634904761904762
|
50 |
+
0,245000,0.6317015873015873
|
51 |
+
0,250000,0.6196603174603175
|
52 |
+
0,255000,0.6287396825396825
|
53 |
+
0,260000,0.6095746031746031
|
54 |
+
0,265000,0.6263492063492063
|
55 |
+
0,270000,0.6171079365079365
|
56 |
+
0,275000,0.6289523809523809
|
57 |
+
0,280000,0.6202634920634921
|
58 |
+
0,285000,0.6255301587301587
|
59 |
+
0,290000,0.5993841269841268
|
60 |
+
0,295000,0.6191841269841271
|
61 |
+
0,300000,0.6203396825396825
|
62 |
+
0,305000,0.6128412698412699
|
63 |
+
0,310000,0.6090825396825398
|
64 |
+
0,315000,0.5950539682539682
|
65 |
+
0,320000,0.5990444444444444
|
66 |
+
0,325000,0.6042412698412698
|
67 |
+
0,330000,0.5960190476190476
|
68 |
+
0,335000,0.6106222222222223
|
69 |
+
0,340000,0.6055968253968255
|
70 |
+
0,345000,0.5984095238095238
|
71 |
+
0,350000,0.6142984126984128
|
72 |
+
0,355000,0.6137746031746032
|
73 |
+
0,360000,0.6018412698412698
|
74 |
+
0,365000,0.6123079365079365
|
75 |
+
0,370000,0.6130285714285715
|
76 |
+
0,375000,0.6008412698412698
|
77 |
+
0,380000,0.6020698412698412
|
78 |
+
0,385000,0.6100222222222222
|
79 |
+
0,390000,0.5971650793650793
|
80 |
+
0,395000,0.5941968253968255
|
81 |
+
0,400000,0.5871428571428571
|
82 |
+
0,405000,0.6100190476190476
|
83 |
+
0,410000,0.5903174603174602
|
84 |
+
0,415000,0.5988317460317459
|
85 |
+
0,420000,0.6132380952380952
|
86 |
+
0,425000,0.6144412698412698
|
87 |
+
0,430000,0.5980888888888888
|
88 |
+
0,435000,0.5973746031746032
|
89 |
+
0,440000,0.595384126984127
|
90 |
+
0,445000,0.5871714285714286
|
91 |
+
0,450000,0.6012412698412699
|
92 |
+
0,455000,0.5873047619047618
|
93 |
+
0,460000,0.595584126984127
|
94 |
+
0,465000,0.5804285714285713
|
95 |
+
0,470000,0.5887619047619047
|
96 |
+
0,475000,0.5872761904761904
|
97 |
+
0,480000,0.5871396825396825
|
98 |
+
0,485000,0.5907174603174602
|
99 |
+
0,490000,0.5880412698412699
|
100 |
+
0,495000,0.5807968253968254
|
101 |
+
0,500000,0.5909746031746032
|
102 |
+
0,505000,0.5912984126984128
|
103 |
+
0,510000,0.5942761904761905
|
104 |
+
0,515000,0.5840222222222223
|
105 |
+
0,520000,0.5852380952380952
|
106 |
+
0,525000,0.582784126984127
|
107 |
+
0,530000,0.5916190476190476
|
108 |
+
0,535000,0.5777269841269841
|
109 |
+
0,540000,0.582120634920635
|
110 |
+
0,545000,0.5746634920634921
|
111 |
+
0,550000,0.5746444444444445
|
112 |
+
0,555000,0.5632444444444444
|
113 |
+
0,560000,0.5799650793650795
|
114 |
+
0,565000,0.5932507936507936
|
115 |
+
0,570000,0.5816190476190476
|
116 |
+
0,575000,0.5838857142857143
|
117 |
+
0,580000,0.5859650793650794
|
118 |
+
0,585000,0.5843968253968255
|
119 |
+
0,590000,0.5840634920634921
|
120 |
+
0,595000,0.5958285714285714
|
121 |
+
0,600000,0.5842857142857142
|
122 |
+
0,605000,0.5892507936507937
|
123 |
+
0,610000,0.5914507936507937
|
124 |
+
0,615000,0.5953968253968254
|
125 |
+
0,620000,0.5925174603174603
|
126 |
+
0,625000,0.5890857142857143
|
127 |
+
0,-1,0.5890857142857143
|
README.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Cross-Encoder for MS Marco
|
2 |
+
|
3 |
+
This model uses [TinyBERT](https://github.com/huawei-noah/Pretrained-Language-Model/tree/master/TinyBERT), a tiny BERT model with only 6 layers. The base model is: General_TinyBERT_v2(6layer-768dim)
|
4 |
+
|
5 |
+
It was trained on [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
|
6 |
+
|
7 |
+
The model can be used for Information Retrieval: Given a query, encode the query will all possible passages (e.g. retrieved with ElasticSearch). Then sort the passages in a decreasing order. See [SBERT.net Information Retrieval](https://github.com/UKPLab/sentence-transformers/tree/master/examples/applications/information-retrieval) for more details. The training code is available here: [SBERT.net Training MS Marco](https://github.com/UKPLab/sentence-transformers/tree/master/examples/training/ms_marco)
|
8 |
+
|
9 |
+
## Usage and Performance
|
10 |
+
|
11 |
+
Pre-trained models can be used like this:
|
12 |
+
```
|
13 |
+
from sentence_transformers import CrossEncoder
|
14 |
+
model = CrossEncoder('model_name', max_length=512)
|
15 |
+
scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')])
|
16 |
+
```
|
17 |
+
|
18 |
+
In the following table, we provide various pre-trained Cross-Encoders together with their performance on the [TREC Deep Learning 2019](https://microsoft.github.io/TREC-2019-Deep-Learning/) and the [MS Marco Passage Reranking](https://github.com/microsoft/MSMARCO-Passage-Ranking/) dataset.
|
19 |
+
|
20 |
+
|
21 |
+
| Model-Name | NDCG@10 (TREC DL 19) | MRR@10 (MS Marco Dev) | Docs / Sec (BertTokenizerFast) | Docs / Sec |
|
22 |
+
| ------------- |:-------------| -----| --- | --- |
|
23 |
+
| cross-encoder/ms-marco-TinyBERT-L-2 | 67.43 | 30.15 | 9000 | 780
|
24 |
+
| cross-encoder/ms-marco-TinyBERT-L-4 | 68.09 | 34.50 | 2900 | 760
|
25 |
+
| cross-encoder/ms-marco-TinyBERT-L-6 | 69.57 | 36.13 | 680 | 660
|
26 |
+
| cross-encoder/ms-marco-electra-base | 71.99 | 36.41 | 340 | 340
|
27 |
+
| *Other models* | | | |
|
28 |
+
| nboost/pt-tinybert-msmarco | 63.63 | 28.80 | 2900 | 760
|
29 |
+
| nboost/pt-bert-base-uncased-msmarco | 70.94 | 34.75 | 340 | 340|
|
30 |
+
| nboost/pt-bert-large-msmarco | 73.36 | 36.48 | 100 | 100 |
|
31 |
+
| Capreolus/electra-base-msmarco | 71.23 | | 340 | 340 |
|
32 |
+
| amberoad/bert-multilingual-passage-reranking-msmarco | 68.40 | | 330 | 330
|
33 |
+
|
34 |
+
Note: Runtime was computed on a V100 GPU. A bottleneck for smaller models is the standard Python tokenizer from Huggingface v3. Replacing it with the fast tokenizer based on Rust, the throughput is significantly improved:
|
config.json
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "nreimers/TinyBERT_L-6_H-768_v2",
|
3 |
+
"architectures": [
|
4 |
+
"BertForSequenceClassification"
|
5 |
+
],
|
6 |
+
"attention_probs_dropout_prob": 0.1,
|
7 |
+
"gradient_checkpointing": false,
|
8 |
+
"hidden_act": "gelu",
|
9 |
+
"hidden_dropout_prob": 0.1,
|
10 |
+
"hidden_size": 768,
|
11 |
+
"id2label": {
|
12 |
+
"0": "LABEL_0"
|
13 |
+
},
|
14 |
+
"initializer_range": 0.02,
|
15 |
+
"intermediate_size": 3072,
|
16 |
+
"label2id": {
|
17 |
+
"LABEL_0": 0
|
18 |
+
},
|
19 |
+
"layer_norm_eps": 1e-12,
|
20 |
+
"max_position_embeddings": 512,
|
21 |
+
"model_type": "bert",
|
22 |
+
"num_attention_heads": 12,
|
23 |
+
"num_hidden_layers": 6,
|
24 |
+
"pad_token_id": 0,
|
25 |
+
"type_vocab_size": 2,
|
26 |
+
"vocab_size": 30522
|
27 |
+
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0cf70691b32e33dc1a57845ceeeb81bc70cfec62a0d584d063f13576403f2759
|
3 |
+
size 267871721
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"do_lower_case": true, "do_basic_tokenize": true, "never_split": null, "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": "/home/ukp-reimers/.cache/torch/transformers/68bd39c5d90b3f2d8da930bf3efe6ad55d21ae39cfbf97d37817cce8149bf2f3.dd8bd9bfd3664b530ea4e645105f557769387b3da9f79bdb55ed556bdd80611d", "tokenizer_file": null, "name_or_path": "nreimers/TinyBERT_L-6_H-768_v2"}
|
train_script.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
from sentence_transformers import LoggingHandler
|
3 |
+
from sentence_transformers.cross_encoder import CrossEncoder
|
4 |
+
from sentence_transformers.cross_encoder.evaluation import CEBinaryClassificationEvaluator
|
5 |
+
from sentence_transformers import InputExample
|
6 |
+
import logging
|
7 |
+
from datetime import datetime
|
8 |
+
import gzip
|
9 |
+
import sys
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
from shutil import copyfile
|
13 |
+
import csv
|
14 |
+
import tqdm
|
15 |
+
|
16 |
+
#### Just some code to print debug information to stdout
|
17 |
+
logging.basicConfig(format='%(asctime)s - %(message)s',
|
18 |
+
datefmt='%Y-%m-%d %H:%M:%S',
|
19 |
+
level=logging.INFO,
|
20 |
+
handlers=[LoggingHandler()])
|
21 |
+
#### /print debug information to stdout
|
22 |
+
|
23 |
+
|
24 |
+
#Define our Cross-Encoder
|
25 |
+
model_name = sys.argv[1] #'google/electra-small-discriminator'
|
26 |
+
train_batch_size = 32
|
27 |
+
num_epochs = 1
|
28 |
+
model_save_path = 'output/training_ms-marco_cross-encoder-'+model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
29 |
+
|
30 |
+
#We set num_labels=1, which predicts a continous score between 0 and 1
|
31 |
+
model = CrossEncoder(model_name, num_labels=1, max_length=512)
|
32 |
+
|
33 |
+
|
34 |
+
# Write self to path
|
35 |
+
os.makedirs(model_save_path, exist_ok=True)
|
36 |
+
|
37 |
+
train_script_path = os.path.join(model_save_path, 'train_script.py')
|
38 |
+
copyfile(__file__, train_script_path)
|
39 |
+
with open(train_script_path, 'a') as fOut:
|
40 |
+
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
41 |
+
|
42 |
+
|
43 |
+
corpus = {}
|
44 |
+
queries = {}
|
45 |
+
|
46 |
+
#### Read train file
|
47 |
+
with gzip.open('../data/collection.tsv.gz', 'rt') as fIn:
|
48 |
+
for line in fIn:
|
49 |
+
pid, passage = line.strip().split("\t")
|
50 |
+
corpus[pid] = passage
|
51 |
+
|
52 |
+
with open('../data/queries.train.tsv', 'r') as fIn:
|
53 |
+
for line in fIn:
|
54 |
+
qid, query = line.strip().split("\t")
|
55 |
+
queries[qid] = query
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
pos_neg_ration = (4+1)
|
60 |
+
cnt = 0
|
61 |
+
train_samples = []
|
62 |
+
dev_samples = {}
|
63 |
+
|
64 |
+
num_dev_queries = 125
|
65 |
+
num_max_dev_negatives = 200
|
66 |
+
|
67 |
+
with gzip.open('../data/qidpidtriples.rnd-shuf.train-eval.tsv.gz', 'rt') as fIn:
|
68 |
+
for line in fIn:
|
69 |
+
qid, pos_id, neg_id = line.strip().split()
|
70 |
+
|
71 |
+
if qid not in dev_samples and len(dev_samples) < num_dev_queries:
|
72 |
+
dev_samples[qid] = {'query': queries[qid], 'positive': set(), 'negative': set()}
|
73 |
+
|
74 |
+
if qid in dev_samples:
|
75 |
+
dev_samples[qid]['positive'].add(corpus[pos_id])
|
76 |
+
|
77 |
+
if len(dev_samples[qid]['negative']) < num_max_dev_negatives:
|
78 |
+
dev_samples[qid]['negative'].add(corpus[neg_id])
|
79 |
+
|
80 |
+
with gzip.open('../data/qidpidtriples.rnd-shuf.train.tsv.gz', 'rt') as fIn:
|
81 |
+
for line in tqdm.tqdm(fIn, unit_scale=True):
|
82 |
+
cnt += 1
|
83 |
+
qid, pos_id, neg_id = line.strip().split()
|
84 |
+
query = queries[qid]
|
85 |
+
if (cnt % pos_neg_ration) == 0:
|
86 |
+
passage = corpus[pos_id]
|
87 |
+
label = 1
|
88 |
+
else:
|
89 |
+
passage = corpus[neg_id]
|
90 |
+
label = 0
|
91 |
+
|
92 |
+
train_samples.append(InputExample(texts=[query, passage], label=label))
|
93 |
+
|
94 |
+
if len(train_samples) >= 2e7:
|
95 |
+
break
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
|
100 |
+
|
101 |
+
# We add an evaluator, which evaluates the performance during training
|
102 |
+
|
103 |
+
class CERerankingEvaluator:
|
104 |
+
def __init__(self, samples, mrr_at_k: int = 10, name: str = ''):
|
105 |
+
self.samples = samples
|
106 |
+
self.name = name
|
107 |
+
self.mrr_at_k = mrr_at_k
|
108 |
+
|
109 |
+
if isinstance(self.samples, dict):
|
110 |
+
self.samples = list(self.samples.values())
|
111 |
+
|
112 |
+
self.csv_file = "CERerankingEvaluator" + ("_" + name if name else '') + "_results.csv"
|
113 |
+
self.csv_headers = ["epoch", "steps", "MRR@{}".format(mrr_at_k)]
|
114 |
+
|
115 |
+
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
|
116 |
+
if epoch != -1:
|
117 |
+
if steps == -1:
|
118 |
+
out_txt = " after epoch {}:".format(epoch)
|
119 |
+
else:
|
120 |
+
out_txt = " in epoch {} after {} steps:".format(epoch, steps)
|
121 |
+
else:
|
122 |
+
out_txt = ":"
|
123 |
+
|
124 |
+
logging.info("CERerankingEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
|
125 |
+
|
126 |
+
all_mrr_scores = []
|
127 |
+
num_queries = 0
|
128 |
+
num_positives = []
|
129 |
+
num_negatives = []
|
130 |
+
for instance in self.samples:
|
131 |
+
query = instance['query']
|
132 |
+
positive = list(instance['positive'])
|
133 |
+
negative = list(instance['negative'])
|
134 |
+
docs = positive + negative
|
135 |
+
is_relevant = [True]*len(positive) + [False]*len(negative)
|
136 |
+
|
137 |
+
if len(positive) == 0 or len(negative) == 0:
|
138 |
+
continue
|
139 |
+
|
140 |
+
num_queries += 1
|
141 |
+
num_positives.append(len(positive))
|
142 |
+
num_negatives.append(len(negative))
|
143 |
+
|
144 |
+
model_input = [[query, doc] for doc in docs]
|
145 |
+
pred_scores = model.predict(model_input, convert_to_numpy=True, show_progress_bar=False)
|
146 |
+
pred_scores_argsort = np.argsort(-pred_scores) #Sort in decreasing order
|
147 |
+
|
148 |
+
mrr_score = 0
|
149 |
+
for rank, index in enumerate(pred_scores_argsort[0:self.mrr_at_k]):
|
150 |
+
if is_relevant[index]:
|
151 |
+
mrr_score = 1 / (rank+1)
|
152 |
+
|
153 |
+
all_mrr_scores.append(mrr_score)
|
154 |
+
|
155 |
+
mean_mrr = np.mean(all_mrr_scores)
|
156 |
+
logging.info("Queries: {} \t Positives: Min {:.1f}, Mean {:.1f}, Max {:.1f} \t Negatives: Min {:.1f}, Mean {:.1f}, Max {:.1f}".format(num_queries, np.min(num_positives), np.mean(num_positives), np.max(num_positives), np.min(num_negatives), np.mean(num_negatives), np.max(num_negatives)))
|
157 |
+
logging.info("MRR@{}: {:.2f}".format(self.mrr_at_k, mean_mrr*100))
|
158 |
+
|
159 |
+
if output_path is not None:
|
160 |
+
csv_path = os.path.join(output_path, self.csv_file)
|
161 |
+
output_file_exists = os.path.isfile(csv_path)
|
162 |
+
with open(csv_path, mode="a" if output_file_exists else 'w', encoding="utf-8") as f:
|
163 |
+
writer = csv.writer(f)
|
164 |
+
if not output_file_exists:
|
165 |
+
writer.writerow(self.csv_headers)
|
166 |
+
|
167 |
+
writer.writerow([epoch, steps, mean_mrr])
|
168 |
+
|
169 |
+
return mean_mrr
|
170 |
+
|
171 |
+
|
172 |
+
evaluator = CERerankingEvaluator(dev_samples)
|
173 |
+
|
174 |
+
# Configure the training
|
175 |
+
warmup_steps = 5000
|
176 |
+
logging.info("Warmup-steps: {}".format(warmup_steps))
|
177 |
+
|
178 |
+
|
179 |
+
# Train the model
|
180 |
+
model.fit(train_dataloader=train_dataloader,
|
181 |
+
evaluator=evaluator,
|
182 |
+
epochs=num_epochs,
|
183 |
+
evaluation_steps=5000,
|
184 |
+
warmup_steps=warmup_steps,
|
185 |
+
output_path=model_save_path,
|
186 |
+
use_amp=True)
|
187 |
+
|
188 |
+
#Save latest model
|
189 |
+
model.save(model_save_path+'-latest')
|
190 |
+
|
191 |
+
|
192 |
+
# Script was called via:
|
193 |
+
#python train_cross-encoder.py nreimers/TinyBERT_L-6_H-768_v2
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|