import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from imutils import contours
from PIL import Image
import numpy as np
import cv2

# Load the TrOCR processor and a checkpoint fine-tuned on MNIST digits.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("aico/TrOCR-MNIST")


def union(a, b):
    """Return the smallest rectangle [x, y, w, h] enclosing rectangles a and b.

    Note: this helper was missing from the original snippet; it is a minimal
    implementation of the behaviour assumed by _group_rectangles.
    """
    x = min(a[0], b[0])
    y = min(a[1], b[1])
    w = max(a[0] + a[2], b[0] + b[2]) - x
    h = max(a[1] + a[3], b[1] + b[3]) - y
    return (x, y, w, h)


def intersect_area(a, b):
    """Return True if rectangles a and b ([x, y, w, h]) overlap.

    Note: like union(), this is a minimal implementation assumed by
    _group_rectangles; the original definition was not included.
    """
    x = max(a[0], b[0])
    y = max(a[1], b[1])
    w = min(a[0] + a[2], b[0] + b[2]) - x
    h = min(a[1] + a[3], b[1] + b[3]) - y
    return w > 0 and h > 0


def _group_rectangles(rec):
    """Union intersecting rectangles.

    Args:
        rec: list of rectangles in the form [x, y, w, h]
    Returns:
        list of grouped rectangles
    """
    tested = [False] * len(rec)
    final = []
    i = 0
    while i < len(rec):
        if not tested[i]:
            j = i + 1
            while j < len(rec):
                if not tested[j] and intersect_area(rec[i], rec[j]):
                    rec[i] = union(rec[i], rec[j])
                    tested[j] = True
                    j = i  # restart the scan: the merged box may now overlap earlier ones
                j += 1
            final.append(rec[i])
        i += 1
    return final


def process_image(image):
    """Segment the sketchpad drawing into digits and run TrOCR on each crop."""
    bounding_boxes = []
    generated_text_list = []

    # Binarise the drawing: Otsu picks the threshold, INV makes strokes white on black.
    thresh = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]

    # Find contours; the tuple indexing handles OpenCV 3 vs 4 return formats,
    # then imutils sorts them so digits are read left to right.
    cnts = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    cnts = cnts[0] if len(cnts) == 2 else cnts[1]
    (cnts, _) = contours.sort_contours(cnts, method="left-to-right")

    dim = (28, 28)
    for c in cnts:
        bounding_boxes.append(cv2.boundingRect(c))

    # Drop the box spanning the whole 128x128 canvas, then merge overlapping boxes.
    boundingBoxes_filter = [b for b in bounding_boxes if b != (0, 0, 128, 128)]
    boundingBoxes = _group_rectangles(boundingBoxes_filter)

    for (x, y, w, h) in boundingBoxes:
        # Crop the digit, invert it back to dark-on-light, pad it, and resize to 28x28.
        ROI = thresh[y:y + h, x:x + w]
        ROI2 = cv2.bitwise_not(ROI)
        borderoutput = cv2.copyMakeBorder(ROI2, 30, 30, 30, 30, cv2.BORDER_CONSTANT, value=[0, 0, 0])
        resized = cv2.resize(borderoutput, dim, interpolation=cv2.INTER_AREA)
        cv2.imwrite('ROI_{}.png'.format(x), resized)  # save each crop for inspection

        # Run the crop through TrOCR (greedy decoding, no beam search) and decode.
        img = Image.fromarray(resized.astype('uint8')).convert("RGB")
        pixel_values = processor(img, return_tensors="pt").pixel_values
        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        generated_text_list.append(generated_text)

    # Concatenate the per-digit predictions into a single string.
    return ''.join(generated_text_list)


title = "Interactive demo: Single Digits MNIST"
description = "Aico - University Utrecht"

iface = gr.Interface(fn=process_image,
                     inputs="sketchpad",
                     outputs="label",
                     title=title,
                     description=description)
iface.launch(debug=True)