import streamlit as st
import os
import openai
import PyPDF2
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain import OpenAI
from langchain import VectorDBQA
from langchain.document_loaders import UnstructuredFileLoader, UnstructuredPDFLoader
from langchain.text_splitter import CharacterTextSplitter
import nltk
from streamlit_chat import message

# The Unstructured loaders rely on NLTK's sentence tokenizer
nltk.download("punkt")
def run_query_app(username=None):
    openai_api_key = st.sidebar.text_input("OpenAI API Key", key="openai_api_key_input", type="password")
    uploaded_file = st.file_uploader("Upload a file", type=['txt', 'pdf'], key="file_uploader")

    if uploaded_file:
        if not openai_api_key:
            st.warning("Please enter your OpenAI API key in the sidebar.")
            return

        # Save the uploaded file to disk
        os.makedirs('./uploaded_files', exist_ok=True)
        file_path = os.path.join('./uploaded_files', uploaded_file.name)
        with open(file_path, "wb") as f:
            f.write(uploaded_file.read())

        # Make the key available to the OpenAI-backed LangChain components
        os.environ['OPENAI_API_KEY'] = openai_api_key

        # Initialize OpenAIEmbeddings
        embeddings = OpenAIEmbeddings(openai_api_key=os.environ['OPENAI_API_KEY'])

        # Load the file as a document
        _, ext = os.path.splitext(file_path)
        if ext == '.txt':
            loader = UnstructuredFileLoader(file_path)
        elif ext == '.pdf':
            loader = UnstructuredPDFLoader(file_path)
        else:
            st.write("Unsupported file format.")
            return

        documents = loader.load()

        # Split the documents into chunks of at most 800 characters
        text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)

        # Create a Chroma vector store from the document chunks
        doc_search = Chroma.from_documents(texts, embeddings)

        # Initialize the QA chain; "stuff" puts the retrieved chunks directly
        # into a single prompt for the LLM
        chain = VectorDBQA.from_chain_type(llm=OpenAI(), chain_type="stuff", vectorstore=doc_search)
        # Conversation state
        if 'messages' not in st.session_state:
            st.session_state['messages'] = []
        if 'past' not in st.session_state:
            st.session_state['past'] = []
        if 'generated' not in st.session_state:
            st.session_state['generated'] = []

        def update_chat(messages, sender, text):
            messages.append({'sender': sender, 'text': text})
            return messages

        def get_response(chain, messages):
            # Answer the most recent user question
            user_texts = [m['text'] for m in messages if m['sender'] == 'user']
            return chain.run(user_texts[-1])

        def get_text():
            return st.text_input("You: ", key="input")

        query = get_text()

        if st.button("Run Query"):
            with st.spinner("Generating..."):
                messages = st.session_state.get('messages', [])
                messages = update_chat(messages, "user", query)
                response = get_response(chain, messages)
                messages = update_chat(messages, "assistant", response)
                st.session_state['messages'] = messages
                st.session_state['past'].append(query)
                st.session_state['generated'].append(response)
        message(f"You are chatting with {uploaded_file.name}. Ask anything about it!")

        # Render the chat history, newest exchange first
        if st.session_state['generated']:
            for i in range(len(st.session_state['generated']) - 1, -1, -1):
                message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
                message(st.session_state['generated'][i], key=str(i))

        with st.expander("Show Messages"):
            for i, msg in enumerate(st.session_state['messages']):
                if msg['sender'] == 'user':
                    message(msg['text'], is_user=True, key=f"user_{i}")
                else:
                    message(msg['text'], key=f"assistant_{i}")
if __name__ == '__main__':
    run_query_app()
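# A minimal way to try the app locally (assuming this script is saved as app.py;
# the actual filename is not given in the source):
#   streamlit run app.py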