# Captured from a Hugging Face Space (Space status at capture time: "Runtime error");
# original file size 711 bytes, commit f85e212.
import torch
from medical_diffusion.models.utils.attention_blocks import LinearTransformer, SpatialTransformer
def run_attention_smoke_test(
    attention_cls=SpatialTransformer,
    spatial_shape=(64, 64),
    batch_size=1,
    channels=32,
    num_heads=3,
    with_embedding=False,
):
    """Build one attention block, run a random tensor through it, and return the result.

    The defaults reproduce the original ad-hoc script: a SpatialTransformer applied
    to a 2D input of shape (1, 32, 64, 64) with 3 heads and no embedding.

    Args:
        attention_cls: Attention class to exercise, e.g. ``SpatialTransformer`` or
            ``LinearTransformer``. It is constructed as
            ``attention_cls(spatial_dims, in_channels=..., out_channels=...,
            num_heads=..., emb_dim=...)``.
        spatial_shape: Spatial extent of the random input. ``(64, 64)`` gives a 2D
            input and ``(16, 64, 64)`` gives a 3D input. Its length becomes the
            ``spatial_dims`` constructor argument.
        batch_size: Leading dimension of the random input.
        channels: Channel dimension of the input. It is used for both
            ``in_channels`` and ``out_channels``.
        num_heads: Number of attention heads passed to the block.
        with_embedding: If True, pass an embedding to the block (see the note below);
            otherwise pass ``None`` for both the embedding and ``emb_dim``.

    Returns:
        Whatever ``attention(x, embedding)`` returns.
    """
    # Input layout is (batch, channels, *spatial).
    x = torch.randn((batch_size, channels, *spatial_shape))
    spatial_dims = x.ndim - 2

    # NOTE(review): the original toggle reused the input tensor itself as the
    # embedding (so emb_dim == channels). The embedding shape the attention blocks
    # actually expect is not visible here; confirm before relying on this path.
    embedding = x if with_embedding else None
    emb_dim = embedding.shape[1] if embedding is not None else None

    # NOTE(review): with the defaults, 32 channels is not divisible by 3 heads.
    # Confirm the attention block does not require channels % num_heads == 0.
    attention = attention_cls(
        spatial_dims,
        in_channels=channels,
        out_channels=channels,
        num_heads=num_heads,
        emb_dim=emb_dim,
    )
    return attention(x, embedding)


if __name__ == "__main__":
    print(run_attention_smoke_test())