CountingMstar committed on
Commit
1c9cd05
1 Parent(s): c592be6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -10,20 +10,20 @@ model = timm.create_model('hf_hub:pseudolab/AI_Tutor_BERT', pretrained=True)
10
  #model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
- def get_prediction(context, question):
14
- inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
15
- outputs = model(**inputs)
16
 
17
- answer_start = torch.argmax(outputs[0])
18
- answer_end = torch.argmax(outputs[1]) + 1
19
 
20
- answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
21
 
22
- return answer
23
 
24
- def question_answer(context, question):
25
- prediction = get_prediction(context,question)
26
- return prediction
27
 
28
  def split(text):
29
  context, question = '', ''
 
10
  #model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
 
13
+ # def get_prediction(context, question):
14
+ # inputs = tokenizer.encode_plus(question, context, return_tensors='pt').to(device)
15
+ # outputs = model(**inputs)
16
 
17
+ # answer_start = torch.argmax(outputs[0])
18
+ # answer_end = torch.argmax(outputs[1]) + 1
19
 
20
+ # answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))
21
 
22
+ # return answer
23
 
24
+ # def question_answer(context, question):
25
+ # prediction = get_prediction(context,question)
26
+ # return prediction
27
 
28
  def split(text):
29
  context, question = '', ''