Spaces:
Runtime error
Runtime error
File size: 435 Bytes
f85e212 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch
from medical_diffusion.models.embedders import TimeEmbbeding, SinusoidalPosEmb, LabelEmbedder
cond_emb = LabelEmbedder(10, num_classes=2)
c = torch.tensor([[0,], [1,]])
v = cond_emb(c)
print(v)
tim_emb = SinusoidalPosEmb(20, max_period=10)
t = torch.tensor([1,2,3, 1000])
v = tim_emb(t)
print(v)
tim_emb = TimeEmbbeding(4*4, SinusoidalPosEmb, {'max_period':10})
t = torch.tensor([1,2,3, 1000])
v = tim_emb(t)
print(v) |