Jangai committed on
Commit
e02beda
·
verified ·
1 Parent(s): 79d3533

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -23
app.py CHANGED
@@ -1,39 +1,66 @@
1
  import gradio as gr
2
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
- from PIL import Image
4
  import torch
5
- import numpy as np
 
 
 
6
 
7
- # Load the pre-trained model and processor
8
  processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
9
  model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
10
 
11
- # Define the prediction function
 
 
 
 
 
 
 
12
  def recognize_handwriting(image):
13
- image = Image.fromarray(image.astype('uint8'), 'RGB')
 
 
 
 
14
  pixel_values = processor(images=image, return_tensors="pt").pixel_values
15
  generated_ids = model.generate(pixel_values)
16
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
17
  return generated_text
18
 
19
- def provide_feedback(image, correct_text):
20
- # Save the feedback to a file or database for later retraining
21
- with open("feedback.txt", "a") as f:
22
- f.write(f"{correct_text}\n")
 
 
 
 
 
 
 
23
  return "Feedback received. Thank you!"
24
 
 
25
  with gr.Blocks() as demo:
26
- with gr.Tab("Recognize Handwriting"):
27
- sketchpad = gr.Sketchpad(label="Draw something")
28
- output = gr.Textbox(label="Recognized Text")
29
- recognize_button = gr.Button("Recognize")
30
- recognize_button.click(fn=recognize_handwriting, inputs=sketchpad, outputs=output)
31
-
32
- with gr.Tab("Provide Feedback"):
33
- image_feedback = gr.Image(type="numpy")
34
- correct_text = gr.Textbox(label="Correct Text")
35
- feedback_button = gr.Button("Submit Feedback")
36
- feedback_output = gr.Textbox()
37
- feedback_button.click(fn=provide_feedback, inputs=[image_feedback, correct_text], outputs=feedback_output)
 
 
 
 
 
38
 
39
- demo.launch()
 
 
1
  import gradio as gr
 
 
2
  import torch
3
+ from PIL import Image
4
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
5
+ import os
6
+ import json
7
 
8
+ # Load the model and processor
9
  processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
10
  model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
11
 
# Load feedback data
feedback_data_path = "feedback_data.json"


def load_feedback_data(path):
    """Return the feedback entries stored at ``path`` as a list.

    Args:
        path: Location of the JSON file holding previously collected feedback.

    Returns:
        The decoded list of entries, or an empty list when the file does not
        exist or cannot be used (invalid JSON, undecodable bytes, or a
        top-level value that is not a list).
    """
    try:
        with open(path, "r", encoding="utf-8") as file:
            loaded = json.load(file)
    except FileNotFoundError:
        # First run: nothing has been collected yet.
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A half-written or damaged file must not stop the app from starting.
        loaded = None

    if isinstance(loaded, list):
        return loaded

    # Keep the unusable file for inspection instead of letting the next save
    # overwrite it.
    os.replace(path, path + ".corrupt")
    print(f"Ignoring unreadable feedback file; moved it to {path}.corrupt")
    return []


feedback_data = load_feedback_data(feedback_data_path)
19
+
def recognize_handwriting(image):
    """Transcribe the handwriting in ``image`` with the loaded TrOCR model.

    Args:
        image: The value delivered by the image input. Accepted forms are a
            dict carrying the picture under the ``"image"`` key, an array
            (any object exposing ``astype``), a PIL image, or a path / file
            object that ``Image.open`` can read. ``None`` (or a dict without
            a picture) means nothing was supplied.

    Returns:
        The decoded text of the first generated sequence, or ``""`` when no
        image was supplied.
    """
    if isinstance(image, dict):
        image = image.get("image")
    if image is None:
        # Nothing drawn or uploaded: report no text rather than failing.
        return ""

    if hasattr(image, "astype"):
        # The input component is declared with type="numpy", so a bare array
        # can arrive here; Image.open only accepts paths and file objects.
        image = Image.fromarray(image.astype("uint8"))
    elif not isinstance(image, Image.Image):
        image = Image.open(image)
    image = image.convert("RGB")

    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return generated_text
31
 
def provide_feedback(image, corrected_text):
    """Record a user correction and periodically persist all feedback.

    Args:
        image: The picture the correction refers to, as delivered by the
            feedback image input (an array, a dict of arrays, or ``None``).
        corrected_text: The text the user says the image actually contains.

    Returns:
        A fixed confirmation message for display in the UI.
    """

    def _jsonable(value):
        # json.dump cannot encode array objects, so turn them (also when
        # nested inside a dict) into plain lists before they are stored.
        if isinstance(value, dict):
            return {key: _jsonable(item) for key, item in value.items()}
        if hasattr(value, "tolist"):
            return value.tolist()
        return value

    feedback_data.append({
        "image": _jsonable(image),
        "corrected_text": corrected_text,
    })

    # Save feedback every 5 entries.
    # NOTE: entries added since the last save exist only in memory until the
    # next multiple of 5 is reached.
    if len(feedback_data) % 5 == 0:
        # Write to a temporary file and swap it in, so a failed dump can never
        # truncate the feedback that is already on disk.
        temp_path = feedback_data_path + ".tmp"
        with open(temp_path, "w") as file:
            json.dump(feedback_data, file)
        os.replace(temp_path, feedback_data_path)

    return "Feedback received. Thank you!"
44
 
# Gradio Interface: one tab runs recognition, the other collects corrections.
with gr.Blocks() as demo:
    gr.Markdown("# Handwriting Recognition with Feedback")

    with gr.Tabs():
        with gr.TabItem("Recognize Handwriting"):
            image_input = gr.Image(
                source="upload",
                tool="sketch",
                type="numpy",
                label="Draw or Upload an Image",
            )
            recognize_button = gr.Button("Recognize Handwriting")
            output_text = gr.Textbox(label="Recognized Text")

            # Pressing the button sends the image to the recognizer and shows
            # the returned text.
            recognize_button.click(
                fn=recognize_handwriting,
                inputs=image_input,
                outputs=output_text,
            )

        with gr.TabItem("Provide Feedback"):
            feedback_image_input = gr.Image(
                source="upload",
                tool="sketch",
                type="numpy",
                label="Draw or Upload an Image",
            )
            corrected_text_input = gr.Textbox(label="Corrected Text")
            feedback_button = gr.Button("Submit Feedback")
            feedback_output = gr.Textbox(label="Feedback Status")

            # The image and its correction are submitted together; the status
            # box displays the confirmation message.
            feedback_button.click(
                fn=provide_feedback,
                inputs=[feedback_image_input, corrected_text_input],
                outputs=feedback_output,
            )

# Run the Gradio app
demo.launch(share=True)