File size: 2,105 Bytes
5e4488c
2fd0f3a
 
4a003ae
 
2fd0f3a
6bd6ea4
 
86d342d
2fd0f3a
3dd9291
6bd6ea4
 
 
 
2fd0f3a
4da9241
 
fd8f944
 
4da9241
 
2fd0f3a
4da9241
2fd0f3a
4a003ae
 
 
 
 
 
 
4da9241
 
 
 
7e53392
6bd6ea4
 
 
 
 
 
 
 
 
 
2fd0f3a
4da9241
2fd0f3a
6bd6ea4
4da9241
6bd6ea4
 
 
cff0816
2fd0f3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import os
import logging
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image

# Emit DEBUG-level records on the root logger so the handlers' diagnostic
# messages are visible.
logging.basicConfig(level=logging.DEBUG)

# Load the TrOCR processor and model once at import time; both are created
# from the same pretrained checkpoint identifier.
_TROCR_CHECKPOINT = 'microsoft/trocr-large-handwritten'
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)

def display_sketch(sketch):
    """Render the sketchpad's composite image to a temporary PNG file.

    Args:
        sketch: Value delivered by the sketchpad component; expected to be a
            dict holding the drawn image under the ``"composite"`` key.

    Returns:
        Path of the PNG written by matplotlib, or an error-message string
        when ``sketch`` is not a dict with a usable ``"composite"`` entry.
    """
    logging.debug("Received sketch data: %s", sketch)

    if not (isinstance(sketch, dict) and "composite" in sketch):
        error_message = f"Unexpected sketch data format: {type(sketch)}"
        logging.error(error_message)
        return error_message

    image_data = sketch["composite"]
    if image_data is None:
        # Nothing was drawn; imshow cannot render None, so report it the
        # same way as any other unusable payload.
        error_message = "Sketch data contains no composite image"
        logging.error(error_message)
        return error_message

    logging.debug("Image data type: %s", type(image_data))
    logging.debug("Image data shape: %s", np.array(image_data).shape)

    # Reserve the output path and close the handle before matplotlib opens
    # the file itself. delete=False keeps the file for the caller.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
        temp_file_path = temp_file.name

    # Draw on a dedicated figure and always close it: pyplot keeps every
    # figure alive until closed, so reusing the implicit global figure would
    # leak memory and let successive sketches draw over each other.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image_data, cmap='gray')
        ax.axis('off')
        fig.savefig(temp_file_path, bbox_inches='tight')
    except Exception:
        # Do not leave an empty temp file behind when rendering fails.
        os.remove(temp_file_path)
        raise
    finally:
        plt.close(fig)

    return temp_file_path

def recognize_text(image_path):
    """Run TrOCR handwriting recognition on an image and return the text.

    Args:
        image_path: The image to read. Usually a file path, but an in-memory
            image (NumPy array or PIL image) is accepted too, because a
            Gradio ``Image`` input passes a NumPy array unless it is
            configured with ``type="filepath"``. ``None`` means no image.

    Returns:
        The decoded text of the first generated sequence, or an empty string
        when no image was supplied.
    """
    if image_path is None:
        # The image component is empty; there is nothing to recognise.
        return ""

    # Normalise every supported input to an RGB PIL image.
    if isinstance(image_path, Image.Image):
        image = image_path.convert("RGB")
    elif isinstance(image_path, np.ndarray):
        image = Image.fromarray(image_path).convert("RGB")
    else:
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the block exits.
        with Image.open(image_path) as opened:
            image = opened.convert("RGB")

    # Prepare the image for the model
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Generate token ids and decode them back to text
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Build the UI: a sketchpad, a preview of the rendered sketch, the recognised
# text, and a button that runs both steps.
with gr.Blocks() as demo:
    sketchpad = gr.Sketchpad(label="Draw Something")
    # type="filepath" makes this component hand its value to recognize_text
    # as a file path (the default would be a NumPy array).
    output_image = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")
    submit_btn = gr.Button("Submit")

    # Chain the two steps instead of registering two independent click
    # handlers: both would fire on the same click, so recognition would read
    # output_image before display_sketch had updated it. success() runs the
    # recognition only after the sketch was rendered without an error.
    render_event = submit_btn.click(fn=display_sketch, inputs=sketchpad, outputs=output_image)
    render_event.success(fn=recognize_text, inputs=output_image, outputs=recognized_text)

demo.launch()