File size: 1,120 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch


class DeepPrompt(torch.nn.Module):
    """Learnable deep-prompt tokens, broadcast across the batch dimension.

    A naive implementation: the prompt tokens are simply the rows of a
    ``torch.nn.Embedding`` weight table, returned for every sample in the batch.

    Config keys read:
        cfg.MODEL.BERT.HIDDEN_SIZE: width of each prompt token.
        cfg.MODEL.PROMPT_EMBED.INPUT_DEEP_PROMPT_LENGTH: number of default prompt tokens.
        cfg.MODEL.PROMPT_EMBED.TARGET_DEEP_PROMPT: request a separate target prompt table.
        cfg.MODEL.PROMPT_EMBED.SHARE_DEEP_PROMPT: if truthy, no separate target table is
            built and target inputs reuse the default table.
        cfg.MODEL.PROMPT_EMBED.TARGET_DEEP_PROMPT_LENGTH: number of target prompt tokens
            (read only when a separate target table is built).
    """

    def __init__(self, cfg):
        super().__init__()

        hidden = cfg.MODEL.BERT.HIDDEN_SIZE
        prompt_cfg = cfg.MODEL.PROMPT_EMBED

        # A dedicated target table exists only when target prompts are requested
        # AND they are not shared with the default prompts.
        self.target_prompt = prompt_cfg.TARGET_DEEP_PROMPT and not prompt_cfg.SHARE_DEEP_PROMPT

        # Creation order (default table first, then target) is kept stable so
        # parameter initialisation and state_dict ordering do not shift.
        self.embedding = torch.nn.Embedding(prompt_cfg.INPUT_DEEP_PROMPT_LENGTH, hidden)
        if self.target_prompt:
            self.target_embedding = torch.nn.Embedding(prompt_cfg.TARGET_DEEP_PROMPT_LENGTH, hidden)

    def forward(self, x, batch_first=False, data_type=None, **kwargs):
        """Return the prompt tokens expanded to the batch size of ``x``.

        Args:
            x: Input tensor. Only its batch size is used: ``x.shape[0]`` when
                ``batch_first`` is true, otherwise ``x.shape[1]``. The original
                code documents the non-batch-first layout as
                (length, bs, hidden_size).
            batch_first: Selects which axis of ``x`` is the batch axis and the
                layout of the returned tensor.
            data_type: When ``'target'`` and a separate target table exists,
                the target prompt table is used; otherwise the default table.
            **kwargs: Accepted and ignored.

        Returns:
            A broadcast view (via ``expand``, no copy) of the selected weight
            table. Shape is (bs, prompt_len, hidden) when ``batch_first`` is
            true, else (prompt_len, bs, hidden).
        """
        use_target_table = data_type == 'target' and self.target_prompt
        prompt_weight = (self.target_embedding if use_target_table else self.embedding).weight

        if batch_first:
            batch_size = x.shape[0]
            return prompt_weight.unsqueeze(0).expand(batch_size, -1, -1)

        batch_size = x.shape[1]
        return prompt_weight.unsqueeze(1).expand(-1, batch_size, -1)