import torch
import torch.nn as nn
from monai.networks.layers.utils import get_act_layer


class LabelEmbedder(nn.Module):
    """Map integer class labels to learned embedding vectors.

    Args:
        emb_dim: Size of each embedding vector.
        num_classes: Number of distinct labels; valid labels are the
            integers ``0 .. num_classes - 1``.
        act_name: Currently unused by this module. Kept in the signature
            so existing callers that pass it continue to work.
    """

    def __init__(self, emb_dim=32, num_classes=2, act_name=("SWISH", {})):
        super().__init__()
        self.emb_dim = emb_dim
        self.num_classes = num_classes
        # One learned vector of length emb_dim per class label.
        self.embedding = nn.Embedding(num_classes, emb_dim)

    def forward(self, condition):
        """Look up the embedding vector for each label in ``condition``.

        Args:
            condition: Integer tensor of class indices (e.g. shape ``[B]``).
                ``nn.Embedding`` requires an integer dtype, and every value
                must lie in ``[0, num_classes)``.

        Returns:
            Tensor of shape ``condition.shape + (emb_dim,)``, i.e.
            ``[B] -> [B, emb_dim]`` for a 1-D batch of labels.
        """
        return self.embedding(condition)