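"""HRA Document QA: a Streamlit chatbot.

The app embeds the user's question, retrieves the most similar document
from a precomputed JSON vector store via cosine similarity, and then runs
an extractive question-answering model over the retrieved document.
"""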
import os
import json

import streamlit as st
import torch
from sklearn.metrics.pairwise import cosine_similarity

from utils import ModelWrapper
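
# NOTE: utils.py is not shown here. Based on how it is used below, ModelWrapper
# is assumed to expose get_embeddings(text, mode) returning a 2-D array of shape
# (1, dim) (mode 0 for queries, mode 1 for documents), a Hugging Face
# `tokenizer`, and an extractive QA model `model_qa`.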

st.title('HRA Document QA')

with st.spinner("Loading models, please wait..."):
    model_loader = ModelWrapper()

with st.chat_message("assistant"):
    st.write("Hello 👋 I am an HRA chatbot~")
    st.write("I know everything about the leadership of HRA.")
    st.write(
        "Please ask your questions about the leadership of HRA. For example: "
        "'Where did Robert Kauffman graduate?' or 'What's the position of Fred Danback?'"
    )

question = st.chat_input("Ask me a question about the leadership of HRA:")

if question:
    with st.chat_message("assistant"):
        st.write("You asked:")
    with st.chat_message("user"):
        st.write(question)

    # Embed the question (mode 0 = query embeddings).
    question_embeddings = model_loader.get_embeddings(question, 0)

    # One-time pre-computation of the document embeddings. Flip this flag to
    # rebuild the vector store after the files in ./documents change.
    REBUILD_EMBEDDINGS = False
    if REBUILD_EMBEDDINGS:
        with st.spinner("Computing document embeddings, please wait..."):
            document_embeddings = {}
            for file in os.listdir("./documents"):
                with open("./documents/" + file, "r", encoding="utf-8") as f:
                    text = f.read()
                # Embed the document (mode 1 = document embeddings).
                document_embeddings[file] = model_loader.get_embeddings(text, 1).tolist()
            # Persist the embeddings as a simple JSON vector store.
            with open("./vectors/embeddings.json", "w") as outfile:
                json.dump(document_embeddings, outfile, indent=4)

    # Load the precomputed document embeddings.
    with open("./vectors/embeddings.json", "r") as embeddings_file:
        document_embeddings = json.load(embeddings_file)

    # Linear search for the most relevant document.
    SIMILARITY_THRESHOLD = 0.35
    max_similarity = -1.0
    most_relevant_document = None
    for document in document_embeddings:
        # cosine_similarity returns a (1, 1) array for a single query/document
        # pair, so pull the scalar out before comparing.
        cur_similarity = cosine_similarity(question_embeddings, document_embeddings[document])[0][0]
        if cur_similarity > max_similarity:
            most_relevant_document = document
            max_similarity = cur_similarity
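
    # NOTE: this is an O(n) scan over every stored embedding, which is fine for
    # a handful of leadership documents; a dedicated vector index would only be
    # needed at a much larger scale.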

    with st.chat_message("assistant"):
        if max_similarity < SIMILARITY_THRESHOLD:
            st.write("Sorry, I couldn't find a relevant document.")
        else:
            st.write("The most relevant document is:")
            st.write(most_relevant_document)
            st.write("And the cosine similarity is: " + str(max_similarity))
    if max_similarity >= SIMILARITY_THRESHOLD:
        with open("./documents/" + most_relevant_document, "r", encoding="utf-8") as f:
            context = f.read()
        inputs = model_loader.tokenizer(question, context, return_tensors="pt")
        with torch.no_grad():
            outputs = model_loader.model_qa(**inputs)
        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()
        predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
        predict_answer = model_loader.tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)

        with st.chat_message("assistant"):
            st.write("Answer:")
            if predict_answer:
                st.write(predict_answer)
            else:
                # Fall back to the full document if no answer span was extracted.
                st.write(context)
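
# To try the app locally (assuming this file is saved as app.py):
#   streamlit run app.py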