nbonetto commited on
Commit
812c941
·
1 Parent(s): 9fe0a0f

feat: create streamlit app

Browse files
Files changed (9) hide show
  1. .gitignore +3 -0
  2. app.py +24 -0
  3. config.py +8 -0
  4. core/__init__.py +0 -0
  5. core/data.py +27 -0
  6. core/model.py +13 -0
  7. core/train.py +32 -0
  8. core/utils.py +15 -0
  9. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/
2
+ core/__pycache__/
3
+ trocr-ocr/
app.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit front-end: upload an image and run TrOCR-based OCR on it."""
import streamlit
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

MODEL_NAME = 'model/'


@streamlit.cache_resource
def _load_ocr():
    """Load the processor/model once per server process.

    Streamlit re-executes the whole script on every user interaction;
    without caching the model would be reloaded from disk each time.
    """
    processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
    model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
    model.eval()  # inference only
    return processor, model


processor, model = _load_ocr()

streamlit.title('Light OCR')

uploaded_file = streamlit.file_uploader('Choose an image...', type=['png', 'jpg', 'jpeg'])
if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    # No manual resize: the processor performs its own resizing/normalization,
    # and a hard 384x384 resize would distort the displayed image's aspect ratio.
    # (use_column_width is deprecated in recent Streamlit; use_container_width replaces it.)
    streamlit.image(image, caption='Uploaded Image', use_container_width=True)

    pixel_values = processor(images=image, return_tensors='pt').pixel_values
    with torch.no_grad():  # no gradients needed for generation
        output_ids = model.generate(pixel_values)
    text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]

    streamlit.subheader('Recognized Text')
    streamlit.write(text)
config.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
# Shared configuration for training (core/*) and data loading.

# Line-level handwritten-text dataset on the Hugging Face Hub.
DATASET_NAME = 'Teklia/IAM-line'
# Base checkpoint to fine-tune.
MODEL_NAME = 'microsoft/trocr-small-printed'
# Only the first 200 rows are loaded — demo-scale run, not a full training.
TRAIN_SPLIT = 'train[:200]'
# Fraction of the loaded rows held out for evaluation.
TEST_SPLIT_RATIO = 0.2
BATCH_SIZE = 2
EPOCHS = 1
# Directory where the Trainer writes checkpoints.
OUTPUT_DIR = './trocr-ocr'

core/__init__.py ADDED
File without changes
core/data.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from datasets import load_dataset
3
+ from config import DATASET_NAME, TRAIN_SPLIT, TEST_SPLIT_RATIO
4
+ from core.model import processor
5
+
def preprocess_batch(batch):
    """Turn a batch of (image, text) rows into model-ready features.

    Adds "pixel_values" (encoder inputs) and "labels" (decoder target ids)
    to the batch. Padding positions in the labels are replaced with -100 so
    that the cross-entropy loss ignores them — this matches what
    core/utils.compute_metrics expects when it maps -100 back to the pad id
    before decoding.
    """
    images = [img.convert("RGB") for img in batch["image"]]
    pixel_values = processor(images=images, return_tensors="pt").pixel_values
    batch["pixel_values"] = pixel_values

    labels = processor.tokenizer(batch["text"], padding=True, truncation=True).input_ids
    pad_id = processor.tokenizer.pad_token_id
    # Mask padding with -100 so it does not contribute to the training loss.
    batch["labels"] = [
        [tok if tok != pad_id else -100 for tok in seq] for seq in labels
    ]

    return batch

def load():
    """Load the configured dataset split, divide it into train/eval, and
    map the preprocessing over both parts.

    Returns a (train_dataset, eval_dataset) pair whose rows contain only
    the mapped features (original columns are dropped).
    """
    full = load_dataset(DATASET_NAME, split=TRAIN_SPLIT)
    split = full.train_test_split(test_size=TEST_SPLIT_RATIO)

    def _prepare(ds):
        # Drop raw columns so only pixel_values/labels remain.
        return ds.map(preprocess_batch, batched=True, remove_columns=ds.column_names)

    return _prepare(split['train']), _prepare(split['test'])

core/model.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import VisionEncoderDecoderModel, TrOCRProcessor
3
+ from config import MODEL_NAME
4
+
# Prefer Apple-silicon GPU (MPS) when available, otherwise fall back to CPU.
device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
# use_fast=False keeps the slow (Python) image processor/tokenizer path.
processor = TrOCRProcessor.from_pretrained(MODEL_NAME, use_fast = False)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)

# Token the decoder starts generation from.
# NOTE(review): HF TrOCR examples typically use tokenizer.cls_token_id here
# (RoBERTa-style tokenizer) — confirm bos_token_id is correct for this checkpoint.
model.config.decoder_start_token_id = processor.tokenizer.bos_token_id

# Ensure the model knows its pad token; fall back to EOS when the tokenizer
# defines no pad token.
model.config.pad_token_id = processor.tokenizer.pad_token_id
if model.config.pad_token_id is None:
    model.config.pad_token_id = processor.tokenizer.eos_token_id
core/train.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Fine-tune the TrOCR model on the configured dataset and save the result.

Run as a script: `python -m core.train`.
"""
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from core.model import model, processor
from core.data import load
from core.utils import compute_metrics
from config import OUTPUT_DIR, BATCH_SIZE, EPOCHS

train_ds, eval_ds = load()

training_args = Seq2SeqTrainingArguments(
    output_dir = OUTPUT_DIR,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    # Run generate() during evaluation so compute_metrics gets decoded ids
    # and can compute CER on text.
    predict_with_generate = True,
    eval_strategy = 'epoch',
    logging_steps = 10,
    num_train_epochs = EPOCHS,
    save_total_limit = 1,
    # Targets MPS/CPU (see core/model.py); fp16 is CUDA-only.
    fp16 = False
)

trainer = Seq2SeqTrainer(
    model = model,
    args = training_args,
    train_dataset = train_ds,
    eval_dataset = eval_ds,
    # NOTE(review): passing the full `processor` here would make the Trainer
    # save it with the model — confirm against the transformers version in use.
    processing_class = processor.image_processor,
    data_collator = default_data_collator,
    compute_metrics = compute_metrics
)

if __name__ == '__main__':
    trainer.train()
    # Persist the fine-tuned weights AND the processor so inference code
    # (e.g. app.py) can load both from a single directory. Without this,
    # training finishes but only intermediate checkpoints exist.
    trainer.save_model(OUTPUT_DIR)
    processor.save_pretrained(OUTPUT_DIR)
core/utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ from core.model import processor
3
+
# Character Error Rate metric, loaded once at import time.
cer_metric = evaluate.load('cer')


def compute_metrics(pred):
    """Compute character error rate (CER) for a Seq2Seq evaluation batch.

    `pred.predictions` holds generated token ids, `pred.label_ids` the
    reference ids with padding marked as -100 by the Trainer.
    """
    generated_ids = pred.predictions
    reference_ids = pred.label_ids

    decoded_preds = processor.batch_decode(generated_ids, skip_special_tokens = True)
    # Restore the pad id where the Trainer wrote -100, so decoding works.
    reference_ids[ reference_ids == -100 ] = processor.tokenizer.pad_token_id
    decoded_refs = processor.batch_decode(reference_ids, skip_special_tokens = True)

    score = cer_metric.compute(predictions = decoded_preds, references = decoded_refs)

    return { 'cer': score }

requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transformers
4
+ datasets
5
+ evaluate
6
+ jiwer
7
+ Pillow
8
+ accelerate
9
+ streamlit