File size: 3,398 Bytes
a5b8812
b2cf88a
 
a5b8812
b2cf88a
 
 
a5b8812
 
 
b2cf88a
 
 
a5b8812
b2cf88a
 
 
a930d50
b2cf88a
a930d50
b2cf88a
 
 
 
 
 
a5b8812
b2cf88a
 
 
a5b8812
b2cf88a
 
 
 
 
 
 
 
a5b8812
b2cf88a
 
 
 
 
 
 
 
 
a5b8812
b2cf88a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5b8812
b2cf88a
a5b8812
 
 
 
 
b2cf88a
 
 
 
 
 
 
 
a5b8812
b2cf88a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
import os
import json
import torch
import numpy as np
from utils import ModelWrapper
from sklearn.metrics.pairwise import cosine_similarity

st.title('HRA Document QA')


@st.cache_resource(show_spinner=False)
def _load_model():
    """Create the ModelWrapper once and reuse it across Streamlit reruns.

    Streamlit re-executes this whole script on every user interaction; without
    caching, the models would be reloaded each time a question is asked.
    The cached instance is shared by all sessions of this server process
    (NOTE(review): assumes ModelWrapper holds no per-user mutable state).
    """
    return ModelWrapper()


with st.spinner("Please wait for loading the models"):
    model_loader = _load_model()


# Greeting shown at the top of the chat.
with st.chat_message("assistant"):
    st.write("Hello 👋 I am an HRA chatbot~")
    st.write("I know everything about the leadership of HRA.")
    st.write("Please ask your questions about the leadership of HRA. For example, you can ask 'Where did Robert Kauffman graduate?', 'What's the position for Fred Danback?' ")

# Returns the submitted text, or None when nothing has been submitted this run.
question = st.chat_input("Please ask me some questions about the leadership of HRA:")

# Minimum cosine similarity for a document to be treated as relevant.
SIMILARITY_THRESHOLD = 0.35
# Set to True to recompute every document embedding and overwrite the JSON
# vector store; otherwise the precomputed embeddings file is reused.
REBUILD_EMBEDDINGS = False
DOCUMENTS_DIR = "./documents"
EMBEDDINGS_PATH = "./vectors/embeddings.json"

if question:
    # Echo the user's question into the chat.
    with st.chat_message("assistant"):
        st.write("You asked a question:")
    with st.chat_message("user"):
        st.write(question)

    # Embed the question. The second argument is 0 here and 1 for documents
    # below; its meaning is defined by utils.ModelWrapper (presumably a
    # query/document mode switch — TODO confirm).
    question_embeddings = model_loader.get_embeddings(question, 0)

    if REBUILD_EMBEDDINGS:
        with st.spinner("Please wait for computing the embeddings"):
            document_embeddings = {}
            for file in os.listdir(DOCUMENTS_DIR):
                with open(os.path.join(DOCUMENTS_DIR, file), "r", encoding="utf-8") as doc:
                    text = doc.read()
                document_embeddings[file] = model_loader.get_embeddings(text, 1).tolist()

        # Persist the embeddings as a simple JSON "vector database".
        with open(EMBEDDINGS_PATH, "w", encoding="utf-8") as outfile:
            json.dump(document_embeddings, outfile, indent=4)

    # Load {file name: embedding} from the vector store.
    with open(EMBEDDINGS_PATH, "r", encoding="utf-8") as embeddings_file:
        document_embeddings = json.load(embeddings_file)

    # Linear search for the most relevant document.
    max_similarity = -1.0
    most_relevant_document = None
    for document, embedding in document_embeddings.items():
        # cosine_similarity returns a 2-D (n_queries, n_documents) array; take
        # the single entry so the score is a plain float rather than [[x]].
        cur_similarity = float(cosine_similarity(question_embeddings, embedding)[0, 0])
        if cur_similarity > max_similarity:
            most_relevant_document = document
            max_similarity = cur_similarity

    if max_similarity < SIMILARITY_THRESHOLD:
        # Also covers an empty vector store (max_similarity stays -1.0).
        with st.chat_message("assistant"):
            st.write("Sorry we can't find relevant document")
    else:
        with st.chat_message("assistant"):
            st.write("The most relevant document is:")
            st.write(most_relevant_document)
            st.write("And the cosine similarity is:" + str(max_similarity))

        with open(os.path.join(DOCUMENTS_DIR, most_relevant_document), "r", encoding="utf-8") as doc:
            context = doc.read()

        # Truncate only the document (never the question) so long documents
        # do not exceed the QA model's maximum input length.
        inputs = model_loader.tokenizer(
            question, context, return_tensors="pt", truncation="only_second"
        )
        with torch.no_grad():
            outputs = model_loader.model_qa(**inputs)

        # Extractive QA: the answer is the token span between the highest
        # scoring start and end positions.
        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()
        predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
        predict_answer = model_loader.tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)

        with st.chat_message("assistant"):
            st.write("Answer:")
            # An empty span (e.g. end index before start index) decodes to "";
            # fall back to showing the whole document.
            if predict_answer:
                st.write(predict_answer)
            else:
                st.write(context)