Text2Text Generation · Transformers · PyTorch · Safetensors · Czech · English · mt5 · Inference Endpoints
michal-stefanik committed · Commit 289fb31 · 1 Parent(s): af3183d

README & Training resources

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  tokenizer.json filter=lfs diff=lfs merge=lfs -text
+ czech_squad_4-sents.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,81 @@
+ ---
+ license: mit
+ datasets:
+ - squad
+ - cs_sqad-3.0
+ language:
+ - cs
+ - en
+ metrics:
+ - rouge
+ pipeline_tag: text2text-generation
+ ---
+
+ # Model Card for mTk-SQuAD_en-SQAD_cs-1B
+
+ This model is a generative in-context few-shot learner specialized in Czech. It was trained on a combination of the English SQuAD and Czech SQAD datasets.
+
+ You can find detailed information in the [Project Github](https://github.com/fewshot-goes-multilingual/slavic-incontext-learning) and the referenced paper.
+
+
+ ## Model Details
+
+ ### Model Description
+
+
+ - **Developed by:** Michal Stefanik & Marek Kadlcik, Masaryk University
+ - **Model type:** mt5
+ - **Language(s) (NLP):** cs, en
+ - **License:** MIT
+ - **Finetuned from model:** google/mt5-large
+
+ ### Model Sources
+
+ - **Repository:** https://github.com/fewshot-goes-multilingual/slavic-incontext-learning
+ - **Paper:** [To be filled]
+
+ ## Uses
+
+ This model is intended to be used in a few-shot in-context learning format in the target language (Czech) or in the source language (English, see below).
+ It was evaluated on unseen-task learning (with k=3 demonstrations) in Czech: see the referenced paper for details.
+
+ ### How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ ```python
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+
+ model = AutoModelForSeq2SeqLM.from_pretrained("{this model path}")
+ tokenizer = AutoTokenizer.from_pretrained("{this model path}")
+
+ # For Czech few-shot prompts, use the keywords "Otázka", "Kontext" and "Odpověď" instead
+ input_text = """
+ Question: What is the customer's name?
+ Context: Origin: Barrack Obama, Customer id: Bill Moe.
+ Answer: Bill Moe,
+ Question: What is the customer's name?
+ Context: Customer id: Barrack Obama, if not deliverable, return to Bill Clinton.
+ Answer:
+ """
+
+ inputs = tokenizer(input_text, return_tensors="pt")
+
+ outputs = model.generate(**inputs)
+
+ print("Answer:")
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+ ```
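For a Czech prompt, the same generation call applies, just with the Czech keywords from the comment above. A minimal sketch with made-up toy data (load `model` and `tokenizer` exactly as in the snippet above):

```python
# Czech few-shot prompt: "Otázka" = Question, "Kontext" = Context, "Odpověď" = Answer
czech_input_text = """
Otázka: Jak se jmenuje zákazník?
Kontext: Původ: Barrack Obama, ID zákazníka: Bill Moe.
Odpověď: Bill Moe,
Otázka: Jak se jmenuje zákazník?
Kontext: ID zákazníka: Barrack Obama, v případě nedoručitelnosti vraťte Billu Clintonovi.
Odpověď:
"""

inputs = tokenizer(czech_input_text, return_tensors="pt")
outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```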
+
+ ## Training Details
+
+ Training of this model can be reproduced by running `pip install -r requirements.txt && python train.py`.
+ See the referenced script for hyperparameters and other training configurations.
+
+ ## Citation
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [Will be filled soon]
czech_squad_4-sents.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:95083cfcb45f6eb1d2640d35af360e6100ce5ef01182838e86b7ac8d4b4e0d9b
+ size 13358790
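The dataset file itself is stored as a Git LFS pointer. Judging from how `train_mt5_qa_en_SQuAD+cs_random.py` (added below) reads it, the file maps example ids to records carrying a question, a context, a SQuAD-style answer list, and an answer-type category used to group demonstrations. A hypothetical record, with invented values:

```python
# Assumed shape of czech_squad_4-sents.json, inferred from the loading code below;
# all field values here are invented for illustration.
czech_squad_example = {
    "0": {
        "question": "Kdo napsal Babičku?",
        "context": "Babička je kniha, kterou napsala Božena Němcová.",
        "answers": {"text": ["Božena Němcová"]},
        "answer_type": "PERSON",  # category used to group demonstrations
    }
}
```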
priming_objective.py ADDED
@@ -0,0 +1,200 @@
+ import logging
+ import random
+ from typing import Iterable, Union, Dict, List, Optional
+
+ import torch
+ from adaptor.objectives.seq2seq import Sequence2Sequence
+ from transformers import BatchEncoding
+
+ logger = logging.getLogger()
+
+ priming_formats = {
+     "QA": {"cs": "Otázka: %s Kontext: %s Odpověď:",
+            "en": "Question: %s Context: %s Answer:",
+            "ru": "Вопрос: %s Контекст: %s Отвечать:"}}
+
+
+ class Priming(Sequence2Sequence):
+
+     def __init__(self, *args,
+                  train_question_categories: Iterable[str],
+                  max_eval_samples: int,
+                  val_question_categories: Optional[Iterable[str]] = None,
+                  min_num_demonstrations: int = 2,
+                  max_num_demonstrations: int = 5,
+                  demos_infer_batch_size: int = 32,
+                  demos_selection_strategy: str = "hard",
+                  difficulty_sample: int = 64,
+                  max_input_length: int = 8000,
+                  **kwargs):
+         super().__init__(*args, **kwargs)
+
+         self.train_question_categories = list(train_question_categories)
+         self.val_question_categories = list(val_question_categories) if val_question_categories is not None else None
+
+         self.min_num_demonstrations = min_num_demonstrations
+         self.max_num_demonstrations = max_num_demonstrations
+         self.demos_infer_batch_size = demos_infer_batch_size
+         self.demos_selection_strategy = demos_selection_strategy
+         self.difficulty_sample = difficulty_sample
+         self.max_input_length = max_input_length
+         self.max_eval_samples = max_eval_samples
+
+     def _construct_qa_prompt(self, question: str, context: str) -> str:
+         return priming_formats["QA"][self.source_lang_id] % (question, context)
+
+     def _construct_demonstration(self, prompt: str, answer: str) -> str:
+         return "%s %s " % (prompt, answer)
+
+     def _construct_primed_prompt(self, primed_demonstrations: List[str], prompt: str) -> str:
+         return " ".join(primed_demonstrations) + " " + prompt
+
+     def forced_generation_score(self, input_texts: List[str], forced_output: str) -> torch.FloatTensor:
+         inputs = self.tokenizer(input_texts, return_tensors="pt", padding="longest", truncation=True)
+         inputs = inputs.to(self.compatible_head_model.device)
+
+         with self.tokenizer.as_target_tokenizer():
+             output_ids = self.tokenizer(forced_output, return_tensors="pt", padding="longest",
+                                         truncation=True).input_ids.to(self.compatible_head_model.device)
+         forced_outputs = self.compatible_head_model.prepare_decoder_input_ids_from_labels(output_ids)
+         forced_outputs = forced_outputs.to(self.compatible_head_model.device)
+
+         outputs = self.compatible_head_model(**inputs,
+                                              decoder_input_ids=forced_outputs.expand(inputs.input_ids.shape[0], -1))
+         output_log_probs = outputs.logits.log_softmax(-1)
+         forced_output_logits = torch.gather(output_log_probs, -1,
+                                             output_ids.expand(inputs.input_ids.shape[0], -1).unsqueeze(-1))
+         forced_output_log_score = forced_output_logits.sum((-1, -2))
+         # we do not need to normalize, as all the targets are the same <=> same length
+         return forced_output_log_score.double().exp()
+
+     def _pick_most_difficult_demo(self,
+                                   selected_demos: List[str],
+                                   next_demo_cands: List[str],
+                                   predict_prompt: str,
+                                   predicted_answer: str) -> int:
+         with torch.no_grad():
+             difficulties = torch.empty(0, device=self.compatible_head_model.device, dtype=torch.float)
+
+             for batch_offset in range(0, len(next_demo_cands), self.demos_infer_batch_size):
+                 next_demo_cands_batch = next_demo_cands[batch_offset: batch_offset + self.demos_infer_batch_size]
+
+                 primed_prompts = [self._construct_primed_prompt(selected_demos + [demo], predict_prompt)
+                                   for demo in next_demo_cands_batch]
+                 cands_difficulty = self.forced_generation_score(primed_prompts, predicted_answer)
+
+                 difficulties = torch.hstack((difficulties, cands_difficulty))
+
+         assert difficulties.argmin() < len(next_demo_cands)
+
+         return difficulties.argmin()
+
+     def _get_inputs_iterator(self, split: str) -> Iterable[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
+         """
+         Creates a default iterator over encodings with aligned input and output texts.
+         :param split: Data split. `train` or `eval`.
+         :return: Iterator of model input encodings.
+         """
+         # we materialize all samples in memory, so that we can heuristically pick the combinations
+         questions, contexts, answers = (list(it) for it in self._per_split_iterators(split))
+         question_categories = self.train_question_categories if split == "train" else self.val_question_categories
+
+         assert len(questions) == len(contexts) == len(answers) == len(question_categories), \
+             "Given numbers of questions, contexts and answers do not match."
+
+         prompts = [self._construct_qa_prompt(q, c) for q, c in zip(questions, contexts)]
+
+         features_batch = []
+         cat_index = {cat: [i for i, sample_cat in enumerate(question_categories) if cat == sample_cat]
+                      for cat in set(question_categories)}
+
+         retrieved_samples = 0
+
+         for idx, sample_category in enumerate(question_categories):
+             if not cat_index[sample_category]:
+                 logger.warning("No samples within the category %s", sample_category)
+                 continue
+
+             pred_prompt, pred_answer = prompts[idx], answers[idx]
+
+             picked_demonstrations = []
+
+             # the number of demonstrations is sampled from the specified range
+             expected_num_demonstrations = random.randint(self.min_num_demonstrations, self.max_num_demonstrations)
+
+             while len(picked_demonstrations) < expected_num_demonstrations:
+                 if sum(map(len, picked_demonstrations)) > self.max_input_length:
+                     logger.warning("Skipping too long prompt.")
+                     break
+                 if self.demos_selection_strategy == "hard":
+                     # pick the most difficult examples out of a sample
+                     # in the hard strategy, we do not need to worry about picking the predicted sample as a demonstration
+                     if len(cat_index[sample_category]) <= 1:
+                         # we cannot construct informative demonstrations for categories with a single item
+                         break
+
+                     samples_idx = random.choices(cat_index[sample_category], k=self.difficulty_sample)
+                     cand_demonstrations = [self._construct_demonstration(prompts[i], answers[i]) for i in samples_idx]
+                     selected_index = self._pick_most_difficult_demo(picked_demonstrations, cand_demonstrations,
+                                                                     pred_prompt, pred_answer)
+                     picked_demonstrations.append(cand_demonstrations[selected_index])
+                 elif self.demos_selection_strategy == "informative":
+                     if len(cat_index[sample_category]) <= 1:
+                         # we cannot construct informative demonstrations for categories with a single item
+                         break
+                     selected_cat_index = random.randint(1, len(cat_index[sample_category]) - 1)
+                     selected_index = cat_index[sample_category][selected_cat_index]
+                     if selected_index == idx:
+                         # we do not want to expose the predicted sample in demonstrations
+                         selected_index = cat_index[sample_category][selected_cat_index - 1]
+                     picked_demonstration = self._construct_demonstration(prompts[selected_index],
+                                                                          answers[selected_index])
+                     picked_demonstrations.append(picked_demonstration)
+                 elif self.demos_selection_strategy == "random":
+                     # evaluation: do not infer samples' difficulty, pick randomly
+                     selected_index = random.randint(1, len(prompts) - 1)
+                     if selected_index == idx:
+                         # we do not want to expose the predicted sample in demonstrations
+                         selected_index -= 1
+                     picked_demonstration = self._construct_demonstration(prompts[selected_index],
+                                                                          answers[selected_index])
+                     picked_demonstrations.append(picked_demonstration)
+                 else:
+                     raise ValueError("Unknown demos selection strategy: '%s'" % self.demos_selection_strategy)
+             if len(picked_demonstrations) != expected_num_demonstrations:
+                 # we omit examples with none or only one demonstration in the category
+                 continue
+
+             # encode a yielded batch
+             primed_prompt = self._construct_primed_prompt(picked_demonstrations, pred_prompt)
+
+             primed_prompt_encoding = self.tokenizer(primed_prompt, truncation=True)
+             label_encoding = self.tokenizer(pred_answer, truncation=True)
+
+             features_batch.append({"input_ids": primed_prompt_encoding.input_ids,
+                                    "attention_mask": primed_prompt_encoding.attention_mask,
+                                    "labels": label_encoding.input_ids})
+             if len(features_batch) == self.batch_size:
+                 yield self.collator(features_batch)
+                 features_batch = []
+
+             retrieved_samples += 1
+             if split == "eval" and retrieved_samples >= self.max_eval_samples:
+                 # custom evaluation break - we need all samples in the set to match categories,
+                 # but do not want to iterate over all of them
+                 break
+
+         if features_batch:
+             # yield the last nonempty residual batch
+             yield self.collator(features_batch)
+
+     def _compute_loss(self,
+                       lm_logit_outputs: torch.FloatTensor,
+                       labels: torch.LongTensor,
+                       inputs: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]] = None) -> torch.FloatTensor:
+         # customization for the mt5 model, with incorrectly-set tokenizer.vocab_size
+         # This should be fixed in an upcoming release of adaptor (>=0.1.5)
+         loss_fct = torch.nn.CrossEntropyLoss()
+         lm_loss = loss_fct(lm_logit_outputs.flatten(end_dim=1), labels.flatten())
+
+         return lm_loss
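The `hard` selection strategy above rates every candidate demonstration with `forced_generation_score`: the (length-unnormalized) probability that the current model assigns to the gold answer once the candidate is appended to the already picked demonstrations, and `_pick_most_difficult_demo` keeps the candidate with the lowest score. A minimal standalone sketch of the same idea with plain `transformers` (model choice and helper names are illustrative, not taken from this repository):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# a small seq2seq model, used here only for illustration
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")


def answer_log_prob(prompt: str, answer: str) -> float:
    """Sum of log-probabilities of the answer tokens, teacher-forced after the prompt."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    labels = tokenizer(answer, return_tensors="pt", truncation=True).input_ids
    with torch.no_grad():
        # passing `labels` makes the model score the forced answer with teacher forcing
        logits = model(**inputs, labels=labels).logits
    token_log_probs = logits.log_softmax(-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum().item()


demo_candidates = [
    "Question: Where is Brno? Context: Brno lies in Moravia. Answer: in Moravia ",
    "Question: Who wrote Hamlet? Context: Hamlet was written by Shakespeare. Answer: Shakespeare ",
]
predicted_prompt = "Question: Where is Prague? Context: Prague lies in Bohemia. Answer:"
gold_answer = "in Bohemia"

# the "hardest" demonstration is the one whose addition makes the gold answer least likely
scores = [answer_log_prob(demo + predicted_prompt, gold_answer) for demo in demo_candidates]
hardest_demo = demo_candidates[min(range(len(scores)), key=scores.__getitem__)]
print(hardest_demo)
```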
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ adaptor[generative]==0.2.0
+ torch==1.11.0
+ pandas
+ nltk
train_mt5_qa_en_SQuAD+cs_random.py ADDED
@@ -0,0 +1,118 @@
+ import json
+ from typing import List
+
+ from adaptor.adapter import Adapter
+ from adaptor.evaluators.generative import BLEU
+ from adaptor.lang_module import LangModule
+ from adaptor.schedules import ParallelSchedule
+ from adaptor.utils import AdaptationArguments, StoppingStrategy
+ from datasets import load_dataset
+
+ from priming_objective import Priming
+
+ training_arguments = AdaptationArguments(output_dir="train_dir_SQuAD_random_large",
+                                          learning_rate=2e-5,  # we set LR=2e-4 for pre-training experiments
+                                          stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
+                                          # stopping_strategy=StoppingStrategy.NUM_STEPS_TOTAL,
+                                          do_train=True,
+                                          do_eval=True,
+                                          warmup_steps=1000,
+                                          max_steps=10000,
+                                          gradient_accumulation_steps=30,
+                                          eval_steps=500,
+                                          logging_steps=10,
+                                          save_steps=500,
+                                          num_train_epochs=5,
+                                          evaluation_strategy="steps",
+                                          save_total_limit=10,
+                                          stopping_patience=10)
+ eval_examples = 200
+
+ # priming
+ num_demonstrations = 3
+
+
+ def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str:
+     return " ".join(previous_examples + [current_example])
+
+
+ lang_module = LangModule("google/mt5-large")
+
+ # priming
+ per_type_examples = {}
+
+ qa_en = load_dataset("squad")
+ qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000)
+
+ val_metrics = [BLEU(**{"additional_sep_char": "▁"})]
+
+ # SQuAD QA dataset & objective:
+
+
+ def _get_en_qa_categories(data) -> List[str]:
+     return [question.split()[0] if not question.startswith("To")
+             else " ".join(question.split()[:2])
+             for question in data["question"]]
+
+
+ q_answering_en = Priming(lang_module,
+                          max_eval_samples=eval_examples,
+                          demos_selection_strategy="random",
+                          texts_or_path=qa_train["question"],
+                          text_pair_or_path=qa_train["context"],
+                          val_texts_or_path=qa_en["validation"]["question"][-eval_examples:],
+                          val_text_pair_or_path=qa_en["validation"]["context"][-eval_examples:],
+                          labels_or_path=[a["text"][0] for a in qa_train["answers"]],
+                          val_labels_or_path=[a["text"][0] for a in qa_en["validation"]["answers"]][-eval_examples:],
+                          train_question_categories=_get_en_qa_categories(qa_train),
+                          val_question_categories=_get_en_qa_categories(qa_en["validation"])[-eval_examples:],
+                          batch_size=1,
+                          val_evaluators=val_metrics,
+                          # val_evaluators=val_metrics,
+                          source_lang_id="en",
+                          objective_id="AQA-en")
+
+ # Czech data & objective
+
+ squad_cs_dataset = json.load(open("czech_squad_4-sents.json"))
+
+ skipped = 0
+
+ questions_cs = []
+ contexts_cs = []
+ answers_cs = []
+ categories_cs = []
+
+ for i, entry in squad_cs_dataset.items():
+     if len(entry["context"]) > 800:
+         skipped += 1
+         continue
+
+     questions_cs.append(entry["question"])
+     contexts_cs.append(entry["context"])
+     answers_cs.append(entry["answers"]["text"][0])
+     categories_cs.append(entry["answer_type"])
+
+ print("Skipped cs examples: %s" % skipped)
+
+ q_answering_cs = Priming(lang_module,
+                          max_eval_samples=eval_examples,
+                          demos_selection_strategy="random",
+                          texts_or_path=questions_cs[:-eval_examples],
+                          text_pair_or_path=contexts_cs[:-eval_examples],
+                          val_texts_or_path=questions_cs[-eval_examples:],
+                          val_text_pair_or_path=contexts_cs[-eval_examples:],
+                          labels_or_path=answers_cs[:-eval_examples],
+                          val_labels_or_path=answers_cs[-eval_examples:],
+                          train_question_categories=categories_cs[:-eval_examples],
+                          val_question_categories=categories_cs[-eval_examples:],
+                          batch_size=1,
+                          val_evaluators=val_metrics,
+                          source_lang_id="cs",
+                          objective_id="SQUAD-cs")
+
+ schedule = ParallelSchedule(objectives=[q_answering_en, q_answering_cs],
+                             args=training_arguments)
+
+ adapter = Adapter(lang_module, schedule, args=training_arguments)
+ adapter.train()