Jangai committed
Commit 9317cd1 · verified · 1 Parent(s): 3dd9291

Update app.py

Files changed (1)
app.py +28 -52
app.py CHANGED
@@ -1,62 +1,38 @@
 import gradio as gr
-import numpy as np
-import os
-import pickle
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 from PIL import Image
-from transformers import pipeline
-
-# Load the classifier model
-classifier = pipeline("image-classification", model="google/vit-base-patch16-224")
-
-# Initialize a dictionary to store feedback
-feedback_data = {"images": [], "labels": []}
-
-def save_feedback():
-    with open("feedback_data.pkl", "wb") as f:
-        pickle.dump(feedback_data, f)
-
-def load_feedback():
-    global feedback_data
-    if os.path.exists("feedback_data.pkl"):
-        with open("feedback_data.pkl", "rb") as f:
-            feedback_data = pickle.load(f)
-
-def predict(image):
-    image = Image.fromarray(image.astype('uint8'), 'RGB')
-    prediction = classifier(image)
-    return {pred["label"]: pred["score"] for pred in prediction}
-
-def provide_feedback(image, label):
-    global feedback_data
-    feedback_data["images"].append(image)
-    feedback_data["labels"].append(label)
-    save_feedback()
-
-    if len(feedback_data["images"]) % 5 == 0:
-        retrain_model()
-    return "Feedback saved. Thank you!"
-
-def retrain_model():
-    global classifier
-    # Here, include the retraining logic using the feedback_data
-    # This is a placeholder for actual retraining logic
-    print("Retraining the model with new data...")
-
-# Load existing feedback data
-load_feedback()
+import requests
+import torch
+
+# Load the pre-trained model and processor
+processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
+model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
+
+# Define the prediction function
+def recognize_handwriting(image):
+    pixel_values = processor(images=image, return_tensors="pt").pixel_values
+    generated_ids = model.generate(pixel_values)
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return generated_text
+
+def provide_feedback(image, correct_text):
+    # Save the feedback to a file or database for later retraining
+    with open("feedback.txt", "a") as f:
+        f.write(f"{correct_text}\n")
+    return "Feedback received. Thank you!"
 
 with gr.Blocks() as demo:
-    with gr.Tab("Predict"):
-        image_input = gr.Sketchpad()
-        output = gr.JSON()
-        image_input.change(fn=predict, inputs=image_input, outputs=output)
+    with gr.Tab("Recognize Handwriting"):
+        image_input = gr.Image(type="pil")
+        output = gr.Textbox(label="Recognized Text")
+        recognize_button = gr.Button("Recognize")
+        recognize_button.click(fn=recognize_handwriting, inputs=image_input, outputs=output)
 
     with gr.Tab("Provide Feedback"):
-        image_feedback = gr.Sketchpad()
-        label_feedback = gr.Textbox(label="Enter the correct label")
+        image_feedback = gr.Image(type="pil")
+        correct_text = gr.Textbox(label="Correct Text")
         feedback_button = gr.Button("Submit Feedback")
         feedback_output = gr.Textbox()
-
-    feedback_button.click(fn=provide_feedback, inputs=[image_feedback, label_feedback], outputs=feedback_output)
+        feedback_button.click(fn=provide_feedback, inputs=[image_feedback, correct_text], outputs=feedback_output)
 
 demo.launch()
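
Note on the new feedback handler: provide_feedback now appends only the corrected text to feedback.txt, although its comment says the feedback is saved for later retraining; the submitted image itself is discarded, and the requests import is unused in the new version. The sketch below shows one possible way to keep the image alongside the text, assuming the handler still receives a PIL image from gr.Image(type="pil"); the folder name and index file are illustrative and not part of this commit.

import os
import uuid
from PIL import Image

FEEDBACK_DIR = "feedback_images"   # illustrative location, not in the commit
FEEDBACK_INDEX = "feedback.tsv"    # illustrative index pairing image paths with corrections

def provide_feedback(image: Image.Image, correct_text: str) -> str:
    # Gradio passes None when no image was uploaded or drawn.
    if image is None:
        return "Please provide an image along with the correct text."

    # Persist the image so it is actually available for later retraining.
    os.makedirs(FEEDBACK_DIR, exist_ok=True)
    image_path = os.path.join(FEEDBACK_DIR, f"{uuid.uuid4().hex}.png")
    image.save(image_path)

    # Record the image path and the corrected text side by side.
    with open(FEEDBACK_INDEX, "a", encoding="utf-8") as f:
        f.write(f"{image_path}\t{correct_text}\n")

    return "Feedback received. Thank you!"

A tab-separated index keeps each saved image paired with its correction, which is the minimum needed if the collected samples are later used to fine-tune the TrOCR model.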