import gradio as gr
import json
import os
import torch
import numpy as np
import matplotlib
import matplotlib.cm as cm
from SpikeT.model.S2DepthNet import S2DepthTransformerUNetConv
from SpikeT.utils.data_augmentation import CenterCrop

# === Settings ===
DEVICE = torch.device("cpu")
title = "# Spike Transformer - Depth Estimation (CPU)"
description = "Upload a `.dat` or `.npy` spike file; the model will reconstruct the spike image and predict the corresponding depth map."

# === Load model and config ===
model_path = 'SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar'
config_path = os.path.join(os.path.dirname(model_path), 'config.json')
with open(config_path) as f:
    config = json.load(f)

config['model']['gpu'] = config['gpu']
config['model']['every_x_rgb_frame'] = config['data_loader']['train']['every_x_rgb_frame']
config['model']['baseline'] = config['data_loader']['train']['baseline']
config['model']['loss_composition'] = config['trainer']['loss_composition']

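# config['arch'] names the model class (imported above) and is resolved via eval below.
# Checkpoint parameter names may carry a 'module.' prefix left over from DataParallel
# training; it is stripped so the weights load into a plain single-device model.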
model = eval(config['arch'])(config['model'])
checkpoint = torch.load(model_path, map_location='cpu')
state_dict = checkpoint['state_dict']
cleaned_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
model.load_state_dict(cleaned_state_dict)
model.eval()
model.to(DEVICE)
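# CenterCrop(224) matches the 224x224 input size the Swin-based encoder presumably expects.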
data_transform = CenterCrop(224)

# === Utility functions ===
def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a raw spike byte stream into a (num_frames, h, w) binary matrix.

    Each byte packs 8 pixels (one bit per pixel, LSB first), so a single
    frame occupies h * w / 8 bytes of the stream.
    """
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    img_num = len(video_seq) // (img_size // 8)
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    pix_id = np.arange(0, h * w).reshape((h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))  # per-pixel bit mask within its byte
    byte_id = pix_id // 8                              # which byte of the frame holds each pixel
    for img_id in range(img_num):
        id_start = img_id * img_size // 8
        id_end = id_start + img_size // 8
        cur_info = video_seq[id_start:id_end]
        data = cur_info[byte_id]
        result = np.bitwise_and(data, comparator)
        SpikeMatrix[img_id] = np.flipud(result == comparator) if flipud else (result == comparator)
    return SpikeMatrix.astype(np.float32)
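
# Example: a byte value 0b00000101 at byte index b encodes spikes at pixels 8*b and
# 8*b + 2 of the frame (bits are read LSB-first), with the other six pixels silent.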

def load_spike_file(path):
    if path.endswith(".npy"):
        return np.load(path).astype(np.float32)
    elif path.endswith(".dat"):
        with open(path, 'rb') as f:
            video_seq = np.frombuffer(f.read(), dtype='b')
        return RawToSpike(video_seq, h=260, w=346)  # 260x346 is the sensor resolution this demo assumes
    else:
        raise ValueError("Unsupported file format. Only .dat and .npy are supported.")

def predict_recon_bsf(spike, model, device):
    spikes = torch.from_numpy(spike).to(device)
    data = data_transform(spikes)  # center-crop the spike stream to 224x224
    dT, dH, dW = data.shape
    # Feed the 128 central frames as one temporal window; recurrent states start empty.
    input_tensor = {'image': data[None, dT // 2 - 64: dT // 2 + 64].to(device)}
    prev_super_states = {'image': None}
    prev_states_lstm = {}

    with torch.no_grad():
        pred, _, _ = model(input_tensor, prev_super_states['image'], prev_states_lstm)
        depth = pred['image'][0].cpu().numpy()
        # Average the binary spikes over time for a grayscale view of the input.
        spikes_np = data.permute(1, 2, 0).cpu().numpy()
        spike_vis = np.mean(spikes_np, axis=2)

    return torch.tensor(spike_vis).unsqueeze(0), depth
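
# Minimal offline sketch of the two helpers above (the path is illustrative, not shipped):
#   spike = load_spike_file("assets/sample.npy")              # (T, H, W) float32 spikes
#   spike_vis, depth = predict_recon_bsf(spike, model, DEVICE)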

# === Gradio interface ===
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        input_file = gr.File(label="Upload .dat or .npy Spike File", type="filepath")
        output_spike = gr.Image(label="Reconstructed Spike Image")
        output_depth = gr.Image(label="Depth Prediction (Colormap)")

    cmap = matplotlib.colormaps.get_cmap('Spectral_r')  # reversed Spectral colormap for depth rendering
    submit = gr.Button("Submit")

    def on_submit(file_path):
        spike = load_spike_file(file_path)
        spike_img, depth = predict_recon_bsf(spike, model, DEVICE)

        # Spike image: replicate the single channel to RGB and center-crop to a square
        spike_img = spike_img.repeat(3, 1, 1)
        h, w = spike_img.shape[1:]
        min_dim = min(h, w)
        center_crop = spike_img[:, (h - min_dim) // 2:(h + min_dim) // 2, (w - min_dim) // 2:(w + min_dim) // 2]
        spike_img_np = (center_crop.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)

        # Normalize depth to [0, 255] and render it through the colormap
        depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255.0
        colored_depth = (cmap(depth.astype(np.uint8))[:, :, :3] * 255).astype(np.uint8)

        return spike_img_np, colored_depth

    submit.click(fn=on_submit, inputs=[input_file], outputs=[output_spike, output_depth])

    # Example data (.npy or .dat files under assets/)
    example_dir = "assets/"
    if os.path.exists(example_dir):
        example_files = sorted([
            os.path.join(example_dir, f)
            for f in os.listdir(example_dir)
            if f.endswith(".npy") or f.endswith(".dat")
        ])
    else:
        example_files = []

    gr.Examples(
        examples=example_files,
        inputs=[input_file],
        outputs=[output_spike, output_depth],
        fn=on_submit,
        cache_examples=False
    )

if __name__ == "__main__":
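    # queue() routes requests through Gradio's queue, avoiding timeouts on slow CPU inference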
    demo.queue().launch()