# NOTE: the next three lines are residue from the hosting page, not module code.
# herrius's picture
# Upload 259 files
# 32b542e
import torch
class DeepPrompt(torch.nn.Module):
    """Learned prompt embeddings broadcast across the batch (naive implementation).

    The module owns one embedding table for "input" prompts and, optionally,
    a second table for "target" prompts. ``forward`` does no lookup by index:
    it returns the whole weight matrix of the selected table, expanded along
    a batch dimension whose size is read from ``x``.

    Configuration values read from ``cfg``:
        * ``cfg.MODEL.BERT.HIDDEN_SIZE`` -- embedding width of both tables.
        * ``cfg.MODEL.PROMPT_EMBED.INPUT_DEEP_PROMPT_LENGTH`` -- number of
          rows in the input prompt table.
        * ``cfg.MODEL.PROMPT_EMBED.TARGET_DEEP_PROMPT`` and
          ``cfg.MODEL.PROMPT_EMBED.SHARE_DEEP_PROMPT`` -- a separate target
          table is created only when the former is truthy and the latter is
          falsy.
        * ``cfg.MODEL.PROMPT_EMBED.TARGET_DEEP_PROMPT_LENGTH`` -- number of
          rows in the target prompt table (read only when that table is
          created).
    """

    def __init__(self, cfg):
        """Create the prompt table(s) described by ``cfg``."""
        super().__init__()
        hidden_size = cfg.MODEL.BERT.HIDDEN_SIZE
        prompt_cfg = cfg.MODEL.PROMPT_EMBED

        # A dedicated target table exists only when target prompts are
        # enabled and are not configured to share the input table.
        self.target_prompt = prompt_cfg.TARGET_DEEP_PROMPT and not prompt_cfg.SHARE_DEEP_PROMPT

        self.embedding = torch.nn.Embedding(prompt_cfg.INPUT_DEEP_PROMPT_LENGTH, hidden_size)
        if self.target_prompt:
            self.target_embedding = torch.nn.Embedding(
                prompt_cfg.TARGET_DEEP_PROMPT_LENGTH, hidden_size
            )

    def forward(self, x, batch_first=False, data_type=None, **kwargs):
        """Return the prompt table expanded to the batch size of ``x``.

        Args:
            x: Tensor used only for its batch size. The batch dimension is
                dim 0 when ``batch_first`` is true and dim 1 otherwise
                (i.e. ``length, bs, hidden_size`` by default).
            batch_first: Selects which dimension of ``x`` is the batch and
                where the batch dimension is placed in the result.
            data_type: When equal to ``'target'`` and a separate target table
                was created, that table is used; in every other case the
                input table is used.
            **kwargs: Accepted and ignored.

        Returns:
            The selected weight matrix expanded to
            ``(bs, prompt_length, hidden_size)`` when ``batch_first`` is true,
            otherwise ``(prompt_length, bs, hidden_size)``.
        """
        use_target = data_type == 'target' and self.target_prompt
        table = self.target_embedding if use_target else self.embedding
        embeddings = table.weight

        # ``expand`` broadcasts the new singleton batch dimension without
        # copying the underlying parameter data.
        if batch_first:
            return embeddings.unsqueeze(0).expand(x.shape[0], -1, -1)
        return embeddings.unsqueeze(1).expand(-1, x.shape[1], -1)