DeepDiveDev committed · Commit 429d160 · verified · 1 Parent(s): fca31c6

Update app.py

Files changed (1):
  1. app.py +16 -19
app.py CHANGED
@@ -1,35 +1,32 @@
 import gradio as gr
-from transformers import TrOCRProcessor, VisionEncoderDecoderModel
-from PIL import Image
-import numpy as np
 import torch
+import numpy as np
+from PIL import Image
+from transformers import AutoProcessor, AutoModelForVision2Seq

-# Load TrOCR model and processor
-processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
-model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
+# Load the model & processor
+model_name = "Murasajo/Llama-3.2-VL-Finetuned-on-HandwrittenText"
+processor = AutoProcessor.from_pretrained(model_name)
+model = AutoModelForVision2Seq.from_pretrained(model_name)

-# Function to extract text from handwritten images
+# Function to extract handwritten text
 def extract_text(image):
     try:
-        # Convert image to RGB if needed
+        # Convert input to PIL Image
         if isinstance(image, np.ndarray):
-            if len(image.shape) == 2:  # If grayscale (H, W), convert to RGB
+            if len(image.shape) == 2:  # If grayscale (H, W), add channels
                 image = np.stack([image] * 3, axis=-1)
             image = Image.fromarray(image)
         else:
             image = Image.open(image).convert("RGB")

-        # Preprocessing (convert to grayscale for better OCR)
-        image = image.convert("L")
-        image = image.resize((640, 640))
-
-        # Process image
+        # Process image through model
         pixel_values = processor(images=image, return_tensors="pt").pixel_values
         generated_ids = model.generate(pixel_values)
         extracted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

-        return extracted_text if extracted_text.strip() else "No text detected."
-
+        return extracted_text
+
     except Exception as e:
         return f"Error: {str(e)}"

@@ -38,9 +35,9 @@ iface = gr.Interface(
     fn=extract_text,
     inputs="image",
     outputs="text",
-    title="Handwritten OCR Extractor",
-    description="Upload a handwritten image to extract text.",
+    title="Handwritten Text OCR",
+    description="Upload a handwritten document and extract text using AI.",
 )

-# Launch the app
+# Run the app
 iface.launch()
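A minimal standalone sketch of the committed inference path, for sanity-checking the new checkpoint without starting the Gradio UI (importing app.py directly would also run iface.launch()). The test image name sample_handwriting.png is hypothetical; everything else mirrors the code in this commit:

from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq

# Same checkpoint as in the commit
model_name = "Murasajo/Llama-3.2-VL-Finetuned-on-HandwrittenText"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForVision2Seq.from_pretrained(model_name)

# Hypothetical test image; any RGB handwritten sample works
image = Image.open("sample_handwriting.png").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])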