michal-stefanik
commited on
Commit
·
289fb31
1
Parent(s):
af3183d
README & Training resources
Browse files- .gitattributes +1 -0
- README.md +81 -0
- czech_squad_4-sents.json +3 -0
- priming_objective.py +200 -0
- requirements.txt +4 -0
- train_mt5_qa_en_SQuAD+cs_random.py +118 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
36 |
+
czech_squad_4-sents.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: mit
|
3 |
+
datasets:
|
4 |
+
- squad
|
5 |
+
- cs_sqad-3.0
|
6 |
+
language:
|
7 |
+
- cs
|
8 |
+
- en
|
9 |
+
metrics:
|
10 |
+
- rouge
|
11 |
+
pipeline_tag: text2text-generation
|
12 |
+
---
|
13 |
+
|
14 |
+
# Model Card for mTk-SQuAD_en-SQAD_cs-1B
|
15 |
+
|
16 |
+
This model is a generative in-context few-shot learner specialized in Czech. It was trained on a combination of English SQuAD and Czech SQAD dataset.
|
17 |
+
|
18 |
+
You can find detailed information on [Project Github](https://github.com/fewshot-goes-multilingual/slavic-incontext-learning) & the referenced paper.
|
19 |
+
|
20 |
+
|
21 |
+
## Model Details
|
22 |
+
|
23 |
+
### Model Description
|
24 |
+
|
25 |
+
|
26 |
+
- **Developed by:** Michal Stefanik & Marek Kadlcik, Masaryk University
|
27 |
+
- **Model type:** mt5
|
28 |
+
- **Language(s) (NLP):** cs,en
|
29 |
+
- **License:** MIT
|
30 |
+
- **Finetuned from model:** google/mt5-large
|
31 |
+
|
32 |
+
### Model Sources
|
33 |
+
|
34 |
+
- **Repository:** https://github.com/fewshot-goes-multilingual/slavic-incontext-learning
|
35 |
+
- **Paper:** [To be filled]
|
36 |
+
|
37 |
+
## Uses
|
38 |
+
|
39 |
+
This model is intended to be used in a few-shot in-context learning format in the target language (Czech), or in the source language (English, see below).
|
40 |
+
It was evaluated for unseen task learning (with k=3 demonstrations) in Czech: see the referenced paper for details.
|
41 |
+
|
42 |
+
### How to Get Started with the Model
|
43 |
+
|
44 |
+
Use the code below to get started with the model.
|
45 |
+
|
46 |
+
```python
|
47 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
48 |
+
|
49 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("{this model path}")
|
50 |
+
tokenizer = AutoTokenizer.from_pretrained("{this model path}")
|
51 |
+
|
52 |
+
# Instead, use keywords "Otázka", "Kontext" and "Odpověď" for Czech few-shot prompts
|
53 |
+
input_text = """
|
54 |
+
Question: What is the customer's name?
|
55 |
+
Context: Origin: Barrack Obama, Customer id: Bill Moe.
|
56 |
+
Answer: Bill Moe,
|
57 |
+
Question: What is the customer's name?
|
58 |
+
Context: Customer id: Barrack Obama, if not deliverable, return to Bill Clinton.
|
59 |
+
Answer:
|
60 |
+
"""
|
61 |
+
|
62 |
+
inputs = tokenizer(input_text, return_tensors="pt")
|
63 |
+
|
64 |
+
outputs = model.generate(**inputs)
|
65 |
+
|
66 |
+
print("Answer:")
|
67 |
+
print(tokenizer.decode(outputs))
|
68 |
+
```
|
69 |
+
|
70 |
+
## Training Details
|
71 |
+
|
72 |
+
Training this model can be reproduced by running `pip install -r requirements.txt && python train.py`.
|
73 |
+
See the referenced script for hyperparameters and other training configurations.
|
74 |
+
|
75 |
+
## Citation
|
76 |
+
|
77 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
78 |
+
|
79 |
+
**BibTeX:**
|
80 |
+
|
81 |
+
[Will be filled soon]
|
czech_squad_4-sents.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:95083cfcb45f6eb1d2640d35af360e6100ce5ef01182838e86b7ac8d4b4e0d9b
|
3 |
+
size 13358790
|
priming_objective.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import random
|
3 |
+
from typing import Iterable, Union, Dict, List, Optional
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from adaptor.objectives.seq2seq import Sequence2Sequence
|
7 |
+
from transformers import BatchEncoding
|
8 |
+
|
9 |
+
logger = logging.getLogger()
|
10 |
+
|
11 |
+
priming_formats = {
|
12 |
+
"QA": {"cs": "Otázka: %s Kontext: %s Odpověď:",
|
13 |
+
"en": "Question: %s Context: %s Answer:",
|
14 |
+
"ru": "Вопрос: %s Контекст: %s Отвечать:"}}
|
15 |
+
|
16 |
+
|
17 |
+
class Priming(Sequence2Sequence):
|
18 |
+
|
19 |
+
def __init__(self, *args,
|
20 |
+
train_question_categories: Iterable[str],
|
21 |
+
max_eval_samples: int,
|
22 |
+
val_question_categories: Optional[Iterable[str]] = None,
|
23 |
+
min_num_demonstrations: int = 2,
|
24 |
+
max_num_demonstrations: int = 5,
|
25 |
+
demos_infer_batch_size: int = 32,
|
26 |
+
demos_selection_strategy: str = "hard",
|
27 |
+
difficulty_sample: int = 64,
|
28 |
+
max_input_length: int = 8000,
|
29 |
+
**kwargs):
|
30 |
+
super().__init__(*args, **kwargs)
|
31 |
+
|
32 |
+
self.train_question_categories = list(train_question_categories)
|
33 |
+
self.val_question_categories = list(val_question_categories) if val_question_categories is not None else None
|
34 |
+
|
35 |
+
self.min_num_demonstrations = min_num_demonstrations
|
36 |
+
self.max_num_demonstrations = max_num_demonstrations
|
37 |
+
self.demos_infer_batch_size = demos_infer_batch_size
|
38 |
+
self.demos_selection_strategy = demos_selection_strategy
|
39 |
+
self.difficulty_sample = difficulty_sample
|
40 |
+
self.max_input_length = max_input_length
|
41 |
+
self.max_eval_samples = max_eval_samples
|
42 |
+
|
43 |
+
def _construct_qa_prompt(self, question: str, context: str) -> str:
|
44 |
+
return priming_formats["QA"][self.source_lang_id] % (question, context)
|
45 |
+
|
46 |
+
def _construct_demonstration(self, prompt: str, answer: str) -> str:
|
47 |
+
return "%s %s " % (prompt, answer)
|
48 |
+
|
49 |
+
def _construct_primed_prompt(self, primed_demonstrations: List[str], prompt: str) -> str:
|
50 |
+
return " ".join(primed_demonstrations) + " " + prompt
|
51 |
+
|
52 |
+
def forced_generation_score(self, input_texts: List[str], forced_output: str) -> torch.FloatTensor:
|
53 |
+
inputs = self.tokenizer(input_texts, return_tensors="pt", padding="longest", truncation=True)
|
54 |
+
inputs = inputs.to(self.compatible_head_model.device)
|
55 |
+
|
56 |
+
with self.tokenizer.as_target_tokenizer():
|
57 |
+
output_ids = self.tokenizer(forced_output, return_tensors="pt", padding="longest",
|
58 |
+
truncation=True).input_ids.to(self.compatible_head_model.device)
|
59 |
+
forced_outputs = self.compatible_head_model.prepare_decoder_input_ids_from_labels(output_ids)
|
60 |
+
forced_outputs = forced_outputs.to(self.compatible_head_model.device)
|
61 |
+
|
62 |
+
outputs = self.compatible_head_model(**inputs,
|
63 |
+
decoder_input_ids=forced_outputs.expand(inputs.input_ids.shape[0], -1))
|
64 |
+
output_log_probs = outputs.logits.log_softmax(-1)
|
65 |
+
forced_output_logits = torch.gather(output_log_probs, -1,
|
66 |
+
output_ids.expand(inputs.input_ids.shape[0], -1).unsqueeze(-1))
|
67 |
+
forced_output_log_score = forced_output_logits.sum((-1, -2))
|
68 |
+
# we do not need to normalize, as all the targets are the same <=> same length
|
69 |
+
return forced_output_log_score.double().exp()
|
70 |
+
|
71 |
+
def _pick_most_difficult_demo(self,
|
72 |
+
selected_demos: List[str],
|
73 |
+
next_demo_cands: List[str],
|
74 |
+
predict_prompt: str,
|
75 |
+
predicted_answer: str) -> int:
|
76 |
+
with torch.no_grad():
|
77 |
+
difficulties = torch.empty(0, device=self.compatible_head_model.device, dtype=torch.float)
|
78 |
+
|
79 |
+
for batch_offset in range(0, len(next_demo_cands), self.demos_infer_batch_size):
|
80 |
+
next_demo_cands_batch = next_demo_cands[batch_offset: batch_offset + self.demos_infer_batch_size]
|
81 |
+
|
82 |
+
primed_prompts = [self._construct_primed_prompt(selected_demos + [demo], predict_prompt)
|
83 |
+
for demo in next_demo_cands_batch]
|
84 |
+
cands_difficulty = self.forced_generation_score(primed_prompts, predicted_answer)
|
85 |
+
|
86 |
+
difficulties = torch.hstack((difficulties, cands_difficulty))
|
87 |
+
|
88 |
+
assert difficulties.argmin() < len(next_demo_cands)
|
89 |
+
|
90 |
+
return difficulties.argmin()
|
91 |
+
|
92 |
+
def _get_inputs_iterator(self, split: str) -> Iterable[Union[BatchEncoding, Dict[str, torch.Tensor]]]:
|
93 |
+
"""
|
94 |
+
Creates a default iterator over encodings with aligned input and output texts.
|
95 |
+
:param split: Data split. `train` or `eval`.
|
96 |
+
:return: Iterator of model input encodings.
|
97 |
+
"""
|
98 |
+
# we materialize all samples in memory, so that we can heuristically pick the combinations
|
99 |
+
questions, contexts, answers = (list(it) for it in self._per_split_iterators(split))
|
100 |
+
question_categories = self.train_question_categories if split == "train" else self.val_question_categories
|
101 |
+
|
102 |
+
assert len(questions) == len(contexts) == len(answers) == len(question_categories), \
|
103 |
+
"Given numbers of questions, contexts and answers do not match."
|
104 |
+
|
105 |
+
prompts = [self._construct_qa_prompt(q, c) for q, c in zip(questions, contexts)]
|
106 |
+
|
107 |
+
features_batch = []
|
108 |
+
cat_index = {cat: [i for i, sample_cat in enumerate(question_categories) if cat == sample_cat]
|
109 |
+
for cat in set(question_categories)}
|
110 |
+
|
111 |
+
retrieved_samples = 0
|
112 |
+
|
113 |
+
for idx, sample_category in enumerate(question_categories):
|
114 |
+
if not cat_index[sample_category]:
|
115 |
+
logger.warning("No samples within the category %s", sample_category)
|
116 |
+
continue
|
117 |
+
|
118 |
+
pred_prompt, pred_answer = prompts[idx], answers[idx]
|
119 |
+
|
120 |
+
picked_demonstrations = []
|
121 |
+
|
122 |
+
# a number of demonstrations is in the specified range
|
123 |
+
expected_num_demonstrations = random.randint(self.min_num_demonstrations, self.max_num_demonstrations)
|
124 |
+
|
125 |
+
while len(picked_demonstrations) < expected_num_demonstrations:
|
126 |
+
if sum(map(len, picked_demonstrations)) > self.max_input_length:
|
127 |
+
logger.warning("Skipping too long prompt.")
|
128 |
+
break
|
129 |
+
if self.demos_selection_strategy == "hard":
|
130 |
+
# pick the most difficult examples out of a sample
|
131 |
+
# we do not need to worry for picking up the predicted sample among demonstrations in hard strategy
|
132 |
+
if len(cat_index[sample_category]) <= 1:
|
133 |
+
# we can not construct informative demonstrations for categories of a single item
|
134 |
+
break
|
135 |
+
|
136 |
+
samples_idx = random.choices(cat_index[sample_category], k=self.difficulty_sample)
|
137 |
+
cand_demonstrations = [self._construct_demonstration(prompts[i], answers[i]) for i in samples_idx]
|
138 |
+
selected_index = self._pick_most_difficult_demo(picked_demonstrations, cand_demonstrations,
|
139 |
+
pred_prompt, pred_answer)
|
140 |
+
picked_demonstrations.append(cand_demonstrations[selected_index])
|
141 |
+
elif self.demos_selection_strategy == "informative":
|
142 |
+
if len(cat_index[sample_category]) <= 1:
|
143 |
+
# we can not construct informative demonstrations for categories of a single item
|
144 |
+
break
|
145 |
+
selected_cat_index = random.randint(1, len(cat_index[sample_category])-1)
|
146 |
+
selected_index = cat_index[sample_category][selected_cat_index]
|
147 |
+
if selected_index == idx:
|
148 |
+
# we do not want to expose the predicted sample in demonstrations
|
149 |
+
selected_index = cat_index[sample_category][selected_cat_index-1]
|
150 |
+
picked_demonstration = self._construct_demonstration(prompts[selected_index],
|
151 |
+
answers[selected_index])
|
152 |
+
picked_demonstrations.append(picked_demonstration)
|
153 |
+
elif self.demos_selection_strategy == "random":
|
154 |
+
# evaluation: do not infer samples' difficulty, pick randomly
|
155 |
+
selected_index = random.randint(1, len(prompts)-1)
|
156 |
+
if selected_index == idx:
|
157 |
+
# we do not want to expose the predicted sample in demonstrations
|
158 |
+
selected_index -= 1
|
159 |
+
picked_demonstration = self._construct_demonstration(prompts[selected_index],
|
160 |
+
answers[selected_index])
|
161 |
+
picked_demonstrations.append(picked_demonstration)
|
162 |
+
else:
|
163 |
+
raise ValueError("Unknown demon selection strategy: '%s'" % self.demos_selection_strategy)
|
164 |
+
if len(picked_demonstrations) != expected_num_demonstrations:
|
165 |
+
# we omit examples with none or only one demonstration in the category
|
166 |
+
continue
|
167 |
+
|
168 |
+
# encode a yielded batch
|
169 |
+
primed_prompt = self._construct_primed_prompt(picked_demonstrations, pred_prompt)
|
170 |
+
|
171 |
+
primed_prompt_encoding = self.tokenizer(primed_prompt, truncation=True)
|
172 |
+
label_encoding = self.tokenizer(pred_answer, truncation=True)
|
173 |
+
|
174 |
+
features_batch.append({"input_ids": primed_prompt_encoding.input_ids,
|
175 |
+
"attention_mask": primed_prompt_encoding.attention_mask,
|
176 |
+
"labels": label_encoding.input_ids})
|
177 |
+
if len(features_batch) == self.batch_size:
|
178 |
+
yield self.collator(features_batch)
|
179 |
+
features_batch = []
|
180 |
+
|
181 |
+
retrieved_samples += 1
|
182 |
+
if split == "eval" and retrieved_samples >= self.max_eval_samples:
|
183 |
+
# custom evaluation break - we need all samples in set to match categories,
|
184 |
+
# but do not want to iterate them all
|
185 |
+
break
|
186 |
+
|
187 |
+
if features_batch:
|
188 |
+
# yield last nonempty residual batch
|
189 |
+
yield self.collator(features_batch)
|
190 |
+
|
191 |
+
def _compute_loss(self,
|
192 |
+
lm_logit_outputs: torch.FloatTensor,
|
193 |
+
labels: torch.LongTensor,
|
194 |
+
inputs: Optional[Union[BatchEncoding, Dict[str, torch.Tensor]]] = None) -> torch.FloatTensor:
|
195 |
+
# customization for mt5 model, with incorrectly-set tokenizer.vocab_size
|
196 |
+
# This should be fixed in upcoming release of adaptor (>=0.1.5)
|
197 |
+
loss_fct = torch.nn.CrossEntropyLoss()
|
198 |
+
lm_loss = loss_fct(lm_logit_outputs.flatten(end_dim=1), labels.flatten())
|
199 |
+
|
200 |
+
return lm_loss
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
adaptor[generative]==0.2.0
|
2 |
+
torch==1.11.0
|
3 |
+
pandas
|
4 |
+
nltk
|
train_mt5_qa_en_SQuAD+cs_random.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
from adaptor.adapter import Adapter
|
5 |
+
from adaptor.evaluators.generative import BLEU
|
6 |
+
from adaptor.lang_module import LangModule
|
7 |
+
from adaptor.schedules import ParallelSchedule
|
8 |
+
from adaptor.utils import AdaptationArguments, StoppingStrategy
|
9 |
+
from datasets import load_dataset
|
10 |
+
|
11 |
+
from priming_objective import Priming
|
12 |
+
|
13 |
+
training_arguments = AdaptationArguments(output_dir="train_dir_SQuAD_random_large",
|
14 |
+
learning_rate=2e-5, # we set LR=2e-4 for pre-training experiments
|
15 |
+
stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
|
16 |
+
# stopping_strategy=StoppingStrategy.NUM_STEPS_TOTAL,
|
17 |
+
do_train=True,
|
18 |
+
do_eval=True,
|
19 |
+
warmup_steps=1000,
|
20 |
+
max_steps=10000,
|
21 |
+
gradient_accumulation_steps=30,
|
22 |
+
eval_steps=500,
|
23 |
+
logging_steps=10,
|
24 |
+
save_steps=500,
|
25 |
+
num_train_epochs=5,
|
26 |
+
evaluation_strategy="steps",
|
27 |
+
save_total_limit=10,
|
28 |
+
stopping_patience=10)
|
29 |
+
eval_examples = 200
|
30 |
+
|
31 |
+
# priming
|
32 |
+
num_demonstrations = 3
|
33 |
+
|
34 |
+
|
35 |
+
def _construct_priming_prompt(previous_examples: List[str], current_example: str) -> str:
|
36 |
+
return " ".join(previous_examples + [current_example])
|
37 |
+
|
38 |
+
|
39 |
+
lang_module = LangModule("google/mt5-large")
|
40 |
+
|
41 |
+
# priming
|
42 |
+
per_type_examples = {}
|
43 |
+
|
44 |
+
qa_en = load_dataset("squad")
|
45 |
+
qa_train = qa_en["train"].filter(lambda entry: len(entry["context"]) < 2000)
|
46 |
+
|
47 |
+
val_metrics = [BLEU(**{"additional_sep_char": "▁"})]
|
48 |
+
|
49 |
+
# SQuAD QA dataset & objective:
|
50 |
+
|
51 |
+
|
52 |
+
def _get_en_qa_categories(data) -> List[str]:
|
53 |
+
return [question.split()[0] if not question.startswith("To")
|
54 |
+
else " ".join(question.split()[:2])
|
55 |
+
for question in data["question"]]
|
56 |
+
|
57 |
+
|
58 |
+
q_answering_en = Priming(lang_module,
|
59 |
+
max_eval_samples=eval_examples,
|
60 |
+
demos_selection_strategy="random",
|
61 |
+
texts_or_path=qa_train["question"],
|
62 |
+
text_pair_or_path=qa_train["context"],
|
63 |
+
val_texts_or_path=qa_en["validation"]["question"][-eval_examples:],
|
64 |
+
val_text_pair_or_path=qa_en["validation"]["context"][-eval_examples:],
|
65 |
+
labels_or_path=[a["text"][0] for a in qa_train["answers"]],
|
66 |
+
val_labels_or_path=[a["text"][0] for a in qa_en["validation"]["answers"]][-eval_examples:],
|
67 |
+
train_question_categories=_get_en_qa_categories(qa_train),
|
68 |
+
val_question_categories=_get_en_qa_categories(qa_en["validation"])[-eval_examples:],
|
69 |
+
batch_size=1,
|
70 |
+
val_evaluators=val_metrics,
|
71 |
+
# val_evaluators=val_metrics,
|
72 |
+
source_lang_id="en",
|
73 |
+
objective_id="AQA-en")
|
74 |
+
|
75 |
+
# Czech data & objective
|
76 |
+
|
77 |
+
squad_cs_dataset = json.load(open("czech_squad_4-sents.json"))
|
78 |
+
|
79 |
+
skipped = 0
|
80 |
+
|
81 |
+
questions_cs = []
|
82 |
+
contexts_cs = []
|
83 |
+
answers_cs = []
|
84 |
+
categories_cs = []
|
85 |
+
|
86 |
+
for i, entry in squad_cs_dataset.items():
|
87 |
+
if len(entry["context"]) > 800:
|
88 |
+
skipped += 1
|
89 |
+
continue
|
90 |
+
|
91 |
+
questions_cs.append(entry["question"])
|
92 |
+
contexts_cs.append(entry["context"])
|
93 |
+
answers_cs.append(entry["answers"]["text"][0])
|
94 |
+
categories_cs.append(entry["answer_type"])
|
95 |
+
|
96 |
+
print("Skipped cs examples: %s" % skipped)
|
97 |
+
|
98 |
+
q_answering_cs = Priming(lang_module,
|
99 |
+
max_eval_samples=eval_examples,
|
100 |
+
demos_selection_strategy="random",
|
101 |
+
texts_or_path=questions_cs[:-eval_examples],
|
102 |
+
text_pair_or_path=contexts_cs[:-eval_examples],
|
103 |
+
val_texts_or_path=questions_cs[-eval_examples:],
|
104 |
+
val_text_pair_or_path=contexts_cs[-eval_examples:],
|
105 |
+
labels_or_path=answers_cs[:-eval_examples],
|
106 |
+
val_labels_or_path=answers_cs[-eval_examples:],
|
107 |
+
train_question_categories=categories_cs[:-eval_examples],
|
108 |
+
val_question_categories=categories_cs[-eval_examples:],
|
109 |
+
batch_size=1,
|
110 |
+
val_evaluators=val_metrics,
|
111 |
+
source_lang_id="cs",
|
112 |
+
objective_id="SQUAD-cs")
|
113 |
+
|
114 |
+
schedule = ParallelSchedule(objectives=[q_answering_en, q_answering_cs],
|
115 |
+
args=training_arguments)
|
116 |
+
|
117 |
+
adapter = Adapter(lang_module, schedule, args=training_arguments)
|
118 |
+
adapter.train()
|