File size: 674 Bytes
a936ce4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from typing import List
from transformers import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
from QBModelConfig import QBModelConfig
from qbmodel import QuizBowlModel
import torch 

class QBModelWrapper(PreTrainedModel):
    """``transformers.PreTrainedModel`` wrapper around ``QuizBowlModel``.

    Exposes the quiz-bowl model through the standard Hugging Face model
    interface, paired with ``QBModelConfig``. All inference is delegated to
    the wrapped ``QuizBowlModel`` instance stored in ``self.model``.
    """

    # Configuration class transformers associates with this model type.
    config_class = QBModelConfig

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        """Build the wrapper and its underlying ``QuizBowlModel``.

        Args:
            config: Model configuration, expected to be a ``QBModelConfig``.
            *inputs: Extra positional arguments forwarded unchanged to
                ``PreTrainedModel.__init__`` (e.g. model args supplied via
                ``from_pretrained``).
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``PreTrainedModel.__init__``.
        """
        # Forward everything the base class accepts so construction through
        # ``from_pretrained(..., *model_args, **kwargs)`` does not raise.
        super().__init__(config, *inputs, **kwargs)
        # NOTE(review): QuizBowlModel is constructed without ``config``;
        # confirm whether any config fields should be passed through.
        self.model = QuizBowlModel()

    def forward(self, question):
        """Run the wrapped model's guess-and-buzz step on *question*.

        Args:
            question: Passed unchanged to ``QuizBowlModel.guess_and_buzz`` —
                presumably question text (or a batch of it); TODO confirm
                against ``qbmodel``.

        Returns:
            Whatever ``QuizBowlModel.guess_and_buzz`` returns.
        """
        return self.model.guess_and_buzz(question)