upload
Browse files- README.md +82 -0
- config.json +30 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- spiece.model +3 -0
- tokenizer.json +3 -0
- tokenizer_config.json +1 -0
- train_script.py +164 -0
- training_args.bin +3 -0
README.md
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: nl
|
3 |
+
datasets:
|
4 |
+
- unicamp-dl/mmarco
|
5 |
+
widget:
|
6 |
+
- text: "Python is een programmeertaal die begin jaren 90 ontworpen en ontwikkeld werd door Guido van Rossum, destijds verbonden aan het Centrum voor Wiskunde en Informatica (daarvoor Mathematisch Centrum) in Amsterdam. De taal is mede gebaseerd op inzichten van professor Lambert Meertens, die een taal genaamd ABC had ontworpen, bedoeld als alternatief voor BASIC, maar dan met geavanceerde datastructuren. Inmiddels wordt de taal doorontwikkeld door een enthousiaste groep, tot juli 2018 geleid door Van Rossum. Deze groep wordt ondersteund door vrijwilligers op het internet. De ontwikkeling van Python wordt geleid door de Python Software Foundation. Python is vrije software."
|
7 |
+
|
8 |
+
license: apache-2.0
|
9 |
+
---
|
10 |
+
|
11 |
+
# doc2query/msmarco-dutch-mt5-base-v1
|
12 |
+
|
13 |
+
This is a [doc2query](https://arxiv.org/abs/1904.08375) model based on mT5 (also known as [docT5query](https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery-v2.pdf)).
|
14 |
+
|
15 |
+
It can be used for:
|
16 |
+
- **Document expansion**: You generate for your paragraphs 20-40 queries and index the paragraphs and the generates queries in a standard BM25 index like Elasticsearch, OpenSearch, or Lucene. The generated queries help to close the lexical gap of lexical search, as the generate queries contain synonyms. Further, it re-weights words giving important words a higher weight even if they appear seldomn in a paragraph. In our [BEIR](https://arxiv.org/abs/2104.08663) paper we showed that BM25+docT5query is a powerful search engine. In the [BEIR repository](https://github.com/beir-cellar/beir) we have an example how to use docT5query with Pyserini.
|
17 |
+
- **Domain Specific Training Data Generation**: It can be used to generate training data to learn an embedding model. In our [GPL-Paper](https://arxiv.org/abs/2112.07577) / [GPL Example on SBERT.net](https://www.sbert.net/examples/domain_adaptation/README.html#gpl-generative-pseudo-labeling) we have an example how to use the model to generate (query, text) pairs for a given collection of unlabeled texts. These pairs can then be used to train powerful dense embedding models.
|
18 |
+
|
19 |
+
## Usage
|
20 |
+
```python
|
21 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
22 |
+
import torch
|
23 |
+
|
24 |
+
model_name = 'doc2query/msmarco-dutch-mt5-base-v1'
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
26 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
27 |
+
|
28 |
+
text = "Python ist eine universelle, üblicherweise interpretierte, höhere Programmiersprache. Sie hat den Anspruch, einen gut lesbaren, knappen Programmierstil zu fördern. So werden beispielsweise Blöcke nicht durch geschweifte Klammern, sondern durch Einrückungen strukturiert."
|
29 |
+
|
30 |
+
|
31 |
+
def create_queries(para):
|
32 |
+
input_ids = tokenizer.encode(para, return_tensors='pt')
|
33 |
+
with torch.no_grad():
|
34 |
+
# Here we use top_k / top_k random sampling. It generates more diverse queries, but of lower quality
|
35 |
+
sampling_outputs = model.generate(
|
36 |
+
input_ids=input_ids,
|
37 |
+
max_length=64,
|
38 |
+
do_sample=True,
|
39 |
+
top_p=0.95,
|
40 |
+
top_k=10,
|
41 |
+
num_return_sequences=5
|
42 |
+
)
|
43 |
+
|
44 |
+
# Here we use Beam-search. It generates better quality queries, but with less diversity
|
45 |
+
beam_outputs = model.generate(
|
46 |
+
input_ids=input_ids,
|
47 |
+
max_length=64,
|
48 |
+
num_beams=5,
|
49 |
+
no_repeat_ngram_size=2,
|
50 |
+
num_return_sequences=5,
|
51 |
+
early_stopping=True
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
print("Paragraph:")
|
56 |
+
print(para)
|
57 |
+
|
58 |
+
print("\nBeam Outputs:")
|
59 |
+
for i in range(len(beam_outputs)):
|
60 |
+
query = tokenizer.decode(beam_outputs[i], skip_special_tokens=True)
|
61 |
+
print(f'{i + 1}: {query}')
|
62 |
+
|
63 |
+
print("\nSampling Outputs:")
|
64 |
+
for i in range(len(sampling_outputs)):
|
65 |
+
query = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
|
66 |
+
print(f'{i + 1}: {query}')
|
67 |
+
|
68 |
+
create_queries(text)
|
69 |
+
|
70 |
+
```
|
71 |
+
|
72 |
+
**Note:** `model.generate()` is non-deterministic for top_k/top_n sampling. It produces different queries each time you run it.
|
73 |
+
|
74 |
+
## Training
|
75 |
+
This model fine-tuned [google/mt5-base](https://huggingface.co/google/mt5-base) for 66k training steps (4 epochs on the 500k training pairs from MS MARCO). For the training script, see the `train_script.py` in this repository.
|
76 |
+
|
77 |
+
The input-text was truncated to 320 word pieces. Output text was generated up to 64 word pieces.
|
78 |
+
|
79 |
+
This model was trained on a (query, passage) from the [mMARCO dataset](https://github.com/unicamp-dl/mMARCO).
|
80 |
+
|
81 |
+
|
82 |
+
|
config.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "google/mt5-base",
|
3 |
+
"architectures": [
|
4 |
+
"MT5ForConditionalGeneration"
|
5 |
+
],
|
6 |
+
"d_ff": 2048,
|
7 |
+
"d_kv": 64,
|
8 |
+
"d_model": 768,
|
9 |
+
"decoder_start_token_id": 0,
|
10 |
+
"dropout_rate": 0.1,
|
11 |
+
"eos_token_id": 1,
|
12 |
+
"feed_forward_proj": "gated-gelu",
|
13 |
+
"initializer_factor": 1.0,
|
14 |
+
"is_encoder_decoder": true,
|
15 |
+
"layer_norm_epsilon": 1e-06,
|
16 |
+
"model_type": "mt5",
|
17 |
+
"num_decoder_layers": 12,
|
18 |
+
"num_heads": 12,
|
19 |
+
"num_layers": 12,
|
20 |
+
"output_past": true,
|
21 |
+
"pad_token_id": 0,
|
22 |
+
"relative_attention_max_distance": 128,
|
23 |
+
"relative_attention_num_buckets": 32,
|
24 |
+
"tie_word_embeddings": false,
|
25 |
+
"tokenizer_class": "T5Tokenizer",
|
26 |
+
"torch_dtype": "float32",
|
27 |
+
"transformers_version": "4.18.0",
|
28 |
+
"use_cache": true,
|
29 |
+
"vocab_size": 250112
|
30 |
+
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3408a079bd467dc9c9deaaf85630340f3dfa8de07bbaa407cd1bbdd4eb7d3e3d
|
3 |
+
size 2329700301
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
|
spiece.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef78f86560d809067d12bac6c09f19a462cb3af3f54d2b8acbba26e1433125d6
|
3 |
+
size 4309802
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4d3fca0dbb3a53bc1eddfc2e47ef441d7a94a70879e6750baddab04441a78305
|
3 |
+
size 16330621
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "/home/patrick/.cache/torch/transformers/685ac0ca8568ec593a48b61b0a3c272beee9bc194a3c7241d15dcadb5f875e53.f76030f3ec1b96a8199b2593390c610e76ca8028ef3d24680000619ffb646276", "name_or_path": "google/mt5-base", "sp_model_kwargs": {}, "tokenizer_class": "T5Tokenizer"}
|
train_script.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
from torch.utils.data import Dataset, IterableDataset
|
4 |
+
import gzip
|
5 |
+
import json
|
6 |
+
from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
|
7 |
+
import sys
|
8 |
+
from datetime import datetime
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
from shutil import copyfile
|
12 |
+
import os
|
13 |
+
import wandb
|
14 |
+
import random
|
15 |
+
import re
|
16 |
+
from datasets import load_dataset
|
17 |
+
import tqdm
|
18 |
+
|
19 |
+
|
20 |
+
logging.basicConfig(
|
21 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
22 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
23 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
24 |
+
)
|
25 |
+
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
parser.add_argument("--lang", required=True)
|
28 |
+
parser.add_argument("--model_name", default="google/mt5-base")
|
29 |
+
parser.add_argument("--epochs", default=4, type=int)
|
30 |
+
parser.add_argument("--batch_size", default=32, type=int)
|
31 |
+
parser.add_argument("--max_source_length", default=320, type=int)
|
32 |
+
parser.add_argument("--max_target_length", default=64, type=int)
|
33 |
+
parser.add_argument("--eval_size", default=1000, type=int)
|
34 |
+
#parser.add_argument("--fp16", default=False, action='store_true')
|
35 |
+
args = parser.parse_args()
|
36 |
+
|
37 |
+
wandb.init(project="doc2query", name=f"{args.lang}-{args.model_name}")
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def main():
|
44 |
+
############ Load dataset
|
45 |
+
queries = {}
|
46 |
+
for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{args.lang}')['train']):
|
47 |
+
queries[row['id']] = row['text']
|
48 |
+
|
49 |
+
"""
|
50 |
+
collection = {}
|
51 |
+
for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']):
|
52 |
+
collection[row['id']] = row['text']
|
53 |
+
"""
|
54 |
+
collection = load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']
|
55 |
+
|
56 |
+
train_pairs = []
|
57 |
+
eval_pairs = []
|
58 |
+
|
59 |
+
|
60 |
+
with open('qrels.train.tsv') as fIn:
|
61 |
+
for line in fIn:
|
62 |
+
qid, _, did, _ = line.strip().split("\t")
|
63 |
+
|
64 |
+
qid = int(qid)
|
65 |
+
did = int(did)
|
66 |
+
|
67 |
+
assert did == collection[did]['id']
|
68 |
+
text = collection[did]['text']
|
69 |
+
|
70 |
+
pair = (queries[qid], text)
|
71 |
+
if len(eval_pairs) < args.eval_size:
|
72 |
+
eval_pairs.append(pair)
|
73 |
+
else:
|
74 |
+
train_pairs.append(pair)
|
75 |
+
|
76 |
+
|
77 |
+
print(f"Train pairs: {len(train_pairs)}")
|
78 |
+
|
79 |
+
|
80 |
+
############ Model
|
81 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
|
82 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
83 |
+
|
84 |
+
save_steps = 1000
|
85 |
+
|
86 |
+
output_dir = 'output/'+args.lang+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
87 |
+
print("Output dir:", output_dir)
|
88 |
+
|
89 |
+
# Write self to path
|
90 |
+
os.makedirs(output_dir, exist_ok=True)
|
91 |
+
|
92 |
+
train_script_path = os.path.join(output_dir, 'train_script.py')
|
93 |
+
copyfile(__file__, train_script_path)
|
94 |
+
with open(train_script_path, 'a') as fOut:
|
95 |
+
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
96 |
+
|
97 |
+
####
|
98 |
+
|
99 |
+
training_args = Seq2SeqTrainingArguments(
|
100 |
+
output_dir=output_dir,
|
101 |
+
bf16=True,
|
102 |
+
per_device_train_batch_size=args.batch_size,
|
103 |
+
evaluation_strategy="steps",
|
104 |
+
save_steps=save_steps,
|
105 |
+
logging_steps=100,
|
106 |
+
eval_steps=save_steps, #logging_steps,
|
107 |
+
warmup_steps=1000,
|
108 |
+
save_total_limit=1,
|
109 |
+
num_train_epochs=args.epochs,
|
110 |
+
report_to="wandb",
|
111 |
+
)
|
112 |
+
|
113 |
+
############ Arguments
|
114 |
+
|
115 |
+
############ Load datasets
|
116 |
+
|
117 |
+
|
118 |
+
print("Input:", train_pairs[0][1])
|
119 |
+
print("Target:", train_pairs[0][0])
|
120 |
+
|
121 |
+
print("Input:", eval_pairs[0][1])
|
122 |
+
print("Target:", eval_pairs[0][0])
|
123 |
+
|
124 |
+
|
125 |
+
def data_collator(examples):
|
126 |
+
targets = [row[0] for row in examples]
|
127 |
+
inputs = [row[1] for row in examples]
|
128 |
+
label_pad_token_id = -100
|
129 |
+
|
130 |
+
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)
|
131 |
+
|
132 |
+
# Setup the tokenizer for targets
|
133 |
+
with tokenizer.as_target_tokenizer():
|
134 |
+
labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)
|
135 |
+
|
136 |
+
# replace all tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss.
|
137 |
+
labels["input_ids"] = [
|
138 |
+
[(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
|
139 |
+
]
|
140 |
+
|
141 |
+
|
142 |
+
model_inputs["labels"] = torch.tensor(labels["input_ids"])
|
143 |
+
return model_inputs
|
144 |
+
|
145 |
+
## Define the trainer
|
146 |
+
trainer = Seq2SeqTrainer(
|
147 |
+
model=model,
|
148 |
+
args=training_args,
|
149 |
+
train_dataset=train_pairs,
|
150 |
+
eval_dataset=eval_pairs,
|
151 |
+
tokenizer=tokenizer,
|
152 |
+
data_collator=data_collator
|
153 |
+
)
|
154 |
+
|
155 |
+
### Save the model
|
156 |
+
train_result = trainer.train()
|
157 |
+
trainer.save_model()
|
158 |
+
|
159 |
+
|
160 |
+
if __name__ == "__main__":
|
161 |
+
main()
|
162 |
+
|
163 |
+
# Script was called via:
|
164 |
+
#python train_hf_trainer_multilingual.py --lang dutch
|
training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9734bbc3a626edcaa03fc8ed83caeb957e2b9b684508345e53c2c1375c213e97
|
3 |
+
size 3247
|