File size: 1,635 Bytes
5e4488c
2fd0f3a
cc24739
 
11b9b6a
cc24739
11b9b6a
3dd9291
cc24739
11b9b6a
 
6bd6ea4
2fd0f3a
cc24739
 
 
 
 
 
 
 
 
7e53392
6bd6ea4
cc24739
 
 
 
 
 
11b9b6a
cc24739
11b9b6a
cc24739
11b9b6a
 
 
 
2fd0f3a
cc24739
 
6bd6ea4
11b9b6a
cc24739
 
 
cff0816
2fd0f3a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import tempfile
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load the TrOCR handwriting checkpoint once, at import time. The processor
# (image preprocessing + tokenizer) and the encoder-decoder model are both
# pulled from the same Hugging Face Hub id.
_MODEL_ID = "microsoft/trocr-large-handwritten"
processor = TrOCRProcessor.from_pretrained(_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(_MODEL_ID)

def display_sketch(sketch):
    """Render the Sketchpad composite image to a temporary PNG file.

    Args:
        sketch: The value emitted by ``gr.Sketchpad``; the drawn image is read
            from its ``'composite'`` entry. May be ``None`` (presumably when the
            pad is empty — confirm against the Gradio version in use).

    Returns:
        Path to a newly created PNG file holding the rendered sketch, or
        ``None`` when there is no image to render. The file is created with
        ``delete=False`` so it outlives this call.
    """
    if sketch is None:
        return None
    image_data = sketch['composite']
    if image_data is None:
        return None

    # Use an explicit figure instead of pyplot's implicit "current figure",
    # and always close it, so no figure state lingers between calls.
    fig, ax = plt.subplots()
    try:
        ax.imshow(image_data)
        ax.axis('off')
        # A unique temp file replaces the hard-coded /mnt/data/output.png,
        # which may not exist and was overwritten by every request.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png", bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

    return tmp.name

def recognize_text(image_path):
    """Run TrOCR handwriting recognition on the image stored at *image_path*.

    The image is converted to grayscale, resized to 256x256, binarized at a
    threshold of 128, and converted back to 3-channel RGB (the form the TrOCR
    model card feeds to ``TrOCRProcessor``). It is then passed through the
    module-level ``processor`` and ``model``.

    Args:
        image_path: Path to an image file, or ``None`` when there is no image.

    Returns:
        The first decoded string produced by the model, or ``""`` when
        *image_path* is ``None``.
    """
    if image_path is None:
        return ""

    # The context manager closes the file handle promptly; ``convert`` returns
    # a new, fully loaded image that remains usable after the block exits.
    with Image.open(image_path) as source:
        image = source.convert("L")
    image = image.resize((256, 256))
    # Binarize: pixels brighter than 128 become white (255), the rest black (0).
    image = image.point(lambda p: 255 if p > 128 else 0)
    # The ViT encoder is fed RGB images in TrOCR's documented usage; replicate
    # the single gray channel into three channels.
    image = image.convert("RGB")

    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

with gr.Blocks() as demo:
    # ``brush_radius`` is a Gradio 3 keyword. The dict-valued sketch consumed by
    # display_sketch (``sketch['composite']``) is the Gradio 4+ API, where the
    # brush is configured via ``gr.Brush``. NOTE(review): default_size is meant
    # to approximate the old brush_radius=10 — confirm radius vs. size semantics.
    sketchpad = gr.Sketchpad(
        label="Draw Something", brush=gr.Brush(default_size=10)
    )
    # type="filepath" so the chained step receives a path that recognize_text
    # can hand to Image.open (the default type would deliver a numpy array).
    sketchpad_output = gr.Image(label="Your Sketch", type="filepath")
    recognized_text = gr.Textbox(label="Recognized Text")
    # Sketchpad has no ``submit`` event; an explicit button triggers
    # recognition once, instead of running the model on every stroke.
    recognize_button = gr.Button("Recognize")

    recognize_button.click(
        display_sketch, inputs=sketchpad, outputs=sketchpad_output
    ).then(recognize_text, inputs=sketchpad_output, outputs=recognized_text)

demo.launch()