# pdfchat / app.py — author: markytools
# Last change: "Update app.py" (commit 161c0d8), 7.68 kB
import streamlit as st
from tempfile import NamedTemporaryFile
import pprint
import google.generativeai as palm
import os
from dotenv import load_dotenv, find_dotenv
from langchain.embeddings import GooglePalmEmbeddings
from langchain.llms import GooglePalm
from langchain.document_loaders import UnstructuredURLLoader #load urls into docoument-loader
from langchain.chains.question_answering import load_qa_chain
from langchain.indexes import VectorstoreIndexCreator #vectorize db index with chromadb
from langchain.text_splitter import CharacterTextSplitter #text splitter
from langchain.chains import RetrievalQA
from langchain.document_loaders import UnstructuredPDFLoader #load pdf
from langchain.agents import create_pandas_dataframe_agent
# from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
import pandas as pd
import numpy as np
import pprint
# Indices into _RADIO_OPTIONS, one per data source the user can chat with.
_ECOMMERCE_CSV, _CUSTOM_CSV, _CUSTOM_PDF, _EARNINGS_URL, _CUSTOM_URL = range(5)

_RADIO_OPTIONS = [
    "E-commerce CSV (https://www.kaggle.com/datasets/mervemenekse/ecommerce-dataset)",
    "Upload my own CSV",
    "Upload my own PDF",
    "URL Chat with Google's Latest Earnings (https://abc.xyz/investor/)",
    "Enter my own URL",
]

# (document kind used in widget labels, example question), one entry per
# radio option, in the same order as _RADIO_OPTIONS.
_OPTION_DETAILS = [
    ("CSV", "Question1: What was the most sold item? Question2: What was the most common payment?"),
    ("CSV", "What are the data columns?"),
    ("PDF", "Can you summarize the contents?"),
    ("URL", "What is Google's latest earnings?"),
    ("URL", "Can you summarize the contents?"),
]

_EARNINGS_URLS = ['https://abc.xyz/investor/']

# Add some designs to the radio buttons
_RADIO_CSS = """
<style>
.stRadio {
padding: 10px;
border-radius: 5px;
background-color: #f5f5f5;
}
.stRadio input[type="radio"] {
position: absolute;
opacity: 0;
cursor: pointer;
}
.stRadio label {
display: flex;
justify-content: center;
align-items: center;
cursor: pointer;
font-size: 16px;
color: #333;
}
.stRadio label:hover {
color: #000;
}
.stRadio.st-selected input[type="radio"] ~ label {
color: #000;
background-color: #d9d9d9;
}
</style>
"""


def _password_is_valid():
    """Return True when the ``pwd`` query parameter matches ``st.secrets["PSWD"]``.

    A missing query parameter, a missing secret, or a missing secrets file all
    count as "not valid" instead of crashing the page. Any other exception is a
    genuine bug and is allowed to propagate.
    """
    import hmac

    try:
        supplied = st.experimental_get_query_params()['pwd'][0]
        expected = st.secrets["PSWD"]
    except (KeyError, IndexError, FileNotFoundError):
        return False
    # Constant-time comparison: response timing must not reveal how many
    # leading characters of a guessed password were correct.
    return hmac.compare_digest(str(supplied).encode("utf-8"),
                               str(expected).encode("utf-8"))


def _build_retrieval_chain(loaders, llm):
    """Index the documents from ``loaders`` and wrap them in a RetrievalQA chain.

    Documents are split into 1000-character chunks with no overlap, embedded
    with GooglePalmEmbeddings, and answered with the "stuff" chain type. The
    returned chain reads its input under the key ``"question"``.
    """
    index = VectorstoreIndexCreator(
        embedding=GooglePalmEmbeddings(),
        text_splitter=CharacterTextSplitter(chunk_size=1000, chunk_overlap=0),
    ).from_loaders(loaders)
    return RetrievalQA.from_chain_type(llm=llm,
                                       chain_type="stuff",
                                       retriever=index.vectorstore.as_retriever(),
                                       input_key="question")


def _answer_from_dataframe(llm, data_frame, question):
    """Answer ``question`` with a LangChain pandas-dataframe agent over ``data_frame``."""
    agent = create_pandas_dataframe_agent(llm, data_frame, verbose=False)
    return agent.run(question)


def _answer_question(choice, llm, question, uploaded_file, url):
    """Build the data source selected by ``choice`` and answer ``question`` from it.

    All expensive work (CSV parsing, URL fetching, embedding) happens here, i.e.
    only when the user clicks the chat button. It no longer runs on every
    Streamlit rerun.

    Args:
        choice: one of the ``_ECOMMERCE_CSV`` ... ``_CUSTOM_URL`` indices.
        llm: the configured language model.
        question: the user's question text.
        uploaded_file: the Streamlit upload (used for the custom CSV/PDF options).
        url: the user-entered URL (used for the custom URL option).

    Returns:
        The answer produced by the agent or chain.
    """
    if choice == _ECOMMERCE_CSV:
        # 'unicode_escape' is kept from the original; presumably it tolerates
        # non-UTF-8 bytes in this dataset — verify before changing.
        data_frame = pd.read_csv('EcommerceDataset.csv', encoding='unicode_escape')
        return _answer_from_dataframe(llm, data_frame, question)
    if choice == _EARNINGS_URL:
        chain = _build_retrieval_chain([UnstructuredURLLoader(urls=_EARNINGS_URLS)], llm)
        return chain.run(question)
    if choice == _CUSTOM_URL:
        chain = _build_retrieval_chain([UnstructuredURLLoader(urls=[url.strip()])], llm)
        return chain.run(question)

    # Custom CSV / PDF upload: the loaders need a real path, so spill the
    # upload into a temp file that is deleted when the `with` block exits.
    suffix = ".csv" if choice == _CUSTOM_CSV else ".pdf"
    with NamedTemporaryFile(dir='.', suffix=suffix) as tmp:
        tmp.write(uploaded_file.getbuffer())
        # The readers below reopen the file by name. Without a flush, buffered
        # bytes may not be on disk yet, and they would see a truncated file.
        # (Reopening an open NamedTemporaryFile by name is not supported on
        # Windows per the tempfile docs.)
        tmp.flush()
        if choice == _CUSTOM_CSV:
            data_frame = pd.read_csv(tmp.name, encoding='unicode_escape')
            return _answer_from_dataframe(llm, data_frame, question)
        chain = _build_retrieval_chain([UnstructuredPDFLoader(tmp.name)], llm)
        return chain.run(question)


def _run_chat_app():
    """Render the data-source picker and input widgets, and answer on button click."""
    st.markdown(_RADIO_CSS, unsafe_allow_html=True)
    genre = st.radio(
        "Tired of reading your files? Chat with it using AI! Choose dataset to finetune",
        _RADIO_OPTIONS, index=0
    )
    choice = _RADIO_OPTIONS.index(genre)
    doc_kind, example_question = _OPTION_DETAILS[choice]

    # Initialize language model
    load_dotenv(find_dotenv())  # read local .env file
    api_key = st.secrets["PALM_API_KEY"]
    os.environ["GOOGLE_API_KEY"] = api_key
    palm.configure(api_key=api_key)
    llm = GooglePalm()
    llm.temperature = 0.1

    url_input = st.text_input('Enter your own URL', '',
                              placeholder="Type your URL here (e.g. https://abc.xyz/investor/)",
                              disabled=choice != _CUSTOM_URL)
    uploaded_file = st.file_uploader(f"Upload your own {doc_kind} here",
                                     type=doc_kind.lower(),
                                     disabled=choice not in (_CUSTOM_CSV, _CUSTOM_PDF))

    # Chat is only possible once the selected source actually has data.
    if choice in (_CUSTOM_CSV, _CUSTOM_PDF):
        enable_chat_box = uploaded_file is not None
    elif choice == _CUSTOM_URL:
        enable_chat_box = bool(url_input.strip())
    else:
        enable_chat_box = True

    chat_text = st.text_input(f'Ask me anything about this {doc_kind}', '',
                              placeholder=f"Type here (e.g. {example_question})",
                              disabled=not enable_chat_box)
    # Require BOTH a usable source and a non-blank question. The original
    # `not a and not b` enabled the button when either held, which allowed
    # clicks with no uploaded data or with an empty question.
    can_chat = enable_chat_box and bool(chat_text.strip())
    if st.button("CLICK HERE TO START CHATTING", disabled=not can_chat):  # Button clicked
        st.write(_answer_question(choice, llm, chat_text, uploaded_file, url_input))


if _password_is_valid():
    _run_chat_app()
else:
    st.write("Invalid Password")