"""Streamlit front end for a transformer chatbot.

Reads one line of user text, encodes it with the word map stored in
``WORDMAP_corpus.json``, greedily decodes a reply with the transformer
stored in the checkpoint file, and writes the reply to the page.
"""

import json

import streamlit as st
import torch
import torch.utils.data
from torch.utils.data import Dataset

# NOTE(review): `device` is not defined in this file; presumably one of these
# star imports provides it -- confirm.
from models import *
from utils import *

# Keys of the special tokens inside the word map.
# NOTE(review): all three keys are the empty string in this file. That looks
# like angle-bracket token names were stripped somewhere upstream -- confirm
# against the keys actually present in WORDMAP_corpus.json.
START_TOKEN = ''
END_TOKEN = ''
UNK_TOKEN = ''

# Characters deleted from the user's text before tokenising. A raw string is
# used so that the backslash is taken literally instead of forming the invalid
# escape sequence "\,".
_PUNCTUATION = r'''!()-[]{};:'"\,<>./?@#$%^&*_~'''
_PUNCTUATION_TABLE = str.maketrans('', '', _PUNCTUATION)


def evaluate(transformer, question, question_mask, max_len, word_map):
    """Greedily decode a reply for a single question (batch size of 1).

    Args:
        transformer: Model object exposing ``eval``, ``encode``, ``decode``
            and ``logit``.
        question: Encoded question; the caller in this file passes a
            ``(1, n)`` LongTensor.
        question_mask: Mask for ``question``; the caller in this file passes
            ``question != 0`` with two extra singleton dimensions.
        max_len: Maximum reply length in tokens, counting the start token;
            at most ``max_len - 1`` decoding steps are performed.
        word_map: Mapping from word to integer index.

    Returns:
        The decoded words joined by single spaces, with the start token
        removed.
    """
    rev_word_map = {index: word for word, index in word_map.items()}
    transformer.eval()

    start_token = word_map[START_TOKEN]
    end_token = word_map[END_TOKEN]

    # Inference only: do not record operations for autograd.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        words = torch.LongTensor([[start_token]]).to(device)

        for _ in range(max_len - 1):
            size = words.shape[1]
            # Lower-triangular mask: position i may only look at positions <= i.
            target_mask = torch.tril(torch.ones(size, size)).type(dtype=torch.uint8)
            target_mask = target_mask.to(device).unsqueeze(0).unsqueeze(0)

            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            # Only the last position is needed to pick the next word.
            predictions = transformer.logit(decoded[:, -1])
            _, next_word = torch.max(predictions, dim=1)
            next_word = next_word.item()
            if next_word == end_token:
                break
            # words grows by one column per step: (1, step + 2)
            words = torch.cat(
                [words, torch.LongTensor([[next_word]]).to(device)], dim=1
            )

    # Construct the sentence from the (1, length) index tensor.
    indices = words.squeeze(0).tolist()
    sen_idx = [index for index in indices if index != start_token]
    return ' '.join(rev_word_map[index] for index in sen_idx)


def remove_punc(string):
    """Return ``string`` lower-cased with every punctuation character removed.

    Spaces are kept, so the result can still be split into words.
    """
    return string.translate(_PUNCTUATION_TABLE).lower()


st.title("UniLM chatbot")
st.subheader("AI language chatbot by Webraft-AI")

# Textbox for the text the user is entering.
st.subheader("Start the conversation")
text2 = st.text_input('Human: ')  # the user's text is stored in this variable

load_checkpoint = True
ckpt_path = 'checkpoint_9.pth.tar'

with open('WORDMAP_corpus.json', 'r', encoding='utf-8') as j:
    word_map = json.load(j)

if load_checkpoint:
    # SECURITY: torch.load unpickles arbitrary objects; only load checkpoint
    # files from a trusted source.
    # NOTE(review): recent PyTorch releases may require `weights_only=False`
    # to load a checkpoint that stores a whole model object -- confirm against
    # the deployed PyTorch version.
    # map_location lets a checkpoint saved on another device load here.
    checkpoint = torch.load(ckpt_path, map_location=device)
    transformer = checkpoint['transformer']

question = remove_punc(text2)
max_len = 128
enc_qus = [word_map.get(word, word_map[UNK_TOKEN]) for word in question.split()]

# Streamlit runs this script before the user has typed anything, and the input
# may consist of punctuation only; skip the model when there is nothing to
# answer.
if enc_qus:
    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)
    # NOTE(review): index 0 is presumably the padding index -- confirm.
    question_mask = (question != 0).to(device).unsqueeze(1).unsqueeze(1)
    sentence = evaluate(transformer, question, question_mask, int(max_len), word_map)
    st.write("UniLM: " + sentence)