# -*- coding: utf-8 -*-
"""Untitled1.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1J4fCr7TGzdFvkCeikMAQ5af5ml2Q83W0
"""
import os

# Install a CPU-only PyTorch build at startup (kept from the original notebook;
# on a Hugging Face Space this would normally live in requirements.txt).
os.system('pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu')

import fitz  # PyMuPDF, used to rasterize PDF pages
import PIL
import torch
import numpy as np
import pandas as pd
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModelForTokenClassification, AutoProcessor
# define id2label
id2label = {0: 'song name', 1: 'artist', 2: 'year', 3: 'album', 4: 'genres',
            5: 'song writer', 6: 'lyrics', 7: 'others'}

custom_config = r'--oem 3 --psm 6'  # Tesseract: default engine, assume a uniform block of text
# lang = 'eng+deu+ita+chi_sim'
lang = 'spa'

# assign a random PIL named color to each label name
label_ints = np.random.randint(0, len(PIL.ImageColor.colormap), 42)
label_color_pil = list(PIL.ImageColor.colormap)
label_color = [label_color_pil[i] for i in label_ints]
label2color = {v: label_color[k] for k, v in id2label.items()}
# the processor runs Tesseract OCR internally (apply_ocr=True); ocr_lang selects
# the OCR language for LayoutLMv3's feature extractor
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=True, ocr_lang=lang)
model = AutoModelForTokenClassification.from_pretrained("alitavanaali/music_layoutlmv3_model")
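# Note: apply_ocr=True relies on pytesseract, so the runtime needs the tesseract-ocr
# binary plus the Spanish language pack (tesseract-ocr-spa on Debian) installed,
# e.g. via packages.txt on Hugging Face Spaces.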
def unnormalize_box(bbox, width, height):
    # LayoutLMv3 boxes are normalized to a 0-1000 grid; scale back to pixel coordinates
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]
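# Worked example: on a 1654x2339 px page, the normalized box [100, 50, 200, 80]
# maps to [165.4, 116.95, 330.8, 187.12] (each coordinate scaled by page_size / 1000).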
def iob_to_label(label):
    # map a class index to its label name, falling back to 'others'
    return id2label.get(label, 'others')
# this method detects whether two boxes intersect
def intersect(w, z):
    x1 = max(w[0], z[0])
    y1 = max(w[1], z[1])
    x2 = min(w[2], z[2])
    y2 = min(w[3], z[3])
    if x1 > x2 or y1 > y2:
        return 0
    # annotation mistakes can make rows or columns overlap by a few pixels,
    # so discard degenerate intersections with zero area
    area = (x2 - x1) * (y2 - y1)
    if area > 0:
        return [int(x1), int(y1), int(x2), int(y2)]
    return 0
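# Worked example: intersect([0, 0, 100, 100], [50, 50, 150, 150]) -> [50, 50, 100, 100];
# intersect([0, 0, 10, 10], [20, 20, 30, 30]) -> 0 because the boxes do not overlap.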
def process_image(image):
    width, height = image.size

    # encode; long documents are split into overlapping 512-token windows
    inference_image = [image.convert("RGB")]
    encoding = processor(inference_image, truncation=True, return_offsets_mapping=True,
                         return_tensors="pt", padding="max_length", stride=128,
                         max_length=512, return_overflowing_tokens=True)
    # drop keys the model's forward() does not accept
    encoding.pop('offset_mapping')
    encoding.pop('overflow_to_sample_mapping')

    # stack the per-window pixel values into a single batch tensor
    encoding['pixel_values'] = torch.stack(list(encoding['pixel_values']))

    # forward pass
    outputs = model(**encoding)

    # get predictions
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()

    # only keep non-subword predictions
    preds = []
    l_words = []
    bboxes = []
    token_section_num = []  # which 512-token window each word came from
    if len(token_boxes) == 512:
        # a single window: re-wrap so the loops below always iterate over windows
        predictions = [predictions]
        token_boxes = [token_boxes]
    for i in range(len(token_boxes)):
        for j in range(len(token_boxes[i])):
            if np.asarray(token_boxes[i][j]).shape != (4,):
                continue
            if token_boxes[i][j] == [0, 0, 0, 0] or token_boxes[i][j] == 0:
                # padding / special tokens
                continue
            unnormal_box = unnormalize_box(token_boxes[i][j], width, height)
            if unnormal_box not in bboxes:
                # first time this box is seen: record prediction, text and box
                preds.append(predictions[i][j])
                l_words.append(processor.tokenizer.decode(encoding["input_ids"][i][j]))
                bboxes.append(unnormal_box)
                token_section_num.append(i)
            else:
                # the box is already known, so this token is a subword continuation;
                # documents longer than 512 tokens are split into separate windows and
                # a word can appear in two of them, so only merge tokens that belong
                # to the same window to avoid duplicating text
                _index = bboxes.index(unnormal_box)
                if token_section_num[_index] == i:
                    l_words[_index] = l_words[_index] + processor.tokenizer.decode(encoding["input_ids"][i][j])
    return bboxes, preds, l_words, image
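# Usage sketch (hypothetical page file): preds holds class indices into id2label,
# and bboxes are already in pixel coordinates:
#   bboxes, preds, words, img = process_image(Image.open("page0.jpg").convert("RGB"))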
def visualize_image(final_bbox, final_preds, l_words, image):
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    # id2label = {0: 'song name', 1: 'artist', 2: 'year', 3: 'album', 4: 'genres', 5: 'song writer', 6: 'lyrics', 7: 'others'}
    label2color = {'song name': 'red', 'artist': 'blue', 'year': 'black', 'album': 'green',
                   'genres': 'brown', 'song writer': 'blue', 'lyrics': 'purple', 'others': 'white'}
    json_df = []
    # draw bboxes on image
    for ix, (prediction, box) in enumerate(zip(final_preds, final_bbox)):
        predicted_label = iob_to_label(prediction).lower()
        if predicted_label != 'others':
            draw.rectangle(box, outline=label2color[predicted_label])
            draw.text((box[0] + 10, box[1] - 10), text=predicted_label,
                      fill=label2color[predicted_label], font=font)
        # record the label name (not its display color) alongside the text
        json_df.append({'TEXT': l_words[ix], 'LABEL': predicted_label})
    return image, json_df
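# json_df ends up as a list of {'TEXT': ..., 'LABEL': ...} dicts, one entry per
# predicted box, e.g. [{'TEXT': 'Yesterday', 'LABEL': 'song name'}, ...].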
def mergeCloseBoxes(pr, bb, wr, threshold):
    # merge horizontally adjacent boxes that share a label into one box per field
    final_bbox = []
    final_preds = []
    final_words = []
    for box, pred, word in zip(bb, pr, wr):
        if pred == 'others':
            continue
        flag = False
        for b, p, w in zip(bb, pr, wr):
            if p == 'others':
                continue
            elif box == b:
                # don't compare an item with itself
                continue
            else:
                XMIN, YMIN, XMAX, YMAX = box
                xmin, ymin, xmax, ymax = b
                # widen the boxes horizontally by `threshold` pixels and test for overlap
                intsc = intersect([XMIN, YMIN, XMAX + threshold, YMAX],
                                  [xmin - threshold, ymin, xmax, ymax])
                if intsc != 0 and pred == p:
                    flag = True
                    merged_box = [
                        min(XMIN, xmin),
                        min(YMIN, ymin),
                        max(XMAX, xmax),
                        max(YMAX, ymax)
                    ]
                    merged_words = word + ' ' + w
                    # if this box overlaps an entry already in final_bbox (because it
                    # was merged before), grow that entry instead of appending a new one
                    wasAvailable = False
                    for id, fbox in enumerate(final_bbox):
                        if intersect(box, fbox) != 0 and pred == final_preds[id]:
                            wasAvailable = True
                            merged_box = [
                                min(fbox[0], min(XMIN, xmin)),
                                min(fbox[1], min(YMIN, ymin)),
                                max(fbox[2], max(XMAX, xmax)),
                                max(fbox[3], max(YMAX, ymax))
                            ]
                            final_bbox[id] = merged_box
                            final_words[id] = final_words[id] + ' ' + w
                            break
                    if not wasAvailable:
                        # no earlier intersection: start a new merged entry
                        final_bbox.append(merged_box)
                        final_preds.append(pred)
                        final_words.append(merged_words)
        if not flag:
            # the word overlaps nothing, but it may be the last word of a field that was
            # already merged (e.g. 'Alexander' in 'Juan Mulian Alexander'), so only
            # append it if it is not contained in an existing merged box
            for id, fbox in enumerate(final_bbox):
                if intersect(box, fbox) != 0 and pred == final_preds[id]:
                    flag = True
            if not flag:
                final_bbox.append(box)
                final_preds.append(pred)
                final_words.append(word)
    return final_bbox, final_preds, final_words
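# Worked example with threshold=30: two 'artist' boxes [100, 50, 180, 70] and
# [205, 50, 260, 70] are only 25 px apart, so after widening they intersect and
# merge into [100, 50, 260, 70] with their words joined by a space.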
def createDataframe(preds, words):
    # build one row per record: a new row starts whenever the first-seen label repeats
    df = pd.DataFrame(columns=['song name', 'artist', 'year', 'album', 'genres',
                               'song writer', 'lyrics', 'others'])
    if len(preds) > 0:
        flag_label = preds[0]
        row_number = -1
        for i in range(len(preds)):
            if preds[i] == flag_label:
                row_number = row_number + 1
                df.at[row_number, preds[i]] = words[i]
            elif pd.isna(df[preds[i]].iloc[row_number]):
                df.at[row_number, preds[i]] = words[i]
            else:
                # the cell is already filled, so start a new row for this label
                row_number = row_number + 1
                df.at[row_number, preds[i]] = words[i]
    return df
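# Worked example: preds = ['song name', 'artist', 'song name'] with
# words = ['Yesterday', 'Beatles', 'Help'] gives row 0 = (Yesterday, Beatles)
# and row 1 = (Help, NaN), because the repeated first label opens a new row.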
def isInside(w, z):
    # True if box w lies completely inside box z; containment is checked one way only
    return w[0] >= z[0] and w[1] >= z[1] and w[2] <= z[2] and w[3] <= z[3]
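# Example: isInside([10, 10, 20, 20], [0, 0, 100, 100]) -> True, while swapping
# the arguments returns False.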
def removeSimilarItems(final_bbox, final_preds, final_words):
    # drop any box that is fully contained in another box with the same label;
    # identical boxes are left alone
    _bb = []
    _pp = []
    _ww = []
    for i in range(len(final_bbox)):
        keep = True
        for j in range(len(final_bbox)):
            if final_bbox[i] == final_bbox[j]:
                continue
            elif isInside(final_bbox[i], final_bbox[j]) and final_preds[i] == final_preds[j]:
                # box i is inside box j, so we have to remove it
                keep = False
                break
        if keep:
            _bb.append(final_bbox[i])
            _pp.append(final_preds[i])
            _ww.append(final_words[i])
    return _bb, _pp, _ww
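# Example: with boxes [[10, 10, 20, 20], [0, 0, 100, 100]] both labeled 'artist',
# the first box is dropped because it lies entirely inside the second.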
def process_form(preds, words, bboxes):
    final_bbox, final_preds, final_words = mergeCloseBoxes(preds, bboxes, words, 30)
    _bbox, _preds, _words = removeSimilarItems(final_bbox, final_preds, final_words)
    # convert float coordinates to int
    _bbox = [[int(x) for x in item] for item in _bbox]
    # create (box, pred, word) tuples and sort them top-to-bottom by the y coordinate
    data = list(zip(_bbox, _preds, _words))
    sorted_list = sorted(data, key=lambda x: x[0][1])
    _bbox = [item[0] for item in sorted_list]
    _preds = [item[1] for item in sorted_list]
    _words = [item[2] for item in sorted_list]
    return _bbox, _preds, _words
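# process_form ties the post-processing together: merge neighboring boxes per label,
# drop contained duplicates, then sort fields into reading order so createDataframe
# sees them top-to-bottom.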
def mergeImageVertical(a):
    imgs = [Image.open(i) for i in a]
    # pick the smallest page and resize the others to match it so they can be stacked
    min_shape = sorted([(np.sum(i.size), i.size) for i in imgs])[0][1]
    imgs_comb = np.vstack([i.resize(min_shape) for i in imgs])
    imgs_comb = Image.fromarray(imgs_comb)
    imgs_comb.save('Trifecta_vertical.jpg')
    return imgs_comb
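# Design note: resizing every page to the smallest one gives np.vstack arrays of
# identical width, so output quality follows the smallest page in the document.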
def completepreprocess(pdffile):
    myDataFrame = pd.DataFrame()
    a = []
    # gr.File may hand us a tempfile wrapper, so fall back to .name for the path
    path = pdffile.name if hasattr(pdffile, "name") else pdffile
    doc = fitz.open(path)
    for i in range(len(doc)):
        page = doc.load_page(i)
        # render the page at 200 dpi (when dpi is given, PyMuPDF ignores any zoom matrix)
        pix = page.get_pixmap(dpi=200)
        pix.save("page" + str(i) + ".jpg")
        image = Image.open("page" + str(i) + ".jpg").convert("RGB")
        bbox, preds, words, image = process_image(image)
        im, df = visualize_image(bbox, preds, words, image)
        im.save("page" + str(i) + ".jpg")
        a.append("page" + str(i) + ".jpg")
        pred_list = [iob_to_label(number) for number in preds]
        _bbox, _preds, _words = process_form(pred_list, words, bbox)
        print('page: ' + str(i) + ' ' + str(len(_preds)) + ' ' + str(len(_words)))
        df = createDataframe(_preds, _words)
        myDataFrame = pd.concat([myDataFrame, df], ignore_index=True)
    im2 = mergeImageVertical(a)
    return im2, myDataFrame
title = "Interactive demo: Music Information Extraction model"
description = ("Music Information Extraction - we fine-tuned Microsoft's LayoutLMv3 on our music dataset "
               "(annotated via CSV files; 16 training / 7 test documents) to predict the labels. To use it, "
               "upload a PDF or pick one of the example files below and click 'Submit'. Results show up in "
               "a few seconds. To enlarge the output, right-click on it and select 'Open image in new tab'.")
css = """.output_image, .input_image {height: 600px !important}"""
examples = [['test1.jpg'], ['doc1.pdf'], ['doc1.2.pdf']]
iface = gr.Interface(fn=completepreprocess,
                     inputs=gr.File(label="PDF"),
                     outputs=[gr.Image(type="pil", label="annotated image"), "dataframe"],
                     title=title,
                     description=description,
                     examples=examples,
                     css=css,
                     analytics_enabled=True)
iface.queue().launch(inline=False, debug=True)