nreimers commited on
Commit
c7d6d48
·
1 Parent(s): eaac9b1
README.md ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ datasets:
4
+ - datasets/sentence-transformers/reddit-title-body
5
+ widget:
6
+ - text: "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."
7
+
8
+ license: apache-2.0
9
+ ---
10
+
11
+ # doc2query/reddit-t5-small-v1
12
+
13
+ This is a [doc2query](https://arxiv.org/abs/1904.08375) model based on T5 (also known as [docT5query](https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery-v2.pdf)).
14
+
15
+ It can be used for:
16
+ - **Document expansion**: You generate for your paragraphs 20-40 queries and index the paragraphs and the generates queries in a standard BM25 index like Elasticsearch, OpenSearch, or Lucene. The generated queries help to close the lexical gap of lexical search, as the generate queries contain synonyms. Further, it re-weights words giving important words a higher weight even if they appear seldomn in a paragraph. In our [BEIR](https://arxiv.org/abs/2104.08663) paper we showed that BM25+docT5query is a powerful search engine. In the [BEIR repository](https://github.com/UKPLab/beir) we have an example how to use docT5query with Pyserini.
17
+ - **Domain Specific Training Data Generation**: It can be used to generate training data to learn an embedding model. On [SBERT.net](https://www.sbert.net/examples/unsupervised_learning/query_generation/README.html) we have an example how to use the model to generate (query, text) pairs for a given collection of unlabeled texts. These pairs can then be used to train powerful dense embedding models.
18
+
19
+ ## Usage
20
+ ```python
21
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
22
+
23
+ model_name = 'doc2query/reddit-t5-small-v1'
24
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
25
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
26
+
27
+ text = "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."
28
+
29
+
30
+ input_ids = tokenizer.encode(text, max_length=384, truncation=True, return_tensors='pt')
31
+ outputs = model.generate(
32
+ input_ids=input_ids,
33
+ max_length=64,
34
+ do_sample=True,
35
+ top_p=0.95,
36
+ num_return_sequences=5)
37
+
38
+ print("Text:")
39
+ print(text)
40
+
41
+ print("\nGenerated Queries:")
42
+ for i in range(len(outputs)):
43
+ query = tokenizer.decode(outputs[i], skip_special_tokens=True)
44
+ print(f'{i + 1}: {query}')
45
+ ```
46
+
47
+ **Note:** `model.generate()` is non-deterministic. It produces different queries each time you run it.
48
+
49
+ ## Training
50
+ This model fine-tuned [google/t5-v1_1-small](https://huggingface.co/google/t5-v1_1-small) for 547k training steps. For the training script, see the `train_script.py` in this repository.
51
+
52
+ The input-text was truncated to 384 word pieces. Output text was generated up to 64 word pieces.
53
+
54
+ This model was trained on a (title, body) from Reddit.
55
+
56
+
57
+
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/t5-v1_1-small",
3
+ "architectures": [
4
+ "T5ForConditionalGeneration"
5
+ ],
6
+ "d_ff": 1024,
7
+ "d_kv": 64,
8
+ "d_model": 512,
9
+ "decoder_start_token_id": 0,
10
+ "dropout_rate": 0.1,
11
+ "eos_token_id": 1,
12
+ "feed_forward_proj": "gated-gelu",
13
+ "gradient_checkpointing": false,
14
+ "initializer_factor": 1.0,
15
+ "is_encoder_decoder": true,
16
+ "layer_norm_epsilon": 1e-06,
17
+ "model_type": "t5",
18
+ "num_decoder_layers": 8,
19
+ "num_heads": 6,
20
+ "num_layers": 8,
21
+ "output_past": true,
22
+ "pad_token_id": 0,
23
+ "relative_attention_num_buckets": 32,
24
+ "tie_word_embeddings": false,
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.10.2",
27
+ "use_cache": true,
28
+ "vocab_size": 32128
29
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2924f2f11e5a444710a05bc413777da1527f264629a7e7234f4c84d7da558192
3
+ size 307934749
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"]}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
3
+ size 791656
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 100, "additional_special_tokens": ["<extra_id_0>", "<extra_id_1>", "<extra_id_2>", "<extra_id_3>", "<extra_id_4>", "<extra_id_5>", "<extra_id_6>", "<extra_id_7>", "<extra_id_8>", "<extra_id_9>", "<extra_id_10>", "<extra_id_11>", "<extra_id_12>", "<extra_id_13>", "<extra_id_14>", "<extra_id_15>", "<extra_id_16>", "<extra_id_17>", "<extra_id_18>", "<extra_id_19>", "<extra_id_20>", "<extra_id_21>", "<extra_id_22>", "<extra_id_23>", "<extra_id_24>", "<extra_id_25>", "<extra_id_26>", "<extra_id_27>", "<extra_id_28>", "<extra_id_29>", "<extra_id_30>", "<extra_id_31>", "<extra_id_32>", "<extra_id_33>", "<extra_id_34>", "<extra_id_35>", "<extra_id_36>", "<extra_id_37>", "<extra_id_38>", "<extra_id_39>", "<extra_id_40>", "<extra_id_41>", "<extra_id_42>", "<extra_id_43>", "<extra_id_44>", "<extra_id_45>", "<extra_id_46>", "<extra_id_47>", "<extra_id_48>", "<extra_id_49>", "<extra_id_50>", "<extra_id_51>", "<extra_id_52>", "<extra_id_53>", "<extra_id_54>", "<extra_id_55>", "<extra_id_56>", "<extra_id_57>", "<extra_id_58>", "<extra_id_59>", "<extra_id_60>", "<extra_id_61>", "<extra_id_62>", "<extra_id_63>", "<extra_id_64>", "<extra_id_65>", "<extra_id_66>", "<extra_id_67>", "<extra_id_68>", "<extra_id_69>", "<extra_id_70>", "<extra_id_71>", "<extra_id_72>", "<extra_id_73>", "<extra_id_74>", "<extra_id_75>", "<extra_id_76>", "<extra_id_77>", "<extra_id_78>", "<extra_id_79>", "<extra_id_80>", "<extra_id_81>", "<extra_id_82>", "<extra_id_83>", "<extra_id_84>", "<extra_id_85>", "<extra_id_86>", "<extra_id_87>", "<extra_id_88>", "<extra_id_89>", "<extra_id_90>", "<extra_id_91>", "<extra_id_92>", "<extra_id_93>", "<extra_id_94>", "<extra_id_95>", "<extra_id_96>", "<extra_id_97>", "<extra_id_98>", "<extra_id_99>"], "model_max_length": 512, "name_or_path": "google/t5-v1_1-small", "special_tokens_map_file": "/root/.cache/huggingface/transformers/3ad6f8335c1b1ef8966245899d47dcf735abd134d21fd7d26f621fe45ac01184.c94798918c92ded6aeef2d2f0e666d2cc4145eca1aa6e1336fde07f2e13e2f46", "sp_model_kwargs": {}, "tokenizer_class": "T5Tokenizer"}
train_script.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ from torch.utils.data import Dataset, IterableDataset
4
+ import gzip
5
+ import json
6
+ from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
7
+ import sys
8
+ from datetime import datetime
9
+ import torch
10
+ import random
11
+ from shutil import copyfile
12
+ import os
13
+ import wandb
14
+ import re
15
+
16
+
17
+ logging.basicConfig(
18
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
19
+ datefmt="%Y-%m-%d %H:%M:%S",
20
+ handlers=[logging.StreamHandler(sys.stdout)],
21
+ )
22
+
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--model_name", default="google/t5-v1_1-base")
25
+ parser.add_argument("--train_files", required=True, nargs='+', default=[])
26
+ parser.add_argument("--epochs", default=1, type=int)
27
+ parser.add_argument("--batch_size", default=32, type=int)
28
+ parser.add_argument("--max_source_length", default=320, type=int)
29
+ parser.add_argument("--max_target_length", default=64, type=int)
30
+ parser.add_argument("--name", required=True)
31
+ parser.add_argument("--train_size", default=10*1000*1000, type=int)
32
+ parser.add_argument("--eval_size", default=10000, type=int)
33
+ parser.add_argument("--fp16", default=False, action='store_true')
34
+ args = parser.parse_args()
35
+
36
+ wandb.init(project="doc2query", name=f"{args.name}-{args.model_name}")
37
+
38
+
39
+
40
+
41
+ class PairDataset:
42
+ def __init__(self, filepath):
43
+ self.filepath = filepath
44
+ self.examples = []
45
+
46
+ def __iter__(self):
47
+ print("open", self.filepath)
48
+ with gzip.open(self.filepath, 'rt') as fIn:
49
+ for line in fIn:
50
+ example = self.get_example(json.loads(line))
51
+ if example is not None:
52
+ self.examples.append(example)
53
+ yield example
54
+
55
+ while True:
56
+ random.shuffle(self.examples)
57
+ for ex in self.examples:
58
+ yield ex
59
+
60
+
61
+ def get_example(self, raw_example):
62
+ return [raw_example[0], raw_example[1]]
63
+
64
+
65
+ class RedditTitleDataset(PairDataset):
66
+ def get_example(self, raw_example):
67
+ return [self.clean_title(raw_example['title']), raw_example['body']]
68
+
69
+
70
+ def clean_title(self, text):
71
+ text = text.replace("&amp;", "&").strip()
72
+ if text.startswith("["):
73
+ text = re.sub("^\[[a-zA-Z0-9]+\]", "", text).strip()
74
+
75
+ if text.endswith("]"):
76
+ text = re.sub("\[[a-zA-Z0-9\.]+\]$", "", text).strip()
77
+
78
+ if text.startswith("/r"):
79
+ text = re.sub("^/[a-zA-Z0-9/]+[;,: \-]+", "", text).strip()
80
+
81
+ return text
82
+
83
+
84
+
85
+ class MultiDataset(IterableDataset):
86
+ def __init__(self, filepaths, num_samples):
87
+ self.num_samples = num_samples
88
+ self.datasets = []
89
+ self.data_iterators = []
90
+
91
+ for filepath in filepaths:
92
+ if 'reddit_title_text' in filepath:
93
+ dataset = RedditTitleDataset(filepath)
94
+ else:
95
+ dataset = PairDataset(filepath)
96
+ self.datasets.append(dataset)
97
+ self.data_iterators.append(iter(dataset))
98
+
99
+ def __len__(self):
100
+ return self.num_samples
101
+
102
+ def __iter__(self):
103
+ while True:
104
+ for dataset in self.data_iterators:
105
+ yield next(dataset)
106
+
107
+ random.shuffle(self.data_iterators)
108
+
109
+ def delete_examples_cache(self):
110
+ for dataset in self.datasets:
111
+ dataset.examples = []
112
+
113
+
114
+
115
+ def main():
116
+ ############ Model
117
+ model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
118
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
119
+
120
+ save_steps = 1000
121
+
122
+ output_dir = 'output/'+args.name+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
123
+ print("Output dir:", output_dir)
124
+
125
+ # Write self to path
126
+ os.makedirs(output_dir, exist_ok=True)
127
+
128
+ train_script_path = os.path.join(output_dir, 'train_script.py')
129
+ copyfile(__file__, train_script_path)
130
+ with open(train_script_path, 'a') as fOut:
131
+ fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
132
+
133
+ ####
134
+
135
+ training_args = Seq2SeqTrainingArguments(
136
+ output_dir=output_dir,
137
+ fp16=args.fp16,
138
+ fp16_backend="amp",
139
+ per_device_train_batch_size=args.batch_size,
140
+ evaluation_strategy="steps",
141
+ save_steps=save_steps,
142
+ logging_steps=100,
143
+ eval_steps=save_steps, #logging_steps,
144
+ warmup_steps=1000,
145
+ save_total_limit=1,
146
+ num_train_epochs=args.epochs,
147
+ report_to="wandb",
148
+ )
149
+
150
+ ############ Arguments
151
+
152
+ ############ Load datasets
153
+
154
+
155
+ train_dataset = MultiDataset(args.train_files, args.train_size)
156
+ train_dataset_iter = iter(train_dataset)
157
+ eval_dataset = [next(train_dataset_iter) for _ in range(args.eval_size)]
158
+ train_dataset.delete_examples_cache() #Make sure dev data is no re-used for training
159
+ print("Target:", eval_dataset[0][0])
160
+ print("Input:", eval_dataset[0][1])
161
+
162
+ print("Train dataset len:", len(train_dataset))
163
+
164
+
165
+ def data_collator(examples):
166
+ targets = [row[0] for row in examples]
167
+ inputs = [row[1] for row in examples]
168
+ label_pad_token_id = -100
169
+
170
+ model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)
171
+
172
+ # Setup the tokenizer for targets
173
+ with tokenizer.as_target_tokenizer():
174
+ labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)
175
+
176
+ # replace all tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss.
177
+ labels["input_ids"] = [
178
+ [(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
179
+ ]
180
+
181
+
182
+ model_inputs["labels"] = torch.tensor(labels["input_ids"])
183
+ return model_inputs
184
+
185
+ ## Define the trainer
186
+ trainer = Seq2SeqTrainer(
187
+ model=model,
188
+ args=training_args,
189
+ train_dataset=train_dataset,
190
+ eval_dataset=eval_dataset,
191
+ tokenizer=tokenizer,
192
+ data_collator=data_collator
193
+ )
194
+
195
+ ### Save the model
196
+ train_result = trainer.train()
197
+ trainer.save_model()
198
+
199
+
200
+ if __name__ == "__main__":
201
+ main()
202
+
203
+ # Script was called via:
204
+ #python train_hf_trainer.py --model_name google/t5-v1_1-small --train_files /home/reddit/submissions_parsed/reddit_title_text_2010.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2011.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2012.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2013.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2014.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2015.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2016.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2017.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2018.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2019.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2020.jsonl.gz /home/reddit/submissions_parsed/reddit_title_text_2021.jsonl.gz --name reddit_title_text_all --train_size 100000000 --max_source_length 384
trainer_state.json ADDED
The diff for this file is too large to render. See raw diff
 
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7058fbc6f0b77640072b5461ff4594e25374997b286d794af21fb51b7a460cfa
3
+ size 2927