Sketch / app.py
Jangai's picture
Update app.py
6bd6ea4 verified
raw
history blame
2.11 kB
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import os
import logging
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
# Emit DEBUG-level (and above) records from the root logger.
logging.basicConfig(level=logging.DEBUG)

# The TrOCR processor and model are loaded from the same pretrained
# checkpoint, so the identifier is named once and reused for both.
_TROCR_CHECKPOINT = 'microsoft/trocr-large-handwritten'
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
def display_sketch(sketch):
    """Render the sketch's composite image to a temporary PNG file.

    Args:
        sketch: Value received from the sketchpad. Expected to be a dict
            carrying the drawn image under the ``"composite"`` key.

    Returns:
        The path of the PNG file written on success. If ``sketch`` is not a
        dict with a ``"composite"`` key, or that entry is ``None``, the
        problem is logged and the error message string is returned instead.

    NOTE(review): the click handler in this file routes this return value to
    an image output; an error string is not an image, so raising a UI error
    may be preferable. The return contract is kept unchanged here.
    """
    logging.debug("Received sketch data: %s", sketch)

    if not (isinstance(sketch, dict) and "composite" in sketch):
        error_message = f"Unexpected sketch data format: {type(sketch)}"
        logging.error(error_message)
        return error_message

    image_data = sketch["composite"]
    if image_data is None:
        # Nothing to draw; imshow(None) would raise instead of reporting.
        error_message = "Sketch data contains no composite image"
        logging.error(error_message)
        return error_message

    logging.debug("Image data type: %s", type(image_data))
    logging.debug("Image data shape: %s", np.array(image_data).shape)

    # Only reserve a unique path here. The handle is closed before matplotlib
    # writes to it, because some platforms refuse a second open of the file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        temp_file_path = temp_file.name

    # Use a dedicated figure instead of the implicit global one so repeated
    # calls neither stack images on the same axes nor leak open figures.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image_data, cmap='gray')
        ax.axis('off')
        fig.savefig(temp_file_path, bbox_inches='tight')
    finally:
        plt.close(fig)

    return temp_file_path
def recognize_text(image_path):
    """Run TrOCR handwriting recognition on an image and return the text.

    Args:
        image_path: The image to read. May be a filesystem path (or any
            file-like object ``PIL.Image.open`` accepts), a ``PIL.Image``,
            a numpy array of pixel data, or ``None`` when no image is
            available yet.

    Returns:
        The decoded text of the first generated sequence, or an empty string
        when ``image_path`` is ``None``.
    """
    if image_path is None:
        # No image to process; report no text instead of failing.
        return ""

    if isinstance(image_path, np.ndarray):
        # An image component may hand over raw pixels rather than a path.
        # NOTE(review): assumes a dtype Image.fromarray accepts (e.g. uint8).
        image = Image.fromarray(image_path).convert("RGB")
    elif isinstance(image_path, Image.Image):
        image = image_path.convert("RGB")
    else:
        # convert() loads the pixel data, so the file can be closed at once.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")

    # Prepare the image for the model
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Generate token ids and decode them, dropping special tokens
    generated_ids = model.generate(pixel_values)
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]
# UI layout: a sketchpad input, the rendered sketch, and the recognized text.
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    # type="filepath" makes handlers that take this component as input
    # receive the image as a file path rather than as pixel data.
    output_image = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")
    submit_btn = gr.Button("Submit")

    # Two separate click listeners would run independently, so recognition
    # could read output_image before the new sketch was rendered into it.
    # Chain them instead: recognition runs only after rendering succeeds.
    render_event = submit_btn.click(
        fn=display_sketch, inputs=sketchpad, outputs=output_image
    )
    render_event.success(
        fn=recognize_text, inputs=output_image, outputs=recognized_text
    )

demo.launch()