czq commited on
Commit
e9d72aa
·
1 Parent(s): c26953b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -59,7 +59,7 @@ class BigBirdForQuestionAnsweringWithNull(PreTrainedModel):
59
  model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
60
  config = BigBirdConfig.from_pretrained(model_id)
61
  model = BigBirdForQuestionAnsweringWithNull(config, model_id)
62
- model.to('cuda')
63
  model.eval()
64
 
65
  model.load_state_dict(torch.load(model_path+'/pytorch_model.bin', map_location='cuda')) # map_location是指定加载到哪个设备
@@ -70,7 +70,7 @@ def main(question, context):
70
  # 编码输入
71
  text = question + " [SEP] " + context
72
  inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
73
- inputs.to('cuda')
74
  # 预测答案
75
  outputs = model(**inputs)
76
  start_scores = outputs[0]
 
59
  model_id = 'vasudevgupta/bigbird-roberta-natural-questions'
60
  config = BigBirdConfig.from_pretrained(model_id)
61
  model = BigBirdForQuestionAnsweringWithNull(config, model_id)
62
+ # model.to('cuda')
63
  model.eval()
64
 
65
  model.load_state_dict(torch.load(model_path+'/pytorch_model.bin', map_location='cuda')) # map_location是指定加载到哪个设备
 
70
  # 编码输入
71
  text = question + " [SEP] " + context
72
  inputs = tokenizer(text, max_length=4096, truncation=True, return_tensors="pt")
73
+ # inputs.to('cuda')
74
  # 预测答案
75
  outputs = model(**inputs)
76
  start_scores = outputs[0]