Spaces:
Paused
Paused
File size: 1,682 Bytes
e15f1e0 9a229a7 993e75e 75eaa7d dd97cd7 c592be6 1c8cf8d 993e75e c592be6 89057d0 ad59b0f e15f1e0 32a4ae2 9a229a7 32a4ae2 9a229a7 32a4ae2 9a229a7 32a4ae2 9a229a7 32a4ae2 01c2292 0a1beeb 8521de0 01c2292 8521de0 01c2292 8521de0 9a229a7 32a4ae2 e15f1e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import gradio as gr
from transformers import BertForQuestionAnswering
from transformers import BertTokenizerFast
import torch
from nltk.tokenize import word_tokenize
import timm
# Tokenizer comes from the base checkpoint while the weights come from the
# fine-tuned model; presumably the fine-tuned model kept the bert-base-uncased
# vocabulary -- TODO confirm.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
#model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering.from_pretrained("CountingMstar/ai-tutor-bert-model")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# The weights must live on the same device the inputs are sent to; without
# this, CUDA hosts fail with a device-mismatch error at inference time.
model.to(device)
# Inference-only app: make sure dropout etc. are disabled.
model.eval()
def get_prediction(context, question):
    """Extract the answer to *question* from *context* with the QA model.

    Uses the module-level ``tokenizer`` and ``model``. The question is encoded
    as the first segment and the context as the second; the highest-scoring
    start position is chosen, then the highest-scoring end position at or
    after it.

    Args:
        context: Passage in which to look for the answer.
        question: Question to answer.

    Returns:
        The answer span decoded back to a string, with special tokens
        ([CLS]/[SEP]) removed. May be ``''`` if the model points at a
        special token only.
    """
    inputs = tokenizer(
        question,
        context,
        return_tensors='pt',
        # BERT has a fixed position-embedding limit; trim the context rather
        # than crash on long passages (the question is never truncated).
        truncation='only_second',
        max_length=model.config.max_position_embeddings,
    ).to(model.device)  # follow the model's weights, wherever they were placed

    with torch.no_grad():  # inference only: don't build an autograd graph
        outputs = model(**inputs)
    start_logits = outputs[0][0]  # batch of one -> 1-D logits over tokens
    end_logits = outputs[1][0]

    answer_start = int(torch.argmax(start_logits))
    # Restrict the end search to positions >= start so the span is never
    # inverted (an inverted span would slice to an empty answer).
    answer_end = answer_start + int(torch.argmax(end_logits[answer_start:])) + 1

    answer_ids = inputs['input_ids'][0][answer_start:answer_end].tolist()
    tokens = tokenizer.convert_ids_to_tokens(answer_ids, skip_special_tokens=True)
    return tokenizer.convert_tokens_to_string(tokens)
def question_answer(context, question):
    """Return the model's answer to *question* given *context*.

    Thin delegation to ``get_prediction``; kept as a separate entry point.
    """
    return get_prediction(context, question)
def split(text, sep='///'):
    """Split a combined ``<context><sep><question>`` string into its parts.

    Only the first occurrence of *sep* acts as the delimiter; any later
    occurrences remain part of the question. Surrounding whitespace is kept
    as-is.

    Args:
        text: Raw input from the single text box.
        sep: Non-empty delimiter between context and question
            (default ``'///'``).

    Returns:
        A ``(context, question)`` tuple. When *sep* does not occur, the whole
        text is returned unchanged as the context and the question is ``''``
        (previously the last two characters of the context were dropped).
    """
    context, _, question = text.partition(sep)
    return context, question
def greet(texts):
    """Gradio handler: answer a question embedded in one text box.

    The input is split into context and question by ``split``; both parts
    are then passed to ``question_answer``.

    Args:
        texts: Raw text-box contents (``None`` is treated as empty).

    Returns:
        The model's answer, or a usage hint when either the context or the
        question is missing/blank.
    """
    context, question = split(texts or '')
    if not context.strip() or not question.strip():
        # Running the model on an empty question or passage yields a
        # meaningless span; tell the user the expected format instead.
        return "Please enter text as: <context> /// <question>"
    return question_answer(context, question)
# def greet(text):
# context, question = split(text)
# # answer = question_answer(context, question)
# return context
# Single text box in, single text box out; each submission is handled by greet.
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()