import gradio as gr
import numpy as np
from PIL import Image
import os
from rtmdet import RTMDet
from parseq import PARSEQ
from yaml import safe_load
from ndl_parser import convert_to_xml_string3
from concurrent.futures import ThreadPoolExecutor
import xml.etree.ElementTree as ET
from reading_order.xy_cut.eval import eval_xml
from xml.dom import minidom
import re
# Model Heading and Description
# Page title; passed as `title=` to gr.Interface below.
model_heading = "NDL Kotenseki OCR-Lite Gradio App"
# Usage text passed as `description=` to gr.Interface; it also lists the
# sources of the two example images.
description = """
Upload an image or click an example image to use.
Examples:
1. 『竹取物語』上, 江戸前期. https://dl.ndl.go.jp/pid/1287221/1/2
2. 曲亭馬琴 作 ほか『人間万事賽翁馬 3巻』, 鶴喜, 寛政12. https://dl.ndl.go.jp/pid/10301438/1/17
"""
# Markdown text passed as `article=` to gr.Interface; links to the OCR project.
article = "This application is powered by NDL Kotenseki OCR-Lite. For more details, please visit the official repository: [NDL Kotenseki OCR-Lite GitHub Repository](https://github.com/ndl-lab/ndlkotenocr-lite)."
# https://github.com/ndl-lab/ndlkotenocr-lite
# Example inputs passed as `examples=` to gr.Interface: one single-element row
# per example, since the interface has a single input component.
# NOTE: `process` has a parameter with the same name, which shadows this list
# inside that function.
image_path = [
['samples/digidepo_1287221_00000002.jpg'],
['samples/digidepo_10301438_0017.jpg']
]
# Functions to load models
def get_detector(weights_path, classes_path, device='cpu'):
    """Build the RTMDet detector from a weight file and a class-mapping file.

    Args:
        weights_path: Path to the detector model file.
        classes_path: Path to the class-mapping file handed to RTMDet.
        device: Device identifier forwarded to RTMDet (default ``'cpu'``).

    Returns:
        An ``RTMDet`` instance configured with 0.3 for every threshold.

    Raises:
        FileNotFoundError: If either path does not point to an existing file.
    """
    # Explicit checks instead of ``assert``: assertions are removed when Python
    # runs with -O, which would let a bad path reach the model loader.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weight file not found: {weights_path}")
    if not os.path.isfile(classes_path):
        raise FileNotFoundError(f"Classes file not found: {classes_path}")
    return RTMDet(
        model_path=weights_path,
        class_mapping_path=classes_path,
        score_threshold=0.3,
        # NOTE(review): keyword kept exactly as spelled in the original call;
        # presumably it matches RTMDet's own signature -- confirm.
        conf_thresold=0.3,
        iou_threshold=0.3,
        device=device,
    )
def get_recognizer(weights_path, classes_path, device='cpu'):
    """Build the PARSEQ text recognizer and its character list.

    Args:
        weights_path: Path to the recognizer model file.
        classes_path: Path to a YAML file whose ``model.charset_train`` entry
            holds the characters the recognizer can emit.
        device: Device identifier forwarded to PARSEQ (default ``'cpu'``).

    Returns:
        A ``PARSEQ`` instance created with the loaded character list.

    Raises:
        FileNotFoundError: If either path does not point to an existing file.
    """
    # Explicit checks instead of ``assert``: assertions are removed when Python
    # runs with -O, which would let a bad path reach the model loader.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weight file not found: {weights_path}")
    if not os.path.isfile(classes_path):
        raise FileNotFoundError(f"Classes file not found: {classes_path}")
    with open(classes_path, encoding="utf-8") as f:
        config = safe_load(f)
    # list() splits the charset into one entry per item (per character when
    # the YAML value is a string).
    charlist = list(config["model"]["charset_train"])
    return PARSEQ(model_path=weights_path, charlist=charlist, device=device)
def create_txt(recognizer, root, img):
    """Recognize every LINE region of *root* inside *img* and return the text.

    NOTE: a second ``create_txt(root)`` is defined further down in this module
    and rebinds the name, so this three-argument variant is not reachable as
    ``create_txt`` once the module has finished loading.

    Args:
        recognizer: Object whose ``read(line_image)`` returns the recognized
            string for one cropped line.
        root: ElementTree element; each descendant ``LINE`` element must carry
            integer ``X``, ``Y``, ``WIDTH`` and ``HEIGHT`` attributes.
        img: Image array indexed as ``img[y, x, channel]``.

    Returns:
        The recognized lines joined by newlines plus one trailing newline
        (a single newline when there is no LINE element).
    """
    line_images = []
    for line in root.findall(".//LINE"):
        left = int(line.get("X"))
        top = int(line.get("Y"))
        width = int(line.get("WIDTH"))
        height = int(line.get("HEIGHT"))
        line_images.append(img[top:top + height, left:left + width, :])
    # executor.map yields results in input order, so the text order follows
    # the order of the LINE elements.
    with ThreadPoolExecutor(max_workers=4, thread_name_prefix="thread") as executor:
        texts = list(executor.map(recognizer.read, line_images))
    return "\n".join(texts) + "\n"
def create_xml(detections,classeslist,img_w,img_h,imgname, recognizer, img):
    """Turn detections into an XML tree and fill each LINE with its text.

    Args:
        detections: Iterable of dicts with ``"box"`` (xmin, ymin, xmax, ymax),
            ``"confidence"`` and ``"class_index"`` entries.
        classeslist: Class names, forwarded to ``convert_to_xml_string3``.
        img_w: Image width, forwarded to ``convert_to_xml_string3``.
        img_h: Image height, forwarded to ``convert_to_xml_string3``.
        imgname: Image name, forwarded to ``convert_to_xml_string3``.
        recognizer: Object whose ``read(line_image)`` returns the recognized
            string for one cropped line.
        img: Image array indexed as ``img[y, x, channel]``.

    Returns:
        The root element; every descendant ``LINE`` has its ``STRING``
        attribute set to the recognized text.

    Raises:
        KeyError: If a detection's ``class_index`` is outside 0-15.
    """
    # resultobj[0][0] collects plain boxes of class 0; resultobj[1][c] collects
    # boxes with confidence for each of the 16 pre-created class slots.
    resultobj = [{0: []}, {class_index: [] for class_index in range(16)}]
    for det in detections:
        xmin, ymin, xmax, ymax = det["box"]
        class_index = det["class_index"]
        if class_index == 0:
            resultobj[0][0].append([xmin, ymin, xmax, ymax])
        resultobj[1][class_index].append([xmin, ymin, xmax, ymax, det["confidence"]])

    xmlstr = convert_to_xml_string3(
        img_w, img_h, imgname, classeslist, resultobj,
        score_thr=0.3, min_bbox_size=5, use_block_ad=False,
    )
    # Wrap the generated markup in a single root element: create_json looks up
    # PAGE as a *child* of the root. The wrapper had degraded to empty strings
    # (a no-op concatenation).
    # NOTE(review): OCRDATASET is the wrapper name recalled from the upstream
    # ndlkotenocr-lite pipeline -- confirm against convert_to_xml_string3 output.
    xmlstr = "<OCRDATASET>" + xmlstr + "</OCRDATASET>"
    root = ET.fromstring(xmlstr)
    # Imported from reading_order.xy_cut.eval; presumably rearranges the tree
    # into reading order -- confirm. LINE elements are collected only afterwards.
    eval_xml(root, logger=None)

    line_elements = root.findall(".//LINE")
    line_images = []
    for line in line_elements:
        left = int(line.get("X"))
        top = int(line.get("Y"))
        width = int(line.get("WIDTH"))
        height = int(line.get("HEIGHT"))
        line_images.append(img[top:top + height, left:left + width, :])
    # executor.map yields results in input order, so texts align with elements.
    with ThreadPoolExecutor(max_workers=4, thread_name_prefix="thread") as executor:
        texts = list(executor.map(recognizer.read, line_images))
    for line, text in zip(line_elements, texts):
        line.set("STRING", text)
    return root
def create_txt(root):
    """Return the text of all LINE elements under *root*, one per line.

    Args:
        root: ElementTree element whose descendant ``LINE`` elements carry the
            recognized text in their ``STRING`` attribute.

    Returns:
        The ``STRING`` values joined by newlines, in document order. A LINE
        without a ``STRING`` attribute contributes an empty line instead of
        making ``str.join`` fail on ``None``.
    """
    return "\n".join(line.get("STRING", "") for line in root.findall(".//LINE"))
def create_xmlstr(root):
    """Serialize *root* to a tab-indented XML string without blank lines.

    Args:
        root: ElementTree element to serialize.

    Returns:
        The pretty-printed XML text produced by minidom, cleaned up by the
        three substitutions below.
    """
    raw_bytes = ET.tostring(root, encoding='utf-8')
    document = minidom.parseString(raw_bytes)
    text = document.toprettyxml(indent="\t")
    # Remove trailing tabs/spaces together with the newline that follows them.
    text = re.sub(r"[\t ]+\n", "", text)
    # Remove the empty line between a closing '>' and a once-indented tag.
    text = text.replace(">\n\n\t<", ">\n\t<")
    # Collapse consecutive newlines (including whitespace-only lines) into one.
    return re.sub(r"\n\s*\n", "\n", text)
def create_json(root):
    """Convert the OCR XML tree into a JSON-serializable dictionary.

    Args:
        root: ElementTree element with a direct ``PAGE`` child (carrying
            ``WIDTH``, ``HEIGHT`` and ``IMAGENAME``) and descendant ``LINE``
            elements carrying ``X``, ``Y``, ``WIDTH``, ``HEIGHT``, ``STRING``
            and optionally ``CONF``.

    Returns:
        A dict with ``"contents"`` (a list holding one list of per-line
        entries) and ``"imginfo"`` (image size, path and base name).
    """
    page = root.find("PAGE")
    img_w = int(page.get("WIDTH"))
    img_h = int(page.get("HEIGHT"))
    inputpath = page.get("IMAGENAME")

    line_entries = []
    for idx, line in enumerate(root.findall(".//LINE")):
        left = int(line.get("X"))
        top = int(line.get("Y"))
        right = left + int(line.get("WIDTH"))
        bottom = top + int(line.get("HEIGHT"))
        # A missing CONF (None -> TypeError) or a non-numeric one (ValueError)
        # falls back to 0; anything else is a real error and must propagate.
        try:
            conf = float(line.get("CONF"))
        except (TypeError, ValueError):
            conf = 0
        line_entries.append({
            # Corner order: top-left, bottom-left, top-right, bottom-right.
            "boundingBox": [[left, top], [left, bottom], [right, top], [right, bottom]],
            "id": idx,
            # Both flags are emitted as the fixed string "true" for every line.
            "isVertical": "true",
            "text": line.get("STRING"),
            "isTextline": "true",
            "confidence": conf,
        })

    return {
        "contents": [line_entries],
        "imginfo": {
            "img_width": img_w,
            "img_height": img_h,
            "img_path": inputpath,
            "img_name": os.path.basename(inputpath),
        },
    }
# Inference Function
def process(image_path: str):
    """Run detection and recognition on one image file for the Gradio UI.

    Both models are loaded on every call, from fixed paths under ``model/``
    and ``config/``.

    Args:
        image_path: Path of the image to process (shadows the module-level
            ``image_path`` example list inside this function).

    Returns:
        On success, a 4-tuple ``(annotated_image, text, xml_string, json_dict)``
        where ``annotated_image`` is whatever ``detector.draw_detections``
        returns. On any failure, a 4-element list: a 100x100 all-zero 8-bit
        image, ``"Error"``, ``"Error"`` and an empty dict.
    """
    # Local import so this function stays self-contained with respect to the
    # module's import block.
    import logging

    try:
        # Load the models
        detector = get_detector(
            weights_path="model/rtmdet-s-1280x1280.onnx",
            classes_path="config/ndl.yaml",
            device="cpu",
        )
        recognizer = get_recognizer(
            weights_path="model/parseq-ndl-32x384-tiny-10.onnx",
            classes_path="config/NDLmoji.yaml",
            device="cpu",
        )

        # Load image as an RGB array
        npimg = np.array(Image.open(image_path).convert('RGB'))

        # Object detection
        detections = detector.detect(npimg)
        classeslist = list(detector.classes.values())
        img_h, img_w = npimg.shape[:2]
        imgname = os.path.basename(image_path)

        # Build the XML tree (with recognized text) and derive the outputs
        root = create_xml(detections, classeslist, img_w, img_h, imgname, recognizer, npimg)
        alltext = create_txt(root)
        result_json = create_json(root)
        annotated = detector.draw_detections(npimg, detections=detections)
        return annotated, alltext, create_xmlstr(root), result_json
    except Exception:
        # UI boundary: keep returning the placeholder outputs, but record the
        # traceback instead of discarding the failure silently.
        logging.getLogger(__name__).exception("OCR processing failed for %r", image_path)
        return [
            Image.fromarray(np.zeros((100, 100), dtype=np.uint8)),
            "Error",
            "Error",
            {}
        ]
# Gradio Inputs and Outputs
# Single image input; type="filepath" presumably hands `process` a path string
# (it is opened there with Image.open) -- confirm against the Gradio version.
inputs_image = gr.Image(type="filepath", label="Input Image")
# Four output components, in the same order as the four values that `process`
# returns: image, plain text, XML string, JSON object.
outputs_image = [
gr.Image(type="pil", label="Output Image"),
gr.TextArea(label="Output Text"),
gr.TextArea(label="Output XML"),
gr.JSON(label="Output JSON")
]
# Gradio Interface
# Wires `process` to the components above together with the title,
# description, example images and footer text defined earlier in this module.
demo = gr.Interface(
fn=process,
inputs=inputs_image,
outputs=outputs_image,
title=model_heading,
description=description,
examples=image_path,
article=article,
cache_examples=False,
# NOTE(review): alternative keyword left commented out by the author;
# presumably the name used by other Gradio versions -- confirm before switching.
# flagging_mode="never"
allow_flagging="never"
)
# Runs whenever this module is executed or imported (there is no __main__
# guard); listens on 0.0.0.0 with share=False.
demo.launch(share=False, server_name="0.0.0.0")