Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -59,7 +59,7 @@ class BigBirdForQuestionAnsweringWithNull(PreTrainedModel):
|
|
59 |
model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
|
60 |
config = BigBirdConfig.from_pretrained(model_id)
|
61 |
model = BigBirdForQuestionAnsweringWithNull(config, model_id)
|
62 |
-
model.to('cuda')
|
63 |
model.eval()
|
64 |
|
65 |
model.load_state_dict(torch.load(model_path+'/pytorch_model.bin', map_location='cuda')) # map_location是指定加载到哪个设备
|
@@ -70,7 +70,7 @@ def main(question, context):
|
|
70 |
# 编码输入
|
71 |
text = question + " [SEP] " + context
|
72 |
inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
|
73 |
-
inputs.to('cuda')
|
74 |
# 预测答案
|
75 |
outputs = model(**inputs)
|
76 |
start_scores = outputs[0]
|
|
|
59 |
model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
|
60 |
config = BigBirdConfig.from_pretrained(model_id)
|
61 |
model = BigBirdForQuestionAnsweringWithNull(config, model_id)
|
62 |
+
# model.to('cuda')
|
63 |
model.eval()
|
64 |
|
65 |
model.load_state_dict(torch.load(model_path+'/pytorch_model.bin', map_location='cuda')) # map_location是指定加载到哪个设备
|
|
|
70 |
# 编码输入
|
71 |
text = question + " [SEP] " + context
|
72 |
inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
|
73 |
+
# inputs.to('cuda')
|
74 |
# 预测答案
|
75 |
outputs = model(**inputs)
|
76 |
start_scores = outputs[0]
|