File size: 3,690 Bytes
4c45d43
 
 
 
 
 
 
 
 
f4d5676
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c45d43
f4d5676
4c45d43
f4d5676
 
 
 
4c45d43
 
f4d5676
 
4c45d43
 
f4d5676
 
 
 
 
 
 
 
 
 
 
 
 
c96ffbf
f4d5676
 
 
 
 
 
 
 
c96ffbf
f4d5676
 
 
 
 
 
 
 
 
 
 
 
c96ffbf
f4d5676
 
4c45d43
f4d5676
 
4c45d43
f4d5676
4c45d43
 
f4d5676
4c45d43
 
f4d5676
 
 
4c45d43
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import requests
from PIL import Image
import numpy as np
import cv2
# Hugging Face checkpoints used by the demo. The processor (used to build
# pixel_values and to decode generated ids) comes from the base handwritten
# TrOCR checkpoint, while the model weights come from a separate checkpoint —
# presumably fine-tuned from that base on MNIST; confirm on the model card.
_PROCESSOR_CHECKPOINT = "microsoft/trocr-base-handwritten"
_MODEL_CHECKPOINT = "aico/TrOCR-MNIST"

processor = TrOCRProcessor.from_pretrained(_PROCESSOR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_MODEL_CHECKPOINT)

def _group_rectangles(rec):
    """
    Uion intersecting rectangles.
    Args:
        rec - list of rectangles in form [x, y, w, h]
    Return:
        list of grouped ractangles 
    """
    tested = [False for i in range(len(rec))]
    final = []
    i = 0
    while i < len(rec):
        if not tested[i]:
            j = i+1
            while j < len(rec):
                if not tested[j] and intersect_area(rec[i], rec[j]):
                    rec[i] = union(rec[i], rec[j])
                    tested[j] = True
                    j = i
                j += 1
            final += [rec[i]]
        i += 1

    return final

def _digit_boxes(thresh):
    """
    Return the bounding boxes of the blobs in a binary mask, sorted left-to-right.

    Args:
        thresh - single-channel binary image as produced by cv2.threshold.
    Return:
        list of (x, y, w, h) tuples ordered by x; empty if nothing was drawn.
    """
    found = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4.x returns (contours, hierarchy); OpenCV 3.x returns (image, contours, hierarchy).
    cnts = found[0] if len(found) == 2 else found[1]
    height, width = thresh.shape[:2]
    full_frame = (0, 0, width, height)
    # Drop the contour spanning the whole canvas: it would intersect, and be
    # merged with, every digit box.
    boxes = [box for box in map(cv2.boundingRect, cnts) if box != full_frame]
    # Stable sort on x so the recognised text reads left-to-right.
    return sorted(boxes, key=lambda box: box[0])


def _roi_to_image(thresh, box, size=(28, 28), border=30):
    """
    Crop one box from the binary mask and render it as an MNIST-style RGB PIL image.

    The crop is inverted, padded on every side with `border` black pixels, and
    downscaled to `size` with area interpolation.
    """
    x, y, w, h = box
    roi = cv2.bitwise_not(thresh[y:y + h, x:x + w])
    padded = cv2.copyMakeBorder(roi, border, border, border, border,
                                cv2.BORDER_CONSTANT, value=[0, 0, 0])
    resized = cv2.resize(padded, size, interpolation=cv2.INTER_AREA)
    return Image.fromarray(resized.astype('uint8')).convert("RGB")


def process_image(image):
    """
    Recognise the digits drawn on the sketchpad and return them as one string.

    The sketch is binarised (inverted Otsu threshold) and split into blobs by
    contour detection. Overlapping blobs are merged with _group_rectangles. Each
    remaining box is cropped, normalised to a 28x28 image and read by the TrOCR
    model in a single batched generate call.

    Args:
        image - the sketchpad image as a NumPy array. cv2.threshold with
            THRESH_OTSU needs an 8-bit single-channel image; presumably that is
            what the Gradio "sketchpad" input delivers (confirm for the pinned
            Gradio version).
    Return:
        the decoded text of every box, concatenated left-to-right; "" if no
        strokes were found.
    """
    thresh = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
    boxes = _group_rectangles(_digit_boxes(thresh))
    if not boxes:
        # Blank canvas: there is nothing to recognise.
        return ""

    images = [_roi_to_image(thresh, box) for box in boxes]
    # One batched forward pass instead of one model.generate call per digit;
    # batch_decode yields one string per input image, in the same order.
    pixel_values = processor(images, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return ''.join(processor.batch_decode(generated_ids, skip_special_tokens=True))

# Page text shown above the demo.
title = "Interactive demo: Single Digits MNIST"
description = "Aico - University Utrecht"

# Wire the sketchpad input to process_image and show the returned string as a label.
iface = gr.Interface(
    fn=process_image,
    inputs="sketchpad",
    outputs="label",
    title=title,
    description=description,
)

# Start the web app; this runs as soon as the module is executed (or imported).
iface.launch(debug=True)