Spaces:
Sleeping
Sleeping
Merge pull request #1 from NicBonetto/initialize-model
Browse files
- .gitignore +3 -0
- config.py +8 -0
- core/__init__.py +0 -0
- core/data.py +27 -0
- core/model.py +13 -0
- core/train.py +32 -0
- core/utils.py +15 -0
- requirements.txt +8 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
core/__pycache__/
|
3 |
+
trocr-ocr/
|
config.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Configuration for TrOCR fine-tuning on the IAM handwritten-line dataset.
DATASET_NAME = 'Teklia/IAM-line'              # Hugging Face dataset of handwritten text lines
MODEL_NAME = 'microsoft/trocr-small-printed'  # base checkpoint to fine-tune
TRAIN_SPLIT = 'train[:200]'                   # small slice for quick experiments
TEST_SPLIT_RATIO = 0.2                        # fraction of the slice held out for evaluation
BATCH_SIZE = 2
EPOCHS = 1
OUTPUT_DIR = './trocr-ocr'                    # checkpoint/output directory (git-ignored)
|
8 |
+
|
core/__init__.py
ADDED
File without changes
|
core/data.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from datasets import load_dataset
|
3 |
+
from config import DATASET_NAME, TRAIN_SPLIT, TEST_SPLIT_RATIO
|
4 |
+
from core.model import processor
|
5 |
+
|
6 |
+
def preprocess_batch(batch):
    """Convert a batch of (image, text) pairs into model-ready features.

    Adds to ``batch``:
      * ``pixel_values`` -- processor-normalized image tensors.
      * ``labels`` -- tokenized transcriptions with padded positions set to
        -100, the sentinel Hugging Face losses ignore. ``compute_metrics``
        in core/utils.py maps -100 back to the pad id before decoding,
        which only works if padding is masked here.
    """
    images = [img.convert("RGB") for img in batch["image"]]
    batch["pixel_values"] = processor(images=images, return_tensors="pt").pixel_values

    labels = processor.tokenizer(batch["text"], padding=True, truncation=True).input_ids
    # Bug fix: pad tokens previously kept their real id and were scored by the
    # cross-entropy loss; replace them with -100 so they are ignored.
    pad_id = processor.tokenizer.pad_token_id
    batch["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq] for seq in labels
    ]

    return batch
|
15 |
+
|
16 |
+
def load():
    """Load the configured dataset slice and return (train_ds, eval_ds),
    each mapped through ``preprocess_batch`` with raw columns dropped."""
    raw = load_dataset(DATASET_NAME, split=TRAIN_SPLIT)
    splits = raw.train_test_split(test_size=TEST_SPLIT_RATIO)

    def _prepare(ds):
        # Only pixel_values/labels are needed downstream; remove the rest.
        return ds.map(preprocess_batch, batched=True, remove_columns=ds.column_names)

    return _prepare(splits['train']), _prepare(splits['test'])
|
26 |
+
|
27 |
+
|
core/model.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
from transformers import VisionEncoderDecoderModel, TrOCRProcessor
from config import MODEL_NAME

# Prefer the Apple-Silicon GPU when available, else CPU.
# NOTE(review): CUDA is never considered -- confirm this project only targets
# macOS; otherwise a cuda availability check should come first.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')

# Shared processor (image normalization + tokenizer) and model, imported by
# core.data, core.train and core.utils.
processor = TrOCRProcessor.from_pretrained(MODEL_NAME, use_fast = False)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)

# Generation/training needs these ids on the top-level model config.
# NOTE(review): the TrOCR examples typically use cls_token_id as the decoder
# start token; bos_token_id works only if the tokenizer defines it -- confirm.
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id

model.config.pad_token_id = processor.tokenizer.pad_token_id
if model.config.pad_token_id is None:
    # Some tokenizers define no pad token; fall back to EOS so padding is representable.
    model.config.pad_token_id = processor.tokenizer.eos_token_id
|
core/train.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from core.model import model, processor
from core.data import load
from core.utils import compute_metrics
from config import OUTPUT_DIR, BATCH_SIZE, EPOCHS

# Datasets are prepared at import time so `trainer` can be constructed below.
train_ds, eval_ds = load()

training_args = Seq2SeqTrainingArguments(
    output_dir = OUTPUT_DIR,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    predict_with_generate = True,  # evaluate with generate() so CER sees real decodings
    eval_strategy = 'epoch',
    logging_steps = 10,
    num_train_epochs = EPOCHS,
    save_total_limit = 1,          # keep only the most recent checkpoint
    fp16 = False                   # full-precision training
)

trainer = Seq2SeqTrainer(
    model = model,
    args = training_args,
    train_dataset = train_ds,
    eval_dataset = eval_ds,
    # NOTE(review): only the image processor is passed here; passing the full
    # `processor` (tokenizer included) is the more common pattern -- confirm
    # this is intentional.
    processing_class = processor.image_processor,
    data_collator = default_data_collator,
    compute_metrics = compute_metrics
)

if __name__ == '__main__':
    trainer.train()
|
core/utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import evaluate
|
2 |
+
from core.model import processor
|
3 |
+
|
4 |
+
# Character Error Rate metric; backed by the `jiwer` package (see requirements.txt).
cer_metric = evaluate.load('cer')
|
5 |
+
|
6 |
+
def compute_metrics(pred):
    """Compute character error rate (CER) between generated and reference text."""
    predictions = pred.predictions
    references = pred.label_ids

    decoded_preds = processor.batch_decode(predictions, skip_special_tokens = True)

    # -100 is the loss-masking sentinel, not a real token id; restore padding
    # before decoding the reference strings.
    references[references == -100] = processor.tokenizer.pad_token_id
    decoded_refs = processor.batch_decode(references, skip_special_tokens = True)

    score = cer_metric.compute(predictions = decoded_preds, references = decoded_refs)
    return { 'cer': score }
|
15 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
transformers
|
4 |
+
datasets
|
5 |
+
evaluate
|
6 |
+
jiwer
|
7 |
+
Pillow
|
8 |
+
|