|
import streamlit as st |
|
import json |
|
import torch |
|
from torch.utils.data import Dataset |
|
import torch.utils.data |
|
from models import * |
|
from utils import * |
|
# ---- Streamlit page header ----
st.title("UniLM chatbot")

st.subheader("AI language chatbot by Webraft-AI")



st.subheader("Start the conversation")

# Free-text box for the user's message.  Streamlit reruns this whole script
# on every interaction, so text2 holds the current input ('' on first load).
text2 = st.text_input('Human: ')


# Whether to restore the trained model from the checkpoint file below.
load_checkpoint = True

# Serialized checkpoint; expected to contain the whole model under the
# 'transformer' key (see the load further down).
ckpt_path = 'checkpoint_9.pth.tar'

# word -> integer-id vocabulary; used both to encode the user's question and
# (inverted inside evaluate) to decode the model's reply back into words.
with open('WORDMAP_corpus.json', 'r') as j:

    word_map = json.load(j)
|
|
|
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Greedily decode a reply for a single encoded question (batch size 1).

    Starting from the <start> token, repeatedly runs the decoder, appends the
    highest-probability next token, and stops at <end> or after max_len - 1
    steps.  Returns the generated tokens (minus <start>/<end>) joined into a
    space-separated string.
    """
    idx_to_word = {idx: word for word, idx in word_map.items()}
    transformer.eval()
    start_id = word_map['<start>']
    end_id = word_map['<end>']

    # Encode the question once; reuse the memory at every decoding step.
    memory = transformer.encode(question, question_mask)
    generated = torch.LongTensor([[start_id]]).to(device)

    for _ in range(max_len - 1):
        seq_len = generated.shape[1]
        # Lower-triangular causal mask so each position only attends backwards.
        causal_mask = torch.triu(torch.ones(seq_len, seq_len)).transpose(0, 1).type(dtype=torch.uint8)
        causal_mask = causal_mask.to(device).unsqueeze(0).unsqueeze(0)

        decoder_out = transformer.decode(generated, causal_mask, memory, question_mask)
        # Only the logits for the last position matter for the next token.
        logits = transformer.logit(decoder_out[:, -1])
        _, best = torch.max(logits, dim=1)
        best = best.item()

        if best == end_id:
            break
        generated = torch.cat([generated, torch.LongTensor([[best]]).to(device)], dim=1)

    # Drop the batch dimension before converting to a plain Python list.
    if generated.dim() == 2:
        generated = generated.squeeze(0)
    token_ids = generated.tolist()

    reply_ids = [tok for tok in token_ids if tok != start_id]
    return ' '.join(idx_to_word[tok] for tok in reply_ids)
|
|
|
|
|
if load_checkpoint:
    # NOTE(security): torch.load unpickles arbitrary objects — only ever load
    # checkpoints from a trusted source.
    # map_location=device lets a GPU-saved checkpoint load on whatever device
    # this host actually has (e.g. CPU-only) instead of raising a CUDA error.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # The checkpoint stores the whole model object under 'transformer'.
    transformer = checkpoint['transformer']
|
|
|
|
|
def remove_punc(string):
    """Strip punctuation from *string* and return it lowercased.

    Removes every character in the set below (note the literal backslash
    contributed by the "\," in the string literal); keeps whitespace and
    alphanumerics so the result is ready for whitespace tokenization.
    """
    punctuations = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    # str.translate filters all characters in one C-level pass instead of the
    # original quadratic char-by-char string concatenation loop.
    return string.translate(str.maketrans('', '', punctuations)).lower()
|
# ---- Inference: clean the user's text, encode it, and generate a reply ----
question = remove_punc(text2)

# Maximum number of tokens the decoder may generate.
max_len = 128

# Map each word to its vocabulary id; out-of-vocabulary words fall back to <unk>.
enc_qus = [word_map.get(word, word_map['<unk>']) for word in question.split()]

# Guard against empty input: Streamlit reruns the script on load with
# text2 == '', and an empty tensor must not reach the model.
if enc_qus:
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    # Padding mask: True wherever the token id is non-zero (0 == pad).
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)
    sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
    st.write("UniLM: " + sentence)
|
|
|
|