File size: 1,124 Bytes
1c3eb47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpl.registry import MODELS
from mmengine.model import BaseModule
from transformers import GPT2Model, GPT2Config


@MODELS.register_module()
class HFGPTTransformerDecoderNeck(BaseModule):
    """Neck module that wraps a Hugging Face ``GPT2Model``.

    The wrapped model is stored as ``self.gpt_model`` and is built in one of
    two ways:

    * ``from_pretrained=True``: ``GPT2Model.from_pretrained(model_name)`` is
      used; ``update_kwargs`` is not consulted in this mode.
    * ``from_pretrained=False``: the configuration of ``model_name`` is
      fetched with ``GPT2Config.from_pretrained``, overridden with
      ``update_kwargs`` and passed to the ``GPT2Model`` constructor.

    Args:
        model_name: Identifier handed to the Hugging Face ``from_pretrained``
            calls. Defaults to ``'gpt2'``.
        from_pretrained: Selects between the two construction modes described
            above. Defaults to ``True``.
        update_kwargs: Mapping of configuration overrides applied through
            ``config.update`` when ``from_pretrained`` is ``False``. ``None``
            (the default) selects ``max_position_embeddings=512`` and
            ``hidden_size=512``.
    """

    def __init__(
            self,
            model_name='gpt2',
            from_pretrained=True,
            update_kwargs=None,
    ):
        super().__init__()
        self.model_name = model_name
        if from_pretrained:
            self.gpt_model = GPT2Model.from_pretrained(model_name)
        else:
            # Build the default overrides per call instead of sharing one
            # mutable dict object across every instantiation of the class.
            if update_kwargs is None:
                update_kwargs = dict(
                    max_position_embeddings=512,
                    hidden_size=512,
                )
            config = GPT2Config.from_pretrained(model_name)
            config.update(update_kwargs)
            self.gpt_model = GPT2Model(config=config)

    def forward(self, *args, **kwargs):
        """Forward all arguments to ``self.gpt_model`` and return its output.

        Positional and keyword arguments are passed through unchanged, and
        the wrapped model's return value is handed back without
        post-processing.
        """
        return self.gpt_model(*args, **kwargs)