import torch
from torch import nn

from uniperceiver.config import configurable

from ..layers.create_act import get_act_layer
from .build import EMBEDDING_REGISTRY

__all__ = ["PrefixPromptEmbedding"]

@EMBEDDING_REGISTRY.register()
class PrefixPromptEmbedding(nn.Module):
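    """Learnable prefix prompts for prompt tuning.

    Depending on the config, the module prepends a block of learnable prompt
    embeddings to the input sequence and/or the target sequence. When
    ``label_prompt`` is set, one prompt vector is learned per class label
    instead of a fixed-length target prompt.
    """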

    @configurable
    def __init__(
        self,
        *,
        cfg,
        dim: int,
        query_size: int,
        label_prompt,
        label_size,
        **kwargs,
    ):
        super().__init__()

        self.cfg = cfg
        # A prompt is active only if its flag is set and its length is positive.
        self.input_prompt = cfg.MODEL.PROMPT_EMBED.INPUT_PROMPT and cfg.MODEL.PROMPT_EMBED.PROMPT_LENGTH > 0
        self.target_prompt = cfg.MODEL.PROMPT_EMBED.TARGET_PROMPT and cfg.MODEL.PROMPT_EMBED.TARGET_PROMPT_LENGTH > 0
        self.label_prompt = label_prompt

        if self.input_prompt:
            # Learnable prompt tokens prepended to the input sequence.
            self.embeddings = nn.Embedding(cfg.MODEL.PROMPT_EMBED.PROMPT_LENGTH, cfg.MODEL.ENCODER_DIM)
        if self.target_prompt:
            if self.label_prompt:
                # One learnable prompt vector per class label: (label_size, 1, dim).
                self.target_embedding = nn.Parameter(
                    torch.zeros((cfg.MODEL.PROMPT_EMBED.LABEL_SIZE, 1, cfg.MODEL.ENCODER_DIM)))
            else:
                self.target_embedding = nn.Embedding(cfg.MODEL.PROMPT_EMBED.TARGET_PROMPT_LENGTH, cfg.MODEL.ENCODER_DIM)

        # Optional post-processing modules assembled in `from_config`.
        self.embeddings_act = kwargs.pop("embeddings_act", None)
        self.embeddings_norm = kwargs.pop("embeddings_norm", None)
        self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
        self.prompt_with_pos = kwargs.pop("prompt_with_pos", None)

    @classmethod
    def from_config(cls, cfg):
        kwargs = {
            "dim": cfg.MODEL.ENCODER_DIM,
            "query_size": cfg.MODEL.PROMPT_EMBED.PROMPT_LENGTH,
            # NOTE: "LABLE_PROMPT" follows the spelling of the config key.
            "label_prompt": cfg.MODEL.PROMPT_EMBED.LABLE_PROMPT,
            "label_size": cfg.MODEL.PROMPT_EMBED.LABEL_SIZE,
            "target_prompt": cfg.MODEL.PROMPT_EMBED.TARGET_PROMPT,
            "num_layers": cfg.MODEL.BERT.NUM_HIDDEN_LAYERS,
        }

        activation_name = cfg.MODEL.PROMPT_EMBED.ACTIVATION.lower()
        if activation_name != "none":
            activation = get_act_layer(activation_name)
            assert activation is not None

            act_kwargs = {}
            if activation_name in {"elu", "celu"}:
                act_kwargs["alpha"] = cfg.MODEL.PROMPT_EMBED.ELU_ALPHA
            kwargs["embeddings_act"] = activation(**act_kwargs)

        if cfg.MODEL.PROMPT_EMBED.DROPOUT > 0:
            kwargs["embeddings_dropout"] = nn.Dropout(cfg.MODEL.PROMPT_EMBED.DROPOUT)

        if cfg.MODEL.PROMPT_EMBED.USE_NORM:
            kwargs["embeddings_norm"] = nn.LayerNorm(cfg.MODEL.PROMPT_EMBED.DIM)

        kwargs["prompt_with_pos"] = cfg.MODEL.PROMPT_EMBED.WITH_POS
        kwargs["cfg"] = cfg

        return kwargs

    def forward(self, data_list):
        """Prepend prompt embeddings to `data_list` in place (no return value)."""
        bs = data_list[0]['data'].shape[0]
        prompt_embed = self._forward(bs, data_type=data_list[0]['data_type'])

        if prompt_embed is None:
            return

        # Wrap the prompt in the same dict format as the other entries and
        # insert it at the front of the shared data list.
        insert_data = {
            'data': prompt_embed,
            'invalid_mask': None,
            'modality': None,
            'data_type': data_list[0]['data_type'],
        }
        data_list.insert(0, insert_data)

    def _forward(self, bs, data_type: str = None):
        assert data_type in ['input', 'target']
        if data_type == 'input' and self.input_prompt:
            # Broadcast the shared input prompt across the batch: (bs, prompt_length, dim).
            embeddings = self.embeddings.weight.unsqueeze(0).repeat(bs, 1, 1)
        elif data_type == 'target' and self.target_prompt:
            if self.label_prompt:
                # Per-label prompts are used as-is: (label_size, 1, dim).
                embeddings = self.target_embedding
            else:
                embeddings = self.target_embedding.weight.unsqueeze(0).repeat(bs, 1, 1)
        else:
            return None

        if self.embeddings_act is not None:
            embeddings = self.embeddings_act(embeddings)

        if self.embeddings_norm is not None:
            embeddings = self.embeddings_norm(embeddings)

        if self.embeddings_dropout is not None:
            embeddings = self.embeddings_dropout(embeddings)

        return embeddings
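
# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption-laden example, not part of the
# library). It mocks the project's CfgNode with SimpleNamespace and builds the
# module by splatting `from_config`'s output back into the constructor, which
# works for a detectron2-style `@configurable`; all values below are made up.
# Run in package context, e.g. `python -m <package>.prefix_prompt_embedding`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    prompt_cfg = SimpleNamespace(
        INPUT_PROMPT=True, PROMPT_LENGTH=10,
        TARGET_PROMPT=False, TARGET_PROMPT_LENGTH=0,
        LABLE_PROMPT=False, LABEL_SIZE=0,
        ACTIVATION="none", ELU_ALPHA=1.0,
        DROPOUT=0.1, USE_NORM=False, DIM=512, WITH_POS=False,
    )
    cfg = SimpleNamespace(MODEL=SimpleNamespace(
        ENCODER_DIM=512,
        PROMPT_EMBED=prompt_cfg,
        BERT=SimpleNamespace(NUM_HIDDEN_LAYERS=12),
    ))

    embedder = PrefixPromptEmbedding(**PrefixPromptEmbedding.from_config(cfg))
    data_list = [{'data': torch.randn(2, 5, 512), 'invalid_mask': None,
                  'modality': None, 'data_type': 'input'}]
    embedder(data_list)  # prepends a (2, 10, 512) prompt entry in place
    assert data_list[0]['data'].shape == (2, 10, 512)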