zzzzzeee committed on
Commit 79efd3a · verified · 1 Parent(s): 5cced60

Update app.py

Files changed (1)
  1. app.py +93 -101
app.py CHANGED
@@ -3,138 +3,130 @@ import json
 import os
 import torch
 import numpy as np
-import cv2
-import matplotlib as mpl
 import matplotlib.cm as cm
 from SpikeT.model.S2DepthNet import S2DepthTransformerUNetConv
 from SpikeT.utils.data_augmentation import CenterCrop
 
-# === Helper Functions ===
-
-def RawToSpike(video_seq, h, w, flipud=True):
-    video_seq = np.array(video_seq).astype(np.uint8)
-    img_size = h * w
-    img_num = len(video_seq) // (img_size // 8)
-    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
-    pix_id = np.arange(0, h * w)
-    pix_id = np.reshape(pix_id, (h, w))
-    comparator = np.left_shift(1, np.mod(pix_id, 8))
-    byte_id = pix_id // 8
-
-    for img_id in np.arange(img_num):
-        id_start = int(img_id) * int(img_size) // 8
-        id_end = int(id_start) + int(img_size) // 8
-        cur_info = video_seq[id_start:id_end]
-        data = cur_info[byte_id]
-        result = np.bitwise_and(data, comparator)
-        if flipud:
-            SpikeMatrix[img_id, :, :] = np.flipud((result == comparator))
-        else:
-            SpikeMatrix[img_id, :, :] = (result == comparator)
-
-    return SpikeMatrix.astype(np.float32)
-
-def make_colormap(img, color_mapper):
-    color_map_inv = np.ones_like(img[0]) * np.amax(img[0]) - img[0]
-    color_map_inv = np.nan_to_num(color_map_inv, nan=1)
-    color_map_inv = color_map_inv / np.amax(color_map_inv)
-    color_map_inv = np.nan_to_num(color_map_inv)
-    color_map_inv = color_mapper.to_rgba(color_map_inv)
-    color_map_inv[:, :, 0:3] = color_map_inv[:, :, 0:3][..., ::-1]
-    return color_map_inv
-
-# === Load model and config (CPU only, strip 'module.') ===
-device = torch.device("cpu")
 
 model_path = 'SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar'
 config_path = os.path.join(os.path.dirname(model_path), 'config.json')
 with open(config_path) as f:
     config = json.load(f)
 
-# Update the config structure
 config['model']['gpu'] = config['gpu']
 config['model']['every_x_rgb_frame'] = config['data_loader']['train']['every_x_rgb_frame']
 config['model']['baseline'] = config['data_loader']['train']['baseline']
 config['model']['loss_composition'] = config['trainer']['loss_composition']
 
-# Build the model
 model = eval(config['arch'])(config['model'])
-
-# Load the checkpoint and strip DataParallel's 'module.' prefix
 checkpoint = torch.load(model_path, map_location='cpu')
 state_dict = checkpoint['state_dict']
 cleaned_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
 model.load_state_dict(cleaned_state_dict)
 model.eval()
-model.to(device)
-
 data_transform = CenterCrop(224)
 
-# If the model uses a phased arch, run a dummy forward pass first to initialise its state
-use_phased_arch = config['use_phased_arch']
-if use_phased_arch:
-    C, (H, W) = config["model"]["num_bins_events"], config["model"]["spatial_resolution"]
-    dummy_input = torch.zeros(1, C, H, W).to(device)
-    times = torch.zeros(1).to(device)
-    _ = model.forward(dummy_input, times=times, prev_states=None)
-
-# === Main inference logic ===
-def infer_depth(filepath):
-    file = open(filepath, 'rb')
-    spike_seq = np.frombuffer(file.read(), 'b')
-    spikes = RawToSpike(spike_seq, 260, 346)
-    spikes = torch.from_numpy(spikes).to(device)
-    data = data_transform(spikes)
 
     dT, dH, dW = data.shape
-    input = {'image': data[None, dT//2-64:dT//2+64].to(device)}
     prev_super_states = {'image': None}
     prev_states_lstm = {}
 
     with torch.no_grad():
-        new_predicted_targets, _, _ = model(input, prev_super_states['image'], prev_states_lstm)
-        predict_depth = new_predicted_targets['image'][0].cpu().numpy()
         spikes_np = data.permute(1, 2, 0).cpu().numpy()
         spike_vis = np.mean(spikes_np, axis=2)
 
-    # Colormap
-    color_map_inv = np.ones_like(predict_depth[0]) * np.amax(predict_depth[0]) - predict_depth[0]
-    color_map_inv = np.nan_to_num(color_map_inv, nan=1)
-    color_map_inv = color_map_inv / np.amax(color_map_inv)
-    color_map_inv = np.nan_to_num(color_map_inv)
-    vmax = np.percentile(color_map_inv, 95)
-    normalizer = mpl.colors.Normalize(vmin=color_map_inv.min(), vmax=vmax)
-    color_mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
-    color_map = make_colormap(predict_depth, color_mapper)
-
-    return (
-        (predict_depth[0] * 255.0).astype(np.uint8),
-        (spike_vis * 255.0).astype(np.uint8),
-        (color_map * 255.0).astype(np.uint8)
-    )
-
-example_dir = "assets/"
-if os.path.exists(example_dir):
-    example_files = sorted([
-        os.path.join(example_dir, f)
-        for f in os.listdir(example_dir)
-        if f.endswith(".npy")
-    ])
-else:
-    example_files = []
-
-iface = gr.Interface(
-    fn=infer_depth,
-    inputs=gr.File(label="Upload .dat Spike File"),
-    outputs=[
-        gr.Image(label="Predicted Depth (raw)", image_mode="L"),
-        gr.Image(label="Spike Input Mean", image_mode="L"),
-        gr.Image(label="Colored Depth Map")
-    ],
-    title="Spike Transformer - Depth Estimation (CPU)",
-    description="Upload a .dat spike file to get a depth-map prediction",
-    examples=example_files  # correct approach: pass the file list in directly
-)
 
 if __name__ == "__main__":
-    iface.launch()

 import os
 import torch
 import numpy as np
+import matplotlib
 import matplotlib.cm as cm
 from SpikeT.model.S2DepthNet import S2DepthTransformerUNetConv
 from SpikeT.utils.data_augmentation import CenterCrop
 
+# === Settings ===
+DEVICE = torch.device("cpu")
+title = "# Spike Transformer - Depth Estimation (CPU)"
+description = "Upload a `.dat` or `.npy` spike file; the model reconstructs the spike image and predicts the corresponding depth map"
 
+# === Load model and config ===
 model_path = 'SpikeT/s2d_weights/debug_A100_SpikeTransformerUNetConv_LocalGlobal-Swin3D-T/model_best.pth.tar'
 config_path = os.path.join(os.path.dirname(model_path), 'config.json')
 with open(config_path) as f:
     config = json.load(f)
 
 config['model']['gpu'] = config['gpu']
 config['model']['every_x_rgb_frame'] = config['data_loader']['train']['every_x_rgb_frame']
 config['model']['baseline'] = config['data_loader']['train']['baseline']
 config['model']['loss_composition'] = config['trainer']['loss_composition']
 
 model = eval(config['arch'])(config['model'])
+# Load the checkpoint and strip DataParallel's 'module.' prefix
 checkpoint = torch.load(model_path, map_location='cpu')
 state_dict = checkpoint['state_dict']
 cleaned_state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
 model.load_state_dict(cleaned_state_dict)
 model.eval()
+model.to(DEVICE)
 data_transform = CenterCrop(224)
 
+# === Utility functions ===
+def RawToSpike(video_seq, h, w, flipud=True):
+    # Decode a packed spike stream: each byte holds 8 pixels, and pixel pix_id is
+    # stored in bit (pix_id % 8) of byte (pix_id // 8); frames are optionally flipped vertically.
+    video_seq = np.array(video_seq).astype(np.uint8)
+    img_size = h * w
+    img_num = len(video_seq) // (img_size // 8)
+    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
+    pix_id = np.arange(0, h * w).reshape((h, w))
+    comparator = np.left_shift(1, np.mod(pix_id, 8))
+    byte_id = pix_id // 8
+    for img_id in range(img_num):
+        id_start = img_id * img_size // 8
+        id_end = id_start + img_size // 8
+        cur_info = video_seq[id_start:id_end]
+        data = cur_info[byte_id]
+        result = np.bitwise_and(data, comparator)
+        SpikeMatrix[img_id] = np.flipud(result == comparator) if flipud else (result == comparator)
+    return SpikeMatrix.astype(np.float32)
+
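+# Size check: at the 260x346 resolution used by load_spike_file below, one spike
+# frame occupies 260 * 346 / 8 = 11245 bytes, so an N-frame .dat recording is
+# expected to be N * 11245 bytes long.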
+def load_spike_file(path):
+    # Load spikes from a pre-decoded .npy array or decode a raw .dat byte stream
+    if path.endswith(".npy"):
+        return np.load(path).astype(np.float32)
+    elif path.endswith(".dat"):
+        with open(path, 'rb') as f:
+            video_seq = np.frombuffer(f.read(), dtype='b')
+        return RawToSpike(video_seq, h=260, w=346)
+    else:
+        raise ValueError("Unsupported file format. Only .dat and .npy are supported.")
+
+def predict_recon_bsf(spike, model, device):
+    spikes = torch.from_numpy(spike).to(device)
+    data = data_transform(spikes)
     dT, dH, dW = data.shape
+    input_tensor = {'image': data[None, dT // 2 - 64: dT // 2 + 64].to(device)}
     prev_super_states = {'image': None}
     prev_states_lstm = {}
 
     with torch.no_grad():
+        pred, _, _ = model(input_tensor, prev_super_states['image'], prev_states_lstm)
+        depth = pred['image'][0].cpu().numpy()
         spikes_np = data.permute(1, 2, 0).cpu().numpy()
         spike_vis = np.mean(spikes_np, axis=2)
 
+    return torch.tensor(spike_vis).unsqueeze(0), depth
+
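+# Note: the model consumes a fixed 128-frame window (frames dT//2-64 to dT//2+64)
+# of the 224x224 centre-cropped spike stream, so uploads are expected to provide
+# at least 128 spike frames.
+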
+# === Gradio interface ===
+with gr.Blocks() as demo:
+    gr.Markdown(title)
+    gr.Markdown(description)
+
+    with gr.Row():
+        input_file = gr.File(label="Upload .dat or .npy Spike File", type="file")
+        output_spike = gr.Image(label="Reconstructed Spike Image")
+        output_depth = gr.Image(label="Depth Prediction (Colormap)")
+
+    cmap = matplotlib.colormaps.get_cmap('Spectral_r')
+    submit = gr.Button("Submit")
+
+    def on_submit(file_obj):
+        spike = load_spike_file(file_obj.name)
+        spike_img, depth = predict_recon_bsf(spike, model, DEVICE)
+
+        # Spike visualisation: grayscale -> 3 channels, square centre crop, uint8
+        spike_img = spike_img.repeat(3, 1, 1)
+        h, w = spike_img.shape[1:]
+        min_dim = min(h, w)
+        center_crop = spike_img[:, (h - min_dim) // 2:(h + min_dim) // 2, (w - min_dim) // 2:(w + min_dim) // 2]
+        spike_img_np = (center_crop.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8)
+
+        # Normalise depth to 0-255 and apply the colormap
+        depth = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8) * 255.0
+        colored_depth = (cmap(depth.astype(np.uint8))[:, :, :3] * 255).astype(np.uint8)
+
+        return spike_img_np, colored_depth
+
+    submit.click(fn=on_submit, inputs=[input_file], outputs=[output_spike, output_depth])
+
+    # Example data (.npy or .dat)
+    example_dir = "assets/examples"
+    if os.path.exists(example_dir):
+        example_files = sorted([
+            os.path.join(example_dir, f)
+            for f in os.listdir(example_dir)
+            if f.endswith(".npy") or f.endswith(".dat")
+        ])
+    else:
+        example_files = []
+
+    gr.Examples(
+        examples=example_files,
+        inputs=[input_file],
+        outputs=[output_spike, output_depth],
+        fn=on_submit,
+        cache_examples=False
+    )
 
 if __name__ == "__main__":
+    demo.queue().launch()
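
For reference, the new helpers can also be called outside the Gradio UI; a minimal sketch, assuming a locally available spike file (the path below is hypothetical):

    spike = load_spike_file("assets/examples/sample.npy")        # decoded spike frames as float32
    spike_img, depth = predict_recon_bsf(spike, model, DEVICE)   # mean-spike visualisation and predicted depth map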