Jangai commited on
Commit
7574fa9
·
verified ·
1 Parent(s): 443e319

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -57
app.py CHANGED
@@ -1,69 +1,32 @@
1
  import gradio as gr
2
- import torch
3
- from PIL import Image
4
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
5
- import os
6
- import json
7
-
8
- # Load the model and processor
9
- processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
10
- model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
11
-
12
- # Handle model initialization warnings
13
- model.eval()
14
 
15
- # Load feedback data
16
- feedback_data_path = "feedback_data.json"
17
- if os.path.exists(feedback_data_path):
18
- with open(feedback_data_path, "r") as file:
19
- feedback_data = json.load(file)
20
- else:
21
- feedback_data = []
22
 
23
  def recognize_handwriting(image):
24
  if isinstance(image, dict):
25
- image = Image.fromarray(image['image'].astype('uint8')).convert("RGB")
26
- else:
27
- image = Image.open(image).convert("RGB")
28
-
29
- pixel_values = processor(images=image, return_tensors="pt").pixel_values
30
  generated_ids = model.generate(pixel_values)
31
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
32
-
33
  return generated_text
34
 
35
- def provide_feedback(image, corrected_text):
36
- feedback_data.append({
37
- "image": image,
38
- "corrected_text": corrected_text
39
- })
40
-
41
- # Save feedback every 5 entries
42
- if len(feedback_data) % 5 == 0:
43
- with open(feedback_data_path, "w") as file:
44
- json.dump(feedback_data, file)
45
-
46
- return "Feedback received. Thank you!"
47
 
48
- # Gradio Interface
49
- with gr.Blocks() as demo:
50
- gr.Markdown("# Handwriting Recognition with Feedback")
51
-
52
- with gr.Tabs():
53
- with gr.TabItem("Recognize Handwriting"):
54
- image_input = gr.Image(source="upload", tool="editor", type="numpy", label="Draw or Upload an Image")
55
- recognize_button = gr.Button("Recognize Handwriting")
56
- output_text = gr.Textbox(label="Recognized Text")
57
-
58
- recognize_button.click(recognize_handwriting, inputs=image_input, outputs=output_text)
59
-
60
- with gr.TabItem("Provide Feedback"):
61
- feedback_image_input = gr.Image(source="upload", tool="editor", type="numpy", label="Draw or Upload an Image")
62
- corrected_text_input = gr.Textbox(label="Corrected Text")
63
- feedback_button = gr.Button("Submit Feedback")
64
- feedback_output = gr.Textbox(label="Feedback Status")
65
-
66
- feedback_button.click(provide_feedback, inputs=[feedback_image_input, corrected_text_input], outputs=feedback_output)
67
 
68
- # Run the Gradio app
69
- demo.launch(share=True)
 
1
  import gradio as gr
 
 
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
+ from PIL import Image
4
+ import numpy as np
 
 
 
 
 
 
 
5
 
6
+ # Load the processor and model
7
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
8
+ model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten')
 
 
 
 
9
 
10
  def recognize_handwriting(image):
11
  if isinstance(image, dict):
12
+ image = image["image"]
13
+ pil_image = Image.fromarray(image).convert("RGB")
14
+ pixel_values = processor(images=pil_image, return_tensors="pt").pixel_values
 
 
15
  generated_ids = model.generate(pixel_values)
16
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
17
  return generated_text
18
 
19
+ # Gradio interface
20
+ image_input = gr.Image(tool="editor", type="numpy", label="Draw or Upload an Image")
21
+ output_text = gr.Textbox(label="Recognized Text")
 
 
 
 
 
 
 
 
 
22
 
23
+ iface = gr.Interface(
24
+ fn=recognize_handwriting,
25
+ inputs=image_input,
26
+ outputs=output_text,
27
+ title="Handwriting Recognition",
28
+ description="Draw or upload an image of handwritten text, and the model will recognize the text.",
29
+ live=True,
30
+ )
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ iface.launch()