DeepDiveDev committed on
Commit 398e23b · verified · 1 Parent(s): d010bf6

Update app.py

Files changed (1)
app.py +21 -29
app.py CHANGED
@@ -4,51 +4,43 @@ from PIL import Image
 import numpy as np
 import torch
 
-# Load the primary OCR model (DeepDiveDev/transformodocs-ocr)
-processor1 = TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr")
-model1 = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr")
-
-# Load the fallback model (microsoft/trocr-base-handwritten) for handwritten text
-processor2 = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
-model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+# Load TrOCR model and processor
+processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
+model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
 
 # Function to extract text from handwritten images
 def extract_text(image):
     try:
-        # Ensure input is a PIL Image
+        # Convert image to RGB if needed
         if isinstance(image, np.ndarray):
-            if len(image.shape) == 2:  # Grayscale (H, W) -> Convert to RGB
+            if len(image.shape) == 2:  # If grayscale (H, W), convert to RGB
                 image = np.stack([image] * 3, axis=-1)
             image = Image.fromarray(image)
-        elif isinstance(image, str):  # If a file path is given, open the image
+        else:
             image = Image.open(image).convert("RGB")
 
-        # Maintain aspect ratio while resizing (better for OCR)
-        image.thumbnail((800, 800))
-
-        # Process image with the first model
-        pixel_values = processor1(images=image, return_tensors="pt").pixel_values.to(torch.float32)
-        generated_ids = model1.generate(pixel_values)
-        extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        # If output is short or incorrect, use the fallback model
-        if len(extracted_text.strip()) < 2:
-            inputs = processor2(images=image, return_tensors="pt").pixel_values.to(torch.float32)
-            generated_ids = model2.generate(inputs)
-            extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-        return extracted_text if extracted_text else "No text detected."
+        # Preprocessing (convert to grayscale for better OCR)
+        image = image.convert("L")
+        image = image.resize((640, 640))
 
+        # Process image
+        pixel_values = processor(images=image, return_tensors="pt").pixel_values
+        generated_ids = model.generate(pixel_values)
+        extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return extracted_text if extracted_text.strip() else "No text detected."
+
     except Exception as e:
         return f"Error: {str(e)}"
 
-# Gradio UI for OCR Extraction
+# Gradio Interface
 iface = gr.Interface(
     fn=extract_text,
-    inputs=gr.Image(type="pil"),  # Ensures input is a PIL image
+    inputs="image",
     outputs="text",
-    title="Handwritten OCR Extraction",
-    description="Upload a handwritten image to extract text using AI OCR.",
+    title="Handwritten OCR Extractor",
+    description="Upload a handwritten image to extract text.",
 )
 
+# Launch the app
 iface.launch()
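
For reference, a minimal smoke test of the updated inference path outside Gradio, a sketch only: it assumes the Space's dependencies (transformers, torch, Pillow) are installed, and "sample.png" is a hypothetical handwritten-image file, not part of this commit.

# Sketch: exercise the new single-model TrOCR path directly (no Gradio).
# Assumes transformers/torch/Pillow are installed; "sample.png" is hypothetical.
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")

image = Image.open("sample.png").convert("RGB")  # three-channel input, as in the old code path
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])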
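
One caveat on the new preprocessing: `image.convert("L")` hands the processor a single-channel image, while TrOCR checkpoints are typically fed three-channel RGB (the version removed by this commit kept RGB throughout). A variant that keeps the grayscale step but restores the channel count before feature extraction, offered as a suggestion and not as part of the commit:

# Sketch (assumption, not in the commit): apply grayscale as a filter,
# then convert back to RGB so the processor receives 3 channels.
image = image.convert("L").convert("RGB")
image = image.resize((640, 640))
pixel_values = processor(images=image, return_tensors="pt").pixel_values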