Spaces:
Runtime error
Runtime error
File size: 1,124 Bytes
1c3eb47 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch.nn as nn
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmpl.registry import MODELS
from mmengine.model import BaseModule
from transformers import GPT2Model, GPT2Config
@MODELS.register_module()
class HFGPTTransformerDecoderNeck(BaseModule):
    """Neck that wraps a Hugging Face ``GPT2Model``.

    The wrapped model is stored as ``self.gpt_model``. ``forward`` delegates
    to it unchanged.

    Args:
        model_name: Identifier passed to ``GPT2Model.from_pretrained`` /
            ``GPT2Config.from_pretrained``. Defaults to ``'gpt2'``.
        from_pretrained: If True, load the model (config and weights) via
            ``GPT2Model.from_pretrained(model_name)``. If False, load only
            the config, apply ``update_kwargs`` to it, and build a new
            ``GPT2Model`` from that config.
        update_kwargs: Config overrides applied via ``config.update`` when
            ``from_pretrained`` is False. They are ignored when
            ``from_pretrained`` is True. ``None`` (the default) means
            ``dict(max_position_embeddings=512, hidden_size=512)``.
        init_cfg: Optional initialization config forwarded to mmengine's
            ``BaseModule``.
    """

    def __init__(
        self,
        model_name='gpt2',
        from_pretrained=True,
        update_kwargs=None,
        init_cfg=None,
    ):
        super().__init__(init_cfg=init_cfg)
        # Build the defaults per call rather than sharing one mutable dict
        # default across every instance.
        if update_kwargs is None:
            update_kwargs = dict(
                max_position_embeddings=512,
                hidden_size=512,
            )
        self.model_name = model_name
        if from_pretrained:
            # Pretrained path: update_kwargs is intentionally not applied.
            self.gpt_model = GPT2Model.from_pretrained(model_name)
        else:
            config = GPT2Config.from_pretrained(model_name)
            # NOTE(review): the stock 'gpt2' config appears to use 12
            # attention heads, and 512 is not divisible by 12. The default
            # hidden_size=512 therefore presumably also needs an n_head
            # override on this path. Confirm against the configs that
            # build this neck.
            config.update(update_kwargs)
            self.gpt_model = GPT2Model(config=config)

    def forward(self, *args, **kwargs):
        """Run the wrapped ``GPT2Model``.

        All positional and keyword arguments go straight to
        ``self.gpt_model``. Its output is returned unchanged.
        """
        return self.gpt_model(*args, **kwargs)
|