Spaces:
Sleeping
Sleeping
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import os | |
from rtmdet import RTMDet | |
from parseq import PARSEQ | |
from yaml import safe_load | |
from ndl_parser import convert_to_xml_string3 | |
from concurrent.futures import ThreadPoolExecutor | |
import xml.etree.ElementTree as ET | |
from reading_order.xy_cut.eval import eval_xml | |
from xml.dom import minidom | |
import re | |
# Title and Markdown description shown at the top of the Gradio page.
model_heading = "NDL Kotenseki OCR-Lite Gradio App"

# Markdown rendered under the title; lists the two bundled sample images and their source URLs.
description = """
Upload an image or click an example image to use.
Examples:
1. 『竹取物語』上, 江戸前期. https://dl.ndl.go.jp/pid/1287221/1/2
2. 曲亭馬琴 作 ほか『人間万事賽翁馬 3巻』, 鶴喜, 寛政12. https://dl.ndl.go.jp/pid/10301438/1/17
"""

# Markdown footer linking to the upstream project.
article = (
    "This application is powered by NDL Kotenseki OCR-Lite. "
    "For more details, please visit the official repository: "
    "[NDL Kotenseki OCR-Lite GitHub Repository](https://github.com/ndl-lab/ndlkotenocr-lite)."
)

# Example inputs for gr.Interface: one inner list per example, one entry per input component.
image_path = [
    ["samples/digidepo_1287221_00000002.jpg"],
    ["samples/digidepo_10301438_0017.jpg"],
]
# Functions to load models
def get_detector(weights_path, classes_path, device='cpu'):
    """Build the RTMDet layout detector used by ``process``.

    Args:
        weights_path: Path to the detector's ONNX weights file.
        classes_path: Path to the class-mapping file handed to RTMDet.
        device: Device string forwarded to RTMDet (this app passes ``"cpu"``).

    Returns:
        An ``RTMDet`` instance with score, confidence and IoU thresholds all 0.3.

    Raises:
        FileNotFoundError: If either file does not exist.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``,
    # which would turn a missing file into an obscure failure deeper inside RTMDet.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weight file not found: {weights_path}")
    if not os.path.isfile(classes_path):
        raise FileNotFoundError(f"Classes file not found: {classes_path}")
    return RTMDet(model_path=weights_path,
                  class_mapping_path=classes_path,
                  score_threshold=0.3,
                  # NOTE(review): "thresold" spelling presumably matches RTMDet's own
                  # keyword argument — confirm in rtmdet.py before renaming.
                  conf_thresold=0.3,
                  iou_threshold=0.3,
                  device=device)
def get_recognizer(weights_path, classes_path, device='cpu'):
    """Build the PARSeq text-line recognizer used by ``process``.

    Args:
        weights_path: Path to the recognizer's ONNX weights file.
        classes_path: UTF-8 YAML file whose ``model.charset_train`` entry holds the
            recognizer's character set.
        device: Device string forwarded to PARSEQ.

    Returns:
        A ``PARSEQ`` instance.

    Raises:
        FileNotFoundError: If either file does not exist.
    """
    # Explicit checks rather than ``assert``: asserts are stripped under ``python -O``.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weight file not found: {weights_path}")
    if not os.path.isfile(classes_path):
        raise FileNotFoundError(f"Classes file not found: {classes_path}")
    with open(classes_path, encoding="utf-8") as f:
        config = safe_load(f)
    # list() splits the charset into single characters if it is stored as one string
    # (presumably it is — confirm against config/NDLmoji.yaml).
    charlist = list(config["model"]["charset_train"])
    return PARSEQ(model_path=weights_path, charlist=charlist, device=device)
def create_txt(recognizer, root, img):
    """Recognize every LINE region under ``root`` and return the text, one line per region.

    NOTE(review): a later ``create_txt(root)`` in this file rebinds the name, so this
    three-argument version is never reached through the module-level name; it is kept
    only for backward compatibility.

    Args:
        recognizer: Object with a ``read(image)`` method returning a string.
        root: ElementTree element whose descendant ``LINE`` elements carry integer
            ``X``, ``Y``, ``WIDTH`` and ``HEIGHT`` attributes.
        img: Image array indexed as ``img[y, x, channel]``.

    Returns:
        The recognized strings joined with newlines plus one trailing newline
        (a tree without LINE elements yields a single newline).
    """
    crops = []
    for line in root.findall(".//LINE"):
        x = int(line.get("X"))
        y = int(line.get("Y"))
        w = int(line.get("WIDTH"))
        h = int(line.get("HEIGHT"))
        crops.append(img[y:y + h, x:x + w, :])
    # recognizer.read runs on a small thread pool; Executor.map keeps input order.
    with ThreadPoolExecutor(max_workers=4, thread_name_prefix="thread") as executor:
        texts = list(executor.map(recognizer.read, crops))
    return "\n".join(texts) + "\n"
def create_xml(detections, classeslist, img_w, img_h, imgname, recognizer, img):
    """Convert detector output into an OCR XML tree and fill in the recognized text.

    Args:
        detections: Iterable of dicts with ``"box"`` (xmin, ymin, xmax, ymax),
            ``"confidence"`` and ``"class_index"`` keys, as returned by
            ``detector.detect`` in ``process``.
        classeslist: The values of ``detector.classes``, forwarded to
            ``convert_to_xml_string3``.
        img_w: Image width in pixels.
        img_h: Image height in pixels.
        imgname: Image file name recorded in the XML.
        recognizer: Object with a ``read(image)`` method returning a string.
        img: Image array indexed as ``img[y, x, channel]``.

    Returns:
        The ``<OCRDATASET>`` root element; every descendant ``LINE`` element gets a
        ``STRING`` attribute holding its recognized text.
    """
    # Input layout for convert_to_xml_string3:
    #   resultobj[0][0] -> [xmin, ymin, xmax, ymax] for every class-0 detection
    #   resultobj[1][k] -> [xmin, ymin, xmax, ymax, confidence] for class k
    # NOTE(review): 16 class slots are hard-coded; a class_index outside 0..15 raises
    # KeyError — confirm it matches the class count in config/ndl.yaml.
    resultobj = [{0: []}, {k: [] for k in range(16)}]
    for det in detections:
        xmin, ymin, xmax, ymax = det["box"]
        conf = det["confidence"]
        cls = det["class_index"]
        if cls == 0:
            resultobj[0][0].append([xmin, ymin, xmax, ymax])
        resultobj[1][cls].append([xmin, ymin, xmax, ymax, conf])

    xmlstr = convert_to_xml_string3(img_w, img_h, imgname, classeslist, resultobj,
                                    score_thr=0.3, min_bbox_size=5, use_block_ad=False)
    root = ET.fromstring("<OCRDATASET>" + xmlstr + "</OCRDATASET>")
    # eval_xml's return value is unused, so it is called for its effect on ``root``;
    # presumably it reorders LINE elements into reading order (xy-cut) — confirm.
    eval_xml(root, logger=None)

    # Query the LINE elements once so crops and recognized strings stay aligned.
    lines = root.findall(".//LINE")
    crops = []
    for line in lines:
        x = int(line.get("X"))
        y = int(line.get("Y"))
        w = int(line.get("WIDTH"))
        h = int(line.get("HEIGHT"))
        crops.append(img[y:y + h, x:x + w, :])
    # recognizer.read runs on a small thread pool; Executor.map keeps input order.
    with ThreadPoolExecutor(max_workers=4, thread_name_prefix="thread") as executor:
        texts = list(executor.map(recognizer.read, crops))
    for line, text in zip(lines, texts):
        line.set("STRING", text)
    return root
def create_txt(root):
    """Return the ``STRING`` attribute of every descendant ``LINE`` element, newline-joined.

    Lines appear in document order. A ``LINE`` lacking ``STRING`` contributes ``None``,
    which makes ``str.join`` raise ``TypeError``.
    """
    return "\n".join(line.get("STRING") for line in root.iterfind(".//LINE"))
def create_xmlstr(root):
    """Serialize ``root`` as tab-indented, pretty-printed XML without blank lines.

    ``minidom.toprettyxml`` re-indents the tree, but whitespace text already present in
    the tree can come out as whitespace-only lines; the three passes below remove them.
    """
    serialized = ET.tostring(root, 'utf-8')
    text = minidom.parseString(serialized).toprettyxml(indent="\t")
    # Pass 1: drop every run of spaces/tabs together with the newline that ends it.
    # Whitespace-only lines vanish; NOTE: a content line with trailing blanks is also
    # joined onto the following line.
    text = re.sub(r"[\t ]+\n", "", text)
    # Pass 2: remove an empty line sitting between two tab-indented tags.
    text = text.replace(">\n\n\t<", ">\n\t<")
    # Pass 3: collapse any remaining run of newlines (incl. whitespace-only lines) into one.
    return re.sub(r"\n\s*\n", "\n", text)
def create_json(root):
    """Build the JSON-serializable OCR result for an ``<OCRDATASET>`` tree.

    Args:
        root: Element with a ``PAGE`` child carrying ``WIDTH``, ``HEIGHT`` and
            ``IMAGENAME``; descendant ``LINE`` elements carry integer ``X``, ``Y``,
            ``WIDTH``, ``HEIGHT``, a ``STRING`` and optionally a ``CONF``.

    Returns:
        ``{"contents": [[line, ...]], "imginfo": {...}}``. Each line dict holds its
        four box corners, its index, its text and a confidence (0 when ``CONF`` is
        missing or not a number). ``isVertical`` and ``isTextline`` are always the
        string ``"true"``.
    """
    page = root.find("PAGE")
    img_w = int(page.get("WIDTH"))
    img_h = int(page.get("HEIGHT"))
    inputpath = page.get("IMAGENAME")

    resjsonarray = []
    for idx, lineobj in enumerate(root.findall(".//LINE")):
        xmin = int(lineobj.get("X"))
        ymin = int(lineobj.get("Y"))
        xmax = xmin + int(lineobj.get("WIDTH"))
        ymax = ymin + int(lineobj.get("HEIGHT"))
        try:
            conf = float(lineobj.get("CONF"))
        except (TypeError, ValueError):
            # Missing attribute -> float(None) -> TypeError; non-numeric text -> ValueError.
            conf = 0
        resjsonarray.append({
            # Corner order (xmin,ymin), (xmin,ymax), (xmax,ymin), (xmax,ymax) is kept
            # unchanged for compatibility with existing consumers of this JSON.
            "boundingBox": [[xmin, ymin], [xmin, ymax], [xmax, ymin], [xmax, ymax]],
            "id": idx,
            "isVertical": "true",
            "text": lineobj.get("STRING"),
            "isTextline": "true",
            "confidence": conf,
        })

    return {
        # The line list is wrapped in one more list, as in the existing output format.
        "contents": [resjsonarray],
        "imginfo": {
            "img_width": img_w,
            "img_height": img_h,
            "img_path": inputpath,
            "img_name": os.path.basename(inputpath),
        },
    }
# Inference Function
def process(image_path: str):
    """Run layout detection and text recognition on one image for the Gradio UI.

    Args:
        image_path: Filesystem path of the uploaded image (the input component is
            ``gr.Image(type="filepath")``).

    Returns:
        A 4-tuple ``(annotated_image, text, xml_string, json_dict)`` matching the four
        output components. On any failure the exception is logged and a 100x100 black
        placeholder image, ``"Error"``, ``"Error"`` and ``{}`` are returned instead.
    """
    import logging

    try:
        # NOTE(review): both models are rebuilt on every request. Caching them would be
        # much faster, but only if RTMDet/PARSEQ instances are safe to share between
        # concurrent requests — confirm before hoisting them to module level.
        detector = get_detector(
            weights_path="model/rtmdet-s-1280x1280.onnx",
            classes_path="config/ndl.yaml",
            device="cpu"
        )
        recognizer = get_recognizer(
            weights_path="model/parseq-ndl-32x384-tiny-10.onnx",
            classes_path="config/NDLmoji.yaml",
            device="cpu"
        )
        # Decode the pixels, then release the file handle right away instead of
        # leaving it open until garbage collection.
        with Image.open(image_path) as src:
            pil_image = src.convert('RGB')
        npimg = np.array(pil_image)

        # Layout detection, then XML construction + per-line recognition.
        detections = detector.detect(npimg)
        classeslist = list(detector.classes.values())
        img_h, img_w = npimg.shape[:2]
        imgname = os.path.basename(image_path)
        root = create_xml(detections, classeslist, img_w, img_h, imgname, recognizer, npimg)

        alltext = create_txt(root)
        result_json = create_json(root)
        annotated = detector.draw_detections(npimg, detections=detections)
        return annotated, alltext, create_xmlstr(root), result_json
    except Exception:
        # Top-level UI boundary: keep the app responsive, but record why it failed
        # (previously the exception was silently discarded).
        logging.getLogger(__name__).exception("OCR processing failed for %r", image_path)
        return (
            Image.fromarray(np.zeros((100, 100), dtype=np.uint8)),
            "Error",
            "Error",
            {},
        )
# Gradio I/O components: one image path in; annotated image, plain text, pretty XML and
# JSON out — in the same order as the values returned by process().
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = [
    gr.Image(type="pil", label="Output Image"),
    gr.TextArea(label="Output Text"),
    gr.TextArea(label="Output XML"),
    gr.JSON(label="Output JSON"),
]

# Gradio interface. Examples are not pre-computed and flagging is disabled.
# NOTE(review): newer Gradio releases replace ``allow_flagging`` with
# ``flagging_mode="never"``; switch when upgrading — check the pinned Gradio version.
demo = gr.Interface(
    fn=process,
    inputs=inputs_image,
    outputs=outputs_image,
    examples=image_path,
    cache_examples=False,
    title=model_heading,
    description=description,
    article=article,
    allow_flagging="never",
)

# Listen on all network interfaces without creating a public share link.
demo.launch(share=False, server_name="0.0.0.0")