File size: 2,105 Bytes
5e4488c 2fd0f3a 4a003ae 2fd0f3a 6bd6ea4 86d342d 2fd0f3a 3dd9291 6bd6ea4 2fd0f3a 4da9241 fd8f944 4da9241 2fd0f3a 4da9241 2fd0f3a 4a003ae 4da9241 7e53392 6bd6ea4 2fd0f3a 4da9241 2fd0f3a 6bd6ea4 4da9241 6bd6ea4 cff0816 2fd0f3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import os
import logging
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
# Verbose root logging: the sketch handler below emits DEBUG records about the
# payload it receives, which are only visible at this level.
logging.basicConfig(level=logging.DEBUG)

# Load the handwritten-text TrOCR checkpoint once at import time; the module-level
# `processor` and `model` are then reused by every recognition request.
_TROCR_CHECKPOINT = "microsoft/trocr-large-handwritten"
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
def display_sketch(sketch):
    """Render the sketchpad's composite image to a temporary PNG file.

    Args:
        sketch: Value produced by ``gr.Sketchpad``. Expected to be a dict with a
            ``"composite"`` entry holding the flattened drawing (presumably a
            numpy array -- TODO confirm against the Gradio version in use).

    Returns:
        Path of a PNG written to the system temp directory. The file is not
        deleted here; it is handed to the ``gr.Image`` output for display.

    Raises:
        gr.Error: If ``sketch`` is not in the expected format or holds no image.
            Gradio shows the message in the UI. Returning the text instead, as
            before, made the ``gr.Image`` output try to load it as a file path.
    """
    if not (isinstance(sketch, dict) and "composite" in sketch):
        error_message = f"Unexpected sketch data format: {type(sketch)}"
        logging.error(error_message)
        raise gr.Error(error_message)

    # Log a summary only: dumping the whole dict floods the log with pixel data.
    logging.debug("Received sketch data with keys: %s", list(sketch))

    image_data = sketch["composite"]
    if image_data is None:
        error_message = "Sketch is empty: draw something before submitting."
        logging.error(error_message)
        raise gr.Error(error_message)

    image_array = np.asarray(image_data)
    logging.debug("Image data type: %s", type(image_data))
    logging.debug("Image data shape: %s", image_array.shape)

    # Use a dedicated figure and always close it. Drawing on pyplot's implicit
    # global figure stacked every call onto the same axes, so old sketches could
    # show through transparent pixels, and the figure was never freed.
    fig, ax = plt.subplots()
    try:
        # `cmap` only affects 2-D (single-channel) data; RGB(A) arrays are drawn as-is.
        ax.imshow(image_array, cmap="gray")
        ax.axis("off")
        # Reserve a unique path, close the handle, then write by name: saving
        # while the temp file is still open fails on Windows.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            temp_file_path = temp_file.name
        fig.savefig(temp_file_path, bbox_inches="tight")
    finally:
        plt.close(fig)
    return temp_file_path
def recognize_text(image_path):
    """Run TrOCR handwriting recognition on a single image.

    Args:
        image_path: The image to read. This may be a filesystem path (what
            ``display_sketch`` returns), a ``PIL.Image.Image``, or a numpy array
            (what a ``gr.Image`` input delivers with its default
            ``type="numpy"``). ``None`` means no image is available yet.

    Returns:
        The decoded text for the single image, with special tokens stripped,
        or ``""`` when ``image_path`` is ``None``.
    """
    if image_path is None:
        return ""

    # Normalise every accepted input form to an RGB PIL image.
    if isinstance(image_path, np.ndarray):
        image = Image.fromarray(image_path).convert("RGB")
    elif isinstance(image_path, Image.Image):
        image = image_path.convert("RGB")
    else:
        # The context manager closes the file handle once convert() has copied
        # the pixels out; the original left it open until garbage collection.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

    # Preprocess into the model's pixel tensor (a batch of one).
    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    # batch_decode returns one string per batch item; there is exactly one.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    # type="filepath" makes this component pass recognize_text the path of the
    # rendered PNG; the default type="numpy" would hand it a pixel array instead.
    output_image = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")
    submit_btn = gr.Button("Submit")
    # Chain the two steps so OCR reads the image this click just rendered.
    # Two independent .click() listeners both fire at once, and the second would
    # read output_image before the first had updated it (a stale image or None).
    # .success() also skips OCR when display_sketch raises an error.
    submit_btn.click(
        fn=display_sketch, inputs=sketchpad, outputs=output_image
    ).success(fn=recognize_text, inputs=output_image, outputs=recognized_text)

demo.launch()
|