Hakureirm commited on
Commit
457c927
·
verified ·
1 Parent(s): 719133f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -59
app.py CHANGED
@@ -1,86 +1,109 @@
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  from fastapi import FastAPI
4
- from PIL import Image
5
  import torch
6
  import spaces
7
  import numpy as np
 
8
 
9
  app = FastAPI()
10
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
- # 移除 .to(device),在预测时指定设备
12
- model = YOLO('nailong_yolo11.onnx')
13
 
14
  @spaces.GPU
15
- def predict(img):
16
- # 将输入图像转换为PIL Image对象
17
- input_image = Image.fromarray(img)
18
 
19
- # 保持长宽比的情况下调整尺寸
20
- w, h = input_image.size
21
- scale = min(640/w, 640/h)
22
- new_w, new_h = int(w * scale), int(h * scale)
23
- if scale != 1:
24
- input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
25
 
26
- # 转换为numpy数组并进行预测
27
- img_array = np.array(input_image)
 
 
 
 
 
 
28
 
29
- # 在predict时指定device
30
- results = model.predict(img_array, device=device)
31
- result = results[0]
 
 
32
 
33
- # 获取预测结果
34
- result_img = result.plot()
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- # 处理检测信息
37
- info = {
38
- "detected": len(result.boxes) > 0,
39
- "count": len(result.boxes),
40
- "detections": []
41
- }
 
 
 
 
 
 
 
 
42
 
43
- if info["detected"]:
44
- # 获取每个检测框的信息
45
- for box in result.boxes:
46
- conf = float(box.conf[0])
47
- cls = int(box.cls[0])
48
- cls_name = result.names[cls]
49
-
50
- detection_info = {
51
- "class": cls_name,
52
- "confidence": f"{conf:.2%}"
53
- }
54
- info["detections"].append(detection_info)
55
 
56
- # 生成输出文本
57
- output_text = f"""检测结果:
58
- - 是否检测到目标: {'是' if info['detected'] else '否'}
59
- - 检测到的目标数量: {info['count']}"""
 
 
 
 
60
 
61
- if info["detections"]:
62
- output_text += "\n- 详细信息:"
63
- for idx, det in enumerate(info["detections"], 1):
64
- output_text += f"\n 目标 {idx}: {det['class']} (置信度: {det['confidence']})"
 
 
 
65
 
66
- # 如果需要将结果图像缩放回原始尺寸
67
- if scale != 1:
68
- result_img = Image.fromarray(result_img)
69
- result_img = result_img.resize((w, h), Image.LANCZOS)
70
- result_img = np.array(result_img)
71
-
72
- return result_img, output_text
73
 
74
  demo = gr.Interface(
75
- fn=predict,
76
- inputs=gr.Image(label="输入图片"),
 
 
 
77
  outputs=[
78
- gr.Image(label="检测结果", type="numpy"),
79
- gr.Textbox(label="检测信息")
80
  ],
81
- title="🐉 奶龙杀手 (NailongKiller)",
82
- description="上传图片来检测奶龙 | Upload an image to detect Nailong",
83
- # examples=[["example1.jpg"]],
84
  # cache_examples=True
85
  )
86
 
 
1
  import gradio as gr
2
  from ultralytics import YOLO
3
  from fastapi import FastAPI
4
+ import cv2
5
  import torch
6
  import spaces
7
  import numpy as np
8
+ from pathlib import Path
9
 
10
  app = FastAPI()
11
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ model = YOLO('kunin-mice-pose.v0.1.0.pt') # 使用你的小鼠检测模型
 
13
 
14
  @spaces.GPU
15
+ def process_video(video_path, process_seconds=20):
16
+ # 创建临时输出路径
17
+ output_path = Path("temp_output.mp4")
18
 
19
+ # 获取视频信息
20
+ cap = cv2.VideoCapture(video_path)
21
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
22
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
23
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
24
+ total_frames = int(process_seconds * fps)
25
 
26
+ # 创建视频写入器
27
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
28
+ video_writer = cv2.VideoWriter(
29
+ str(output_path),
30
+ fourcc,
31
+ fps,
32
+ (width, height)
33
+ )
34
 
35
+ # 统计信息
36
+ frame_count = 0
37
+ total_detections = 0
38
+ max_mice = 0
39
+ detection_stats = []
40
 
41
+ # 处理视频
42
+ results = model.predict(
43
+ source=video_path,
44
+ device=device,
45
+ conf=0.5,
46
+ save=False,
47
+ show=False,
48
+ stream=True,
49
+ line_width=2,
50
+ show_boxes=True,
51
+ show_labels=True,
52
+ show_conf=True,
53
+ vid_stride=1,
54
+ )
55
 
56
+ for r in results:
57
+ # 获取当前帧的检测结果
58
+ frame = r.plot()
59
+ num_mice = len(r.boxes)
60
+ max_mice = max(max_mice, num_mice)
61
+ total_detections += num_mice
62
+ detection_stats.append(num_mice)
63
+
64
+ # 写入视频
65
+ video_writer.write(frame)
66
+
67
+ frame_count += 1
68
+ if frame_count >= total_frames:
69
+ break
70
 
71
+ # 释放资源
72
+ video_writer.release()
73
+ cap.release()
 
 
 
 
 
 
 
 
 
74
 
75
+ # 生成统计信息
76
+ avg_mice = total_detections / frame_count if frame_count > 0 else 0
77
+ output_text = f"""小鼠检测统计:
78
+ - 处理时长: {process_seconds}
79
+ - 总帧数: {frame_count}
80
+ - 最大检测数量: {max_mice}只
81
+ - 平均检测数量: {avg_mice:.1f}只
82
+ - 检测帧率: {fps} FPS"""
83
 
84
+ # 如果有检测到的帧
85
+ if detection_stats:
86
+ output_text += "\n\n帧检测分布:"
87
+ for count in range(max(detection_stats) + 1):
88
+ frames = detection_stats.count(count)
89
+ percentage = frames / frame_count * 100
90
+ output_text += f"\n{count}只小鼠: {frames}帧 ({percentage:.1f}%)"
91
 
92
+ return str(output_path), output_text
 
 
 
 
 
 
93
 
94
  demo = gr.Interface(
95
+ fn=process_video,
96
+ inputs=[
97
+ gr.Video(label="输入视频"),
98
+ gr.Slider(minimum=1, maximum=60, value=20, step=1, label="处理时长(秒)")
99
+ ],
100
  outputs=[
101
+ gr.Video(label="检测结果"),
102
+ gr.Textbox(label="检测统计")
103
  ],
104
+ title="🐁 小鼠行为检测系统",
105
+ description="上传视频来检测小鼠 | Upload a video to detect mice",
106
+ # examples=[["example.mp4", 20]],
107
  # cache_examples=True
108
  )
109