|
import gradio as gr |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
import tempfile |
|
import os |
|
import logging |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
from PIL import Image |
|
|
|
# DEBUG level so the per-request sketch diagnostics emitted by the handlers
# below are actually shown.
logging.basicConfig(level=logging.DEBUG)

# Single source of truth for the checkpoint id: the processor and the model
# must always come from the same checkpoint, so the literal is not repeated.
MODEL_NAME = "microsoft/trocr-large-handwritten"

# Loaded once at import time and shared by every recognize_text() call.
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)
|
|
|
def display_sketch(sketch):
    """Render the sketchpad's composite image to a temporary PNG file.

    Args:
        sketch: Value delivered by the sketchpad component. It is expected
            to be a dict holding the drawn image under the "composite" key.

    Returns:
        The path of the PNG file that was written, or -- when ``sketch`` is
        not a dict containing "composite" -- a human-readable error message.

    NOTE(review): the error message is returned (not raised) to keep the
    original contract. The wired output is an image component, which
    presumably cannot display a plain message -- consider raising
    ``gr.Error`` instead once callers are confirmed.
    """
    logging.debug("Received sketch data: %s", sketch)

    # Guard clause: reject anything that is not the expected payload shape.
    if not (isinstance(sketch, dict) and "composite" in sketch):
        error_message = f"Unexpected sketch data format: {type(sketch)}"
        logging.error(error_message)
        return error_message

    image_data = sketch["composite"]
    logging.debug("Image data type: %s", type(image_data))
    logging.debug("Image data shape: %s", np.array(image_data).shape)

    # Only reserve a unique path here and close the handle immediately:
    # writing to a file that is still held open by NamedTemporaryFile fails
    # on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        temp_file_path = temp_file.name

    # Use a dedicated figure instead of pyplot's implicit global one, so
    # repeated calls neither draw on top of each other nor leak figures.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image_data, cmap="gray")
        ax.axis("off")
        fig.savefig(temp_file_path, bbox_inches="tight")
    except Exception:
        # Do not leave an empty/partial temp file behind on failure.
        os.remove(temp_file_path)
        raise
    finally:
        plt.close(fig)

    return temp_file_path
|
|
|
def recognize_text(image_path):
    """Run TrOCR handwriting recognition on an image.

    Args:
        image_path: The image to read. Normally a file path, but a PIL image
            or a numpy array is accepted as well, because an image component
            may deliver its value in either of those forms depending on how
            it is configured. ``None`` means "no image yet".

    Returns:
        The decoded text of the first generated sequence, or an empty string
        when ``image_path`` is ``None``.
    """
    # Nothing to recognise (e.g. the image component is still empty).
    if image_path is None:
        return ""

    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGB")
    elif isinstance(image_path, np.ndarray):
        # assumes a uint8 H x W (x C) array as Image.fromarray expects
        # -- TODO confirm against the component's output
        image = Image.fromarray(image_path).convert("RGB")
    else:
        # convert() returns a new, fully loaded image, so the underlying
        # file handle can be released right away.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")

    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)

    # A single image was submitted, so only the first decoded entry matters.
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]
    return generated_text
|
|
|
# UI: draw on the sketchpad, press Submit, see the rendered sketch and the
# text recognised in it.
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    # type="filepath" so the value handed to recognize_text is a path on
    # disk, which is what that function opens.
    output_image = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")
    submit_btn = gr.Button("Submit")

    # Chain the two steps instead of registering two independent click
    # listeners: recognition must only start after the sketch has been
    # rendered into output_image, and only if that step succeeded.
    submit_btn.click(
        fn=display_sketch,
        inputs=sketchpad,
        outputs=output_image,
    ).success(
        fn=recognize_text,
        inputs=output_image,
        outputs=recognized_text,
    )

demo.launch()
|
|