DeepDiveDev committed
Commit a3df3f5 · verified · 1 Parent(s): 2653a83

Update app.py

Files changed (1)
  1. app.py +14 -14
app.py CHANGED
@@ -4,51 +4,51 @@ from PIL import Image
 import numpy as np
 import torch
 
-# Load the primary model (DeepDiveDev/transformodocs-ocr)
+# Load the primary OCR model (DeepDiveDev/transformodocs-ocr)
 processor1 = TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr")
 model1 = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr")
 
-# Load the fallback model (microsoft/trocr-base-handwritten)
+# Load the fallback model (microsoft/trocr-base-handwritten) for handwritten text
 processor2 = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
 model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
 
-# Function to extract text using both models
+# Function to extract text from handwritten images
 def extract_text(image):
     try:
         # Convert NumPy array to PIL Image if needed
         if isinstance(image, np.ndarray):
-            if len(image.shape) == 2:  # Grayscale (H, W), convert to RGB
+            if len(image.shape) == 2:  # Grayscale (H, W) -> Convert to RGB
                 image = np.stack([image] * 3, axis=-1)
             image = Image.fromarray(image)
         else:
-            image = Image.open(image).convert("RGB")  # Ensure RGB mode
+            image = Image.open(image).convert("RGB")  # Ensure RGB format
 
-        # Maintain aspect ratio while resizing
-        image.thumbnail((640, 640))
+        # Maintain aspect ratio while resizing (better for OCR)
+        image.thumbnail((800, 800))
 
-        # Process with the primary model
+        # Process image with the first model
         pixel_values = processor1(images=image, return_tensors="pt").pixel_values.to(torch.float32)
         generated_ids = model1.generate(pixel_values)
         extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-        # If output seems incorrect, use the fallback model
+        # If output is short or incorrect, use the fallback model
         if len(extracted_text.strip()) < 2:
             inputs = processor2(images=image, return_tensors="pt").pixel_values.to(torch.float32)
             generated_ids = model2.generate(inputs)
             extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-        return extracted_text
+        return extracted_text if extracted_text else "No text detected."
 
     except Exception as e:
         return f"Error: {str(e)}"
 
-# Gradio Interface
+# Gradio UI for OCR Extraction
 iface = gr.Interface(
     fn=extract_text,
-    inputs="image",
+    inputs=gr.Image(type="pil"),
     outputs="text",
-    title="TransformoDocs - AI OCR",
-    description="Upload a handwritten document and get the extracted text.",
+    title="Handwritten OCR Extraction",
+    description="Upload a handwritten image to extract text using AI OCR.",
 )
 
 iface.launch()
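
With the new gr.Image(type="pil") input, Gradio passes extract_text a PIL.Image rather than a NumPy array or a file path, so the ndarray branch no longer fires for uploads. A minimal normalization sketch covering all three input kinds follows; the helper name to_rgb_image is hypothetical and not part of this commit.

from PIL import Image
import numpy as np

def to_rgb_image(image):
    # Hypothetical helper (not in the commit): normalize ndarray, PIL.Image,
    # or file-path inputs to an RGB PIL image, mirroring extract_text's branches.
    if isinstance(image, np.ndarray):
        if image.ndim == 2:  # Grayscale (H, W): replicate into 3 channels
            image = np.stack([image] * 3, axis=-1)
        return Image.fromarray(image).convert("RGB")
    if isinstance(image, Image.Image):  # What gr.Image(type="pil") delivers
        return image.convert("RGB")
    return Image.open(image).convert("RGB")  # Assume a path or file object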
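
A quick local smoke test of the updated function, assuming app.py has been executed in the current session; the sample file path below is hypothetical.

import numpy as np
from PIL import Image

# Synthetic grayscale (H, W) array exercises the channel-replication branch;
# a near-blank page may also trigger the <2-character fallback to model2.
blank = np.full((64, 256), 255, dtype=np.uint8)
print(extract_text(blank))

# A real sample goes through the RGB ndarray path end to end.
sample = np.array(Image.open("sample_handwritten.png").convert("RGB"))
print(extract_text(sample))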