import json

from adaptor.adapter import Adapter
from adaptor.evaluators.generative import BLEU
from adaptor.lang_module import LangModule
from adaptor.objectives.seq2seq import Sequence2Sequence
from adaptor.schedules import ParallelSchedule
from adaptor.utils import AdaptationArguments, StoppingStrategy
from datasets import load_dataset
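
# Training configuration shared by all objectives; training stops once all objectives have converged.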
training_arguments = AdaptationArguments(output_dir="train_dir",
                                         learning_rate=5e-5,
                                         stopping_strategy=StoppingStrategy.ALL_OBJECTIVES_CONVERGED,
                                         do_train=True,
                                         do_eval=True,
                                         warmup_steps=1000,
                                         max_steps=100000,
                                         gradient_accumulation_steps=4,
                                         eval_steps=100,
                                         logging_steps=10,
                                         save_steps=1000,
                                         num_train_epochs=50,
                                         evaluation_strategy="steps",
                                         remove_unused_columns=False)
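
# Shared base model: the Helsinki-NLP OPUS-MT English-to-Czech translation checkpoint.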
lang_module = LangModule("Helsinki-NLP/opus-mt-en-cs")
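
# Validation BLEU also decides convergence (see the stopping strategy above).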
metrics_args = {"additional_sep_char": "▁"}

val_metrics = [BLEU(**metrics_args, decides_convergence=True)]
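
# English QA data: SQuAD, keeping only training contexts shorter than 2000 characters.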
squad_en = load_dataset("squad")
squad_train = squad_en["train"].filter(lambda entry: len(entry["context"]) < 2000)

train_contexts_questions_en = ["question: %s context: %s" % (q, c)
                               for q, c in zip(squad_train["question"], squad_train["context"])]
val_contexts_questions_en = ["question: %s context: %s" % (q, c)
                             for q, c in zip(squad_en["validation"]["question"], squad_en["validation"]["context"])]
train_answers_en = [a["text"][0] for a in squad_train["answers"]]
val_answers_en = [a["text"][0] for a in squad_en["validation"]["answers"]]
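
# English generative QA objective, validated on the first 200 validation samples.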
generative_qa_en = Sequence2Sequence(lang_module,
                                     texts_or_path=train_contexts_questions_en,
                                     val_texts_or_path=val_contexts_questions_en[:200],
                                     labels_or_path=train_answers_en,
                                     val_labels_or_path=val_answers_en[:200],
                                     batch_size=8,
                                     val_evaluators=val_metrics,
                                     objective_id="SQUAD-en")
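
# Czech QA data from a local SQuAD-format JSON file.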
squad_dataset = json.load(open("data/czech_squad.json"))

contexts_questions = []
answers = []

for i, entry in squad_dataset.items():
    contexts_questions.append("question: %s context: %s" % (entry["question"], entry["context"]))
    answers.append(entry["answers"]["text"][0])
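
# Hold out the last 200 Czech samples for validation.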
train_contexts_questions = contexts_questions[:-200]
val_contexts_questions = contexts_questions[-200:]
train_answers = answers[:-200]
val_answers = answers[-200:]
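
# Czech generative QA objective, sharing the same model and evaluators.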
generative_qa_cs = Sequence2Sequence(lang_module,
                                     texts_or_path=train_contexts_questions,
                                     val_texts_or_path=val_contexts_questions[:200],
                                     labels_or_path=train_answers,
                                     val_labels_or_path=val_answers[:200],
                                     batch_size=8,
                                     val_evaluators=val_metrics,
                                     objective_id="SQUAD-cs")
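
# Train both objectives in parallel within a single schedule.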
schedule = ParallelSchedule(objectives=[generative_qa_en, generative_qa_cs],
                            args=training_arguments)
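
# The Adapter runs the training loop, drawing batches according to the schedule.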
adapter = Adapter(lang_module, schedule, args=training_arguments)
adapter.train()