DHRUV SHEKHAWAT committed
Commit e06f3ec · 1 Parent(s): 2eacd00

Update app.py

Files changed (1)
  1. app.py +128 -31
app.py CHANGED
@@ -1,22 +1,63 @@
+
 import streamlit as st
+from streamlit_chat import message
 import json
 import torch
 from torch.utils.data import Dataset
 import torch.utils.data
 from models import *
 from utils import *
-st.title("UniLM Beta Testing")
-st.subheader("AI language chatbot by Webraft-AI")
-#Picking what NLP task you want to do
+# Setting page title and header
+st.set_page_config(page_title="UniLM", page_icon=":robot_face:")
+st.markdown("<h1 style='text-align: center;'>UniLM</h1>", unsafe_allow_html=True)
+
+
+
+# Initialise session state variables
+if 'generated' not in st.session_state:
+    st.session_state['generated'] = []
+if 'past' not in st.session_state:
+    st.session_state['past'] = []
+if 'messages' not in st.session_state:
+    st.session_state['messages'] = [
+        {"role": "system", "content": "You are a helpful assistant."}
+    ]
+if 'model_name' not in st.session_state:
+    st.session_state['model_name'] = []
+if 'cost' not in st.session_state:
+    st.session_state['cost'] = []
+if 'total_tokens' not in st.session_state:
+    st.session_state['total_tokens'] = []
+if 'total_cost' not in st.session_state:
+    st.session_state['total_cost'] = 1
+
+# Sidebar - let user choose model, show total cost of current conversation, and let user clear the current conversation
+st.sidebar.title("Settings")
+model_name = st.sidebar.selectbox("Model:", ("30M_6.1K","NONE"))
+counter_placeholder = st.sidebar.empty()
+
+clear_button = st.sidebar.button("Clear Conversation", key="clear")
+
+# Map model names to OpenAI model IDs
+if model_name == "30M_6.1K":
+    model = "30M_6.1K"
+else:
+    model = "gpt-4"
+
+# reset everything
+if clear_button:
+    st.session_state['generated'] = []
+    st.session_state['past'] = []
+    st.session_state['messages'] = [
+        {"role": "system", "content": "You are a helpful assistant."}
+    ]
+    st.session_state['number_tokens'] = []
+    st.session_state['model_name'] = []
+    st.session_state['cost'] = []
+    st.session_state['total_cost'] = 0.0
+    st.session_state['total_tokens'] = []
 
-#Textbox for text user is entering
-st.subheader("Start the conversation")
-text2 = st.text_input('Human: ') #text is stored in this variable
 
-load_checkpoint = True
-ckpt_path = 'checkpoint_190.pth.tar'
-with open('WORDMAP_corpus.json', 'r') as j:
-    word_map = json.load(j)
 
 def evaluate(transformer, question, question_mask, max_len, word_map):
     """
@@ -27,35 +68,28 @@ def evaluate(transformer, question, question_mask, max_len, word_map):
     start_token = word_map['<start>']
     encoded = transformer.encode(question, question_mask)
     words = torch.LongTensor([[start_token]]).to(device)
-
+
     for step in range(max_len - 1):
         size = words.shape[1]
         target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
         target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)
         decoded = transformer.decode(words, target_mask, encoded, question_mask)
         predictions = transformer.logit(decoded[:, -1])
-        _, next_word = torch.max(predictions, dim = 1)
+        _, next_word = torch.max(predictions, dim=1)
         next_word = next_word.item()
         if next_word == word_map['<end>']:
            break
-        words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim = 1) # (1,step+2)
-
+        words = torch.cat([words, torch.LongTensor([[next_word]]).to(device)], dim=1) # (1,step+2)
+
     # Construct Sentence
     if words.dim() == 2:
         words = words.squeeze(0)
         words = words.tolist()
-
+
     sen_idx = [w for w in words if w not in {word_map['<start>']}]
     sentence = ' '.join([rev_word_map[sen_idx[k]] for k in range(len(sen_idx))])
-
-    return sentence
-
-
-if load_checkpoint:
-    checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
-    transformer = checkpoint['transformer']
-
 
+    return sentence
 def remove_punc(string):
     punctuations = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
     no_punct = ""
@@ -63,12 +97,75 @@ def remove_punc(string):
         if char not in punctuations:
             no_punct = no_punct + char # space is also a character
     return no_punct.lower()
-question = remove_punc(text2)
-
-max_len = 153
-enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]
-question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
-question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)
-sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
-st.write("UniLM: "+sentence)
+
+if model_name == "30M_6.1K":
+    load_checkpoint = True
+    ckpt_path = 'checkpoint_190.pth.tar'
+    with open('WORDMAP_corpus.json', 'r') as j:
+        word_map = json.load(j)
+    if load_checkpoint:
+        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
+        transformer = checkpoint['transformer']
+else:
+    load_checkpoint = True
+    ckpt_path = 'checkpoint_190.pth.tar'
+    with open('WORDMAP_corpus.json', 'r') as j:
+        word_map = json.load(j)
+    if load_checkpoint:
+        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
+        transformer = checkpoint['transformer']
+
+
+
+# generate a response
+def generate_response(prompt):
+    st.session_state['messages'].append({"role": "user", "content": prompt})
+    question = remove_punc(prompt)
+
+    max_len = 153
+    enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]
+    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
+    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)
+    sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
+
+    response = sentence
+    st.session_state['messages'].append({"role": "assistant", "content": response})
+
+    # print(st.session_state['messages'])
+    total_tokens = "153"
+    prompt_tokens = "153"
+    completion_tokens = "153"
+    return response, total_tokens, prompt_tokens, completion_tokens
+
+
+# container for chat history
+response_container = st.container()
+# container for text box
+container = st.container()
+
+with container:
+    with st.form(key='my_form', clear_on_submit=True):
+        user_input = st.text_area("You:", key='input', height=2)
+        submit_button = st.form_submit_button(label='✉')
+
+    if submit_button and user_input:
+        output, total_tokens, prompt_tokens, completion_tokens = generate_response(user_input)
+        st.session_state['past'].append(user_input)
+        st.session_state['generated'].append(output)
+        st.session_state['model_name'].append(model_name)
+        st.session_state['total_tokens'].append(total_tokens)
+
+        # from https://openai.com/pricing#language-models
+        if model_name == "30M_6.1K":
+            cost = "1"
+        else:
+            cost = "2"
+
+
+
+if st.session_state['generated']:
+    with response_container:
+        for i in range(len(st.session_state['generated'])):
+            message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
+            message(st.session_state["generated"][i], key=str(i))
 
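A note on the decoding loop kept in evaluate(): the target mask built from torch.triu(...).transpose(0, 1) is a lower-triangular causal mask, so each generated position attends only to itself and earlier positions. A minimal standalone sketch of that construction (size 4 chosen purely for illustration):

import torch

size = 4  # pretend four tokens have been generated so far
target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
print(target_mask)
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.uint8)
# evaluate() then applies unsqueeze(0).unsqueeze(0) to get the (1, 1, size, size)
# tensor it passes to transformer.decode.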
 
 
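Similarly, the question mask built in generate_response() treats id 0 as padding (that is what the (question != 0) test implies); the word map's actual values are not part of this commit, so the ids below are made up. A small sketch of the shapes involved:

import torch

device = torch.device('cpu')  # chosen for the sketch; app.py gets device from its star imports

enc_qus = [5, 17, 42, 0, 0]  # hypothetical encoded question, padded with 0s
question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)          # shape (1, 5)
question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)  # shape (1, 1, 1, 5)
print(question_mask)
# tensor([[[[ True,  True,  True, False, False]]]])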
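Finally, evaluate() references rev_word_map (and device) without defining them in app.py; both presumably come from the star imports of models/utils. A hedged sketch of the assumed relationship between WORDMAP_corpus.json, word_map and rev_word_map (the JSON file itself is not shown in this commit):

import json

# Assumed structure: WORDMAP_corpus.json maps token strings to integer ids,
# including the special tokens app.py looks up ('<start>', '<end>', '<unk>').
with open('WORDMAP_corpus.json', 'r') as j:
    word_map = json.load(j)

# rev_word_map is assumed to be the inverse mapping (id -> token) that
# evaluate() uses to turn generated ids back into a sentence.
rev_word_map = {v: k for k, v in word_map.items()}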