File size: 5,952 Bytes
e06f3ec 8a998a2 e06f3ec 8a998a2 e06f3ec 8a998a2 e06f3ec 8a998a2 e06f3ec 8a998a2 e06f3ec 8a998a2 e06f3ec 8a998a2 e06f3ec 6817b3e e06f3ec 8a998a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import streamlit as st
from streamlit_chat import message
import json
import torch
from torch.utils.data import Dataset
import torch.utils.data
from models import *
from utils import *
# Setting page title and header
st.set_page_config(page_title="UniLM", page_icon=":robot_face:")
st.markdown("<h1 style='text-align: center;'>UniLM</h1>", unsafe_allow_html=True)

# Initialise session state variables.
# The table is rebuilt on every script run, so each session receives its own
# fresh list objects. The keys and starting values mirror what the
# "Clear Conversation" handler writes, so a brand-new session and a cleared
# session look the same ('total_cost' starts at 0.0 and 'number_tokens' exists).
_SESSION_DEFAULTS = {
    'generated': [],
    'past': [],
    'messages': [
        {"role": "system", "content": "You are a helpful assistant."}
    ],
    'number_tokens': [],
    'model_name': [],
    'cost': [],
    'total_tokens': [],
    'total_cost': 0.0,
}
for _key, _default in _SESSION_DEFAULTS.items():
    # Only seed keys that are missing so an ongoing conversation is preserved.
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Sidebar: model picker, an empty placeholder slot and the button that clears
# the current conversation.
sidebar = st.sidebar
sidebar.title("Settings")
model_options = ("30M_6.1K", "NONE")
model_name = sidebar.selectbox("Model:", model_options)
counter_placeholder = sidebar.empty()
clear_button = sidebar.button("Clear Conversation", key="clear")

# Translate the sidebar choice into a model identifier; every choice other
# than "30M_6.1K" maps to "gpt-4".
# NOTE(review): neither `model` nor `counter_placeholder` is read anywhere in
# the visible part of this file.
model = "30M_6.1K" if model_name == "30M_6.1K" else "gpt-4"
# Reset everything when the "Clear Conversation" button was pressed on this run
if clear_button:
    # Every list-valued entry gets its own new empty list.
    for _list_key in ('generated', 'past', 'number_tokens', 'model_name', 'cost', 'total_tokens'):
        st.session_state[_list_key] = []
    # The chat transcript restarts with only the system prompt.
    st.session_state['messages'] = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]
    st.session_state['total_cost'] = 0.0
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs greedy decoding with a batch size of 1.

    Args:
        transformer: model object exposing ``eval()``,
            ``encode(question, question_mask)``,
            ``decode(words, target_mask, encoded, question_mask)`` and
            ``logit(...)``.
        question: tensor of encoded question token ids. All tensors created
            here are placed on ``question.device``.
        question_mask: mask for ``question``; passed unchanged to
            ``encode`` and ``decode``.
        max_len: decoding runs for at most ``max_len - 1`` steps.
        word_map: dict mapping token string -> id; must contain
            ``'<start>'`` and ``'<end>'``.

    Returns:
        The decoded tokens joined by single spaces, with every ``<start>``
        id removed. The ``<end>`` token is never appended to the output.
    """
    rev_word_map = {v: k for k, v in word_map.items()}
    start_token = word_map['<start>']
    end_token = word_map['<end>']
    # Follow the input tensor's device instead of relying on a module global.
    dev = question.device

    transformer.eval()
    # Pure inference: disabling autograd avoids building a graph that would
    # grow with every decoding step.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        words = torch.LongTensor([[start_token]]).to(dev)

        for _ in range(max_len - 1):
            size = words.shape[1]
            # Lower-triangular mask: position i may only attend to positions <= i.
            target_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
            target_mask = target_mask.to(dev).unsqueeze(0).unsqueeze(0)
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            # Only the last position predicts the next token.
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim=1)
            next_word = next_word.item()
            if next_word == end_token:
                break
            words = torch.cat([words, torch.LongTensor([[next_word]]).to(dev)], dim=1)  # (1, step+2)

    # Construct the sentence from the generated ids.
    if words.dim() == 2:
        words = words.squeeze(0)
    token_ids = words.tolist()
    return ' '.join(rev_word_map[idx] for idx in token_ids if idx != start_token)
def remove_punc(string):
    """
    Strip punctuation from ``string`` and lower-case the result.

    The removed characters are exactly ``!()-[]{};:'"\\,<>./?@#$%^&*_~``.
    Whitespace, digits and any other character (for example ``+`` or ``=``)
    are kept.

    Args:
        string: the text to clean.

    Returns:
        The lower-cased text with the characters above removed.
    """
    # Raw string: the backslash is a literal member of the set, not the start
    # of an (invalid) escape sequence.
    punctuations = r'''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    # A single translate pass deletes every listed character.
    return string.translate(str.maketrans('', '', punctuations)).lower()
# Load the vocabulary and the trained model.
# Both sidebar choices ("30M_6.1K" and "NONE") previously ran byte-identical
# branches guarded by an always-True flag, so the load is done unconditionally.
ckpt_path = 'checkpoint_190.pth.tar'
# Explicit encoding so the word map reads the same on every platform.
with open('WORDMAP_corpus.json', 'r', encoding='utf-8') as j:
    word_map = json.load(j)

# SECURITY: torch.load unpickles the file, and the checkpoint stores a whole
# model object under 'transformer'; only load checkpoints from a trusted source.
# NOTE(review): recent PyTorch releases default torch.load to
# weights_only=True, which rejects pickled model objects - confirm against the
# deployed version.
# NOTE(review): Streamlit re-executes this script on each interaction, so the
# checkpoint is re-read every time; consider caching the loaded model.
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
transformer = checkpoint['transformer']
# generate a response
def generate_response(prompt):
    """
    Produce the model's reply to ``prompt`` and log both turns.

    The prompt is appended to ``st.session_state['messages']`` as a "user"
    entry, cleaned with ``remove_punc``, split on whitespace and encoded with
    ``word_map`` (words not in the map use the ``'<unk>'`` id). The reply is
    decoded by ``evaluate`` and appended as an "assistant" entry.

    Args:
        prompt: raw text typed by the user.

    Returns:
        A 4-tuple ``(response, total_tokens, prompt_tokens,
        completion_tokens)``. The three token values are the fixed string
        ``"153"``, not measured counts.
    """
    st.session_state['messages'].append({"role": "user", "content": prompt})

    max_len = 153
    cleaned = remove_punc(prompt)
    token_ids = []
    for word in cleaned.split():
        token_ids.append(word_map.get(word, word_map['<unk>']))

    question = torch.LongTensor(token_ids).to(device).unsqueeze(0)
    # Positions whose id is 0 are masked out (presumably the padding id - confirm).
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)

    response = evaluate(transformer, question, question_mask, int(max_len), word_map)
    st.session_state['messages'].append({"role": "assistant", "content": response})

    total_tokens = prompt_tokens = completion_tokens = "153"
    return response, total_tokens, prompt_tokens, completion_tokens
# container for chat history
response_container = st.container()
# container for text box
container = st.container()

with container:
    with st.form(key='my_form', clear_on_submit=True):
        # NOTE(review): newer Streamlit releases may reject such a small
        # text_area height - confirm against the deployed version.
        user_input = st.text_area("You:", key='input', height=2)
        submit_button = st.form_submit_button(label='✉')

    if submit_button and user_input:
        output, total_tokens, prompt_tokens, completion_tokens = generate_response(user_input)
        st.session_state['past'].append(user_input)
        st.session_state['generated'].append(output)
        st.session_state['model_name'].append(model_name)
        st.session_state['total_tokens'].append(total_tokens)

        # Fixed per-model cost marker. It is recorded next to the other
        # per-turn lists instead of being computed and then dropped.
        cost = "1" if model_name == "30M_6.1K" else "2"
        st.session_state['cost'].append(cost)

# Render the transcript: each user turn followed by the matching reply.
if st.session_state['generated']:
    with response_container:
        for i, reply in enumerate(st.session_state['generated']):
            message(st.session_state["past"][i], is_user=True, key=str(i) + '_user')
            message(reply, key=str(i))
|