Spaces:
Running
Running
# app.py | |
import gradio as gr | |
import torch | |
import torch.nn as nn | |
from swin_transformer_3d import SwinTransformer3D | |
from spiketencoder import LongSpikeStreamEncoderConv | |
def _shape_report(tensors):
    """Return one ``"Layer i shape: [...]"`` line per tensor in *tensors*.

    A lone ``torch.Tensor`` is reported as a single layer. Iterating it
    directly would instead walk its first dimension and mislabel every slice
    as a separate "layer".
    """
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return [f"Layer {i} shape: {list(t.shape)}" for i, t in enumerate(tensors)]


def test_model(batch_size=2, height=64, width=64):
    """Run a random tensor through LongSpikeStreamEncoderConv and report shapes.

    Args:
        batch_size: Number of samples in the dummy batch.
        height: Spatial height of the dummy input.
        width: Spatial width of the dummy input.
        All three may arrive as floats from Gradio sliders and are coerced
        to ``int``.

    Returns:
        A multi-line string listing the dummy input's shape, the shape of
        each element returned by ``model.swin3d``, and the shape of each
        element returned by the full model's forward pass.
    """
    # Gradio sliders pass their values as floats (e.g. 2.0); torch.randn
    # rejects non-integer sizes, so normalise them before building the input.
    batch_size, height, width = int(batch_size), int(height), int(width)

    # A fresh, randomly initialised model per call; only shapes are reported.
    model = LongSpikeStreamEncoderConv()

    # NOTE(review): 128 is presumably the number of spike frames/channels the
    # encoder expects as input — confirm against LongSpikeStreamEncoderConv.
    input_tensor = torch.randn(batch_size, 128, height, width)

    lines = [f"Input shape: {list(input_tensor.shape)}", ""]

    model.eval()
    with torch.no_grad():
        # Backbone features on their own, before the model's conv layers.
        features = model.swin3d(input_tensor)
        lines.append("Swin Transformer 3D outputs:")
        lines.extend(_shape_report(features))

        # The complete forward pass (backbone + conv layers).
        outputs = model(input_tensor)
        lines.append("")
        lines.append("Final outputs after conv layers:")
        lines.extend(_shape_report(outputs))

    # Trailing newline matches the original "...\n"-terminated report.
    return "\n".join(lines) + "\n"
# ---------------------------------------------------------------------------
# Gradio UI: three sliders drive test_model; its text report fills a textbox.
# ---------------------------------------------------------------------------
def _build_interface():
    """Assemble the Gradio ``Interface`` wrapping ``test_model``.

    The sliders are passed to ``test_model`` positionally, in the order
    batch size, height, width. The returned string is shown in a single
    textbox.
    """
    # Height and width share the same range and step.
    spatial_range = dict(minimum=32, maximum=128, step=32, value=64)

    slider_inputs = [
        gr.Slider(label="Batch Size", minimum=1, maximum=8, step=1, value=2),
        gr.Slider(label="Height", **spatial_range),
        gr.Slider(label="Width", **spatial_range),
    ]
    report_box = gr.Textbox(label="Feature Map Shapes")

    return gr.Interface(
        fn=test_model,
        inputs=slider_inputs,
        outputs=report_box,
        title="LongSpikeStreamEncoderConv Tester",
        description="Test the LongSpikeStreamEncoderConv model and visualize feature map shapes at different stages",
    )


interface = _build_interface()

if __name__ == "__main__":
    # Only start the web server when executed as a script, not on import.
    interface.launch()