import gradio as gr
import cv2
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import torch
from model.S2DepthNet import S2DepthTransformerUNetConv  # repo-local module; ensure it is accessible
from data_augmentation import CenterCrop  # repo-local module
import os
from os.path import join
import tempfile

# Assuming model weights and config are preloaded or available in the Space
CONFIG = {
    "use_phased_arch": True,
    "gpu": 0,
    "arch": "S2DepthTransformerUNetConv",
    "model": {
        "gpu": 0,
        "every_x_rgb_frame": 1,           # example value; adjust as per your config
        "baseline": 0.1,                  # example value
        "loss_composition": "default",    # example value
        "num_bins_events": 128,           # example value; adjust as needed
        "spatial_resolution": (224, 224)  # example resolution
    },
    "data_loader": {
        "train": {
            "every_x_rgb_frame": 1,
            "baseline": 0.1
        }
    },
    "trainer": {
        "loss_composition": "default"
    }
}
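
# NOTE: every value marked "example" above is a placeholder. The keys mirror
# the training configuration consumed by S2DepthTransformerUNetConv; mismatched
# values (especially num_bins_events and spatial_resolution) typically surface
# as shape errors when loading weights or running inference.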

# Load model (assuming weights are in the Space directory)
INITIAL_CHECKPOINT = "path_to_your_model_weights.pth"  # update this path
model = S2DepthTransformerUNetConv(CONFIG["model"])
checkpoint = torch.load(INITIAL_CHECKPOINT, map_location="cpu")  # CPU keeps the Space simple
model.load_state_dict(checkpoint["state_dict"])
model.eval()
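
# Assumption: the checkpoint nests its weights under a 'state_dict' key, the
# common PyTorch trainer convention. If load_state_dict raises a KeyError, the
# weights are probably stored at the top level; a minimal fallback sketch:
#
#     state = checkpoint.get("state_dict", checkpoint)
#     model.load_state_dict(state)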

# Predefine color mapper
vmax = 0.95  # fixed upper bound for depth normalization (stands in for a per-image 95th percentile); adjust as needed
normalizer = mpl.colors.Normalize(vmin=0, vmax=vmax)
color_mapper_overall = cm.ScalarMappable(norm=normalizer, cmap="magma")
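
# Normalize scales depth from [vmin, vmax] into [0, 1], and to_rgba() then maps
# it through 'magma', returning float RGBA values in [0, 1] (hence the * 255.0
# before the PNGs are written below).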


def RawToSpike(video_seq, h, w, flipud=True):
    """Decode a packed binary spike stream into an (N, h, w) 0/1 matrix.

    Each byte holds 8 pixels, least-significant bit first: pixel p of a frame
    lives at bit p % 8 of byte p // 8.
    """
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    img_num = len(video_seq) // (img_size // 8)
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    pix_id = np.arange(0, h * w)
    pix_id = np.reshape(pix_id, (h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))
    byte_id = pix_id // 8
    for img_id in np.arange(img_num):
        id_start = int(img_id) * int(img_size) // 8
        id_end = int(id_start) + int(img_size) // 8
        cur_info = video_seq[id_start:id_end]
        data = cur_info[byte_id]
        result = np.bitwise_and(data, comparator)
        if flipud:
            SpikeMatrix[img_id, :, :] = np.flipud(result == comparator)
        else:
            SpikeMatrix[img_id, :, :] = (result == comparator)
    return SpikeMatrix
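

# For reference, the per-frame loop above can be vectorized with NumPy's bit
# unpacking. A minimal sketch, assuming NumPy >= 1.17 (for the bitorder
# argument) and the same LSB-first pixel packing as RawToSpike:
def raw_to_spike_unpackbits(video_seq, h, w, flipud=True):
    buf = np.asarray(video_seq).astype(np.uint8)
    frame_bytes = h * w // 8
    img_num = len(buf) // frame_bytes
    # bitorder='little' emits bit 0 of byte 0 first, matching pix_id % 8 above
    bits = np.unpackbits(buf[: img_num * frame_bytes].reshape(img_num, frame_bytes),
                         axis=1, bitorder="little")
    frames = bits.reshape(img_num, h, w)
    return frames[:, ::-1, :] if flipud else frames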


def make_colormap(img, color_mapper):
    """Invert the depth map (so nearer is brighter) and render it as RGBA."""
    color_map_inv = np.ones_like(img[0]) * np.amax(img[0]) - img[0]
    color_map_inv = np.nan_to_num(color_map_inv, nan=1)
    color_map_inv = color_map_inv / np.amax(color_map_inv)
    color_map_inv = np.nan_to_num(color_map_inv)
    color_map_inv = color_mapper.to_rgba(color_map_inv)
    color_map_inv[:, :, 0:3] = color_map_inv[:, :, 0:3][..., ::-1]
    return color_map_inv
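
# The channel flip on the last line converts matplotlib's RGB ordering to BGR,
# the layout cv2.imwrite expects; the alpha channel is left untouched.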


def process_spike_file(file):
    with open(file.name, "rb") as f:
        spike_seq = f.read()
        spike_seq = np.frombuffer(spike_seq, "b")

    # Decode the packed stream (hard-coded 250 x 400 sensor resolution),
    # then center-crop to the model's 224 x 224 input size
    spikes = RawToSpike(spike_seq, 250, 400)
    spikes = spikes.astype(np.float32)
    spikes = torch.from_numpy(spikes)
    data_transform = CenterCrop(224)
    data = data_transform(spikes)
    dT, dH, dW = data.shape

    # Feed a 128-frame window around the temporal midpoint (matches num_bins_events)
    input_data = {"image": data[None, dT // 2 - 64 : dT // 2 + 64]}
    prev_super_states = {"image": None}
    prev_states_lstm = {}
    with torch.no_grad():
        new_predicted_targets, _, _ = model(input_data, prev_super_states["image"], prev_states_lstm)

    predict_depth = new_predicted_targets["image"][0].cpu().numpy()
    # Average firing rate over time as a grayscale preview of the input
    input_spikes = np.mean(data.permute(1, 2, 0).cpu().numpy(), axis=2).astype(np.float32)
    color_map = make_colormap(predict_depth, color_mapper_overall)

    # Convert to 8-bit images for Gradio
    depth_img = (predict_depth[0] * 255.0).astype(np.uint8)
    input_img = (input_spikes * 255.0).astype(np.uint8)
    color_img = (color_map * 255.0).astype(np.uint8)

    # Save temporarily and return paths
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_depth, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_input, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_color:
        cv2.imwrite(tmp_depth.name, depth_img[:, :, None])
        cv2.imwrite(tmp_input.name, input_img[:, :, None])
        cv2.imwrite(tmp_color.name, color_img)
    return tmp_depth.name, tmp_input.name, tmp_color.name
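
# Quick local smoke test (a sketch; assumes a raw spike file "sample.dat"
# recorded at the 250 x 400 resolution hard-coded above):
#
#     with open("sample.dat", "rb") as f:
#         depth_path, input_path, color_path = process_spike_file(f)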

# Gradio Interface
interface = gr.Interface(
    fn=process_spike_file,
    inputs=gr.File(label="Upload Spike Sequence File"),
    outputs=[
        gr.Image(label="Depth Estimation"),
        gr.Image(label="Input Spikes"),
        gr.Image(label="Color Map")
    ],
    title="Spike Stream Depth Estimation",
    description="Upload a spike sequence file to estimate depth using S2DepthTransformerUNetConv."
)

if __name__ == "__main__":
    interface.launch()
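
# For local debugging outside the Space, host and port can be pinned explicitly,
# e.g. interface.launch(server_name="0.0.0.0", server_port=7860).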