"""Streamlit chat app that answers questions about the leadership of HRA.

For each question the app:

1. embeds the question with ``ModelWrapper.get_embeddings``;
2. linearly scans the document embeddings stored in ``EMBEDDINGS_PATH``
   (building that file from ``DOCUMENTS_DIR`` if it does not exist yet) for
   the document with the highest cosine similarity;
3. if that similarity reaches ``SIMILARITY_THRESHOLD``, runs the QA model on
   (question, document) and shows the predicted answer span, falling back to
   the whole document text when the predicted span decodes to an empty string.
"""

import json
import os

import numpy as np
import streamlit as st
import torch
from sklearn.metrics.pairwise import cosine_similarity

from utils import ModelWrapper

DOCUMENTS_DIR = "./documents"
EMBEDDINGS_PATH = "./vectors/embeddings.json"
# Minimum cosine similarity for the best document to be treated as relevant.
SIMILARITY_THRESHOLD = 0.35


@st.cache_resource(show_spinner="Please wait for loading the models")
def load_models():
    """Create the ModelWrapper once and reuse it across script reruns.

    Streamlit re-executes this script on every user interaction; without
    caching, the models would be reloaded for every question.
    """
    return ModelWrapper()


def build_document_embeddings(model_loader):
    """Embed every file in DOCUMENTS_DIR and save the mapping to EMBEDDINGS_PATH.

    Returns:
        dict mapping file name -> embedding (``get_embeddings(text, 1).tolist()``).
    """
    document_embeddings = {}
    for file_name in sorted(os.listdir(DOCUMENTS_DIR)):
        path = os.path.join(DOCUMENTS_DIR, file_name)
        if not os.path.isfile(path):  # skip sub-directories and similar entries
            continue
        with open(path, "r", encoding="utf-8") as doc_file:
            text = doc_file.read()
        # NOTE(review): the second argument (1 here, 0 for the question)
        # presumably selects document vs. query encoding — confirm in utils.
        document_embeddings[file_name] = model_loader.get_embeddings(text, 1).tolist()

    # Save the embeddings of all documents as a simple JSON "vector database".
    os.makedirs(os.path.dirname(EMBEDDINGS_PATH), exist_ok=True)
    with open(EMBEDDINGS_PATH, "w", encoding="utf-8") as outfile:
        json.dump(document_embeddings, outfile, indent=4)
    return document_embeddings


def load_document_embeddings(model_loader):
    """Return {file name: embedding}, building the vector file on first use."""
    if not os.path.exists(EMBEDDINGS_PATH):
        with st.spinner("Please wait for computing the embeddings"):
            return build_document_embeddings(model_loader)
    with open(EMBEDDINGS_PATH, "r", encoding="utf-8") as embeddings_file:
        return json.load(embeddings_file)


def find_most_relevant_document(question_embeddings, document_embeddings):
    """Linear search for the document most similar to the question.

    Returns:
        (file name, cosine similarity as a float); (None, -1.0) when
        ``document_embeddings`` is empty.
    """
    best_document = None
    best_similarity = -1.0
    # cosine_similarity requires 2-D inputs; convert the query once, not per document.
    query = np.atleast_2d(question_embeddings)
    for document, embedding in document_embeddings.items():
        # With one query row and one document row the result is a (1, 1)
        # matrix; .item() turns it into a plain float for comparison/display.
        similarity = cosine_similarity(query, np.atleast_2d(embedding)).item()
        if similarity > best_similarity:
            best_document, best_similarity = document, similarity
    return best_document, best_similarity


def answer_question(model_loader, question, context):
    """Return the decoded answer span predicted by the QA model ("" if empty).

    The span runs from the argmax of the start logits to the argmax of the end
    logits; if end < start the slice is empty and "" is returned.
    """
    # truncation="only_second" trims the document (never the question) when the
    # pair exceeds the model's maximum input length.
    # NOTE(review): assumes model_loader.tokenizer is a Hugging Face tokenizer.
    inputs = model_loader.tokenizer(
        question, context, return_tensors="pt", truncation="only_second"
    )
    with torch.no_grad():
        outputs = model_loader.model_qa(**inputs)
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    return model_loader.tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)


def main():
    """Render the chat UI and answer the submitted question, if any."""
    st.title('HRA Document QA')

    model_loader = load_models()

    with st.chat_message("assistant"):
        st.write("Hello 👋 I am an HRA chatbot~")
        st.write("I know everything about the leadership of HRA.")
        st.write("Please ask your questions about the leadership of HRA. For example, you can ask 'Where did Robert Kauffman graduate?', 'What's the position for Fred Danback?' ")

    question = st.chat_input("Please ask me some questions about the leadership of HRA:")
    if not question:
        return

    with st.chat_message("assistant"):
        st.write("You asked a question:")
    with st.chat_message("user"):
        st.write(question)

    question_embeddings = model_loader.get_embeddings(question, 0)
    document_embeddings = load_document_embeddings(model_loader)
    most_relevant_document, max_similarity = find_most_relevant_document(
        question_embeddings, document_embeddings
    )

    if max_similarity < SIMILARITY_THRESHOLD:
        with st.chat_message("assistant"):
            st.write("Sorry we can't find relevant document")
        return

    with st.chat_message("assistant"):
        st.write("The most relevant document is:")
        st.write(most_relevant_document)
        st.write("And the cosine similarity is:" + str(max_similarity))

    with open(os.path.join(DOCUMENTS_DIR, most_relevant_document), "r", encoding="utf-8") as doc_file:
        context = doc_file.read()
    predict_answer = answer_question(model_loader, question, context)

    with st.chat_message("assistant"):
        st.write("Answer:")
        if predict_answer:
            st.write(predict_answer)
        else:
            # No usable span was predicted; show the whole document instead.
            st.write(context)


# Streamlit executes the script as the "__main__" module.
if __name__ == "__main__":
    main()