DHRUV SHEKHAWAT commited on
Commit
8a998a2
·
1 Parent(s): b95d471

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import json
3
+ import torch
4
+ from torch.utils.data import Dataset
5
+ import torch.utils.data
6
+ from models import *
7
+ from utils import *
8
+ st.title("UniLM chatbot")
9
+ st.subheader("AI language chatbot by Webraft-AI")
10
+ #Picking what NLP task you want to do
11
+
12
+ #Textbox for text user is entering
13
+ st.subheader("Start the conversation")
14
+ text2 = st.text_input('Human: ') #text is stored in this variable
15
+
16
+ load_checkpoint = True
17
+ ckpt_path = 'checkpoint_79.pth.tar'
18
+ with open('WORDMAP_corpus.json', 'r') as j:
19
+ word_map = json.load(j)
20
+
21
+ def evaluate(transformer, question, question_mask, max_len, word_map):
22
+ """
23
+ Performs Greedy Decoding with a batch size of 1
24
+ """
25
+ rev_word_map = {v: k for k, v in word_map.items()}
26
+ transformer.eval()
27
+ start_token = word_map['<start>']
28
+ encoded = transformer.encode(question, question_mask)
29
+ words = torch.LongTensor([[start_token]]).to(device)
30
+
31
+ for step in range(max_len - 1):
32
+ size = words.shape[1]
33
+ target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
34
+ target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)
35
+ decoded = transformer.decode(words, target_mask, encoded, question_mask)
36
+ predictions = transformer.logit(decoded[:, -1])
37
+ _, next_word = torch.max(predictions, dim = 1)
38
+ next_word = next_word.item()
39
+ if next_word == word_map['<end>']:
40
+ break
41
+ words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim = 1) # (1,step+2)
42
+
43
+ # Construct Sentence
44
+ if words.dim() == 2:
45
+ words = words.squeeze(0)
46
+ words = words.tolist()
47
+
48
+ sen_idx = [w for w in words if w not in {word_map['<start>']}]
49
+ sentence = ' '.join([rev_word_map[sen_idx[k]] for k in range(len(sen_idx))])
50
+
51
+ return sentence
52
+
53
+
54
+ if load_checkpoint:
55
+ checkpoint = torch.load(ckpt_path)
56
+ transformer = checkpoint['transformer']
57
+
58
+
59
+
60
+ question = text2
61
+ if question == 'quit':
62
+ break
63
+ max_len = 128
64
+ enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]
65
+ question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
66
+ question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)
67
+ sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
68
+ st.write("UniLM: "+sentence)
69
+