|
---
license: apache-2.0
language:
- en
- la
base_model:
- google/mt5-small
---
|
A demonstration of fine-tuning mt5-small for C17th English (and Latin) legal depositions.
|
Uses mt5-small, which is pre-trained on the mC4 Common Crawl dataset covering 101 languages, including some Latin.
|
mt5-small is the smallest of five variants of mt5 (small; base; large; XL; XXL). |
|
Fine-tuned with text-to-text pairs of raw-HTR and hand-corrected Ground Truth from C17th English High Court of Admiralty depositions.
|
|
|
A series of fine-tuned mt5-small models will be created with ascending version numbers. |
|
|
|
The fine-tuning pipeline:

* Training dataset = 80%; validation dataset = 20% (see the data-split sketch below).

* mT5 tokenizer.

* PyTorch datasets.

* T5ForConditionalGeneration model.

* CER/WER evaluation and qualitative evaluation (e.g. capitalisation, HTR error correction); a sketch follows the training code below.

* Training runs on an Nvidia T4 (small, 15 GB, $0.40/hour).
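
A minimal sketch of the 80/20 split, assuming the paired lines sit in a CSV with hypothetical column names `raw_htr` and `ground_truth` (the file name and columns are assumptions, not part of the released data); it produces the `train_inputs`/`train_targets` and `val_inputs`/`val_targets` used in the tokenization step below:

```python
# Sketch only: file name and column names are assumptions.
import pandas as pd
from sklearn.model_selection import train_test_split

pairs = pd.read_csv("hca_deposition_line_pairs.csv")  # hypothetical paired-line file

train_inputs, val_inputs, train_targets, val_targets = train_test_split(
    pairs["raw_htr"].tolist(),        # raw-HTR lines (model input)
    pairs["ground_truth"].tolist(),   # hand-corrected Ground Truth (target)
    test_size=0.2,                    # 80% training / 20% validation
    random_state=42,
)
```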
|
|
|
MT5TOKENIZER |
|
```python
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
```
|
|
|
TOKENIZE DATA |
|
```python
train_encodings = tokenizer(list(train_inputs), text_target=list(train_targets), truncation=True, padding=True)

val_encodings = tokenizer(list(val_inputs), text_target=list(val_targets), truncation=True, padding=True)
```
|
|
|
CREATE PYTORCH DATASETS |
|
```python
import torch

class HTRDataset(torch.utils.data.Dataset):
    """Wraps the tokenizer output so the Trainer can index examples as tensors."""

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # Returns input_ids, attention_mask and labels for one paired line
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

train_dataset = HTRDataset(train_encodings)
val_dataset = HTRDataset(val_encodings)
```
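
An optional sanity check (not part of the pipeline itself) to confirm the datasets index correctly before training:

```python
# Each item should contain input_ids, attention_mask and labels tensors.
print(len(train_dataset), len(val_dataset))
print({key: value.shape for key, value in train_dataset[0].items()})
```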
|
|
|
FINE-TUNING WITH TRANSFORMERS |
|
```python
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("google/mt5-small")
```
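
An optional check of the model size (mt5-small being the smallest of the five mt5 variants listed above):

```python
# Print the parameter count of the loaded checkpoint.
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```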
|
|
|
TRAINING ARGUMENTS

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,   # or 16 if the GPU has enough memory
    per_device_eval_batch_size=8,    # same as the train batch size
    learning_rate=1e-4,
    num_train_epochs=3,              # or 5
    evaluation_strategy="epoch",
    save_strategy="epoch",
    fp16=True,                       # if the GPU supports it, for faster training
    # ... other arguments ...
)
```
|
|
|
EARLY STOPPING

```python
training_args = TrainingArguments(
    # ... other arguments ...
    evaluation_strategy="epoch",
    save_strategy="epoch",             # must match evaluation_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
)
# Note: the patience value is not a TrainingArguments parameter; it is supplied
# through an EarlyStoppingCallback when the Trainer is created (see the sketch below).
```
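
Early stopping itself is enabled with an `EarlyStoppingCallback` passed to the Trainer; a minimal sketch (the patience of 3 matches the optional value mentioned above):

```python
from transformers import EarlyStoppingCallback

# Stop training if eval_loss has not improved for 3 consecutive evaluations.
# Passed to the Trainer below via its `callbacks` argument.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
```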
|
|
|
CREATE TRAINER AND FINE-TUNE |
|
```python
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # callbacks=[early_stopping],  # optional: enable the early stopping configured above
)

trainer.train()
```
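
The CER/WER and qualitative evaluation mentioned above are not part of the training loop; a minimal sketch, assuming the `jiwer` library is installed (any CER/WER implementation would do):

```python
import jiwer  # assumed CER/WER library; not used elsewhere in this card

# Generate corrected text for the validation inputs with the fine-tuned model.
model.eval()
predictions = []
for line in val_inputs:
    input_ids = tokenizer(line, return_tensors="pt", truncation=True).input_ids.to(model.device)
    output_ids = model.generate(input_ids, max_new_tokens=256)
    predictions.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))

references = list(val_targets)
print("WER:", jiwer.wer(references, predictions))
print("CER:", jiwer.cer(references, predictions))

# Qualitative evaluation: inspect a few pairs by eye (e.g. capitalisation, HTR error correction).
for raw, pred, ref in list(zip(val_inputs, predictions, references))[:5]:
    print(f"RAW:  {raw}\nPRED: {pred}\nGT:   {ref}\n")
```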
|
|
|
--- |
|
Fine-tuning data experiments will include: |
|
|
|
* Using 1000 lines of raw-HTR paired with 1000 lines of hand-corrected Ground Truth

* Using 2000 lines of raw-HTR paired with 1000 lines of hand-corrected Ground Truth

* Using 1000 and 2000 lines of synthetic raw-HTR paired with 1000 lines of hand-corrected Ground Truth
|
--- |
|
Hyper-parameter experiments will include the following (see the sketch after this list):
|
|
|
* Adjusting batch size from 8 paired lines to 16 paired lines

* Adjusting epochs from 3 to 5

* Adjusting learning rate

  * Start with a learning rate of 1e-4 (0.0001). This is a common starting point for fine-tuning transformer models.

  * Experiment with slightly higher or lower values (e.g., 5e-4 or 5e-5) in later experiments

* Adjusting early-stopping settings
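
One way to organise these runs, as a sketch (the variant names and the loop are illustrative; the values are taken from the list above):

```python
# Hypothetical grid of hyper-parameter variants; each run starts from a fresh checkpoint.
variants = {
    "baseline": dict(per_device_train_batch_size=8,  num_train_epochs=3, learning_rate=1e-4),
    "batch16":  dict(per_device_train_batch_size=16, num_train_epochs=3, learning_rate=1e-4),
    "epochs5":  dict(per_device_train_batch_size=8,  num_train_epochs=5, learning_rate=1e-4),
    "lr5e-4":   dict(per_device_train_batch_size=8,  num_train_epochs=3, learning_rate=5e-4),
    "lr5e-5":   dict(per_device_train_batch_size=8,  num_train_epochs=3, learning_rate=5e-5),
}

for name, overrides in variants.items():
    run_model = T5ForConditionalGeneration.from_pretrained("google/mt5-small")
    run_args = TrainingArguments(
        output_dir=f"./results/{name}",
        per_device_eval_batch_size=8,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        fp16=True,
        **overrides,
    )
    Trainer(
        model=run_model,
        args=run_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
    ).train()
```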
|
|
|
|