# Spaces: Runtime error
import torch.nn as nn | |
import torch | |
from monai.networks.layers.utils import get_act_layer | |
class LabelEmbedder(nn.Module):
    """Embed integer class labels via a learned lookup table.

    Args:
        emb_dim: Width of each embedding vector; also stored as ``self.emb_dim``.
        num_classes: Number of rows in the embedding table (one per label).
        act_name: Accepted for signature compatibility; the current
            implementation does not read it.
    """

    def __init__(self, emb_dim=32, num_classes=2, act_name=("SWISH", {})):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=num_classes,
            embedding_dim=emb_dim,
        )
        self.emb_dim = emb_dim

    def forward(self, condition):
        """Return the embedding row selected by each index in ``condition``.

        Per the original author's note the expected mapping is
        ``[B,] -> [B, C]`` -- presumably ``C == emb_dim``; confirm with callers.
        """
        return self.embedding(condition)