nreimers commited on
Commit
2054d27
·
1 Parent(s): c542599
README.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cross-Encoder for MS MARCO - EN-DE
2
+
3
+ This is a cross-lingual Cross-Encoder model for EN-DE that can be used for passage re-ranking. It was trained on the [MS Marco Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) task.
4
+
5
+ The model can be used for Information Retrieval: See [SBERT.net Retrieve & Re-rank](https://www.sbert.net/examples/applications/retrieve_rerank/README.html).
6
+
7
+ The training code is available in this repository, see `train_script.py`.
8
+
9
+
10
+ ## Usage with SentenceTransformers
11
+
12
+ When you have [SentenceTransformers](https://www.sbert.net/) installed, you can use the model like this:
13
+ ```python
14
+ from sentence_transformers import CrossEncoder
15
+ model = CrossEncoder('model_name', max_length=512)
16
+ query = 'How many people live in Berlin?'
17
+ docs = ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.']
18
+ pairs = [(query, doc) for doc in docs]
19
+ scores = model.predict(pairs)
20
+ ```
21
+
22
+
23
+ ## Usage with Transformers
24
+ With the transformers library, you can use the model like this:
25
+
26
+ ```python
27
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
28
+ import torch
29
+
30
+ model = AutoModelForSequenceClassification.from_pretrained('model_name')
31
+ tokenizer = AutoTokenizer.from_pretrained('model_name')
32
+
33
+ features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt")
34
+
35
+ model.eval()
36
+ with torch.no_grad():
37
+ scores = model(**features).logits
38
+ print(scores)
39
+ ```
40
+
41
+
42
+
43
+
44
+ ## Performance
45
+ The performance was evaluated on three datasets:
46
+ - **TREC-DL19 EN-EN**: The original [TREC 2019 Deep Learning Track](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2019.html): Given an English query and 1000 documents (retrieved by BM25 lexical search), rank documents with according to their relevance. We compute NDCG@10. BM25 achieves a score of 45.46, a perfect re-ranker can achieve a score 95.47.
47
+ - **TREC-DL19 DE-EN**: The English queries of TREC-DL19 have been translated by a German native speaker to German. We rank the German queries versus the English passages from the original TREC-DL19 setup. We compute NDCG@10.
48
+ - **GermanDPR DE-DE**: The [GermanDPR](https://www.deepset.ai/germanquad) dataset provides German queries and German passages from Wikipedia. We indexed the 2.8 Million paragraphs from German Wikipedia and retrieved for each query the top 100 most relevant passages using BM25 lexical search with Elasticsearch. We compute MRR@10. BM25 achieves a score of 35.85, a perfect re-ranker can achieve a score 76.27.
49
+
50
+ We also check the performance of bi-encoders using the same evaluation: The retrieved documents from BM25 lexical search are re-ranked using query & passage embeddings with cosine-similarity. Bi-Encoders can also be used for end-to-end semantic search.
51
+
52
+
53
+ | Model-Name | TREC-DL19 EN-EN | TREC-DL19 DE-EN | GermanDPR DE-DE | Docs / Sec |
54
+ | ------------- |:-------------:| :-----: | :---: | :----: |
55
+ | BM25 | 45.46 | - | 35.85 | -|
56
+ | **Cross-Encoder Re-Rankers** | | | |
57
+ | [cross-encoder/msmarco-MiniLM-L6-en-de-v1](https://huggingface.co/cross-encoder/msmarco-MiniLM-L6-en-de-v1) | 72.43 | 65.53 | 46.77 | 1600 |
58
+ | [cross-encoder/msmarco-MiniLM-L12-en-de-v1](https://huggingface.co/cross-encoder/msmarco-MiniLM-L12-en-de-v1) | 72.94 | 66.07 | 49.91 | 900 |
59
+ | [svalabs/cross-electra-ms-marco-german-uncased](https://huggingface.co/svalabs/cross-electra-ms-marco-german-uncased) (DE only) | - | - | 53.67 | 260 |
60
+ | **Bi-Encoders (re-ranking)** | | | |
61
+ | [sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-lng-aligned](https://huggingface.co/sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-lng-aligned) | 63.38 | 58.28 | 37.88 | 940 |
62
+ | [sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-trained-scratch](https://huggingface.co/sentence-transformers/msmarco-distilbert-multilingual-en-de-v2-tmp-trained-scratch) | 65.51 | 58.69 | 38.32 | 940 |
63
+ | [svalabs/bi-electra-ms-marco-german-uncased](https://huggingface.co/svalabs/bi-electra-ms-marco-german-uncased) (DE only) | - | - | 34.31 | 450 |
64
+
65
+ Note: Docs / Sec gives the number of (query, document) pairs we can re-rank within a second on a V100 GPU.
config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "microsoft/Multilingual-MiniLM-L12-H384",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "gradient_checkpointing": false,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.1,
10
+ "hidden_size": 384,
11
+ "id2label": {
12
+ "0": "LABEL_0"
13
+ },
14
+ "initializer_range": 0.02,
15
+ "intermediate_size": 1536,
16
+ "label2id": {
17
+ "LABEL_0": 0
18
+ },
19
+ "layer_norm_eps": 1e-12,
20
+ "max_position_embeddings": 512,
21
+ "model_type": "bert",
22
+ "num_attention_heads": 12,
23
+ "num_hidden_layers": 12,
24
+ "pad_token_id": 0,
25
+ "position_embedding_type": "absolute",
26
+ "tokenizer_class": "XLMRobertaTokenizer",
27
+ "transformers_version": "4.6.1",
28
+ "type_vocab_size": 2,
29
+ "use_cache": true,
30
+ "vocab_size": 250037,
31
+ "sbert_ce_default_activation_function": "torch.nn.modules.linear.Identity"
32
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fd9a04351955aa9f56f6b22ca61b14d1a91588ca72ef2769a63e847928b8309f
3
+ size 470705929
sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfc8146abe2a0488e9e2a0c56de7952f7c11ab059eca145a0a727afce0db2865
3
+ size 5069051
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": "<mask>"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "<unk>", "pad_token": "<pad>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "special_tokens_map_file": "/root/.cache/huggingface/transformers/8ed73a1ab9ef4e90a9451497bf96cfc38d34354352838a143f2dda1c81aed5ca.0dc5b1041f62041ebbd23b1297f2f573769d5c97d8b7c28180ec86b8f6185aa8", "name_or_path": "microsoft/Multilingual-MiniLM-L12-H384", "sp_model_kwargs": {}}
train_script.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import random
3
+
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, AdamW
5
+ import sys
6
+ import torch
7
+ import transformers
8
+ from torch.utils.data import Dataset, DataLoader
9
+ from torch.cuda.amp import autocast
10
+ import tqdm
11
+ from datetime import datetime
12
+ from shutil import copyfile
13
+ import os
14
+ ####################################
15
+
16
+ import gzip
17
+ from collections import defaultdict
18
+ import logging
19
+ import tqdm
20
+ import numpy as np
21
+ import sys
22
+ import pytrec_eval
23
+ from sentence_transformers import SentenceTransformer, util, CrossEncoder
24
+ import torch
25
+
26
+
27
+ model_name = sys.argv[1]
28
+ max_length = 350
29
+
30
+ ######### Evaluation
31
+ queries_filepath = 'msmarco-data/trec2019/msmarco-test2019-queries.tsv.gz'
32
+ queries_eval = {}
33
+ with gzip.open(queries_filepath, 'rt', encoding='utf8') as fIn:
34
+ for line in fIn:
35
+ qid, query = line.strip().split("\t")[0:2]
36
+ queries_eval[qid] = query
37
+
38
+ rel = defaultdict(lambda: defaultdict(int))
39
+
40
+ with open('msmarco-data/trec2019/2019qrels-pass.txt') as fIn:
41
+ for line in fIn:
42
+ qid, _, pid, score = line.strip().split()
43
+ score = int(score)
44
+ if score > 0:
45
+ rel[qid][pid] = score
46
+
47
+ relevant_qid = []
48
+ for qid in queries_eval:
49
+ if len(rel[qid]) > 0:
50
+ relevant_qid.append(qid)
51
+
52
+ # Read top 1k
53
+ passage_cand = {}
54
+
55
+ with gzip.open('msmarco-data/trec2019/msmarco-passagetest2019-top1000.tsv.gz', 'rt', encoding='utf8') as fIn:
56
+ for line in fIn:
57
+ qid, pid, query, passage = line.strip().split("\t")
58
+ if qid not in passage_cand:
59
+ passage_cand[qid] = []
60
+
61
+ passage_cand[qid].append([pid, passage])
62
+
63
+
64
+
65
+ def eval_modal(model_path):
66
+ run = {}
67
+ model = CrossEncoder(model_path, max_length=512)
68
+
69
+ for qid in relevant_qid:
70
+ query = queries_eval[qid]
71
+
72
+ cand = passage_cand[qid]
73
+ pids = [c[0] for c in cand]
74
+ corpus_sentences = [c[1] for c in cand]
75
+
76
+ ## CrossEncoder
77
+ cross_inp = [[query, sent] for sent in corpus_sentences]
78
+ if model.config.num_labels > 1:
79
+ cross_scores = model.predict(cross_inp, apply_softmax=True)[:, 1].tolist()
80
+ else:
81
+ cross_scores = model.predict(cross_inp, activation_fct=torch.nn.Identity()).tolist()
82
+
83
+ cross_scores_sparse = {}
84
+ for idx, pid in enumerate(pids):
85
+ cross_scores_sparse[pid] = cross_scores[idx]
86
+
87
+ sparse_scores = cross_scores_sparse
88
+ run[qid] = {}
89
+ for pid in sparse_scores:
90
+ run[qid][pid] = float(sparse_scores[pid])
91
+
92
+ evaluator = pytrec_eval.RelevanceEvaluator(rel, {'ndcg_cut.10'})
93
+ scores = evaluator.evaluate(run)
94
+ scores_mean = np.mean([ele["ndcg_cut_10"] for ele in scores.values()])
95
+
96
+ print("NDCG@10: {:.2f}".format(scores_mean * 100))
97
+ return scores_mean
98
+
99
+ ################################
100
+
101
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
102
+ config = AutoConfig.from_pretrained(model_name)
103
+ config.num_labels = 1
104
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, config=config)
105
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
106
+
107
+
108
+
109
+
110
+ #######################
111
+
112
+ queries = {}
113
+ corpus = {}
114
+
115
+ output_save_path = 'output/train_cross-encoder_mse-{}-{}'.format(model_name.replace("/", "_"), datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
116
+ output_save_path_latest = output_save_path+"-latest"
117
+ tokenizer.save_pretrained(output_save_path)
118
+ tokenizer.save_pretrained(output_save_path_latest)
119
+
120
+
121
+ # Write self to path
122
+ train_script_path = os.path.join(output_save_path, 'train_script.py')
123
+ copyfile(__file__, train_script_path)
124
+ with open(train_script_path, 'a') as fOut:
125
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
126
+
127
+
128
+ ####
129
+ train_script_path = os.path.join(output_save_path_latest, 'train_script.py')
130
+ copyfile(__file__, train_script_path)
131
+ with open(train_script_path, 'a') as fOut:
132
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
133
+
134
+
135
+
136
+ #### Read train files
137
+ class MultilingualDataset(Dataset):
138
+ def __init__(self):
139
+ self.examples = defaultdict(lambda: defaultdict(list)) #[id][lang] => [samples...]
140
+
141
+ def add(self, lang, filepath):
142
+ open_method = gzip.open if filepath.endswith('.gz') else open
143
+ with open_method(filepath, 'rt') as fIn:
144
+ for line in fIn:
145
+ pid, passage = line.strip().split("\t")
146
+ self.examples[pid][lang].append(passage)
147
+
148
+
149
+ def __len__(self):
150
+ return len(self.examples)
151
+
152
+ def __getitem__(self, item):
153
+ all_examples = self.examples[item] #All examples in all languages
154
+ lang_examples = random.choice(list(all_examples.values())) #Examples in on specific language
155
+ return random.choice(lang_examples) #One random example
156
+
157
+
158
+ train_corpus = MultilingualDataset()
159
+ train_corpus.add('en', 'msmarco-data/collection.tsv')
160
+ train_corpus.add('de', 'msmarco-data/de/collection.de.opus-mt.tsv.gz')
161
+ train_corpus.add('de', 'msmarco-data/de/collection.de.wmt19.tsv.gz')
162
+
163
+
164
+ train_queries = MultilingualDataset()
165
+ train_queries.add('en', 'msmarco-data/queries.train.tsv')
166
+ train_queries.add('de', 'msmarco-data/de/queries.train.de.opus-mt.tsv.gz')
167
+ train_queries.add('de', 'msmarco-data/de/queries.train.de.wmt19.tsv.gz')
168
+
169
+ ############## MSE Dataset
170
+ class MSEDataset(Dataset):
171
+ def __init__(self, filepath):
172
+ super().__init__()
173
+
174
+ self.examples = []
175
+ with open(filepath) as fIn:
176
+ for line in fIn:
177
+ pos_score, neg_score, qid, pid1, pid2 = line.strip().split("\t")
178
+ self.examples.append([qid, pid1, pid2, float(pos_score)-float(neg_score)])
179
+
180
+ def __len__(self):
181
+ return len(self.examples)
182
+
183
+ def __getitem__(self, item):
184
+ return self.examples[item]
185
+
186
+ train_batch_size = 16
187
+ train_dataset = MSEDataset('msmarco-data/bert_cat_ensemble_msmarcopassage_train_scores_ids.tsv')
188
+ train_dataloader = DataLoader(train_dataset, drop_last=True, shuffle=True, batch_size=train_batch_size)
189
+
190
+
191
+ ############## Optimizer
192
+
193
+ weight_decay = 0.01
194
+ max_grad_norm = 1
195
+ param_optimizer = list(model.named_parameters())
196
+
197
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
198
+ optimizer_grouped_parameters = [
199
+ {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
200
+ {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
201
+ ]
202
+
203
+ optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
204
+ scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=len(train_dataloader))
205
+ scaler = torch.cuda.amp.GradScaler()
206
+
207
+ loss_fct = torch.nn.MSELoss()
208
+ ### Start training
209
+ model.to(device)
210
+
211
+ auto_save = 10000
212
+ best_ndcg_score = 0
213
+ for step_idx, batch in tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
214
+ batch_queries = [train_queries[qid] for qid in batch[0]]
215
+ batch_pos = [train_corpus[cid] for cid in batch[1]]
216
+ batch_neg = [train_corpus[cid] for cid in batch[2]]
217
+ scores = batch[3].float().to(device) #torch.tensor(batch[3], dtype=torch.float, device=device)
218
+
219
+ with autocast():
220
+ inp_pos = tokenizer(batch_queries, batch_pos, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
221
+ pred_pos = model(**inp_pos).logits.squeeze()
222
+
223
+ inp_neg = tokenizer(batch_queries, batch_neg, max_length=max_length, padding=True, truncation='longest_first', return_tensors='pt').to(device)
224
+ pred_neg = model(**inp_neg).logits.squeeze()
225
+
226
+ pred_diff = pred_pos - pred_neg
227
+ loss_value = loss_fct(pred_diff, scores)
228
+
229
+
230
+ scaler.scale(loss_value).backward()
231
+ scaler.unscale_(optimizer)
232
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
233
+ scaler.step(optimizer)
234
+ scaler.update()
235
+
236
+ optimizer.zero_grad()
237
+ scheduler.step()
238
+
239
+ if (step_idx+1) % auto_save == 0:
240
+ print("Step:", step_idx+1)
241
+ model.save_pretrained(output_save_path_latest)
242
+ ndcg_score = eval_modal(output_save_path_latest)
243
+
244
+ if ndcg_score >= best_ndcg_score:
245
+ best_ndcg_score = ndcg_score
246
+ print("Save to:", output_save_path)
247
+ model.save_pretrained(output_save_path)
248
+
249
+ model.save_pretrained(output_save_path)
250
+
251
+
252
+ # Script was called via:
253
+ #python train_cross-encoder_mse_multilingual.py microsoft/Multilingual-MiniLM-L12-H384