import gradio as gr
import json
import os
import torch
import numpy as np
import matplotlib
import matplotlib.cm as cm

from SpikeT.model.S2DepthNet import S2DepthTransformerUNetConv
from SpikeT.utils.data_augmentation import CenterCrop

# === Settings ===
DEVICE = torch.device("cpu")
title = "# Spike Transformer - Depth Estimation (CPU)"
description = "Upload a `.dat` or `.npy` spike file; the model reconstructs the spike image and predicts the corresponding depth map."

# === Load model and config ===
model_path = 'SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar'
config_path = os.path.join(os.path.dirname(model_path), 'config.json')

with open(config_path) as f:
    config = json.load(f)

config['model']['gpu'] = config['gpu']
config['model']['every_x_rgb_frame'] = config['data_loader']['train']['every_x_rgb_frame']
config['model']['baseline'] = config['data_loader']['train']['baseline']
config['model']['loss_composition'] = config['trainer']['loss_composition']

model = eval(config['arch'])(config['model'])
checkpoint = torch.load(model_path, map_location='cpu')
state_dict = checkpoint['state_dict']
# Strip the 'module.' prefix left over from DataParallel training.
cleaned_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(cleaned_state_dict)
model.eval()
model.to(DEVICE)

data_transform = CenterCrop(224)


# === Utility functions ===
def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a raw bit-packed spike stream into a (frames, h, w) float32 array."""
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    img_num = len(video_seq) // (img_size // 8)
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    pix_id = np.arange(0, h * w).reshape((h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))
    byte_id = pix_id // 8

    for img_id in range(img_num):
        id_start = img_id * img_size // 8
        id_end = id_start + img_size // 8
        cur_info = video_seq[id_start:id_end]
        data = cur_info[byte_id]
        result = np.bitwise_and(data, comparator)
        SpikeMatrix[img_id] = np.flipud((result == comparator)) if flipud else (result == comparator)

    return SpikeMatrix.astype(np.float32)


def load_spike_file(path):
    """Load spikes from a .npy array or a raw .dat stream (260x346 sensor)."""
    if path.endswith(".npy"):
        return np.load(path).astype(np.float32)
    elif path.endswith(".dat"):
        with open(path, 'rb') as f:
            video_seq = np.frombuffer(f.read(), dtype='b')
        return RawToSpike(video_seq, h=260, w=346)
    else:
        raise ValueError("Unsupported file format. Only .dat and .npy are supported.")


def predict_recon_bsf(spike, model, device):
    """Run the model on the 128-frame window around the sequence centre.

    Returns a (1, H, W) tensor for spike visualisation and the predicted depth map.
    """
    spikes = torch.from_numpy(spike).to(device)
    data = data_transform(spikes)
    dT, dH, dW = data.shape
    input_tensor = {'image': data[None, dT // 2 - 64: dT // 2 + 64].to(device)}
    prev_super_states = {'image': None}
    prev_states_lstm = {}

    with torch.no_grad():
        pred, _, _ = model(input_tensor, prev_super_states['image'], prev_states_lstm)

    depth = pred['image'][0].cpu().numpy()
    spikes_np = data.permute(1, 2, 0).cpu().numpy()
    spike_vis = np.mean(spikes_np, axis=2)

    return torch.tensor(spike_vis).unsqueeze(0), depth


# === Gradio interface ===
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        input_file = gr.File(label="Upload .dat or .npy Spike File", type="file")
        output_spike = gr.Image(label="Reconstructed Spike Image")
        output_depth = gr.Image(label="Depth Prediction (Colormap)")

    cmap = matplotlib.colormaps.get_cmap('Spectral_r')
    submit = gr.Button("Submit")

    def on_submit(file_obj):
        spike = load_spike_file(file_obj.name)
        spike_img, depth = predict_recon_bsf(spike, model, DEVICE)

        # Spike visualisation: grey -> 3 channels, square centre crop, uint8
        spike_img = spike_img.repeat(3, 1, 1)
        h, w = spike_img.shape[1:]
        min_dim = min(h, w)
        center_crop = spike_img[:, (h - min_dim) // 2:(h + min_dim) // 2,
                                (w - min_dim) // 2:(w + min_dim) // 2]
        spike_img_np = (center_crop.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)

        # Colormap the depth prediction
        depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255.0
        colored_depth = (cmap(depth.astype(np.uint8))[:, :, :3] * 255).astype(np.uint8)

        return spike_img_np, colored_depth

    submit.click(fn=on_submit, inputs=[input_file], outputs=[output_spike, output_depth])

    # Example files (.npy / .dat) from assets/
    example_dir = "assets/"
    if os.path.exists(example_dir):
        example_files = sorted([
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.endswith(".npy") or f.endswith(".dat")
        ])
    else:
        example_files = []

    gr.Examples(
        examples=example_files,
        inputs=[input_file],
        outputs=[output_spike, output_depth],
        fn=on_submit,
        cache_examples=False
    )

if __name__ == "__main__":
    demo.queue().launch()
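# --- Hypothetical smoke test (not part of the original app) ---
# A minimal sketch, assuming a 260x346 spike camera and 160 frames, which is
# enough to cover the 128-frame window used in predict_recon_bsf. It writes a
# random binary spike cube to "example_spike.npy"; that file can then be
# uploaded through the interface above. Uncomment to generate it:
#
#   import numpy as np
#   spike_cube = (np.random.rand(160, 260, 346) > 0.9).astype(np.float32)
#   np.save("example_spike.npy", spike_cube)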