File size: 1,940 Bytes
9e42319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from QBModelWrapperCopy import QBModelWrapper
from QBModelConfig import QBModelConfig
from QBpipeline import QApipeline
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModelForQuestionAnswering, TFAutoModelForQuestionAnswering
from transformers import pipeline
from transformers import AutoConfig, AutoModel, AutoModelForQuestionAnswering, TFAutoModel
from transformers import AutoTokenizer
    
# Script: register the custom quiz-bowl config/model/pipeline with transformers'
# Auto* and pipeline registries, then build, save, and smoke-test the pipeline.
# The registrations below run before pipeline() / save_pretrained() use them,
# so statement order matters.
# NOTE(review): everything here runs at import time (there is no
# `if __name__ == "__main__":` guard), so importing this module downloads a
# tokenizer, writes ./main, and runs inference.

# Build an in-memory model from a QBModelConfig constructed with no arguments
# (i.e. whatever defaults QBModelConfig defines).
config = QBModelConfig()
qb_model = QBModelWrapper(config)

# Earlier attempt: instantiate the pipeline class directly, bypassing the registry.
# qa_pipe = QApipeline(model=qb_model)


# Register the custom config/model in transformers' in-process Auto* registries
# so AutoConfig/AutoModel/AutoModelForQuestionAnswering can resolve them.
# NOTE(review): per the transformers custom-model docs, the string passed to
# AutoConfig.register must equal QBModelConfig.model_type — confirm
# QBModelConfig sets model_type = "QA-umd-quizbowl".
AutoConfig.register("QA-umd-quizbowl", QBModelConfig)
AutoModel.register(QBModelConfig, QBModelWrapper)
AutoModelForQuestionAnswering.register(QBModelConfig, QBModelWrapper)
# TensorFlow registrations are disabled: QBModelWrapper is only registered with
# the PyTorch Auto classes above.
# TFAutoModel.register(QBModelConfig, QBModelWrapper)
# TFAutoModelForQuestionAnswering.register(QBModelConfig, QBModelWrapper)

# Mark the classes for auto-class export so that saving/pushing records an
# "auto_map" entry in the config, letting the checkpoint be reloaded through
# the Auto classes with trust_remote_code=True (per the transformers docs).
QBModelConfig.register_for_auto_class()
QBModelWrapper.register_for_auto_class("AutoModel")
QBModelWrapper.register_for_auto_class("AutoModelForQuestionAnswering")


# Earlier smoke test for the directly-constructed pipeline.
# result = qa_pipe(question="This star in the solar system has 8 planets", context="Context for the question")
# print(result["answer"])

# Register the custom task "qa-pipeline-qb" so that pipeline("qa-pipeline-qb")
# below dispatches to QApipeline with these model classes.
# NOTE(review): a TF model class is registered here even though this file only
# registers QBModelWrapper with the PyTorch Auto classes — confirm the TF entry
# is intended.
PIPELINE_REGISTRY.register_pipeline(
    "qa-pipeline-qb",
    pipeline_class=QApipeline,
    pt_model=AutoModelForQuestionAnswering,
    tf_model=TFAutoModelForQuestionAnswering, 
    # Alternative (unused): register against the base Auto classes instead.
    # pt_model=AutoModel,
    # tf_model=TFAutoModel
)

# Reuse the tokenizer of the "google-bert/bert-base-cased" checkpoint and map
# QBModelConfig to that tokenizer's slow tokenizer class in AutoTokenizer.
# NOTE(review): only the slow class is registered; AutoTokenizer.register also
# accepts fast_tokenizer_class — presumably `tokenizer` is the fast variant
# (use_fast defaults to True), so consider registering type(tokenizer) too.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")
AutoTokenizer.register(QBModelConfig, tokenizer.slow_tokenizer_class)


# Build the custom pipeline from the in-memory model and tokenizer, then write
# it to the directory "main" (relative to the current working directory).
# safe_serialization=False saves weights in PyTorch's pickle-based format
# instead of safetensors.
# NOTE(review): the reason for disabling safetensors isn't visible here
# (e.g. shared or non-tensor weights in QBModelWrapper?) — confirm.
qa_pipe = pipeline("qa-pipeline-qb", model=qb_model, tokenizer=tokenizer)
# To publish to the Hub instead of saving locally:
# qa_pipe.push_to_hub("new-attempt-pipeline-2", safe_serialization=False)
qa_pipe.save_pretrained("main", safe_serialization=False)

# Smoke test: run the in-memory pipeline on a sample question and print
# whatever QApipeline returns.
result = qa_pipe(question="This star in the solar system has 8 planets", context="Context for the question")
print(result)

# TODO: if this still doesn't work, try a custom pipeline that subclasses
# QuestionAnsweringPipeline instead.