Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
from mmengine.model import BaseModule | |
from mmengine.model.weight_init import constant_init | |
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm | |
from mmpl.registry import MODELS | |
from mmengine.model import BaseModule | |
from transformers import GPT2Model, GPT2Config | |
class HFGPTTransformerDecoderNeck(BaseModule):
    """Neck module that wraps a Hugging Face ``GPT2Model``.

    Args:
        model_name (str): Model identifier passed to
            ``GPT2Model.from_pretrained`` / ``GPT2Config.from_pretrained``.
            Defaults to ``'gpt2'``.
        from_pretrained (bool): If True, the wrapped model is created with
            ``GPT2Model.from_pretrained(model_name)``. If False, only the
            configuration is fetched for ``model_name``, ``update_kwargs``
            is applied to it, and the model is built from that config
            (this class loads no pretrained weights on that path).
            Defaults to True.
        update_kwargs (dict, optional): Config overrides applied through
            ``config.update`` when ``from_pretrained`` is False. It is NOT
            used when ``from_pretrained`` is True. ``None`` (the default)
            selects ``max_position_embeddings=512, hidden_size=512``.
    """

    def __init__(
            self,
            model_name='gpt2',
            from_pretrained=True,
            update_kwargs=None,
    ):
        super().__init__()
        self.model_name = model_name

        # Resolve the default here instead of in the signature so that a
        # single dict object is not shared by every instantiation.
        if update_kwargs is None:
            update_kwargs = dict(
                max_position_embeddings=512,
                hidden_size=512,
            )

        if from_pretrained:
            # update_kwargs is intentionally unused on this path.
            self.gpt_model = GPT2Model.from_pretrained(model_name)
        else:
            config = GPT2Config.from_pretrained(model_name)
            config.update(update_kwargs)
            self.gpt_model = GPT2Model(config=config)

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped GPT-2 model.

        Args:
            *args: Positional arguments passed unchanged to ``self.gpt_model``.
            **kwargs: Keyword arguments passed unchanged to ``self.gpt_model``.

        Returns:
            Whatever ``self.gpt_model`` returns for the given arguments.
        """
        return self.gpt_model(*args, **kwargs)