File size: 674 Bytes
a936ce4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from typing import List
from transformers import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from QBModelConfig import QBModelConfig
from qbmodel import QuizBowlModel
import torch
class QBModelWrapper(PreTrainedModel):
    """Expose ``QuizBowlModel`` through the Hugging Face ``PreTrainedModel`` API.

    Setting ``QBModelConfig`` as ``config_class`` lets the transformers
    loading machinery (e.g. ``from_pretrained``) build this wrapper from a
    saved config. The wrapped ``QuizBowlModel`` is constructed with no
    arguments, so the config does not currently influence the inner model.

    NOTE(review): if ``QuizBowlModel`` is not a ``torch.nn.Module`` it has no
    registered parameters, and ``save_pretrained`` will not persist its
    state. Confirm against ``qbmodel``.
    """

    config_class = QBModelConfig

    def __init__(self, config, *inputs, **kwargs):
        """Initialise the wrapper and its inner ``QuizBowlModel``.

        Args:
            config: Presumably a ``QBModelConfig`` (the registered
                ``config_class``); handed to ``PreTrainedModel.__init__``.
            *inputs: Extra positional args, forwarded to the base class.
            **kwargs: Extra keyword args, forwarded to the base class.

        Why the pass-through: ``from_pretrained`` instantiates the model as
        ``cls(config, *model_args, **model_kwargs)``. A bare
        ``__init__(self, config)`` therefore raises ``TypeError`` whenever a
        caller supplies extra model arguments.
        """
        super().__init__(config, *inputs, **kwargs)
        self.model = QuizBowlModel()

    def forward(self, question):
        """Delegate to ``QuizBowlModel.guess_and_buzz``.

        Args:
            question: Passed through unchanged to ``guess_and_buzz``. Its
                expected type is defined by ``QuizBowlModel``
                (TODO confirm: a single question string vs. a batch/list).

        Returns:
            Whatever ``guess_and_buzz`` returns, unmodified.
        """
        return self.model.guess_and_buzz(question)
|