Update app.py
Browse files
app.py
CHANGED
@@ -1,49 +1,48 @@
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
|
|
|
|
3 |
from PIL import Image
|
|
|
4 |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
5 |
|
6 |
-
#
|
7 |
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
|
8 |
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
|
9 |
|
10 |
def display_sketch(sketch):
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
19 |
|
20 |
def recognize_text(image_path):
|
21 |
-
# Open
|
22 |
-
image = Image.open(image_path)
|
23 |
-
#
|
24 |
-
image = image.
|
|
|
|
|
25 |
|
26 |
-
#
|
27 |
pixel_values = processor(images=image, return_tensors="pt").pixel_values
|
|
|
28 |
generated_ids = model.generate(pixel_values)
|
29 |
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
30 |
-
|
31 |
return generated_text
|
32 |
|
33 |
-
# Define the Gradio interface
|
34 |
with gr.Blocks() as demo:
|
35 |
-
sketchpad = gr.Sketchpad(label="Draw Something")
|
36 |
-
|
37 |
recognized_text = gr.Textbox(label="Recognized Text")
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
text = recognize_text(image_path)
|
43 |
-
return image_path, text
|
44 |
-
return None, ""
|
45 |
-
|
46 |
-
sketchpad.change(process_and_recognize, inputs=sketchpad, outputs=[output_image, recognized_text])
|
47 |
|
48 |
-
# Launch the demo
|
49 |
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import tempfile
|
5 |
from PIL import Image
|
6 |
+
import torch
|
7 |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
8 |
|
9 |
+
# Load model and processor
# Downloads (or reads from the local Hugging Face cache) the TrOCR
# large-handwritten checkpoint. Both objects are module-level singletons
# shared by display_sketch/recognize_text below; loading happens once at
# import time, before the UI is built.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-large-handwritten")
|
12 |
|
13 |
def display_sketch(sketch):
    """Render the sketchpad's merged drawing to a PNG file and return its path.

    Parameters
    ----------
    sketch : dict
        Gradio Sketchpad value; assumed to hold the merged drawing layers
        under the 'composite' key — TODO confirm against the pinned Gradio
        version's Sketchpad payload format.

    Returns
    -------
    str
        Filesystem path of the saved PNG image.
    """
    image_data = sketch['composite']
    plt.imshow(image_data)
    plt.axis('off')

    # Use a real temporary file instead of the hard-coded
    # "/mnt/data/output.png": that directory only exists in specific
    # sandboxes, and a fixed name is clobbered on every call (racy with
    # concurrent users). `tempfile` is already imported at file top.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp.close()  # matplotlib reopens the path itself; release our handle
    plt.savefig(tmp.name, bbox_inches='tight', pad_inches=0)
    plt.close()

    return tmp.name
|
23 |
|
24 |
def recognize_text(image_path):
    """Run TrOCR handwriting recognition on the image stored at *image_path*.

    Parameters
    ----------
    image_path : str
        Path to an image file readable by PIL.

    Returns
    -------
    str
        The decoded text predicted by the TrOCR model.
    """
    # Open image and convert to grayscale
    image = Image.open(image_path).convert("L")
    # Resize image to 256x256
    image = image.resize((256, 256))
    # Binarize image (convert to black and white). The original
    # `lambda p: p > 128 and 255` only worked because False == 0;
    # spell the threshold out explicitly.
    image = image.point(lambda p: 255 if p > 128 else 0)
    # TrOCR's processor expects a 3-channel image; converting back to RGB
    # replicates the binarized values across channels while restoring the
    # mode the ViT feature extractor is configured for.
    image = image.convert("RGB")

    # Preprocess the image into the model's pixel tensor
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    # Generate prediction; inference only, so skip autograd bookkeeping
    with torch.no_grad():
        generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text
|
38 |
|
|
|
39 |
# Gradio interface: a sketchpad whose submitted drawing is first rendered to
# a PNG file (display_sketch), shown in an Image component, then fed by path
# to the OCR model (recognize_text).
with gr.Blocks() as demo:
    # NOTE(review): `brush_radius` is a Gradio 3.x keyword; Gradio 4.x
    # replaced it with `brush=gr.Brush(...)` — confirm against the pinned
    # Gradio version before relying on this.
    sketchpad = gr.Sketchpad(label="Draw Something", brush_radius=10)
    sketchpad_output = gr.Image(label="Your Sketch")
    recognized_text = gr.Textbox(label="Recognized Text")

    # Chain: save the sketch and display it, then run recognition on the
    # saved file path that display_sketch returned.
    # NOTE(review): whether Sketchpad exposes a `.submit` event depends on
    # the Gradio version; `.change` or `.apply` may be required instead —
    # verify at runtime.
    sketchpad.submit(display_sketch, inputs=sketchpad, outputs=sketchpad_output).then(
        recognize_text, inputs=sketchpad_output, outputs=recognized_text
    )

# Start the local Gradio server (blocking call).
demo.launch()
|