Hakureirm commited on
Commit
a8c778e
·
verified ·
1 Parent(s): 4192108

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -261,7 +261,7 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
261
  return filtered_points
262
 
263
  # 修改轨迹图生成部分
264
- trajectory_img = torch.ones((height, width, 3), device=device) * 255
265
  points = np.array(all_positions, dtype=np.int32)
266
  if len(points) > 1:
267
  filtered_points = filter_trajectories_gpu(points.tolist(), width, height)
@@ -273,7 +273,7 @@ def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8)
273
  int((1 - ratio) * 255), # B
274
  50, # G
275
  int(ratio * 255) # R
276
- ], device=device)
277
 
278
  # 使用GPU绘制线段
279
  pt1, pt2 = points[i], points[i + 1]
@@ -403,10 +403,13 @@ def draw_line_gpu(image, pt1, pt2, color, thickness=1):
403
  (points[:, 1] >= 0) & (points[:, 1] < image.shape[0])
404
  points = points[valid_points]
405
 
 
 
 
406
  if thickness > 1:
407
  for dx in range(-thickness//2, thickness//2 + 1):
408
  for dy in range(-thickness//2, thickness//2 + 1):
409
- offset_points = points + torch.tensor([dx, dy], device=device)
410
  valid_offset = (offset_points[:, 0] >= 0) & (offset_points[:, 0] < image.shape[1]) & \
411
  (offset_points[:, 1] >= 0) & (offset_points[:, 1] < image.shape[0])
412
  offset_points = offset_points[valid_offset]
@@ -425,7 +428,7 @@ with gr.Blocks() as demo:
425
  login_msg = gr.Textbox(label="消息", interactive=False)
426
 
427
  with gr.Group(visible=False) as main_interface:
428
- gr.Markdown("上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior")
429
 
430
  with gr.Row():
431
  with gr.Column():
 
261
  return filtered_points
262
 
263
  # 修改轨迹图生成部分
264
+ trajectory_img = torch.ones((height, width, 3), device=device, dtype=torch.float32)
265
  points = np.array(all_positions, dtype=np.int32)
266
  if len(points) > 1:
267
  filtered_points = filter_trajectories_gpu(points.tolist(), width, height)
 
273
  int((1 - ratio) * 255), # B
274
  50, # G
275
  int(ratio * 255) # R
276
+ ], device=device, dtype=torch.float32)
277
 
278
  # 使用GPU绘制线段
279
  pt1, pt2 = points[i], points[i + 1]
 
403
  (points[:, 1] >= 0) & (points[:, 1] < image.shape[0])
404
  points = points[valid_points]
405
 
406
+ # 确保color是正确的数据类型
407
+ color = color.to(image.dtype) # 修改这里,确保颜色张量与图像类型匹配
408
+
409
  if thickness > 1:
410
  for dx in range(-thickness//2, thickness//2 + 1):
411
  for dy in range(-thickness//2, thickness//2 + 1):
412
+ offset_points = points + torch.tensor([dx, dy], device=device, dtype=torch.long) # 修改这里,确保是long类型
413
  valid_offset = (offset_points[:, 0] >= 0) & (offset_points[:, 0] < image.shape[1]) & \
414
  (offset_points[:, 1] >= 0) & (offset_points[:, 1] < image.shape[0])
415
  offset_points = offset_points[valid_offset]
 
428
  login_msg = gr.Textbox(label="消息", interactive=False)
429
 
430
  with gr.Group(visible=False) as main_interface:
431
+ gr.Markdown("上传视频来检测和���析小鼠行为 | Upload a video to detect and analyze mice behavior")
432
 
433
  with gr.Row():
434
  with gr.Column():