|
import gradio as gr |
|
import json |
|
import os |
|
import torch |
|
import numpy as np |
|
import matplotlib |
|
import matplotlib.cm as cm |
|
from SpikeT.model.S2DepthNet import S2DepthTransformerUNetConv |
|
from SpikeT.utils.data_augmentation import CenterCrop |
|
|
|
|
|
# Page header and description shown at the top of the Gradio page via gr.Markdown.
# English gloss of the description: "Upload a `.dat` or `.npy` spike file; the model
# reconstructs a spike image and predicts the corresponding depth map."
title = "# Spike Transformer - Depth Estimation (CPU)"
description = "上傳 `.dat` 或 `.npy` spike 檔案,模型將重建 spike 圖並預測對應的深度圖"

# This demo performs all tensor work and inference on the CPU.
DEVICE = torch.device("cpu")
|
|
|
|
|
# Trained checkpoint; its training config.json is read from the same directory.
model_path = 'SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar'
config_path = os.path.join(os.path.dirname(model_path), 'config.json')

with open(config_path, encoding='utf-8') as f:
    config = json.load(f)

# Copy training-time settings into the model sub-config that is handed to the
# constructor (presumably the constructor reads these keys -- confirm in S2DepthNet).
config['model']['gpu'] = config['gpu']
config['model']['every_x_rgb_frame'] = config['data_loader']['train']['every_x_rgb_frame']
config['model']['baseline'] = config['data_loader']['train']['baseline']
config['model']['loss_composition'] = config['trainer']['loss_composition']

# Resolve the architecture name through an explicit whitelist instead of eval(),
# so the contents of config.json can never execute arbitrary code.
_ARCHITECTURES = {'S2DepthTransformerUNetConv': S2DepthTransformerUNetConv}
try:
    _model_cls = _ARCHITECTURES[config['arch']]
except KeyError:
    raise ValueError(
        f"Unsupported architecture {config['arch']!r} in {config_path}; "
        f"expected one of {sorted(_ARCHITECTURES)}"
    ) from None
model = _model_cls(config['model'])

# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
# checkpoint stores extra pickled objects, loading may need weights_only=False
# (only acceptable for trusted files).
checkpoint = torch.load(model_path, map_location=DEVICE)
state_dict = checkpoint['state_dict']

# Checkpoints saved from nn.DataParallel typically prefix every key with
# 'module.'. Strip only that leading prefix: str.replace would also mangle any
# parameter name that contains 'module.' further inside.
_PREFIX = 'module.'
cleaned_state_dict = {
    (key[len(_PREFIX):] if key.startswith(_PREFIX) else key): value
    for key, value in state_dict.items()
}
model.load_state_dict(cleaned_state_dict)
model.eval()
model.to(DEVICE)

# Spatial center crop to 224x224 applied to every spike stream before inference.
data_transform = CenterCrop(224)
|
|
|
|
|
def RawToSpike(video_seq, h, w, flipud=True):
    """Decode a packed spike bitstream into per-frame binary spike maps.

    Each frame occupies ``h * w // 8`` consecutive bytes. Pixel ``p`` (row-major
    index within the frame) is bit ``p % 8`` (least-significant bit first) of
    byte ``p // 8``. Trailing bytes that do not form a complete frame are ignored.

    Args:
        video_seq: 1-D sequence of byte values. It is cast to uint8, so signed
            bytes (e.g. int8 from ``np.frombuffer(..., dtype='b')``) are
            reinterpreted bit-for-bit.
        h: Frame height in pixels.
        w: Frame width in pixels. ``h * w`` must be a positive multiple of 8.
        flipud: If True, reverse the row order of every frame.

    Returns:
        float32 array of shape ``(num_frames, h, w)`` containing 0.0 / 1.0.

    Raises:
        ValueError: if ``h * w`` is not a positive multiple of 8 (the bit
            layout would then address bytes beyond the frame).
    """
    raw = np.asarray(video_seq).astype(np.uint8)

    pixels = h * w
    if pixels <= 0 or pixels % 8:
        raise ValueError(f"h * w must be a positive multiple of 8, got {h} * {w} = {pixels}")
    bytes_per_frame = pixels // 8
    num_frames = len(raw) // bytes_per_frame

    frames = raw[:num_frames * bytes_per_frame].reshape(num_frames, bytes_per_frame)
    # Little bit order emits bit 0 of byte 0 first, i.e. exactly pixel order.
    bits = np.unpackbits(frames, axis=1, bitorder='little').reshape(num_frames, h, w)
    if flipud:
        bits = bits[:, ::-1, :]
    return bits.astype(np.float32)
|
|
|
def load_spike_file(path, h=260, w=346):
    """Load a spike stream from a ``.npy`` array or a raw ``.dat`` bitstream.

    Args:
        path: File path (str or os.PathLike). The format is chosen from the
            extension, case-insensitively.
        h: Frame height used to decode ``.dat`` files (default 260, the value
            this demo originally hard-coded).
        w: Frame width used to decode ``.dat`` files (default 346).

    Returns:
        float32 numpy array. For ``.dat`` input it is the array produced by
        ``RawToSpike``; for ``.npy`` input it is the stored array cast to
        float32 (its shape is not checked here).

    Raises:
        ValueError: if the extension is neither ``.npy`` nor ``.dat``.
    """
    path = os.fspath(path)
    suffix = path.lower()
    if suffix.endswith(".npy"):
        return np.load(path).astype(np.float32)
    if suffix.endswith(".dat"):
        # Raw bytes; RawToSpike unpacks 8 pixels' spike bits from each byte.
        video_seq = np.fromfile(path, dtype=np.uint8)
        return RawToSpike(video_seq, h=h, w=w)
    raise ValueError("Unsupported file format. Only .dat and .npy are supported.")
|
|
|
def predict_recon_bsf(spike, model, device, window=128):
    """Predict depth from a spike stream and build a mean-spike preview.

    Args:
        spike: numpy array. It must be 3-D after ``data_transform``; presumably
            (T, H, W) as produced by ``load_spike_file`` -- confirm for .npy input.
        model: Depth network, called as ``model(inputs, prev_state, prev_lstm_states)``
            and returning a 3-tuple whose first item is a dict with key ``'image'``.
        device: torch device the input is moved to.
        window: Number of consecutive frames, centred on the middle of the stream,
            fed to the model. The default, 128, is the previously hard-coded window.

    Returns:
        ``(spike_vis, depth)``. ``spike_vis`` is a tensor of shape (1, H', W')
        holding the per-pixel mean over all cropped frames. ``depth`` is
        ``pred['image'][0]`` as a numpy array.

    Raises:
        ValueError: if the stream has fewer than ``window`` frames.
    """
    data = data_transform(torch.from_numpy(spike).to(device))
    num_frames = data.shape[0]
    if num_frames < window:
        # Otherwise the slice start goes negative, wraps around, and silently
        # feeds the model the wrong (and too few) frames.
        raise ValueError(
            f"Spike stream has {num_frames} frames; at least {window} are required."
        )

    start = num_frames // 2 - window // 2
    input_tensor = {'image': data[None, start:start + window].to(device)}

    with torch.no_grad():
        # No recurrent history for a single clip: empty previous state and LSTM states.
        pred, _, _ = model(input_tensor, None, {})
    depth = pred['image'][0].cpu().numpy()

    # Preview: average spike rate per pixel over every cropped frame.
    spike_vis = torch.from_numpy(data.cpu().numpy().mean(axis=0)).unsqueeze(0)
    return spike_vis, depth
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        input_file = gr.File(label="Upload .dat or .npy Spike File", type="file")
        output_spike = gr.Image(label="Reconstructed Spike Image")
        output_depth = gr.Image(label="Depth Prediction (Colormap)")

    # Colormap for the depth output, looked up once rather than per request.
    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

    submit = gr.Button("Submit")

    def on_submit(file_obj):
        """Run load -> predict -> render for one uploaded file.

        Args:
            file_obj: Value of ``input_file``. With ``type="file"`` it is an
                object whose ``.name`` is the uploaded file's path; a plain path
                string is also accepted. ``None`` means nothing was uploaded.

        Returns:
            ``(spike_img_np, colored_depth)`` as uint8 images. The first is the
            mean-spike preview (square centre crop, RGB). The second is the
            min-max normalised depth map rendered with ``Spectral_r``.

        Raises:
            gr.Error: shown in the UI when no file was given, the format is
                unsupported, or the spike stream is too short.
        """
        if file_obj is None:
            raise gr.Error("Please upload a .dat or .npy spike file first.")
        # NOTE(review): newer Gradio versions may pass a path string instead of
        # a tempfile wrapper; accept both.
        path = file_obj if isinstance(file_obj, str) else file_obj.name

        try:
            spike = load_spike_file(path)
            spike_img, depth = predict_recon_bsf(spike, model, DEVICE)
        except ValueError as err:
            # Surface format/length problems as a UI message instead of a traceback.
            raise gr.Error(str(err)) from err

        # Spike preview: grey -> RGB, square centre crop, [0, 1] -> [0, 255].
        rgb = spike_img.repeat(3, 1, 1)
        h, w = rgb.shape[1:]
        side = min(h, w)
        top, left = (h - side) // 2, (w - side) // 2
        crop = rgb[:, top:top + side, left:left + side]
        # Clip before the uint8 cast: .npy input is not guaranteed to be binary,
        # so the mean can leave [0, 1] and would otherwise wrap around.
        spike_img_np = np.clip(crop.permute(1, 2, 0).numpy() * 255.0, 0, 255).astype(np.uint8)

        # Depth: min-max normalise to [0, 255]; the epsilon avoids 0/0 on a flat map.
        # NOTE(review): assumes depth is 2-D (H, W) -- confirm against the model output.
        d_min, d_max = depth.min(), depth.max()
        depth_u8 = ((depth - d_min) / (d_max - d_min + 1e-8) * 255.0).astype(np.uint8)
        # Integer input indexes the colormap's lookup table directly; drop alpha.
        colored_depth = (cmap(depth_u8)[..., :3] * 255).astype(np.uint8)

        return spike_img_np, colored_depth

    submit.click(fn=on_submit, inputs=[input_file], outputs=[output_spike, output_depth])

    # Optional example inputs shipped alongside the app.
    example_dir = "assets/"
    example_files = []
    if os.path.isdir(example_dir):
        example_files = sorted(
            os.path.join(example_dir, name)
            for name in os.listdir(example_dir)
            if name.endswith((".npy", ".dat"))
        )

    # Only render the examples widget when there is something to show.
    if example_files:
        gr.Examples(
            examples=[[example] for example in example_files],
            inputs=[input_file],
            outputs=[output_spike, output_depth],
            fn=on_submit,
            cache_examples=False,
        )
|
|
|
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then start the web server
    # on the object that queue() hands back.
    app = demo.queue()
    app.launch()