"""NDL Kotenseki OCR-Lite Gradio application.

Detects text lines on classical Japanese book pages with RTMDet and
recognizes each line with a PARSEQ model, exposing the pipeline through
a Gradio web UI.

https://github.com/ndl-lab/ndlkotenocr-lite
"""
import os
import re
import xml.etree.ElementTree as ET
from concurrent.futures import ThreadPoolExecutor
from xml.dom import minidom

import gradio as gr
import numpy as np
from PIL import Image
from yaml import safe_load

from ndl_parser import convert_to_xml_string3
from parseq import PARSEQ
from reading_order.xy_cut.eval import eval_xml
from rtmdet import RTMDet

# Model Heading and Description (UI copy shown around the Gradio interface).
model_heading = "NDL Kotenseki OCR-Lite Gradio App"
description = """
Upload an image or click an example image to use.
Examples:
1. 『竹取物語』上, 江戸前期. https://dl.ndl.go.jp/pid/1287221/1/2
2. 曲亭馬琴 作 ほか『人間万事賽翁馬 3巻』, 鶴喜, 寛政12. https://dl.ndl.go.jp/pid/10301438/1/17
"""
article = "This application is powered by NDL Kotenseki OCR-Lite. For more details, please visit the official repository: [NDL Kotenseki OCR-Lite GitHub Repository](https://github.com/ndl-lab/ndlkotenocr-lite)."
# Example images offered in the Gradio UI (one row per example).
image_path = [
    ['samples/digidepo_1287221_00000002.jpg'],
    ['samples/digidepo_10301438_0017.jpg'],
]


def get_detector(weights_path, classes_path, device='cpu'):
    """Build the RTMDet text-line detector.

    Args:
        weights_path: path to the ONNX detector weights.
        classes_path: path to the YAML class-mapping file.
        device: inference device passed through to RTMDet.

    Raises:
        FileNotFoundError: if either input file is missing.
    """
    # Explicit raises instead of assert: asserts vanish under `python -O`.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weight file not found: {weights_path}")
    if not os.path.isfile(classes_path):
        raise FileNotFoundError(f"Classes file not found: {classes_path}")
    # NOTE: "conf_thresold" (sic) is the keyword RTMDet actually expects.
    return RTMDet(
        model_path=weights_path,
        class_mapping_path=classes_path,
        score_threshold=0.3,
        conf_thresold=0.3,
        iou_threshold=0.3,
        device=device,
    )


def get_recognizer(weights_path, classes_path, device='cpu'):
    """Build the PARSEQ line recognizer; its charset comes from the YAML config.

    Raises:
        FileNotFoundError: if either input file is missing.
    """
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weight file not found: {weights_path}")
    if not os.path.isfile(classes_path):
        raise FileNotFoundError(f"Classes file not found: {classes_path}")
    with open(classes_path, encoding="utf-8") as f:
        charlist = list(safe_load(f)["model"]["charset_train"])
    return PARSEQ(model_path=weights_path, charlist=charlist, device=device)


# NOTE(review): the original file defined create_txt twice; the first version
# (taking recognizer/root/img) was dead code silently shadowed by the later
# create_txt(root) below, so it has been removed.

def create_xml(detections, classeslist, img_w, img_h, imgname, recognizer, img):
    """Turn detections into an ordered XML tree and OCR every text line.

    Args:
        detections: iterable of dicts with "box" (xmin, ymin, xmax, ymax),
            "confidence" and "class_index" keys, as produced by RTMDet.
        classeslist: detector class names indexed by class_index.
        img_w, img_h: page size in pixels.
        imgname: image file name recorded in the XML PAGE element.
        recognizer: PARSEQ instance; recognizer.read(crop) returns a string.
        img: H x W x 3 numpy array of the page.

    Returns:
        The ElementTree root; each LINE element carries its recognized text
        in the STRING attribute, in reading order.
    """
    # resultobj[0][0] gathers every class-0 (text line) box;
    # resultobj[1][i] keeps per-class boxes together with their confidence.
    resultobj = [dict(), dict()]
    resultobj[0][0] = list()
    for i in range(16):
        resultobj[1][i] = []
    for det in detections:
        xmin, ymin, xmax, ymax = det["box"]
        conf = det["confidence"]
        if det["class_index"] == 0:
            resultobj[0][0].append([xmin, ymin, xmax, ymax])
        resultobj[1][det["class_index"]].append([xmin, ymin, xmax, ymax, conf])
    xmlstr = convert_to_xml_string3(
        img_w, img_h, imgname, classeslist, resultobj,
        score_thr=0.3, min_bbox_size=5, use_block_ad=False,
    )
    root = ET.fromstring(xmlstr)
    # Sort LINE elements into reading order (XY-cut algorithm).
    eval_xml(root, logger=None)
    # Crop each detected line, then recognize all crops concurrently.
    linecrops = []
    for lineobj in root.findall(".//LINE"):
        xmin = int(lineobj.get("X"))
        ymin = int(lineobj.get("Y"))
        line_w = int(lineobj.get("WIDTH"))
        line_h = int(lineobj.get("HEIGHT"))
        linecrops.append(img[ymin:ymin + line_h, xmin:xmin + line_w, :])
    with ThreadPoolExecutor(max_workers=4, thread_name_prefix="thread") as executor:
        resultlines = list(executor.map(recognizer.read, linecrops))
    for lineobj, text in zip(root.findall(".//LINE"), resultlines):
        lineobj.set("STRING", text)
    return root


def create_txt(root):
    """Return the recognized text of every LINE element, one line per row."""
    return "\n".join(lineobj.get("STRING") for lineobj in root.findall(".//LINE"))


def create_xmlstr(root):
    """Pretty-print the XML tree without the blank lines minidom inserts."""
    rough_string = ET.tostring(root, 'utf-8')
    reparsed = minidom.parseString(rough_string)
    # Drop trailing whitespace left after indentation.
    pretty = re.sub(r"[\t ]+\n", "", reparsed.toprettyxml(indent="\t"))
    # Remove the spurious empty line between sibling elements.
    pretty = pretty.replace(">\n\n\t<", ">\n\t<")
    # Collapse any remaining runs of blank lines into a single newline.
    pretty = re.sub(r"\n\s*\n", "\n", pretty)
    return pretty


def create_json(root):
    """Serialize recognized lines as the NDL OCR JSON structure.

    Returns:
        A dict with "contents" (per-line bounding boxes, text, confidence)
        and "imginfo" (page size and image path/name).
    """
    resjsonarray = []
    img_w = int(root.find("PAGE").get("WIDTH"))
    img_h = int(root.find("PAGE").get("HEIGHT"))
    inputpath = root.find("PAGE").get("IMAGENAME")
    for idx, lineobj in enumerate(root.findall(".//LINE")):
        text = lineobj.get("STRING")
        xmin = int(lineobj.get("X"))
        ymin = int(lineobj.get("Y"))
        line_w = int(lineobj.get("WIDTH"))
        line_h = int(lineobj.get("HEIGHT"))
        try:
            conf = float(lineobj.get("CONF"))
        except (TypeError, ValueError):
            # CONF attribute missing or malformed — fall back to 0.
            conf = 0
        jsonobj = {
            "boundingBox": [
                [xmin, ymin],
                [xmin, ymin + line_h],
                [xmin + line_w, ymin],
                [xmin + line_w, ymin + line_h],
            ],
            "id": idx,
            "isVertical": "true",
            "text": text,
            "isTextline": "true",
            "confidence": conf,
        }
        resjsonarray.append(jsonobj)
    alljsonobj = {
        "contents": [resjsonarray],
        "imginfo": {
            "img_width": img_w,
            "img_height": img_h,
            "img_path": inputpath,
            "img_name": os.path.basename(inputpath),
        },
    }
    return alljsonobj


# Inference Function
def process(image_path: str):
    """Run the full OCR pipeline on one image file.

    Returns a 4-tuple for the Gradio outputs: annotated PIL image, plain
    text, pretty-printed XML string, and the JSON result object. On any
    failure, placeholder error values are returned instead of raising so
    the UI stays responsive.
    """
    try:
        # Load the models (CPU inference).
        detector = get_detector(
            weights_path="model/rtmdet-s-1280x1280.onnx",
            classes_path="config/ndl.yaml",
            device="cpu",
        )
        recognizer = get_recognizer(
            weights_path="model/parseq-ndl-32x384-tiny-10.onnx",
            classes_path="config/NDLmoji.yaml",
            device="cpu",
        )
        # Load image
        pil_image = Image.open(image_path).convert('RGB')
        npimg = np.array(pil_image)
        # Object detection
        detections = detector.detect(npimg)
        classeslist = list(detector.classes.values())
        img_h, img_w = npimg.shape[:2]
        imgname = os.path.basename(image_path)
        root = create_xml(detections, classeslist, img_w, img_h, imgname,
                          recognizer, npimg)
        alltext = create_txt(root)
        result_json = create_json(root)
        pil_image = detector.draw_detections(npimg, detections=detections)
        return pil_image, alltext, create_xmlstr(root), result_json
    except Exception as e:
        # Don't swallow the error silently: surface it in the server log
        # while keeping the UI contract (image, text, xml, json).
        print(f"process() failed: {e}")
        return [
            Image.fromarray(np.zeros((100, 100), dtype=np.uint8)),
            "Error",
            "Error",
            {},
        ]


# Gradio Inputs and Outputs
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = [
    gr.Image(type="pil", label="Output Image"),
    gr.TextArea(label="Output Text"),
    gr.TextArea(label="Output XML"),
    gr.JSON(label="Output JSON"),
]

# Gradio Interface
demo = gr.Interface(
    fn=process,
    inputs=inputs_image,
    outputs=outputs_image,
    title=model_heading,
    description=description,
    examples=image_path,
    article=article,
    cache_examples=False,
    # flagging_mode="never"  # newer Gradio spelling of allow_flagging
    allow_flagging="never",
)
demo.launch(share=False, server_name="0.0.0.0")