# test_embedding_shape / app_spiket.py
# Last commit e41a656 by zzzzzeee (verified): "Rename app.py to app_spiket.py"
# app_spiket.py (renamed from app.py)
import gradio as gr
import torch
import torch.nn as nn
from swin_transformer_3d import SwinTransformer3D
from spiketencoder import LongSpikeStreamEncoderConv
def _shape_lines(tensors):
    """Return one ``"Layer i shape: [...]"`` line for each item in *tensors*."""
    return [f"Layer {idx} shape: {list(t.shape)}" for idx, t in enumerate(tensors)]


def test_model(batch_size=2, height=64, width=64):
    """Feed a dummy tensor through LongSpikeStreamEncoderConv and report shapes.

    A freshly constructed model receives a ``torch.randn`` input of shape
    ``(batch_size, 128, height, width)``. The function lists the shape of each
    item yielded by ``model.swin3d`` and by the full ``model`` forward pass as
    plain text, for display in the Gradio UI.

    Args:
        batch_size: Number of samples in the dummy batch. Floats (as sent by
            a Gradio slider) are truncated to ``int``.
        height: Spatial height of the dummy input (coerced to ``int``).
        width: Spatial width of the dummy input (coerced to ``int``).

    Returns:
        A newline-terminated, multi-line string with the input shape, the
        backbone output shapes, and the final output shapes.
    """
    # Gradio sliders pass their value as a float (e.g. 2.0), and torch.randn
    # raises TypeError on non-integer sizes, so coerce explicitly.
    batch_size, height, width = int(batch_size), int(height), int(width)

    model = LongSpikeStreamEncoderConv()
    # NOTE(review): 128 is hard-coded as the second input dimension; presumably
    # it must match what LongSpikeStreamEncoderConv expects. Confirm.
    input_tensor = torch.randn(batch_size, 128, height, width)

    lines = [f"Input shape: {list(input_tensor.shape)}", ""]

    model.eval()
    with torch.no_grad():
        # Backbone only. Assumes swin3d returns an iterable of per-stage
        # tensors; if it returns a single tensor, enumerate walks its first dim.
        features = model.swin3d(input_tensor)
        lines.append("Swin Transformer 3D outputs:")
        lines.extend(_shape_lines(features))

        # Full forward pass (runs the backbone again, then the remaining layers).
        outputs = model(input_tensor)
        lines.append("")
        lines.append("Final outputs after conv layers:")
        lines.extend(_shape_lines(outputs))

    # Trailing "\n" keeps the text identical to the original per-line "\n" form.
    return "\n".join(lines) + "\n"
# Gradio UI: three sliders supply test_model's arguments, and the returned
# shape report is shown in a single textbox.
interface = gr.Interface(
    title="LongSpikeStreamEncoderConv Tester",
    description=(
        "Test the LongSpikeStreamEncoderConv model and visualize feature map "
        "shapes at different stages"
    ),
    fn=test_model,
    inputs=[
        gr.Slider(label="Batch Size", minimum=1, maximum=8, value=2, step=1),
        gr.Slider(label="Height", minimum=32, maximum=128, value=64, step=32),
        gr.Slider(label="Width", minimum=32, maximum=128, value=64, step=32),
    ],
    outputs=gr.Textbox(label="Feature Map Shapes"),
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    interface.launch()