import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import logging
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image

logging.basicConfig(level=logging.DEBUG)

# Initialize the TrOCR model and processor
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten')


def display_sketch(sketch):
    """Render the sketchpad's composite image and save it as a temporary PNG file."""
    logging.debug(f"Received sketch data: {sketch}")
    if isinstance(sketch, dict) and "composite" in sketch:
        image_data = sketch["composite"]
        logging.debug(f"Image data type: {type(image_data)}")
        logging.debug(f"Image data shape: {np.array(image_data).shape}")
        plt.imshow(image_data, cmap='gray')
        plt.axis('off')
        # Save the rendered sketch to a temporary file and return its path
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            plt.savefig(temp_file.name, bbox_inches='tight')
            temp_file_path = temp_file.name
        plt.close()  # release the figure so repeated submissions don't accumulate plots
        return temp_file_path
    else:
        logging.error(f"Unexpected sketch data format: {type(sketch)}")
        return None


def recognize_text(image_path):
    """Run TrOCR on the saved sketch image and return the recognized text."""
    # Load the image from disk
    image = Image.open(image_path).convert("RGB")
    # Prepare the image for the model
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Generate and decode the text
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text


with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    # type="filepath" so recognize_text receives the saved PNG's path rather than a numpy array
    output_image = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")
    submit_btn = gr.Button("Submit")
    # Chain the two steps with .then() so recognition runs on the freshly saved sketch,
    # not on the stale value of output_image from the previous click
    submit_btn.click(fn=display_sketch, inputs=sketchpad, outputs=output_image).then(
        fn=recognize_text, inputs=output_image, outputs=recognized_text
    )

demo.launch()