# yolo12138 — "publish v1" (commit 37170d6)
import onnxruntime
import numpy as np
import cv2
from typing import Tuple, List, Union
from .base_onnx import BaseONNX
def center_crop(image: np.ndarray, target_size: Tuple[int, int]) -> np.ndarray:
    """
    Center crop the image to the target size.

    Args:
        image (np.ndarray): The input image, indexed as (height, width[, channels]).
            Both 2-D (single-channel) and 3-D arrays are accepted.
        target_size (Tuple[int, int]): The desired output size as (width, height),
            the same order ``cv2.resize`` uses for its ``dsize`` argument.

    Returns:
        np.ndarray: A view into ``image`` covering the centered window. If the
        image is smaller than the target along an axis, the full extent of that
        axis is kept, so the result is then smaller than ``target_size``.
    """
    # Only the first two axes are spatial; this also accepts 2-D input.
    h, w = image.shape[:2]
    target_w, target_h = target_size
    # Clamp at 0: a negative start index would wrap around to the end of the
    # axis under Python slicing and select the wrong region.
    start_x = max(0, w // 2 - target_w // 2)
    start_y = max(0, h // 2 - target_h // 2)
    return image[start_y:start_y + target_h, start_x:start_x + target_w]
# Classifier label name -> one-character board symbol.
# Red pieces map to uppercase letters and black pieces to lowercase;
# 'point' and 'other' are the two non-piece classes.
# NOTE: key insertion order is significant. FULL_CLASSIFIER_ONNX builds its
# index -> label list with list(dict_cate_names.keys()) and indexes it with the
# model's argmax, so this order must match the model's output class order.
# (Keyword-argument order is preserved by dict() on Python 3.7+.)
dict_cate_names = dict(
    point='.',
    other='x',
    # red side
    red_king='K',
    red_advisor='A',
    red_bishop='B',
    red_knight='N',
    red_rook='R',
    red_cannon='C',
    red_pawn='P',
    # black side
    black_king='k',
    black_advisor='a',
    black_bishop='b',
    black_knight='n',
    black_rook='r',
    black_cannon='c',
    black_pawn='p',
)
class FULL_CLASSIFIER_ONNX(BaseONNX):
    """ONNX classifier that labels every point of a 10x9 Chinese-chess board.

    One board image is cropped, resized, normalized and run through the model.
    The model output is expected to hold one score vector per board point
    (10 rows x 9 columns = 90), with one score per entry of ``classes_labels``.
    """

    # Label name -> one-character symbol (uppercase = red, lowercase = black).
    label_2_short = dict_cate_names
    # Index -> label name; indexed by the model's argmax, so the order must
    # match the model's output classes.
    classes_labels = list(dict_cate_names.keys())

    # Per-channel mean/std in RGB order (the usual ImageNet statistics scaled
    # to 0-255), applied to the resized image.
    _MEAN = np.array([123.675, 116.28, 103.53])
    _STD = np.array([58.395, 57.12, 57.375])

    def __init__(self,
                 model_path,
                 # Size the image is resized to before inference.
                 input_size=(280, 315),  # (w, h)
                 # Size of the centered window cropped from the input first.
                 crop_size=(400, 450),  # (w, h)
                 ):
        # NOTE(review): BaseONNX is assumed to set up self.session,
        # self.input_name and self.input_size — confirm in base_onnx.
        super().__init__(model_path, input_size)
        self.crop_size = crop_size

    def preprocess_image(self, img_bgr: cv2.UMat, is_rgb: bool = True) -> np.ndarray:
        """Turn an (H, W, 3) image into a normalized (1, 3, H, W) float32 batch.

        Args:
            img_bgr: Input image array. Despite the name, it is treated as RGB
                when ``is_rgb`` is True.
            is_rgb: If False, the image is converted from BGR to RGB first.

        Returns:
            np.ndarray: float32 batch of shape (1, 3, input_h, input_w).
        """
        img_rgb = img_bgr if is_rgb else cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        # crop_size is (w, h) while ndarray.shape starts with (h, w); compare in
        # the same order so an image already at crop size is recognized.
        h, w = img_rgb.shape[:2]
        if (w, h) != tuple(self.crop_size):
            img_rgb = center_crop(img_rgb, self.crop_size)
        # cv2.resize takes dsize as (w, h), matching input_size.
        img_rgb = cv2.resize(img_rgb, self.input_size)
        img = ((img_rgb - self._MEAN) / self._STD).astype(np.float32)
        # (H, W, C) -> (C, H, W), then add the batch dimension.
        return np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0)

    def run_inference(self, image: np.ndarray) -> np.ndarray:
        """Run the ONNX session on a preprocessed batch.

        Args:
            image (np.ndarray): Batch produced by ``preprocess_image``.

        Returns:
            np.ndarray: The model's single output array.
        """
        # The unpacking requires the model to have exactly one output.
        outputs, = self.session.run(None, {self.input_name: image})
        return outputs

    def pred(self, image: Union[cv2.UMat, np.ndarray, str], is_rgb: bool = True
             ) -> Tuple[List[List[str]], List[List[str]], List[np.ndarray], str]:
        """Classify every point of the board shown in one image.

        Args:
            image: An image array, or a path to an image file. Files are read
                with ``cv2.imread`` (BGR), so ``is_rgb`` is ignored for paths.
            is_rgb: Channel order of an in-memory ``image``.

        Returns:
            A tuple ``(label_names, label_short, confidence, layout_str)``:
            - label_names: 10 rows of 9 class names.
            - label_short: the same grid as one-character symbols.
            - confidence: 10 arrays of 9 values, each the winning class's raw
              model output (presumably a probability — confirm model head).
            - layout_str: the symbol grid as 10 newline-separated lines.

        Raises:
            FileNotFoundError: If ``image`` is a path cv2 cannot read.
            ValueError: If the model output is not shaped (batch, 90, num_classes).
        """
        if isinstance(image, str):
            img = cv2.imread(image)
            if img is None:
                raise FileNotFoundError(f"cv2.imread could not read image: {image}")
            is_rgb = False  # cv2.imread returns BGR
        else:
            # No copy needed: preprocessing never writes into its input.
            img = image

        batch = self.preprocess_image(img, is_rgb)
        labels = self.run_inference(batch)

        rows, cols = 10, 9
        expected = (rows * cols, len(self.classes_labels))
        if labels.shape[1:] != expected:
            raise ValueError(
                f"unexpected model output shape {labels.shape}, expected (batch, *{expected})"
            )

        # First batch item: one score vector per board point, shape (90, num_classes).
        scores = labels[0]
        label_indexes = np.argmax(scores, axis=-1).tolist()
        label_names = [self.classes_labels[index] for index in label_indexes]
        label_short = [self.label_2_short[name] for name in label_names]
        # Score of the winning class at each point.
        confidence = scores.max(axis=-1)

        # Split the flat, row-major list of 90 points into 10 rows of 9.
        row_starts = range(0, rows * cols, cols)
        label_names_10x9 = [label_names[s:s + cols] for s in row_starts]
        label_short_10x9 = [label_short[s:s + cols] for s in row_starts]
        confidence_10x9 = [confidence[s:s + cols] for s in row_starts]
        layout_str = "\n".join("".join(row) for row in label_short_10x9)
        return label_names_10x9, label_short_10x9, confidence_10x9, layout_str

    def draw_pred(self, image: cv2.UMat, label_index: int, label_name: str, label_short: str, confidence: float) -> cv2.UMat:
        """Draw one prediction as text near the top-left corner of ``image``.

        Drawing happens in place (cv2.putText); the same image is returned.
        ``label_index`` and ``label_name`` are not used.
        """
        cv2.putText(image, f"{label_short} {confidence:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
        return image

    def draw_pred_with_result(self, image: cv2.UMat, results: List[Tuple[int, str, str, float]], cells_xyxy: np.ndarray, is_rgb: bool = True) -> cv2.UMat:
        """Annotate each board cell with its predicted symbol.

        Args:
            image: Image to draw on. When ``is_rgb`` is True it is drawn on in
                place; otherwise a channel-swapped copy is made and drawn on.
            results: One ``(label_index, label_name, label_short, confidence)``
                per cell, in the same order as ``cells_xyxy``.
            cells_xyxy: Array with one ``(x1, y1, x2, y2)`` row per cell.
            is_rgb: Channel order of ``image``.

        Returns:
            The annotated image.

        Raises:
            ValueError: If ``results`` and ``cells_xyxy`` differ in length.
        """
        if len(results) != cells_xyxy.shape[0]:
            raise ValueError(
                f"got {len(results)} results for {cells_xyxy.shape[0]} cells"
            )
        if not is_rgb:
            # RGB2BGR is a plain channel swap, so this turns a BGR input into RGB.
            # NOTE(review): the blue/yellow colors below are correct in RGB order,
            # but (180, 105, 255) is pink only in BGR — confirm intended order.
            image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        for cell, (_, label_name, label_short, confidence) in zip(cells_xyxy, results):
            # Make sure the coordinates are plain ints for OpenCV.
            x1, y1, x2, y2 = map(int, cell)
            if confidence < 0.5:
                color = (255, 255, 0)  # low confidence: yellow (overrides side color)
            elif label_name.startswith('red'):
                color = (180, 105, 255)  # red pieces: "pink"
            elif label_name.startswith('black'):
                color = (0, 100, 50)  # black pieces: dark green tone
            else:
                color = (0, 0, 255)  # non-piece classes: blue
            # Show the score (two decimals) only when it is below 0.9.
            label_str = f"{label_short} {confidence:.2f}" if confidence < 0.9 else f"{label_short}"
            cv2.putText(image, label_str, (x1 + 8, y2 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.3, color, 1)
        return image