import gradio as gr
import cv2
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import torch
import tempfile

from model.S2DepthNet import S2DepthTransformerUNetConv  # ensure this is importable in the Space
from data_augmentation import CenterCrop

# Model configuration (adjust the example values to match your training config).
CONFIG = {
    "use_phased_arch": True,
    "gpu": 0,
    "arch": "S2DepthTransformerUNetConv",
    "model": {
        "gpu": 0,
        "every_x_rgb_frame": 1,           # example value, adjust as per your config
        "baseline": 0.1,                  # example value
        "loss_composition": "default",    # example value
        "num_bins_events": 128,           # example value, adjust as needed
        "spatial_resolution": (224, 224)  # example resolution
    },
    "data_loader": {
        "train": {
            "every_x_rgb_frame": 1,
            "baseline": 0.1
        }
    },
    "trainer": {
        "loss_composition": "default"
    }
}

# Load the model (assuming the weights ship with the Space).
INITIAL_CHECKPOINT = "path_to_your_model_weights.pth"  # update this path
model = S2DepthTransformerUNetConv(CONFIG["model"])
checkpoint = torch.load(INITIAL_CHECKPOINT, map_location="cpu")  # CPU keeps the Space simple
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# Predefine the color mapper used for the depth visualization.
vmax = 0.95  # example percentile, adjust as needed
normalizer = mpl.colors.Normalize(vmin=0, vmax=vmax)
color_mapper_overall = cm.ScalarMappable(norm=normalizer, cmap="magma")


def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a raw byte stream into a (num_frames, h, w) binary spike matrix.

    Each byte packs the spike bits of 8 consecutive pixels, so one h x w frame
    occupies h * w / 8 bytes of the stream.
    """
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    img_num = len(video_seq) // (img_size // 8)
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    pix_id = np.arange(0, h * w)
    pix_id = np.reshape(pix_id, (h, w))
    comparator = np.left_shift(1, np.mod(pix_id, 8))  # bit mask for each pixel within its byte
    byte_id = pix_id // 8

    for img_id in np.arange(img_num):
        id_start = int(img_id) * int(img_size) // 8
        id_end = int(id_start) + int(img_size) // 8
        cur_info = video_seq[id_start:id_end]
        data = cur_info[byte_id]
        result = np.bitwise_and(data, comparator)
        if flipud:
            SpikeMatrix[img_id, :, :] = np.flipud(result == comparator)
        else:
            SpikeMatrix[img_id, :, :] = (result == comparator)

    return SpikeMatrix


def make_colormap(img, color_mapper):
    """Invert and normalize a depth map, then apply the magma colormap (channels swapped for cv2)."""
    color_map_inv = np.ones_like(img[0]) * np.amax(img[0]) - img[0]
    color_map_inv = np.nan_to_num(color_map_inv, nan=1)
    color_map_inv = color_map_inv / np.amax(color_map_inv)
    color_map_inv = np.nan_to_num(color_map_inv)
    color_map_inv = color_mapper.to_rgba(color_map_inv)
    color_map_inv[:, :, 0:3] = color_map_inv[:, :, 0:3][..., ::-1]  # RGB -> BGR for cv2.imwrite
    return color_map_inv


def process_spike_file(file):
    # gr.File may pass a path string or a file wrapper, depending on the Gradio version.
    path = file if isinstance(file, str) else file.name
    with open(path, "rb") as f:
        spike_seq = f.read()
    spike_seq = np.frombuffer(spike_seq, np.uint8)

    # Unpack the raw stream into binary spike frames (camera resolution 250x400).
    spikes = RawToSpike(spike_seq, 250, 400)
    spikes = spikes.astype(np.float32)
    spikes = torch.from_numpy(spikes)

    # Center-crop to the model's spatial resolution and keep the middle 128 frames.
    data_transform = CenterCrop(224)
    data = data_transform(spikes)
    dT, dH, dW = data.shape
    input_data = {"image": data[None, dT // 2 - 64:dT // 2 + 64]}

    prev_super_states = {"image": None}
    prev_states_lstm = {}
    with torch.no_grad():
        new_predicted_targets, _, _ = model(input_data, prev_super_states["image"], prev_states_lstm)

    predict_depth = new_predicted_targets["image"][0].cpu().numpy()
    input_spikes = np.mean(data.permute(1, 2, 0).cpu().numpy(), axis=2).astype(np.float32)
    color_map = make_colormap(predict_depth, color_mapper_overall)

    # Convert to 8-bit images for Gradio (clip to guard against values outside [0, 1]).
    depth_img = np.clip(predict_depth[0] * 255.0, 0, 255).astype(np.uint8)
    input_img = np.clip(input_spikes * 255.0, 0, 255).astype(np.uint8)
    color_img = (color_map * 255.0).astype(np.uint8)

    # Save temporarily and return paths
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_depth, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_input, \
         tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_color:
        cv2.imwrite(tmp_depth.name, depth_img[:, :, None])
        cv2.imwrite(tmp_input.name, input_img[:, :, None])
        cv2.imwrite(tmp_color.name, color_img)

    return tmp_depth.name, tmp_input.name, tmp_color.name


# Gradio interface
interface = gr.Interface(
    fn=process_spike_file,
    inputs=gr.File(label="Upload Spike Sequence File"),
    outputs=[
        gr.Image(label="Depth Estimation"),
        gr.Image(label="Input Spikes"),
        gr.Image(label="Color Map")
    ],
    title="Spike Stream Depth Estimation",
    description="Upload a spike sequence file to estimate depth using S2DepthTransformerUNetConv."
)

if __name__ == "__main__":
    interface.launch()
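

# Optional offline smoke test (a minimal sketch, not part of the original app):
# import this module and call debug_run() to exercise the same pipeline without
# the web UI. "demo_spikes.dat" is a placeholder path for a raw 250x400 spike
# recording; swap in a real file before running.
def debug_run(path="demo_spikes.dat"):
    depth_png, input_png, color_png = process_spike_file(path)
    print("Depth map:   ", depth_png)
    print("Input spikes:", input_png)
    print("Color map:   ", color_png)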