# NOTE: the following lines are Hugging Face file-viewer chrome that was
# captured together with the source (author "nakamura196", commit message
# "fix: bug", hash 7c62087, "raw / history blame", size 3.32 kB).
# Kept here as a comment so the module parses as valid Python.
import gradio as gr
import numpy as np
from PIL import Image
import os
from rtmdet import RTMDet
from parseq import PARSEQ
from yaml import safe_load
# Model Heading and Description
# UI strings for the Gradio interface (title, short description, footer HTML).
# The Japanese title reads "YOLOv11x kuzushiji recognition service (single character)".
model_heading = "YOLOv11x くずし字認識サービス(一文字)"
description = """YOLOv11x くずし字認識サービス(一文字) Gradio demo for classification. Upload an image or click an example image to use."""
article = "<p style='text-align: center'>YOLOv11x くずし字認識サービス(一文字) is a classification model trained on the <a href=\"https://lab.hi.u-tokyo.ac.jp/datasets/kuzushiji\">東京大学史料編纂所くずし字データセット</a>.</p>"
# Example images shown under the input widget; each inner list is one example row.
image_path = [
['samples/default.jpg']
]
# Functions to load models
def get_detector(weights_path, classes_path, device='cpu'):
    """Build the RTMDet character detector.

    Args:
        weights_path: Path to the detector's ONNX weight file.
        classes_path: Path to the YAML class-mapping file.
        device: Inference device forwarded to RTMDet (default ``'cpu'``).

    Returns:
        A configured ``RTMDet`` instance.

    Raises:
        FileNotFoundError: If either input file does not exist.
    """
    # Raise instead of assert: asserts are stripped under `python -O`, which
    # would let a missing file surface later as a confusing loader error.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weight file not found: {weights_path}")
    if not os.path.isfile(classes_path):
        raise FileNotFoundError(f"Classes file not found: {classes_path}")
    return RTMDet(model_path=weights_path,
                  class_mapping_path=classes_path,
                  score_threshold=0.3,
                  conf_thresold=0.3,  # spelling matches RTMDet's keyword argument
                  iou_threshold=0.3,
                  device=device)
def get_recognizer(weights_path, classes_path, device='cpu'):
    """Build the PARSEQ text recognizer.

    Args:
        weights_path: Path to the recognizer's ONNX weight file.
        classes_path: Path to the YAML config holding the training charset
            under ``model.charset_train``.
        device: Inference device forwarded to PARSEQ (default ``'cpu'``).

    Returns:
        A configured ``PARSEQ`` instance.

    Raises:
        FileNotFoundError: If either input file does not exist.
    """
    # Raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Weight file not found: {weights_path}")
    if not os.path.isfile(classes_path):
        raise FileNotFoundError(f"Classes file not found: {classes_path}")
    # safe_load on a local, repo-bundled config file — not untrusted input.
    with open(classes_path, encoding="utf-8") as f:
        charlist = list(safe_load(f)["model"]["charset_train"])
    return PARSEQ(model_path=weights_path, charlist=charlist, device=device)
# YOLO Inference Function
def YOLOv11x_img_inference(image_path: str):
    """Detect kuzushiji characters in an image and recognize each crop.

    Args:
        image_path: Filesystem path to the input image (Gradio supplies a
            temp-file path because the input uses ``type="filepath"``).

    Returns:
        On success, a list of dicts with keys ``"boundingBox"`` (the four box
        corners, clockwise from top-left), ``"text"`` (recognized string) and
        ``"confidence"``. On failure, a dict ``{"error": message}`` so the
        Gradio JSON widget shows the problem instead of crashing.
    """
    try:
        # Load both ONNX models once and reuse them across requests; the
        # original re-loaded them from disk on every inference call.
        if not hasattr(YOLOv11x_img_inference, "_models"):
            YOLOv11x_img_inference._models = (
                get_detector(
                    weights_path="model/rtmdet-s-1280x1280.onnx",
                    classes_path="config/ndl.yaml",
                    device="cpu",
                ),
                get_recognizer(
                    weights_path="model/parseq-ndl-32x384-tiny-10.onnx",
                    classes_path="config/NDLmoji.yaml",
                    device="cpu",
                ),
            )
        detector, recognizer = YOLOv11x_img_inference._models

        # Decode the image and normalize away any alpha/palette channel.
        npimg = np.array(Image.open(image_path).convert('RGB'))

        results = []
        for det in detector.detect(npimg):
            xmin, ymin, xmax, ymax = det["box"]
            # Crop the detected character region and run the recognizer on it.
            crop = npimg[int(ymin):int(ymax), int(xmin):int(xmax)]
            results.append({
                "boundingBox": [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]],
                "text": recognizer.read(crop),
                "confidence": det["confidence"],
            })
        return results
    except Exception as e:
        # Broad catch is deliberate: any failure is surfaced in the UI as JSON.
        return {"error": str(e)}
# Gradio Inputs and Outputs
# --- Gradio wiring ---
# Input widget hands the handler a temp-file path; output renders raw JSON.
inputs_image = gr.Image(type="filepath", label="Input Image")
outputs_image = gr.JSON(label="Output JSON")

# Assemble the demo page: title/description/article are the module-level
# UI strings, and the sample image doubles as a clickable example.
demo = gr.Interface(
    YOLOv11x_img_inference,
    inputs_image,
    outputs_image,
    title=model_heading,
    description=description,
    article=article,
    examples=image_path,
    cache_examples=False,
)

# Bind on all interfaces (container deployment); no public share link.
demo.launch(share=False, server_name="0.0.0.0")