motionmask / app.py
zzzzzeee's picture
Update app.py
74aa301 verified
import gradio as gr
import torch
import numpy as np
from PIL import Image
from utils_spike import load_vidar_dat, STPFilter # 假设您有这个函数
import os
from tqdm import tqdm, trange
# 设置设备
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def compute_motion_map(spike, tau_f, tau_s, U_c, theta, sampling_rate=40000):
"""
使用 STPFilter 类从脉冲数据计算运动图。
参数:
spike: 形状为 (T, H, W) 的脉冲数据张量(0 或 1)
tau_f: 易化时间常数(毫秒)
tau_s: 恢复时间常数(毫秒)
U_c: 易化参数(0 到 1)
theta: 运动检测阈值
sampling_rate: 脉冲数据的采样率(赫兹),默认为 40000 Hz
返回:
motion_map: 形状为 (H, W) 的 numpy 数组,表示二值运动图
"""
# 如果 spike 是 numpy 数组,转换为 tensor
if isinstance(spike, np.ndarray):
spike = torch.from_numpy(spike).float().to(DEVICE)
else:
spike = spike.float().to(DEVICE)
T, H, W = spike.shape
device = spike.device
# 设置 STPFilter 参数
STPargs = {
'u0': U_c, # 初始易化参数
'D': tau_s, # 恢复时间常数(毫秒 -> 秒)
'F': tau_f , # 易化时间常数(毫秒 -> 秒)
'f': 0.11, # 假设与 STPFilter 默认值一致
'time_unit': 1000 / sampling_rate, # 每个时间步的持续时间(毫秒)
'filterThr': theta, # 运动检测阈值
'voltageMin': -8, # LIF 模型参数(假设值)
'lifThr': 2 # LIF 阈值(假设值)
}
# 初始化 STPFilter
stp_filter = STPFilter(H, W, device, diff_time=1, **STPargs)
# 逐时间步更新并获取运动掩码
motion_masks = []
for t in trange(T):
cur_spikes = spike[t]
stp_filter.update_dynamics(t, cur_spikes)
motion_masks.append(stp_filter.filter_spk.cpu().numpy())
print("complete calibration.")
# 取中心时间步的运动掩码
central_index = T // 2
motion_map = motion_masks[central_index]
print(motion_map.min(), motion_map.max())
# 转换为图像格式 (0-255)
motion_map = (motion_map * 255).astype(np.uint8)
return motion_map
# Gradio 接口
with gr.Blocks() as demo:
# 标题和描述
gr.Markdown("# 运动图预测")
gr.Markdown(
"上传包含脉冲数据的 `.dat` 文件,并调整超参数以预测运动物体的运动图。"
)
# 输入和输出布局
with gr.Row():
input_dat = gr.File(label="输入 Dat 文件")
output_motion_map = gr.Image(label="运动图")
# 超参数滑块
with gr.Row():
tau_f = gr.Slider(
minimum=1, maximum=400, value=10, step=1,
label="τ_f (ms): 易化时间常数"
)
tau_s = gr.Slider(
minimum=1, maximum=400, value=10, step=1,
label="τ_s (ms): 恢复时间常数"
)
U_c = gr.Slider(
minimum=0.1, maximum=1, value=0.15, step=0.05,
label="U_c: 易化参数"
)
theta = gr.Slider(
minimum=0.1, maximum=10, value=2, step=0.1,
label="θ: 运动检测阈值"
)
# 提交按钮
submit = gr.Button(value="提交")
def on_submit(dat_path, tau_f, tau_s, U_c, theta):
"""
处理上传的 dat 文件并返回运动图。
参数:
dat_path: 上传的 .dat 文件路径
tau_f, tau_s, U_c, theta: 滑块中的超参数
返回:
运动图的 PIL 图像
"""
# 加载脉冲数据(假设从参考代码中获取固定大小)
spike = load_vidar_dat(dat_path, width=400, height=250)
# 计算运动图
motion_map = compute_motion_map(spike, tau_f, tau_s, U_c, theta)
# 转换为 PIL 图像以供 Gradio 使用
return Image.fromarray(motion_map)
# 连接提交按钮到函数
submit.click(
fn=on_submit,
inputs=[input_dat, tau_f, tau_s, U_c, theta],
outputs=[output_motion_map]
)
example_dir = "assets/examples" # 示例文件目录
if os.path.exists(example_dir):
example_files = sorted(os.listdir(example_dir))
example_files = [os.path.join(example_dir, filename) for filename in example_files if filename.endswith(".dat")]
else:
example_files = []
# 添加 gr.Examples 组件
examples = gr.Examples(
examples=example_files,
inputs=[input_dat], # 选中示例时自动填入 input_dat
outputs=[output_motion_map], # 输出运动图
fn=on_submit, # 调用 on_submit 函数
cache_examples=False # 禁用缓存,确保加载最新文件
)
if __name__ == "__main__":
demo.queue().launch()