Jangai committed on
Commit
301a707
·
verified ·
1 Parent(s): b0ac62c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -18
app.py CHANGED
@@ -1,29 +1,34 @@
1
  import gradio as gr
2
- from PIL import Image
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
 
4
 
5
- # Load the processor and model
6
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
7
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
8
 
9
- # Define the function to recognize handwriting
10
  def recognize_handwriting(image):
11
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
 
 
 
 
 
 
 
12
  generated_ids = model.generate(pixel_values)
13
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
14
  return generated_text
15
 
16
  # Create the Gradio interface
17
- with gr.Blocks() as demo:
18
- gr.Markdown("# Handwriting Recognition")
19
- with gr.Row():
20
- with gr.Column():
21
- image_input = gr.Image(tool="editor", type="numpy", label="Draw or Upload an Image")
22
- recognize_button = gr.Button("Recognize Handwriting")
23
- with gr.Column():
24
- output_text = gr.Textbox(label="Recognized Text")
25
-
26
- recognize_button.click(fn=recognize_handwriting, inputs=image_input, outputs=output_text)
27
 
28
- # Launch the Gradio app
29
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
 
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ from PIL import Image
4
+ import numpy as np
5
 
6
+ # Load the model and processor
7
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
8
+ model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten')
9
 
10
+ # Define the prediction function
11
  def recognize_handwriting(image):
12
+ if isinstance(image, dict):
13
+ image = Image.fromarray(image['image'])
14
+ elif isinstance(image, np.ndarray):
15
+ image = Image.fromarray(image)
16
+ else:
17
+ image = Image.open(image)
18
+
19
+ pixel_values = processor(image, return_tensors="pt").pixel_values
20
  generated_ids = model.generate(pixel_values)
21
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
22
  return generated_text
23
 
24
  # Create the Gradio interface
25
+ image_input = gr.Image(type="numpy", label="Draw or Upload an Image")
26
+ output_text = gr.Textbox(label="Recognized Text")
 
 
 
 
 
 
 
 
27
 
28
+ gr.Interface(
29
+ fn=recognize_handwriting,
30
+ inputs=image_input,
31
+ outputs=output_text,
32
+ title="Handwritten Text Recognition",
33
+ description="Draw or upload an image of handwritten text to recognize it.",
34
+ ).launch()