Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import numpy as np | |
from PIL import Image | |
from utils_spike import load_vidar_dat # 假设您有这个函数 | |
import os | |
# 设置设备 | |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' | |
import torch | |
import numpy as np | |
def compute_motion_map(spike, tau_f, tau_s, U_c, theta, sampling_rate=40000): | |
""" | |
使用 STP 动力学从脉冲数据计算运动图。 | |
参数: | |
spike: 形状为 (T, H, W) 的脉冲数据张量(0 或 1) | |
tau_f: 易化时间常数(毫秒) | |
tau_s: 恢复时间常数(毫秒) | |
U_c: 易化参数(0 到 1) | |
theta: 运动检测阈值 | |
sampling_rate: 脉冲数据的采样率(赫兹),默认为 40000 Hz | |
返回: | |
motion_map: 形状为 (H, W) 的 numpy 数组,表示二值运动图 | |
""" | |
# 如果 spike 是 numpy 数组,转换为 tensor | |
if isinstance(spike, np.ndarray): | |
spike = torch.from_numpy(spike).float().to(DEVICE) | |
else: | |
spike = spike.float().to(DEVICE) | |
# 时间步长(秒) | |
delta_t = 1 / sampling_rate # 例如 1/40000 秒 | |
# 将时间常数从毫秒转换为秒 | |
tau_f = tau_f / 1000 # 毫秒 -> 秒 | |
tau_s = tau_s / 1000 # 毫秒 -> 秒 | |
# 在循环外计算指数衰减因子 | |
exp_f = torch.exp(torch.tensor(-delta_t / tau_f, device=DEVICE)) | |
exp_s = torch.exp(torch.tensor(-delta_t / tau_s, device=DEVICE)) | |
T, H, W = spike.shape | |
# 初始化 STP 变量 R(资源)和 u(利用率) | |
R = torch.ones((T, H, W), device=DEVICE) # R 初始为 1 | |
u = U_c * torch.ones((T, H, W), device=DEVICE) # u 初始为 U_c | |
# 计算随时间变化的 STP 动力学 | |
for n in range(1, T): | |
# 脉冲间指数衰减 | |
u_decay = u[n-1] * exp_f | |
R_decay = 1 + (R[n-1] - 1) * exp_s | |
# 当前时间步的脉冲更新 | |
spike_n = spike[n] | |
u[n] = u_decay + spike_n * U_c * (1 - u_decay) # 易化更新 | |
R[n] = R_decay * (1 - spike_n * u[n]) # 资源消耗 | |
# 在中心时间步计算运动掩码 | |
central_index = T // 2 | |
R_diff = torch.abs(R[central_index] - R[central_index - 1]) | |
M = (R_diff >= theta).float() # 二值掩码:运动区域为 1 | |
print(M.max()) | |
print(M.min()) | |
# 转换为图像格式 (0-255) | |
motion_map = (M * 255).cpu().numpy().astype(np.uint8) | |
return motion_map | |
# Gradio 接口 | |
with gr.Blocks() as demo: | |
# 标题和描述 | |
gr.Markdown("# 运动图预测") | |
gr.Markdown( | |
"上传包含脉冲数据的 `.dat` 文件,并调整超参数以预测运动物体的运动图。" | |
) | |
# 输入和输出布局 | |
with gr.Row(): | |
input_dat = gr.File(label="输入 Dat 文件") | |
output_motion_map = gr.Image(label="运动图") | |
# 超参数滑块 | |
with gr.Row(): | |
tau_f = gr.Slider( | |
minimum=10, maximum=1000, value=100, step=10, | |
label="τ_f (ms): 易化时间常数" | |
) | |
tau_s = gr.Slider( | |
minimum=10, maximum=1000, value=100, step=10, | |
label="τ_s (ms): 恢复时间常数" | |
) | |
U_c = gr.Slider( | |
minimum=0.1, maximum=0.9, value=0.5, step=0.05, | |
label="U_c: 易化参数" | |
) | |
theta = gr.Slider( | |
minimum=0.01, maximum=0.5, value=0.1, step=0.01, | |
label="θ: 运动检测阈值" | |
) | |
# 提交按钮 | |
submit = gr.Button(value="提交") | |
def on_submit(dat_path, tau_f, tau_s, U_c, theta): | |
""" | |
处理上传的 dat 文件并返回运动图。 | |
参数: | |
dat_path: 上传的 .dat 文件路径 | |
tau_f, tau_s, U_c, theta: 滑块中的超参数 | |
返回: | |
运动图的 PIL 图像 | |
""" | |
# 加载脉冲数据(假设从参考代码中获取固定大小) | |
spike = load_vidar_dat(dat_path, width=400, height=250) | |
# 计算运动图 | |
motion_map = compute_motion_map(spike, tau_f, tau_s, U_c, theta) | |
# 转换为 PIL 图像以供 Gradio 使用 | |
return Image.fromarray(motion_map) | |
# 连接提交按钮到函数 | |
submit.click( | |
fn=on_submit, | |
inputs=[input_dat, tau_f, tau_s, U_c, theta], | |
outputs=[output_motion_map] | |
) | |
example_dir = "assets/examples" # 示例文件目录 | |
if os.path.exists(example_dir): | |
example_files = sorted(os.listdir(example_dir)) | |
example_files = [os.path.join(example_dir, filename) for filename in example_files if filename.endswith(".dat")] | |
else: | |
example_files = [] | |
# 添加 gr.Examples 组件 | |
examples = gr.Examples( | |
examples=example_files, | |
inputs=[input_dat], # 选中示例时自动填入 input_dat | |
outputs=[output_motion_map], # 输出运动图 | |
fn=on_submit, # 调用 on_submit 函数 | |
cache_examples=False # 禁用缓存,确保加载最新文件 | |
) | |
if __name__ == "__main__": | |
demo.queue().launch() |