File size: 4,011 Bytes
19aa68f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
import os
import openai
import PyPDF2
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain import OpenAI
from langchain import VectorDBQA
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader
from langchain.text_splitter import CharacterTextSplitter
import nltk
from streamlit_chat import message

# Fetch the NLTK "punkt" tokenizer data at import time — presumably required
# by the Unstructured loaders used below for sentence tokenization (TODO
# confirm). nltk.download() is a no-op if the data is already cached locally.
nltk.download("punkt")

def run_query_app(username=None):
    """Streamlit view: upload a .txt/.pdf file and chat with it via a
    retrieval QA chain (OpenAI embeddings + Chroma + VectorDBQA).

    Parameters
    ----------
    username : str, optional
        Identifier of the logged-in user. Not used by this view's logic;
        kept (with a default added) for backward compatibility with callers.
        The default also lets the ``__main__`` guard below run the app
        directly — the original guard called ``run_query_app()`` with no
        argument, which would have raised ``TypeError``.
    """
    openai_api_key = st.sidebar.text_input(
        "OpenAI API Key", key="openai_api_key_input", type="password"
    )

    uploaded_file = st.file_uploader(
        "Upload a file", type=['txt', 'pdf'], key="file_uploader"
    )
    if not uploaded_file:
        return

    # Without a key, OpenAIEmbeddings/OpenAI below would fail opaquely;
    # surface the problem to the user instead.
    if not openai_api_key:
        st.warning("Please enter your OpenAI API Key in the sidebar.")
        return

    # Persist the upload to disk so the Unstructured loaders can read it.
    # BUG FIX: create the target directory — it did not exist on first run.
    upload_dir = './uploaded_files'
    os.makedirs(upload_dir, exist_ok=True)
    file_path = os.path.join(upload_dir, uploaded_file.name)
    with open(file_path, "wb") as f:
        f.write(uploaded_file.read())

    # Some langchain components read the key from the environment.
    os.environ['OPENAI_API_KEY'] = openai_api_key

    embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)

    # Pick a loader by extension. BUG FIX: compare case-insensitively so
    # e.g. "REPORT.PDF" is not rejected as unsupported.
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    if ext == '.txt':
        loader = UnstructuredFileLoader(file_path)
    elif ext == '.pdf':
        loader = UnstructuredPDFLoader(file_path)
    else:
        st.write("Unsupported file format.")
        return

    documents = loader.load()

    # Chunk the document so each retrieved piece fits in the LLM context.
    text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)

    # Embed the chunks into an in-memory Chroma store.
    # NOTE(review): this re-embeds on every Streamlit rerun; consider
    # caching (st.cache_resource keyed on the file) to cut cost/latency.
    doc_search = Chroma.from_documents(texts, embeddings)

    chain = VectorDBQA.from_chain_type(
        llm=OpenAI(), chain_type="stuff", vectorstore=doc_search
    )

    # Conversation state must survive Streamlit reruns.
    st.session_state.setdefault('messages', [])
    st.session_state.setdefault('past', [])
    st.session_state.setdefault('generated', [])

    def update_chat(messages, sender, text):
        # Append one chat entry and return the (mutated) list.
        messages.append({'sender': sender, 'text': text})
        return messages

    def get_response(qa_chain, messages):
        # Answer only the most recent user message.
        user_texts = [m['text'] for m in messages if m['sender'] == 'user']
        return qa_chain.run(user_texts[-1])

    query = st.text_input("You: ", key="input")

    # Guard on a non-empty query so a bare button click does not run the
    # chain on an empty string.
    if st.button("Run Query") and query:
        with st.spinner("Generating..."):
            messages = update_chat(st.session_state['messages'], "user", query)
            response = get_response(chain, messages)
            messages = update_chat(messages, "assistant", response)
            st.session_state['messages'] = messages
            st.session_state['past'].append(query)
            st.session_state['generated'].append(response)

    # uploaded_file is guaranteed non-None here (early return above).
    message(f"You are chatting with {uploaded_file.name}. Ask anything about it?")

    if st.session_state['generated']:
        # Render newest exchange first.
        for i in range(len(st.session_state['generated']) - 1, -1, -1):
            message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
            message(st.session_state['generated'][i], key=str(i))

        with st.expander("Show Messages"):
            for i, msg in enumerate(st.session_state['messages']):
                # BUG FIX: streamlit_chat.message() takes the chat text as
                # its first argument; the old code passed "User"/"Assistant"
                # as the text and the actual text as is_user.
                message(
                    msg['text'],
                    is_user=(msg['sender'] == 'user'),
                    key=f"{msg['sender']}_{i}",
                )


if __name__ == '__main__':
    # BUG FIX: this guard was indented inside run_query_app (dead code) and
    # called the function without its required argument; username now
    # defaults to None so direct execution works.
    run_query_app()