Spaces:
Sleeping
Sleeping
import gradio as gr | |
from ultralytics import YOLO | |
from fastapi import FastAPI | |
from PIL import Image | |
import torch | |
import spaces | |
import numpy as np | |
app = FastAPI() | |
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
# 移除 .to(device),在预测时指定设备 | |
model = YOLO('nailong_yolo11.onnx') | |
def predict(img): | |
# 将输入图像转换为PIL Image对象 | |
input_image = Image.fromarray(img) | |
# 保持长宽比的情况下调整尺寸 | |
w, h = input_image.size | |
scale = min(640/w, 640/h) | |
new_w, new_h = int(w * scale), int(h * scale) | |
if scale != 1: | |
input_image = input_image.resize((new_w, new_h), Image.LANCZOS) | |
# 转换为numpy数组并进行预测 | |
img_array = np.array(input_image) | |
# 在predict时指定device | |
results = model.predict(img_array, device=device) | |
result = results[0] | |
# 获取预测结果 | |
result_img = result.plot() | |
# 处理检测信息 | |
info = { | |
"detected": len(result.boxes) > 0, | |
"count": len(result.boxes), | |
"detections": [] | |
} | |
if info["detected"]: | |
# 获取每个检测框的信息 | |
for box in result.boxes: | |
conf = float(box.conf[0]) | |
cls = int(box.cls[0]) | |
cls_name = result.names[cls] | |
detection_info = { | |
"class": cls_name, | |
"confidence": f"{conf:.2%}" | |
} | |
info["detections"].append(detection_info) | |
# 生成输出文本 | |
output_text = f"""检测结果: | |
- 是否检测到目标: {'是' if info['detected'] else '否'} | |
- 检测到的目标数量: {info['count']}""" | |
if info["detections"]: | |
output_text += "\n- 详细信息:" | |
for idx, det in enumerate(info["detections"], 1): | |
output_text += f"\n 目标 {idx}: {det['class']} (置信度: {det['confidence']})" | |
# 如果需要将结果图像缩放回原始尺寸 | |
if scale != 1: | |
result_img = Image.fromarray(result_img) | |
result_img = result_img.resize((w, h), Image.LANCZOS) | |
result_img = np.array(result_img) | |
return result_img, output_text | |
demo = gr.Interface( | |
fn=predict, | |
inputs=gr.Image(label="输入图片"), | |
outputs=[ | |
gr.Image(label="检测结果", type="numpy"), | |
gr.Textbox(label="检测信息") | |
], | |
title="🐉 奶龙杀手 (NailongKiller)", | |
description="上传图片来检测奶龙 | Upload an image to detect Nailong", | |
# examples=[["example1.jpg"]], | |
# cache_examples=True | |
) | |
if __name__ == "__main__": | |
demo.launch(server_name="0.0.0.0", server_port=7860) |