DeepDiveDev committed on
Commit fa36a00 · verified · 1 Parent(s): e6b9318

Update app.py

Files changed (1):
  1. app.py +6 -6
app.py CHANGED
@@ -1,5 +1,5 @@
 import gradio as gr
-from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AutoProcessor, AutoModelForVision2Seq
 from PIL import Image
 import numpy as np
 import torch
@@ -9,8 +9,8 @@ processor1 = TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr")
 model1 = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr")
 
 # Load the fallback model (allenai/olmOCR-7B-0225-preview)
-processor2 = TrOCRProcessor.from_pretrained("allenai/olmOCR-7B-0225-preview")
-model2 = VisionEncoderDecoderModel.from_pretrained("allenai/olmOCR-7B-0225-preview")
+processor2 = AutoProcessor.from_pretrained("allenai/olmOCR-7B-0225-preview")
+model2 = AutoModelForVision2Seq.from_pretrained("allenai/olmOCR-7B-0225-preview", torch_dtype=torch.float16)
 
 # Function to extract text using both models
 def extract_text(image):
@@ -32,8 +32,8 @@ def extract_text(image):
 
     # If output seems incorrect, use the fallback model
     if len(extracted_text.strip()) < 2:  # If output is too short, retry with second model
-        pixel_values = processor2(images=image, return_tensors="pt").pixel_values
-        generated_ids = model2.generate(pixel_values)
+        inputs = processor2(images=image, return_tensors="pt").pixel_values
+        generated_ids = model2.generate(inputs)
         extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
     return extracted_text
@@ -50,4 +50,4 @@ iface = gr.Interface(
     description="Upload a handwritten document and get the extracted text.",
 )
 
-iface.launch()
+iface.launch()
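
For context, the two-model fallback this commit wires up looks like the sketch below when assembled into one runnable file. The two microsoft/trocr-* checkpoints are stand-ins (assumptions, not what the commit loads), chosen because TrOCR's processor accepts bare pixel_values exactly as the hunks above show; allenai/olmOCR-7B-0225-preview is a 7B vision-language model whose AutoProcessor typically expects a text prompt alongside the image, so the commit's fallback call may need further changes before it runs as written. Treat this as a minimal sketch of the pattern, not the app itself.

# Minimal sketch of the two-model OCR fallback this commit sets up.
# The checkpoints below are public TrOCR models used purely for
# illustration; the commit itself pairs DeepDiveDev/transformodocs-ocr
# with allenai/olmOCR-7B-0225-preview.
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Primary model: tried first on every image.
primary_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
primary_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# Fallback model: larger, consulted only when the primary output looks wrong.
fallback_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
fallback_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")

def run_ocr(processor, model, image):
    # Encode the image, generate token ids, and decode them to text.
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

def extract_text(image):
    text = run_ocr(primary_processor, primary_model, image)
    # Same heuristic as the diff: a near-empty result means the primary
    # model likely failed, so retry with the fallback.
    if len(text.strip()) < 2:
        text = run_ocr(fallback_processor, fallback_model, image)
    return text

if __name__ == "__main__":
    # sample.png is a placeholder path for a handwritten document image.
    print(extract_text(Image.open("sample.png").convert("RGB")))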