chatlm / app.py
DHRUV SHEKHAWAT
Update app.py
de30062
raw
history blame
2.57 kB
import streamlit as st
import json
import torch
from torch.utils.data import Dataset
import torch.utils.data
from models import *
from utils import *
st.title("UniLM chatbot")
st.subheader("AI language chatbot by Webraft-AI")
# Text box for the message the user is typing.
st.subheader("Start the conversation")
text2 = st.text_input('Human: ')  # the user's message is stored in this variable
load_checkpoint = True
ckpt_path = 'checkpoint_9.pth.tar'
# Vocabulary mapping word -> token id; used both to encode the user's message
# and (inverted) to turn generated ids back into words.
# JSON text is UTF-8 by specification, so do not rely on the platform's
# default locale encoding when reading it.
with open('WORDMAP_corpus.json', 'r', encoding='utf-8') as j:
    word_map = json.load(j)
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs Greedy Decoding with a batch size of 1.

    Args:
        transformer: model exposing ``eval()``, ``encode(question, mask)``,
            ``decode(words, target_mask, encoded, question_mask)`` and
            ``logit(hidden)``.
        question: LongTensor of input token ids; the caller in this file
            builds it with shape (1, seq_len).
        question_mask: mask forwarded unchanged to ``encode`` and ``decode``.
        max_len: upper bound on the decoded sequence length including the
            ``<start>`` token, i.e. at most ``max_len - 1`` tokens are generated.
        word_map: dict mapping word -> token id; must contain ``<start>`` and
            ``<end>``.

    Returns:
        The generated words joined by single spaces. Decoding stops before the
        first ``<end>`` token, and ``<start>`` tokens are omitted.
    """
    rev_word_map = {idx: word for word, idx in word_map.items()}
    start_token = word_map['<start>']
    end_token = word_map['<end>']
    # Create every new tensor on the same device as the input, instead of
    # relying on a module-level ``device`` name defined outside this function.
    dev = question.device
    transformer.eval()
    # Pure inference: disable autograd so no graph is accumulated across the
    # decode steps (saves memory and time; results are unchanged).
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        words = torch.LongTensor([[start_token]]).to(dev)
        for _ in range(max_len - 1):
            size = words.shape[1]
            # Lower-triangular mask: position i may only attend to positions <= i.
            target_mask = torch.triu(torch.ones(size, size, device=dev)).transpose(0, 1)
            target_mask = target_mask.type(dtype=torch.uint8).unsqueeze(0).unsqueeze(0)
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            # Only the last position's output predicts the next token.
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim=1)
            next_word = next_word.item()
            if next_word == end_token:
                break
            next_tensor = torch.LongTensor([[next_word]]).to(dev)
            words = torch.cat([words, next_tensor], dim=1)  # (1, step + 2)
    # Construct the sentence from the generated ids, dropping <start> tokens.
    ids = words.squeeze(0).tolist()
    return ' '.join(rev_word_map[w] for w in ids if w != start_token)
if load_checkpoint:
    # checkpoint['transformer'] is a full model object (it is used directly as
    # the model below), so the file has to be fully unpickled. Recent PyTorch
    # releases default torch.load to weights_only=True, which refuses arbitrary
    # pickled objects, so opt out explicitly where that parameter exists.
    # Unpickling can execute code: only load checkpoint files you trust.
    # map_location remaps tensors onto the active device, so a checkpoint saved
    # on a GPU can still be loaded on a CPU-only host.
    import inspect
    load_kwargs = {'map_location': device}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(ckpt_path, **load_kwargs)
    transformer = checkpoint['transformer']
# Characters stripped from user input before tokenisation. A raw string is used
# so the backslash is a literal member of the set rather than an invalid escape
# sequence (which emits a SyntaxWarning on Python 3.12+).
_PUNCTUATION = r'''!()-[]{};:'"\,<>./?@#$%^&*_~'''
_PUNCT_TABLE = str.maketrans('', '', _PUNCTUATION)


def remove_punc(string):
    """
    Delete the characters in ``_PUNCTUATION`` from ``string``, then lowercase it.

    Whitespace (including spaces) is kept. Characters not in the set, such as
    ``+``, ``=``, ``|`` and a backtick, are kept as well.
    """
    # One C-level pass instead of quadratic per-character concatenation.
    return string.translate(_PUNCT_TABLE).lower()
# Normalise the user's message the same way for every request: strip
# punctuation, lowercase, split on whitespace, map words to ids (unknown -> <unk>).
question = remove_punc(text2)
max_len = 128
tokens = question.split()
# The text box starts out empty, and punctuation-only input also leaves no
# words; skip decoding then instead of running the model on a 0-length input.
if tokens:
    enc_qus = [word_map.get(word, word_map['<unk>']) for word in tokens]
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)  # (1, seq_len)
    # Mask of non-zero ids; presumably id 0 is the padding token — TODO confirm.
    # `question` is already on `device`, so the mask is too.
    question_mask = (question != 0).unsqueeze(1).unsqueeze(1)  # (1, 1, 1, seq_len)
    sentence = evaluate(transformer, question, question_mask, max_len, word_map)
    st.write("UniLM: " + sentence)