zzzzzeee commited on
Commit
5352cd7
·
verified ·
1 Parent(s): c679568

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ import torch.nn as nn
5
+ from swin_transformer_3d import SwinTransformer3D
6
+ from spiketencoder import LongSpikeStreamEncoderConv
7
+
8
+
9
+ def test_model(batch_size=2, height=64, width=64):
10
+ # Initialize model
11
+ model = LongSpikeStreamEncoderConv()
12
+
13
+ # Create dummy input
14
+ input_tensor = torch.randn(batch_size, 128, height, width)
15
+
16
+ # Print initial shapes
17
+ output_text = f"Input shape: {list(input_tensor.shape)}\n\n"
18
+
19
+ # Forward pass
20
+ model.eval()
21
+ with torch.no_grad():
22
+ # Get Swin Transformer outputs
23
+ features = model.swin3d(input_tensor)
24
+ output_text += "Swin Transformer 3D outputs:\n"
25
+ for i, feat in enumerate(features):
26
+ output_text += f"Layer {i} shape: {list(feat.shape)}\n"
27
+
28
+ # Process through full model
29
+ outputs = model(input_tensor)
30
+ output_text += "\nFinal outputs after conv layers:\n"
31
+ for i, out in enumerate(outputs):
32
+ output_text += f"Layer {i} shape: {list(out.shape)}\n"
33
+
34
+ return output_text
35
+
36
+ # Gradio interface
37
+ interface = gr.Interface(
38
+ fn=test_model,
39
+ inputs=[
40
+ gr.Slider(minimum=1, maximum=8, step=1, value=2, label="Batch Size"),
41
+ gr.Slider(minimum=32, maximum=128, step=32, value=64, label="Height"),
42
+ gr.Slider(minimum=32, maximum=128, step=32, value=64, label="Width")
43
+ ],
44
+ outputs=gr.Textbox(label="Feature Map Shapes"),
45
+ title="LongSpikeStreamEncoderConv Tester",
46
+ description="Test the LongSpikeStreamEncoderConv model and visualize feature map shapes at different stages"
47
+ )
48
+
49
+ if __name__ == "__main__":
50
+ interface.launch()