# motionmask/utils_spike.py
import gradio as gr
import cv2
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import torch
from model.S2DepthNet import S2DepthTransformerUNetConv # Ensure this is accessible
from data_augmentation import CenterCrop
import os
from os.path import join
import tempfile
# Assuming model weights and config are preloaded or available in the Space
CONFIG = {
"use_phased_arch": True,
"gpu": 0,
"arch": "S2DepthTransformerUNetConv",
"model": {
"gpu": 0,
"every_x_rgb_frame": 1, # Example value, adjust as per your config
"baseline": 0.1, # Example value
"loss_composition": "default", # Example value
"num_bins_events": 128, # Example value, adjust as needed
"spatial_resolution": (224, 224) # Example resolution
},
"data_loader": {
"train": {
"every_x_rgb_frame": 1,
"baseline": 0.1
}
},
"trainer": {
"loss_composition": "default"
}
}
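# Note: "spatial_resolution" should match the CenterCrop size used below, and
# "num_bins_events" should match the temporal window cut from the spike stream.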
# Load model (assuming weights are in the Space directory)
INITIAL_CHECKPOINT = "path_to_your_model_weights.pth" # Update this path
model = S2DepthTransformerUNetConv(CONFIG["model"])
checkpoint = torch.load(INITIAL_CHECKPOINT, map_location="cpu") # Use CPU for simplicity in Spaces
model.load_state_dict(checkpoint['state_dict'])
model.eval()
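# Optional sketch (assumes a single CUDA device; not part of the original flow):
# move the model to GPU for faster inference. Input tensors would then need the
# same .to(device) treatment before the forward pass.
# device = torch.device(f"cuda:{CONFIG['gpu']}" if torch.cuda.is_available() else "cpu")
# model = model.to(device)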
# Predefine color mapper
vmax = 0.95  # Upper bound for depth normalization (reference implementations often take a high percentile of the depth map); adjust as needed
normalizer = mpl.colors.Normalize(vmin=0, vmax=vmax)
color_mapper_overall = cm.ScalarMappable(norm=normalizer, cmap='magma')
def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a raw spike byte stream into a (num_frames, h, w) binary matrix.

    Each frame stores one bit per pixel, packed eight pixels per byte.
    """
    video_seq = np.array(video_seq).astype(np.uint8)
img_size = h * w
img_num = len(video_seq) // (img_size // 8)
SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
pix_id = np.arange(0, h * w)
pix_id = np.reshape(pix_id, (h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))  # bit mask for each pixel within its byte
    byte_id = pix_id // 8  # index of the byte that holds each pixel's bit
for img_id in np.arange(img_num):
id_start = int(img_id) * int(img_size) // 8
id_end = int(id_start) + int(img_size) // 8
cur_info = video_seq[id_start:id_end]
data = cur_info[byte_id]
result = np.bitwise_and(data, comparator)
if flipud:
SpikeMatrix[img_id, :, :] = np.flipud((result == comparator))
else:
SpikeMatrix[img_id, :, :] = (result == comparator)
return SpikeMatrix
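# Minimal usage sketch for RawToSpike (hypothetical file name):
# raw = np.frombuffer(open("example.dat", "rb").read(), 'b')
# spikes = RawToSpike(raw, 250, 400)  # -> uint8 array of shape (T, 250, 400) with 0/1 entries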
def make_colormap(img, color_mapper):
    # Invert the depth map so that near surfaces map to bright colors.
    color_map_inv = np.ones_like(img[0]) * np.amax(img[0]) - img[0]
    color_map_inv = np.nan_to_num(color_map_inv, nan=1)
    # Normalize to [0, 1] before applying the matplotlib color mapper.
    color_map_inv = color_map_inv / np.amax(color_map_inv)
    color_map_inv = np.nan_to_num(color_map_inv)
    color_map_inv = color_mapper.to_rgba(color_map_inv)
    # Swap RGB -> BGR so the result can be written directly with cv2.imwrite.
    color_map_inv[:, :, 0:3] = color_map_inv[:, :, 0:3][..., ::-1]
    return color_map_inv
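# make_colormap returns an (H, W, 4) float array in [0, 1] whose first three
# channels are already in BGR order, so scaling by 255 and writing with
# cv2.imwrite below yields correctly colored PNGs.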
def process_spike_file(file):
with open(file.name, 'rb') as f:
spike_seq = f.read()
spike_seq = np.frombuffer(spike_seq, 'b')
    # Decode the byte stream into binary spike frames (sensor resolution 250 x 400).
    spikes = RawToSpike(spike_seq, 250, 400)
    spikes = spikes.astype(np.float32)
    spikes = torch.from_numpy(spikes)
    data_transform = CenterCrop(224)
    data = data_transform(spikes)
    dT, dH, dW = data.shape
    # Take the central 128 frames to match num_bins_events in CONFIG.
    input_data = {"image": data[None, dT//2-64:dT//2+64]}
    # No recurrent state is carried across calls; start from empty states.
    prev_super_states = {"image": None}
    prev_states_lstm = {}
with torch.no_grad():
new_predicted_targets, _, _ = model(input_data, prev_super_states["image"], prev_states_lstm)
predict_depth = new_predicted_targets["image"][0].cpu().numpy()
    # Average the binary spikes over time for a grayscale preview of the input.
    input_spikes = np.mean(data.permute(1, 2, 0).cpu().numpy(), axis=2).astype(np.float32)
    color_map = make_colormap(predict_depth, color_mapper_overall)
    # Convert to 8-bit images for Gradio; clip the prediction to [0, 1] so the
    # uint8 cast cannot wrap around on out-of-range values.
    depth_img = (np.clip(predict_depth[0], 0.0, 1.0) * 255.0).astype(np.uint8)
    input_img = (input_spikes * 255.0).astype(np.uint8)
    color_img = (color_map * 255.0).astype(np.uint8)
    # Save to temporary PNG files (delete=False so Gradio can still read them) and return the paths.
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_depth, \
tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_input, \
tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_color:
cv2.imwrite(tmp_depth.name, depth_img[:, :, None])
cv2.imwrite(tmp_input.name, input_img[:, :, None])
cv2.imwrite(tmp_color.name, color_img)
return tmp_depth.name, tmp_input.name, tmp_color.name
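# Local smoke test bypassing Gradio (hypothetical path; gr.File hands the
# function an object exposing a .name attribute, which this stub mimics):
# class _FileStub:
#     name = "example.dat"
# depth_path, input_path, color_path = process_spike_file(_FileStub())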
# Gradio Interface
interface = gr.Interface(
fn=process_spike_file,
inputs=gr.File(label="Upload Spike Sequence File"),
outputs=[
gr.Image(label="Depth Estimation"),
gr.Image(label="Input Spikes"),
gr.Image(label="Color Map")
],
title="Spike Stream Depth Estimation",
description="Upload a spike sequence file to estimate depth using S2DepthTransformerUNetConv."
)
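# gr.Image outputs accept file paths, so returning the temp-file names from
# process_spike_file displays the saved PNGs directly.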
if __name__ == "__main__":
interface.launch()
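    # In a hosted Space, the default launch() is sufficient; when running
    # locally, launch(share=True) creates a temporary public URL.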