TrOCR-digit / app.py
aico's picture
Update app.py
f4d5676
raw
history blame
3.69 kB
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import requests
from PIL import Image
import numpy as np
import cv2
# Pretrained checkpoint identifiers passed to `from_pretrained`.
_PROCESSOR_CHECKPOINT = "microsoft/trocr-base-handwritten"
_MODEL_CHECKPOINT = "aico/TrOCR-MNIST"

# Loaded once at import time and shared by every request.
processor = TrOCRProcessor.from_pretrained(_PROCESSOR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_MODEL_CHECKPOINT)
def _group_rectangles(rec):
"""
Uion intersecting rectangles.
Args:
rec - list of rectangles in form [x, y, w, h]
Return:
list of grouped ractangles
"""
tested = [False for i in range(len(rec))]
final = []
i = 0
while i < len(rec):
if not tested[i]:
j = i+1
while j < len(rec):
if not tested[j] and intersect_area(rec[i], rec[j]):
rec[i] = union(rec[i], rec[j])
tested[j] = True
j = i
j += 1
final += [rec[i]]
i += 1
return final
def _binarize(image):
    """Return *image* as an inverted, Otsu-thresholded, single-channel uint8 array."""
    arr = np.asarray(image)
    # cv2.THRESH_OTSU only accepts 8-bit single-channel input.
    if arr.dtype != np.uint8:
        # NOTE(review): assumes pixel values already lie in 0..255 — confirm
        # for the Gradio sketchpad version in use.
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    if arr.ndim == 3:
        if arr.shape[2] == 1:
            arr = arr[:, :, 0]
        else:
            # NOTE(review): assumes RGB / RGBA channel order — confirm.
            code = cv2.COLOR_RGBA2GRAY if arr.shape[2] == 4 else cv2.COLOR_RGB2GRAY
            arr = cv2.cvtColor(arr, code)
    return cv2.threshold(arr, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]


def _find_digit_boxes(thresh):
    """Return grouped (x, y, w, h) boxes of the contours in *thresh*, left to right."""
    found = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns (image, contours, hierarchy).
    cnts = found[0] if len(found) == 2 else found[1]
    height, width = thresh.shape[:2]
    full_frame = (0, 0, width, height)
    # Drop any box spanning the entire canvas (presumably the background
    # region left by the inverted threshold).
    boxes = [box for box in map(cv2.boundingRect, cnts) if tuple(box) != full_frame]
    # Stable sort on x, matching a left-to-right contour ordering.
    boxes.sort(key=lambda box: box[0])
    return sorted(_group_rectangles(boxes), key=lambda box: box[0])


def _roi_to_image(thresh, box, dim):
    """Crop *box* from *thresh* and return it as a padded, resized PIL RGB image."""
    x, y, w, h = box
    roi = cv2.bitwise_not(thresh[y:y + h, x:x + w])
    # Surround the crop with a 30 px black margin before shrinking to *dim*.
    padded = cv2.copyMakeBorder(roi, 30, 30, 30, 30, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    resized = cv2.resize(padded, dim, interpolation=cv2.INTER_AREA)
    return Image.fromarray(resized.astype('uint8')).convert("RGB")


def _recognize(img):
    """Run the module-level TrOCR processor/model on one PIL image and return the text."""
    pixel_values = processor(img, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def process_image(image):
    """
    Recognise the digits drawn on the sketchpad.

    The image is binarised and segmented into contour boxes. Overlapping boxes
    are merged. Each box is cropped, padded, resized to 28x28 and passed to
    the TrOCR model, and the per-box predictions are concatenated left to right.
    Args:
        image - array-like sketchpad image, or None when nothing was submitted
    Return:
        concatenated predicted text ("" when there is no input or no box)
    """
    if image is None:
        return ""
    thresh = _binarize(image)
    dim = (28, 28)
    return ''.join(
        _recognize(_roi_to_image(thresh, box, dim))
        for box in _find_digit_boxes(thresh)
    )
title = "Interactive demo: Single Digits MNIST"
description = "Aico - University Utrecht"

# Sketchpad in, predicted digit string out.
iface = gr.Interface(fn=process_image,
                     inputs="sketchpad",
                     outputs="label",
                     title=title,
                     description=description)

# Launch only when run as a script, so importing this module (e.g. for tests
# or tooling) does not start a blocking server.
if __name__ == "__main__":
    iface.launch(debug=True)