import gradio as gr
import cv2
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import torch
from model.S2DepthNet import S2DepthTransformerUNetConv  # Ensure this is accessible
from data_augmentation import CenterCrop
import os
from os.path import join
import tempfile

# Assuming model weights and config are preloaded or available in the Space
CONFIG = {
    "use_phased_arch": True,
    "gpu": 0,
    "arch": "S2DepthTransformerUNetConv",
    "model": {
        "gpu": 0,
        "every_x_rgb_frame": 1,  # Example value, adjust as per your config
        "baseline": 0.1,         # Example value
        "loss_composition": "default",  # Example value
        "num_bins_events": 128,  # Example value, adjust as needed
        "spatial_resolution": (224, 224)  # Example resolution
    },
    "data_loader": {
        "train": {
            "every_x_rgb_frame": 1,
            "baseline": 0.1
        }
    },
    "trainer": {
        "loss_composition": "default"
    }
}

# Load model (assuming weights are in the Space directory)
INITIAL_CHECKPOINT = "path_to_your_model_weights.pth"  # Update this path
model = S2DepthTransformerUNetConv(CONFIG["model"])
checkpoint = torch.load(INITIAL_CHECKPOINT, map_location="cpu")  # Use CPU for simplicity in Spaces
model.load_state_dict(checkpoint['state_dict'])
model.eval()

# Predefine color mapper
vmax = 0.95  # Upper bound for depth normalization; adjust as needed
normalizer = mpl.colors.Normalize(vmin=0, vmax=vmax)
color_mapper_overall = cm.ScalarMappable(norm=normalizer, cmap='magma')

def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a raw spike byte stream into a (num_frames, h, w) binary matrix.

    Each byte packs 8 pixels, least-significant bit first, so one frame
    occupies h * w / 8 bytes.
    """
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    img_num = len(video_seq) // (img_size // 8)
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    # Per-pixel bit mask and byte offset within one packed frame
    pix_id = np.arange(0, h * w)
    pix_id = np.reshape(pix_id, (h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))
    byte_id = pix_id // 8

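    # Slice each frame's bytes and test every pixel's bit against its mask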
    for img_id in np.arange(img_num):
        id_start = int(img_id) * int(img_size) // 8
        id_end = int(id_start) + int(img_size) // 8
        cur_info = video_seq[id_start:id_end]
        data = cur_info[byte_id]
        result = np.bitwise_and(data, comparator)
        if flipud:
            SpikeMatrix[img_id, :, :] = np.flipud((result == comparator))
        else:
            SpikeMatrix[img_id, :, :] = (result == comparator)
    return SpikeMatrix

def make_colormap(img, color_mapper):
    """Invert and normalize a depth map, then map it to a BGR color image."""
    # Invert (max - value) and rescale to [0, 1] before applying the colormap
    color_map_inv = np.ones_like(img[0]) * np.amax(img[0]) - img[0]
    color_map_inv = np.nan_to_num(color_map_inv, nan=1)
    color_map_inv = color_map_inv / np.amax(color_map_inv)
    color_map_inv = np.nan_to_num(color_map_inv)
    color_map_inv = color_mapper.to_rgba(color_map_inv)
    # Swap RGB -> BGR so the array can be written directly with cv2.imwrite
    color_map_inv[:, :, 0:3] = color_map_inv[:, :, 0:3][..., ::-1]
    return color_map_inv

def process_spike_file(file):
    # gr.File may pass a filepath string or a tempfile-like object, depending on the Gradio version
    path = file if isinstance(file, str) else file.name
    with open(path, 'rb') as f:
        spike_seq = f.read()
        spike_seq = np.frombuffer(spike_seq, 'b')
    
    # Unpack the byte stream into binary spike frames (sensor resolution 250 x 400)
    spikes = RawToSpike(spike_seq, 250, 400)
    spikes = spikes.astype(np.float32)
    spikes = torch.from_numpy(spikes)
    
    data_transform = CenterCrop(224)
    data = data_transform(spikes)
    dT, dH, dW = data.shape
    
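    # Take a 128-frame window centered in time as the network input
    # (128 matches num_bins_events in CONFIG)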
    input_data = {"image": data[None, dT//2-64:dT//2+64]}
    prev_super_states = {"image": None}
    prev_states_lstm = {}
    
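    # Single forward pass; recurrent states are not carried over between calls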
    with torch.no_grad():
        new_predicted_targets, _, _ = model(input_data, prev_super_states["image"], prev_states_lstm)
    
    predict_depth = new_predicted_targets["image"][0].cpu().numpy()
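    # Average the spike frames over time for a grayscale view of the input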
    input_spikes = np.mean(data.permute(1, 2, 0).cpu().numpy(), axis=2).astype(np.float32)
    color_map = make_colormap(predict_depth, color_mapper_overall)
    
    # Scale to 8-bit images (clip to [0, 1] first to avoid uint8 wrap-around)
    depth_img = (np.clip(predict_depth[0], 0, 1) * 255.0).astype(np.uint8)
    input_img = (np.clip(input_spikes, 0, 1) * 255.0).astype(np.uint8)
    color_img = (np.clip(color_map, 0, 1) * 255.0).astype(np.uint8)
    
    # Save temporarily and return paths
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_depth, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_input, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_color:
        cv2.imwrite(tmp_depth.name, depth_img[:, :, None])
        cv2.imwrite(tmp_input.name, input_img[:, :, None])
        cv2.imwrite(tmp_color.name, color_img)
        
        return tmp_depth.name, tmp_input.name, tmp_color.name

# Gradio Interface
interface = gr.Interface(
    fn=process_spike_file,
    inputs=gr.File(label="Upload Spike Sequence File"),
    outputs=[
        gr.Image(label="Depth Estimation"),
        gr.Image(label="Input Spikes"),
        gr.Image(label="Color Map")
    ],
    title="Spike Stream Depth Estimation",
    description="Upload a spike sequence file to estimate depth using S2DepthTransformerUNetConv."
)

if __name__ == "__main__":
    interface.launch()