import gradio as gr
import json
import os
import torch
import numpy as np
import matplotlib
import matplotlib.cm as cm
from SpikeT.model.S2DepthNet import S2DepthTransformerUNetConv
from SpikeT.utils.data_augmentation import CenterCrop
# === Settings ===
DEVICE = torch.device("cpu")
title = "# Spike Transformer - Depth Estimation (CPU)"
description = "Upload a `.dat` or `.npy` spike file; the model will reconstruct the spike image and predict the corresponding depth map."
# === Load model and config ===
model_path = 'SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar'
config_path = os.path.join(os.path.dirname(model_path), 'config.json')
with open(config_path) as f:
    config = json.load(f)
# Mirror training-time settings into the model config expected by the constructor
config['model']['gpu'] = config['gpu']
config['model']['every_x_rgb_frame'] = config['data_loader']['train']['every_x_rgb_frame']
config['model']['baseline'] = config['data_loader']['train']['baseline']
config['model']['loss_composition'] = config['trainer']['loss_composition']
# config['arch'] names the architecture class to instantiate (presumably S2DepthTransformerUNetConv, imported above)
model = eval(config['arch'])(config['model'])
checkpoint = torch.load(model_path, map_location='cpu')
state_dict = checkpoint['state_dict']
# Strip the 'module.' prefix left by nn.DataParallel so the keys match the bare model
cleaned_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(cleaned_state_dict)
model.eval()
model.to(DEVICE)
data_transform = CenterCrop(224)
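# Note: CenterCrop(224) crops every spike frame to 224x224, which appears to be the
# spatial input size expected by the Swin3D-T backbone named in the checkpoint path above.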
# === Utility functions ===
def RawToSpike(video_seq, h, w, flipud=True):
    # Unpack a bit-packed byte stream into a (num_frames, h, w) binary spike array
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    img_num = len(video_seq) // (img_size // 8)  # each frame occupies h*w/8 bytes (1 bit per pixel)
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    pix_id = np.arange(0, h * w).reshape((h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))  # bit mask selecting each pixel within its byte
    byte_id = pix_id // 8
    for img_id in range(img_num):
        id_start = img_id * img_size // 8
        id_end = id_start + img_size // 8
        cur_info = video_seq[id_start:id_end]
        data = cur_info[byte_id]
        result = np.bitwise_and(data, comparator)
        SpikeMatrix[img_id] = np.flipud(result == comparator) if flipud else (result == comparator)
    return SpikeMatrix.astype(np.float32)
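# For reference: the .dat branch below assumes a 260x346 spike sensor, so each frame
# packs into 260 * 346 / 8 = 11,245 bytes and a file of N * 11,245 bytes decodes to
# an (N, 260, 346) spike array.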
def load_spike_file(path):
    # Load spikes from .npy (already decoded) or .dat (raw bit-packed bytes)
    if path.endswith(".npy"):
        return np.load(path).astype(np.float32)
    elif path.endswith(".dat"):
        with open(path, 'rb') as f:
            video_seq = np.frombuffer(f.read(), dtype='b')
        return RawToSpike(video_seq, h=260, w=346)
    else:
        raise ValueError("Unsupported file format. Only .dat and .npy are supported.")
def predict_recon_bsf(spike, model, device):
    spikes = torch.from_numpy(spike).to(device)
    data = data_transform(spikes)  # center-crop each frame to 224x224
    dT, dH, dW = data.shape
    # Feed the 128 frames around the temporal midpoint, with a leading batch dimension
    # (assumes the recording contains at least 128 frames)
    input_tensor = {'image': data[None, dT // 2 - 64: dT // 2 + 64].to(device)}
    prev_super_states = {'image': None}
    prev_states_lstm = {}
    with torch.no_grad():
        pred, _, _ = model(input_tensor, prev_super_states['image'], prev_states_lstm)
    depth = pred['image'][0].cpu().numpy()
    # Time-averaged spikes give a grayscale preview of the input
    spikes_np = data.permute(1, 2, 0).cpu().numpy()
    spike_vis = np.mean(spikes_np, axis=2)
    return torch.tensor(spike_vis).unsqueeze(0), depth
# === Gradio interface ===
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        input_file = gr.File(label="Upload .dat or .npy Spike File", type="file")
        output_spike = gr.Image(label="Reconstructed Spike Image")
        output_depth = gr.Image(label="Depth Prediction (Colormap)")

    cmap = matplotlib.colormaps.get_cmap('Spectral_r')
    submit = gr.Button("Submit")
    def on_submit(file_obj):
        spike = load_spike_file(file_obj.name)
        spike_img, depth = predict_recon_bsf(spike, model, DEVICE)

        # Spike visualization: replicate the single channel to RGB and center-crop to a square
        spike_img = spike_img.repeat(3, 1, 1)
        h, w = spike_img.shape[1:]
        min_dim = min(h, w)
        center_crop = spike_img[:, (h - min_dim) // 2:(h + min_dim) // 2, (w - min_dim) // 2:(w + min_dim) // 2]
        spike_img_np = (center_crop.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)

        # Normalize depth to [0, 255] and apply the colormap
        depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255.0
        colored_depth = (cmap(depth.astype(np.uint8))[:, :, :3] * 255).astype(np.uint8)

        return spike_img_np, colored_depth
    submit.click(fn=on_submit, inputs=[input_file], outputs=[output_spike, output_depth])
    # Example files (.npy or .dat) from the assets/ directory
    example_dir = "assets/"
    if os.path.exists(example_dir):
        example_files = sorted([
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.endswith(".npy") or f.endswith(".dat")
        ])
    else:
        example_files = []

    gr.Examples(
        examples=example_files,
        inputs=[input_file],
        outputs=[output_spike, output_depth],
        fn=on_submit,
        cache_examples=False
    )
if __name__ == "__main__":
    demo.queue().launch()
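    # Gradio serves the app locally (http://127.0.0.1:7860 by default); pass
    # share=True to launch() if a temporary public link is needed.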