Jangai committed on
Commit
86d342d
·
verified ·
1 Parent(s): 5f15dd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -26
app.py CHANGED
@@ -1,34 +1,52 @@
1
  import gradio as gr
2
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  from PIL import Image
4
- import numpy as np
 
 
 
 
 
 
 
5
 
6
- # Load the model and processor
7
- processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
8
- model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten')
 
9
 
10
- # Define the prediction function
11
  def recognize_handwriting(image):
12
- if isinstance(image, dict):
13
- image = Image.fromarray(image['image'])
14
- elif isinstance(image, np.ndarray):
15
- image = Image.fromarray(image)
16
- else:
17
- image = Image.open(image)
 
 
18
 
19
- pixel_values = processor(image, return_tensors="pt").pixel_values
20
- generated_ids = model.generate(pixel_values)
21
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
22
- return generated_text
 
 
 
 
 
23
 
24
- # Create the Gradio interface
25
- image_input = gr.Sketchpad(label="Draw or Upload an Image")
26
- output_text = gr.Textbox(label="Recognized Text")
 
 
 
 
 
 
 
 
27
 
28
- gr.Interface(
29
- fn=recognize_handwriting,
30
- inputs=image_input,
31
- outputs=output_text,
32
- title="Handwritten Text Recognition",
33
- description="Draw or upload an image of handwritten text to recognize it.",
34
- ).launch()
 
1
  import gradio as gr
 
2
  from PIL import Image
3
+ import requests
4
+ import torch
5
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
6
+ import logging
7
+
8
# Configure root logging at DEBUG level and create this module's logger.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Load the TrOCR processor and the encoder-decoder model once at import time;
# both are fetched by the same checkpoint name.
model_name = "microsoft/trocr-large-handwritten"
processor = TrOCRProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)
16
 
17
# Function to recognize handwriting
def recognize_handwriting(image):
    """Recognize handwritten text in an image and return it as a string.

    Args:
        image: The picture to read. May be an array accepted by
            ``Image.fromarray``, an already-constructed PIL image, or a
            dict whose ``"image"`` entry holds the picture. ``None`` (or a
            dict without an ``"image"`` entry) means nothing was submitted.

    Returns:
        The decoded text on success, ``"No image found"`` when no image
        was supplied, or ``"Error: <message>"`` if recognition failed.
        Exceptions are never propagated to the caller.
    """
    try:
        logger.info("Received an image for handwriting recognition.")
        if isinstance(image, dict):
            image = image.get("image")

        if image is None:
            logger.error("No image found in the input.")
            return "No image found"

        # Image.fromarray only accepts array-like data; pass a PIL image
        # through untouched instead of failing on it.
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        pixel_values = processor(images=image, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        logger.info("Handwriting recognized successfully.")
        return generated_text
    except Exception as e:
        # Outermost handler for the UI callback: keep the full traceback in
        # the log while showing the user only the short message.
        logger.exception("Error during handwriting recognition: %s", e)
        return f"Error: {str(e)}"
37
 
38
# Build the page: image input and submit button in the left column,
# recognized text in the right column.
with gr.Blocks() as demo:
    gr.Markdown("## Handwritten Text Recognition")
    with gr.Row():
        with gr.Column():
            # NOTE(review): the `tool` argument is Gradio-version dependent;
            # confirm it is accepted by the release this app is pinned to.
            image_input = gr.Image(
                tool="editor",
                type="numpy",
                label="Draw or Upload an Image",
            )
            submit_button = gr.Button("Submit")
        with gr.Column():
            output_text = gr.Textbox(label="Recognized Text")

    # Run recognition on the current image when the button is pressed.
    submit_button.click(recognize_handwriting, image_input, output_text)

# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()