# rightnow / app.py
# Last commit: "Update app.py" by zzzzzeee (50aad66, verified)
import gradio as gr
import json
import os
import torch
import numpy as np
import matplotlib
import matplotlib.cm as cm
from SpikeT.model.S2DepthNet import S2DepthTransformerUNetConv
from SpikeT.utils.data_augmentation import CenterCrop
# === Settings ===
# All inference in this app runs on the CPU.
DEVICE = torch.device("cpu")
# Page heading (Markdown) and the subtitle shown beneath it. The subtitle reads:
# "Upload a `.dat` or `.npy` spike file; the model will reconstruct the spike
# image and predict the corresponding depth map."
title = "# Spike Transformer - Depth Estimation (CPU)"
description = (
    "上傳 `.dat` 或 `.npy` spike 檔案,"
    "模型將重建 spike 圖並預測對應的深度圖"
)
# === Load the model and its training config ===
model_path = 'SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar'
# The training config is read from config.json in the checkpoint's directory.
config_path = os.path.join(os.path.dirname(model_path), 'config.json')
with open(config_path, encoding='utf-8') as f:
    config = json.load(f)
# Copy the top-level, data-loader and trainer settings into the dict that is
# handed to the model constructor.
config['model']['gpu'] = config['gpu']
config['model']['every_x_rgb_frame'] = config['data_loader']['train']['every_x_rgb_frame']
config['model']['baseline'] = config['data_loader']['train']['baseline']
config['model']['loss_composition'] = config['trainer']['loss_composition']
# NOTE(review): eval() executes whatever expression config['arch'] contains.
# This is only acceptable because config.json ships alongside the weights.
# Presumably it names the imported S2DepthTransformerUNetConv class; confirm.
model = eval(config['arch'])(config['model'])
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
checkpoint = torch.load(model_path, map_location='cpu')
state_dict = checkpoint['state_dict']
# Strip a leading 'module.' from each key. This is presumably the prefix added
# when training wrapped the model in nn.DataParallel / DistributedDataParallel.
# Only the prefix is removed: str.replace would also mangle interior names,
# e.g. 'submodule.weight' -> 'subweight'.
cleaned_state_dict = {
    (k[len('module.'):] if k.startswith('module.') else k): v
    for k, v in state_dict.items()
}
model.load_state_dict(cleaned_state_dict)
model.eval()
model.to(DEVICE)
# Centre-crop transform (size 224) applied to every spike stream in predict_recon_bsf.
data_transform = CenterCrop(224)
# === Utility functions ===
def RawToSpike(video_seq, h, w, flipud=True):
    """Decode a packed binary spike stream into a stack of spike frames.

    Each frame occupies ``h * w // 8`` consecutive bytes. Pixel ``p`` (row-major
    index, ``p = row * w + col``) is stored in bit ``p % 8`` of byte ``p // 8``,
    least-significant bit first. Trailing bytes that do not form a complete
    frame are ignored.

    Args:
        video_seq: 1-D sequence of raw byte values. Signed values (e.g. from
            ``dtype='b'``) are reinterpreted as their unsigned bit pattern.
        h: Frame height in pixels.
        w: Frame width in pixels.
        flipud: If True, reverse the row order of every frame.

    Returns:
        ``float32`` array of shape ``(num_frames, h, w)`` containing 0.0 / 1.0.

    Raises:
        ValueError: If ``h * w`` is not a positive multiple of 8, because a
            frame would then not occupy a whole number of bytes.
    """
    img_size = h * w
    if img_size <= 0 or img_size % 8 != 0:
        raise ValueError(
            f"Frame size h * w = {img_size} must be a positive multiple of 8."
        )
    bytes_per_frame = img_size // 8
    raw = np.array(video_seq).astype(np.uint8)
    img_num = len(raw) // bytes_per_frame
    frames = raw[: img_num * bytes_per_frame].reshape(img_num, bytes_per_frame)
    # 'little' bit order unpacks bit 0 of each byte first, so the flat index
    # byte * 8 + bit equals the pixel id p. This matches the LSB-first packing.
    spikes = np.unpackbits(frames, axis=1, bitorder="little").reshape(img_num, h, w)
    if flipud:
        spikes = spikes[:, ::-1, :]
    return spikes.astype(np.float32)
def load_spike_file(path):
    """Load a spike stream from disk as a float32 numpy array.

    ``.npy`` files are read with ``np.load`` and cast to float32. ``.dat``
    files are read as raw signed bytes and decoded with ``RawToSpike`` using a
    fixed 260 x 346 frame size.

    Raises:
        ValueError: For any extension other than ``.npy`` or ``.dat``.
    """
    is_npy = path.endswith(".npy")
    is_dat = path.endswith(".dat")
    if not (is_npy or is_dat):
        raise ValueError("Unsupported file format. Only .dat and .npy are supported.")
    if is_npy:
        return np.load(path).astype(np.float32)
    with open(path, 'rb') as stream:
        payload = stream.read()
    return RawToSpike(np.frombuffer(payload, dtype='b'), h=260, w=346)
def predict_recon_bsf(spike, model, device, window=128):
    """Run depth inference on the central temporal window of a spike stream.

    Args:
        spike: Numpy array of spike frames. It is converted with
            ``torch.from_numpy`` and passed through the module-level
            ``data_transform``. The result must be 3-D, and its first axis is
            treated as time. Assumed to be ``(T, H, W)``; TODO confirm against
            ``CenterCrop``.
        model: Depth network, called as
            ``model(inputs, prev_super_state, prev_states_lstm)``. It returns a
            3-tuple whose first item maps ``'image'`` to the prediction.
        device: Torch device the input tensor is moved to.
        window: Number of frames, centred on ``T // 2``, fed to the model.
            Defaults to 128, the original hard-coded window.

    Returns:
        ``(spike_vis, depth)``. ``spike_vis`` is a ``(1, H', W')`` tensor with
        the per-pixel mean over all cropped frames. ``depth`` is the numpy
        array ``pred['image'][0]``.
    """
    data = data_transform(torch.from_numpy(spike).to(device))
    num_frames, _, _ = data.shape  # also enforces a 3-D tensor
    center = num_frames // 2
    half = window // 2
    # Clamp the start at 0. A negative start would make slicing wrap around to
    # the end of the stream and silently select only the last few frames when
    # fewer than `window` frames are available.
    start = max(center - half, 0)
    stop = center - half + window
    input_tensor = {'image': data[None, start:stop].to(device)}
    prev_super_states = {'image': None}
    prev_states_lstm = {}
    with torch.no_grad():
        pred, _, _ = model(input_tensor, prev_super_states['image'], prev_states_lstm)
    depth = pred['image'][0].cpu().numpy()
    # Average over the time axis (axis 0 of `data`, the last axis after permute).
    spike_vis = np.mean(data.permute(1, 2, 0).cpu().numpy(), axis=2)
    return torch.tensor(spike_vis).unsqueeze(0), depth
# === Gradio interface ===
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        input_file = gr.File(label="Upload .dat or .npy Spike File", type="file")
        output_spike = gr.Image(label="Reconstructed Spike Image")
        output_depth = gr.Image(label="Depth Prediction (Colormap)")
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')
    submit = gr.Button("Submit")

    def on_submit(file_obj):
        """Gradio callback: load the uploaded spike file, run the depth model,
        and return ``(spike_image, colored_depth)`` as uint8 RGB arrays.

        Raises:
            gr.Error: If no file was provided (Submit clicked with an empty input).
        """
        if file_obj is None:
            raise gr.Error("Please upload a .dat or .npy spike file first.")
        # gr.File(type="file") passes a tempfile-like object with `.name`; a
        # plain path string is used as-is.
        path = getattr(file_obj, "name", file_obj)
        spike = load_spike_file(path)
        spike_img, depth = predict_recon_bsf(spike, model, DEVICE)
        # Spike image: replicate to 3 channels, take a square centre crop,
        # scale to 0-255.
        spike_img = spike_img.repeat(3, 1, 1)
        h, w = spike_img.shape[1:]
        min_dim = min(h, w)
        center_crop = spike_img[:, (h - min_dim) // 2:(h + min_dim) // 2, (w - min_dim) // 2:(w + min_dim) // 2]
        # Clip before the uint8 cast: values above 1 (e.g. a .npy holding spike
        # counts rather than 0/1 spikes) would otherwise wrap around.
        spike_img_np = np.clip(center_crop.permute(1, 2, 0).numpy() * 255.0, 0, 255).astype(np.uint8)
        # Depth: min-max normalise to 0-255. Integer values index the
        # colormap's lookup table; keep RGB and drop alpha.
        depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255.0
        colored_depth = (cmap(depth.astype(np.uint8))[:, :, :3] * 255).astype(np.uint8)
        return spike_img_np, colored_depth

    submit.click(fn=on_submit, inputs=[input_file], outputs=[output_spike, output_depth])

    # Example inputs: every .npy and .dat file found in assets/.
    example_dir = "assets/"
    if os.path.isdir(example_dir):
        example_files = sorted(
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.endswith((".npy", ".dat"))
        )
    else:
        example_files = []
    # Only render the Examples widget when there is something to show.
    if example_files:
        gr.Examples(
            examples=example_files,
            inputs=[input_file],
            outputs=[output_spike, output_depth],
            fn=on_submit,
            cache_examples=False
        )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the web server.
    queued_demo = demo.queue()
    queued_demo.launch()