nbonetto committed
Commit ee142e9 · 1 Parent(s): b2acd9d

fix: trained model to 0.49 cer and fixed streamlit app issues
.gitignore CHANGED
@@ -1,3 +1,4 @@
 __pycache__/
 core/__pycache__/
 trocr-ocr/
+light-orc/
app.py CHANGED
@@ -3,17 +3,16 @@ from PIL import Image
 import torch
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
-MODEL_NAME = 'model/'
-processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
-model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
+MODEL_PATH = 'model/'
+processor = TrOCRProcessor.from_pretrained(MODEL_PATH)
+model = VisionEncoderDecoderModel.from_pretrained(MODEL_PATH)
 
 streamlit.title('Light OCR')
 
 uploaded_file = streamlit.file_uploader('Choose an image...', type=['png', 'jpg', 'jpeg'])
 if uploaded_file:
-    image = Image.open(uploaded_file).convert("RGB")
-    image = image.resize((384, 384))
-    streamlit.image(image, caption='Uploaded Image', use_column_width=True)
+    image = Image.open(uploaded_file).convert('RGB')
+    streamlit.image(image, caption='Uploaded Image', use_container_width=True)
 
     pixel_values = processor(images=image, return_tensors='pt').pixel_values
     output_ids = model.generate(pixel_values)
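
Note: the app now reads the fine-tuned checkpoint from the local model/ directory, drops the manual 384×384 resize (the TrOCR image processor handles resizing itself), and replaces the deprecated use_column_width with use_container_width. The hunk ends at generate; a minimal sketch of the decoding step that presumably follows, outside this diff context:

# Sketch only: decode the generated ids back to text and show it in the app.
text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
streamlit.write(text)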
config.py CHANGED
@@ -2,7 +2,7 @@ DATASET_NAME = 'Teklia/IAM-line'
 MODEL_NAME = 'microsoft/trocr-small-printed'
 TRAIN_SPLIT = 'train[:200]'
 TEST_SPLIT_RATIO = 0.2
-BATCH_SIZE = 2
-EPOCHS = 1
+BATCH_SIZE = 8
+EPOCHS = 8
 OUTPUT_DIR = './trocr-ocr'
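
Note: on the 200-line IAM subset, batch size 2 × 1 epoch is 100 optimizer steps, while batch size 8 × 8 epochs is 200 steps with 4× more samples per step; presumably this longer schedule is what produced the 0.49 CER named in the commit message.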
 
core/data.py CHANGED
@@ -2,13 +2,15 @@ import torch
 from datasets import load_dataset
 from config import DATASET_NAME, TRAIN_SPLIT, TEST_SPLIT_RATIO
 from core.model import processor
+from PIL import Image
 
 def preprocess_batch(batch):
-    images = [img.convert("RGB") for img in batch["image"]]
-    pixel_values = processor(images=images, return_tensors="pt").pixel_values
-    batch["pixel_values"] = pixel_values
+    images = [img.convert('RGB') for img in batch["image"]]
+
+    labels = processor.tokenizer(batch['text'], padding=True, max_length=128, truncation=True).input_ids
+    pixel_values = processor.image_processor(images, return_tensors="pt").pixel_values
 
-    labels = processor.tokenizer(batch["text"], padding=True, truncation=True).input_ids
+    batch["pixel_values"] = pixel_values
     batch["labels"] = labels
 
     return batch
@@ -19,8 +21,8 @@ def load():
     train_ds = train_test['train']
     eval_ds = train_test['test']
 
-    train_ds = train_ds.map(preprocess_batch, batched = True, remove_columns = train_ds.column_names)
-    eval_ds = eval_ds.map(preprocess_batch, batched = True, remove_columns = eval_ds.column_names)
+    train_ds = train_ds.map(preprocess_batch, batched=True, remove_columns=train_ds.column_names)
+    eval_ds = eval_ds.map(preprocess_batch, batched=True, remove_columns=eval_ds.column_names)
 
     return train_ds, eval_ds
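
Note: the labels are now capped at max_length=128 and the images go through processor.image_processor directly. One common refinement that is not part of this commit: with padding=True, pad tokens stay in the labels and count toward the cross-entropy loss unless they are replaced with -100, the index transformers ignores. A hedged sketch of that masking, to sit after the tokenizer call:

# Not in this commit: mask pad tokens so the loss ignores padding positions.
pad_id = processor.tokenizer.pad_token_id
labels = [
    [tok if tok != pad_id else -100 for tok in seq]
    for seq in labels
]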
 
core/model.py CHANGED
@@ -3,7 +3,7 @@ from transformers import VisionEncoderDecoderModel, TrOCRProcessor
 from config import MODEL_NAME
 
 device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
-processor = TrOCRProcessor.from_pretrained(MODEL_NAME, use_fast = False)
+processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
 model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)
 
 model.config.decoder_start_token_id = processor.tokenizer.bos_token_id
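
Note: dropping use_fast=False switches to the fast tokenizer, which lines up with the tokenizer.json artifact this commit adds under model/. Device selection stays MPS-or-CPU; a sketch of a more general fallback that also covers CUDA (an alternative, not what the commit does):

# Alternative device selection (not in this commit): CUDA, then MPS, then CPU.
import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')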
core/train.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
+from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
 from core.model import model, processor
 from core.data import load
 from core.utils import compute_metrics
@@ -12,9 +12,11 @@ training_args = Seq2SeqTrainingArguments(
     per_device_eval_batch_size = BATCH_SIZE,
     predict_with_generate = True,
     eval_strategy = 'epoch',
-    logging_steps = 10,
+    logging_steps = 50,
     num_train_epochs = EPOCHS,
     save_total_limit = 1,
+    remove_unused_columns = False,
+    learning_rate = 5e-5,
     fp16 = False
 )
@@ -24,9 +26,14 @@ trainer = Seq2SeqTrainer(
     train_dataset = train_ds,
     eval_dataset = eval_ds,
     processing_class = processor.image_processor,
-    data_collator = default_data_collator,
     compute_metrics = compute_metrics
 )
 
-if __name__ == '__main__':
+def train_save():
     trainer.train()
+    trainer.save_model('./model')
+    processor.save_pretrained('./model')
+
+
+if __name__ == '__main__':
+    train_save()
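
Note: the default_data_collator import goes away along with the data_collator argument, and remove_unused_columns=False keeps the preprocessed pixel_values/labels columns from being stripped before batching. The new train_save() persists both model and processor to ./model, the directory app.py now loads from. A hedged usage sketch (the module invocation is an assumption):

# Assumed invocation from the repo root: python -m core.train
# Afterwards the saved artifacts reload the same way app.py does:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained('./model')
model = VisionEncoderDecoderModel.from_pretrained('./model')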
model/generation_config.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:66a55c9bff6d80e77c8deb6dba8dd79d867da689c7e0f1e1eddb265f8a92fb1b
-size 185
+oid sha256:91fde0da8b70ba657bd5e495956d6661ebf5ed65daeb70a4bcb488e9c62c046a
+size 155
model/model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6d77161e6a5564a2d70e53b5dabfad12b67fb2e9bd7c3cc7555b1fe056bc8826
+oid sha256:2d7f1b418d01098f6a8c33b290b12952b49235e97d666feea5a95b5e38c250a3
 size 246430696
model/preprocessor_config.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5172006ffcaf0f407db91ac4ada30ad6ca86183fa37ce6059f966bfaffb880cb
-size 411
+oid sha256:36a945a7cc645688b9ef64dabae16979cf5f7c1c448569cc306694edc0598b9b
+size 450
model/sentencepiece.bpe.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f5e2fefcf793761a76a6bfb8ad35489f9c203b25557673284b6d032f41043f4
+size 1356293
model/special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d5469a60db23249c7f8945013d78df30b44b6bf686c6bb4740f4223f77b1b535
+size 279
model/tokenizer.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:969a92d9be8996720f3523976fe57f101a56d920b388707d48641055596c114f
+size 4494958
model/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0836517b3d82dcc162f06172e0b50bc1df3024cce7cf2d71ed009acd4d8c75ea
+size 1268
model/training_args.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:26b9bcd85dc494ff7ae55bc22488e87174ffd31d5b8205496c292075416de299
+size 5432
requirements.txt CHANGED
@@ -5,5 +5,6 @@ datasets
 evaluate
 jiwer
 Pillow
-accelerator
+accelerate
 streamlit
+sentencepiece
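
Note: accelerator was the wrong PyPI package name; accelerate is the one transformers' Trainer actually requires. sentencepiece is presumably pulled in for the tokenizer whose sentencepiece.bpe.model file this commit adds under model/.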