DHRUV SHEKHAWAT committed
Commit 8a998a2 · 1 Parent(s): b95d471
Create app.py
app.py
ADDED
@@ -0,0 +1,69 @@
import streamlit as st
import json
import torch
from torch.utils.data import Dataset
import torch.utils.data
from models import *
from utils import *

st.title("UniLM chatbot")
st.subheader("AI language chatbot by Webraft-AI")

# Picking what NLP task you want to do

# Textbox for the text the user is entering
st.subheader("Start the conversation")
text2 = st.text_input('Human: ')  # user input is stored in this variable

# Run on GPU when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

load_checkpoint = True
ckpt_path = 'checkpoint_79.pth.tar'
with open('WORDMAP_corpus.json', 'r') as j:
    word_map = json.load(j)


def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs greedy decoding with a batch size of 1.
    """
    rev_word_map = {v: k for k, v in word_map.items()}
    transformer.eval()
    start_token = word_map['<start>']
    encoded = transformer.encode(question, question_mask)
    words = torch.LongTensor([[start_token]]).to(device)

    for step in range(max_len - 1):
        size = words.shape[1]
        # Lower-triangular (causal) mask so each position only attends to earlier tokens
        target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)
        decoded = transformer.decode(words, target_mask, encoded, question_mask)
        predictions = transformer.logit(decoded[:, -1])
        _, next_word = torch.max(predictions, dim=1)
        next_word = next_word.item()
        if next_word == word_map['<end>']:
            break
        words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim=1)  # (1, step+2)

    # Construct the sentence from the decoded token ids
    if words.dim() == 2:
        words = words.squeeze(0)
        words = words.tolist()

    sen_idx = [w for w in words if w not in {word_map['<start>']}]
    sentence = ' '.join([rev_word_map[sen_idx[k]] for k in range(len(sen_idx))])

    return sentence


if load_checkpoint:
    # map_location keeps the checkpoint loadable on CPU-only machines
    checkpoint = torch.load(ckpt_path, map_location=device)
    transformer = checkpoint['transformer']


question = text2
# Only run inference once the user has typed something; 'quit' ends the conversation
if question and question != 'quit':
    max_len = 128
    # Encode the question, mapping out-of-vocabulary words to <unk>
    enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)
    sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
    st.write("UniLM: " + sentence)
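A side note on the target mask built in evaluate(): recent PyTorch releases warn when byte (uint8) tensors are used as attention masks and prefer boolean masks. Assuming transformer.decode only tests the mask for zero versus non-zero entries (an assumption, since models.py is not shown in this commit), an equivalent boolean causal mask of the same (1, 1, size, size) shape could be built as in this minimal sketch:

import torch

size = 5  # hypothetical current length of the decoded sequence
# Lower-triangular boolean mask: True where attention to a position is allowed
causal_mask = torch.tril(torch.ones(size, size, dtype=torch.bool))
causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # shape (1, 1, size, size)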