File size: 1,682 Bytes
e15f1e0
9a229a7
993e75e
75eaa7d
dd97cd7
c592be6
1c8cf8d
993e75e
c592be6
89057d0
ad59b0f
e15f1e0
32a4ae2
 
 
9a229a7
32a4ae2
 
9a229a7
32a4ae2
9a229a7
32a4ae2
9a229a7
32a4ae2
 
 
01c2292
0a1beeb
8521de0
 
 
 
 
 
 
 
 
01c2292
8521de0
 
01c2292
8521de0
 
 
 
9a229a7
32a4ae2
 
 
 
 
 
 
 
e15f1e0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import gradio as gr
from transformers import BertForQuestionAnswering
from transformers import BertTokenizerFast
import torch
from nltk.tokenize import word_tokenize
import timm

# Tokenizer and question-answering model shared by every request.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained("CountingMstar/ai-tutor-bert-model")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# get_prediction() moves its inputs to `device`; the weights must live on the
# same device, otherwise a CUDA host fails with a device-mismatch error.
model.to(device)

def get_prediction(context, question):
    """Return the answer span the model extracts from *context* for *question*.

    The question and context are encoded as one sequence pair. The answer is
    the token slice from the arg-max start logit up to and including the
    arg-max end logit. The span is chosen over the whole encoded sequence, so
    it is not restricted to context tokens.

    Args:
        context: Passage to search for the answer.
        question: Question to answer.

    Returns:
        The answer as a string; empty when the predicted end precedes the
        predicted start.
    """
    # truncation="only_second" trims the context (the second sequence) to the
    # tokenizer's maximum length instead of producing an over-long sequence
    # the model cannot embed. Answers can then only come from the kept part.
    inputs = tokenizer(
        question,
        context,
        truncation="only_second",
        return_tensors='pt',
    ).to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs)

    answer_start = torch.argmax(outputs.start_logits)
    answer_end = torch.argmax(outputs.end_logits) + 1  # exclusive slice bound

    answer_ids = inputs['input_ids'][0][answer_start:answer_end]
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(answer_ids))
  
def question_answer(context, question):
    """Answer *question* using *context*; delegates to get_prediction()."""
    return get_prediction(context, question)

def split(text):
    """Split raw ``"context///question"`` input into ``(context, question)``.

    Everything before the first ``///`` is the context; everything after it
    is the question (any later ``///`` stays part of the question).

    Args:
        text: Raw user input.

    Returns:
        A ``(context, question)`` tuple. When *text* contains no ``///``, the
        whole text is the context and the question is empty.
    """
    # str.partition splits on the first occurrence only and never drops
    # characters, including when "///" is at the very start or absent.
    context, _sep, question = text.partition('///')
    return context, question
    
def greet(texts):
    """Gradio callback: parse the raw input and return the model's answer.

    *texts* is split by split() into a context and a question, which are then
    passed to question_answer().
    """
    return question_answer(*split(texts))
# def greet(text):
#     context, question = split(text)
#     # answer = question_answer(context, question)
#     return context

# One text box in, one text box out; greet() parses and answers the input.
iface = gr.Interface(
    fn=greet,
    inputs="text",
    outputs="text",
)
iface.launch()