# app.py
import gradio as gr
import torch
import torch.nn as nn
from swin_transformer_3d import SwinTransformer3D
from spiketencoder import LongSpikeStreamEncoderConv


def _as_tensor_list(outputs):
    """Return *outputs* as a list of tensors.

    A bare tensor is wrapped in a one-element list so it is reported as a
    single layer; iterating a tensor directly would walk its first (batch)
    dimension and mislabel every sample as a separate "layer".
    """
    if isinstance(outputs, torch.Tensor):
        return [outputs]
    return list(outputs)


def _describe_shapes(header, outputs):
    """Format *header* followed by one ``Layer i shape: [...]`` line per tensor.

    Every line, including the header, is terminated by a newline.
    """
    lines = [header]
    lines.extend(
        f"Layer {i} shape: {list(t.shape)}"
        for i, t in enumerate(_as_tensor_list(outputs))
    )
    return "\n".join(lines) + "\n"


def test_model(batch_size=2, height=64, width=64):
    """Push a random input through LongSpikeStreamEncoderConv and report shapes.

    A freshly constructed model is used on every call. The dummy input has
    shape ``(batch_size, 128, height, width)``; the fixed 128 is presumably
    the channel / time-step count the encoder expects (TODO confirm against
    ``spiketencoder``).

    Args:
        batch_size: Number of samples in the dummy input.
        height: Spatial height of the dummy input.
        width: Spatial width of the dummy input.

    Returns:
        A multi-line string listing the input shape, the shape of every
        output of ``model.swin3d``, and the shape of every output of the
        full model's forward pass.
    """
    # Gradio sliders may hand back numeric values as floats, but torch.randn
    # only accepts integer sizes.
    batch_size, height, width = int(batch_size), int(height), int(width)

    # Initialize model
    model = LongSpikeStreamEncoderConv()

    # Create dummy input
    input_tensor = torch.randn(batch_size, 128, height, width)

    report = [f"Input shape: {list(input_tensor.shape)}\n\n"]

    # Forward pass
    model.eval()
    with torch.no_grad():
        # Intermediate features from the Swin Transformer 3D backbone alone.
        features = model.swin3d(input_tensor)
        report.append(_describe_shapes("Swin Transformer 3D outputs:", features))

        # Full model forward pass (backbone + conv layers).
        outputs = model(input_tensor)
        report.append(
            "\n" + _describe_shapes("Final outputs after conv layers:", outputs)
        )

    return "".join(report)


# Gradio interface
interface = gr.Interface(
    fn=test_model,
    inputs=[
        gr.Slider(minimum=1, maximum=8, step=1, value=2, label="Batch Size"),
        gr.Slider(minimum=32, maximum=128, step=32, value=64, label="Height"),
        gr.Slider(minimum=32, maximum=128, step=32, value=64, label="Width")
    ],
    outputs=gr.Textbox(label="Feature Map Shapes"),
    title="LongSpikeStreamEncoderConv Tester",
    description="Test the LongSpikeStreamEncoderConv model and visualize feature map shapes at different stages"
)

if __name__ == "__main__":
    interface.launch()