# Source: study-sherlock / pages / chatbot.py (Hugging Face Space file view; raw / history / blame links omitted)
# Uploaded by Johan713 in commit 5347681 ("Upload 13 files", verified); file size 19.6 kB.
import streamlit as st
import random
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
import os
from dotenv import load_dotenv
import requests
from bs4 import BeautifulSoup
import pandas as pd
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
import time
from langchain.schema import Document
from docx import Document as DocxDocument
from PyPDF2 import PdfReader
import io
# Pull settings from a local .env file (if any) into the process environment.
load_dotenv()
AI71_BASE_URL = "https://api.ai71.ai/v1/"
AI71_API_KEY = os.environ.get('AI71_API_KEY')

# Seed the per-session state once; values already present survive reruns.
for _state_key, _initial_value in (("custom_personality", ""), ("messages", [])):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Chat model factory (Falcon served through the AI71 endpoint)
@st.cache_resource
def get_llm():
    """Return the chat model client configured for the AI71 endpoint.

    Wrapped in ``st.cache_resource`` so the client object is created once
    and reused instead of being rebuilt on each call.
    """
    llm_settings = {
        "model": "tiiuae/falcon-180B-chat",
        "api_key": AI71_API_KEY,
        "base_url": AI71_BASE_URL,
        "streaming": True,
    }
    return ChatOpenAI(**llm_settings)
# Embedding model factory
@st.cache_resource
def get_embeddings():
    """Return a ``HuggingFaceEmbeddings`` instance built with its defaults.

    Wrapped in ``st.cache_resource`` so the instance is created once and
    reused instead of being rebuilt on each call.
    """
    embedding_model = HuggingFaceEmbeddings()
    return embedding_model
def _read_uploaded_text(uploaded_file, file_extension):
    """Return the text of one uploaded file, or None for an unsupported type.

    Args:
        uploaded_file: Object exposing ``getvalue()`` that returns the raw bytes.
        file_extension: Lower-cased extension including the dot (e.g. ".pdf").

    Raises:
        Whatever the underlying decoder/parser raises for malformed input;
        the caller is expected to report it.
    """
    if file_extension in (".txt", ".md"):
        return uploaded_file.getvalue().decode("utf-8")
    if file_extension == ".docx":
        docx_doc = DocxDocument(io.BytesIO(uploaded_file.getvalue()))
        return "\n".join(para.text for para in docx_doc.paragraphs)
    if file_extension == ".pdf":
        pdf_reader = PdfReader(io.BytesIO(uploaded_file.getvalue()))
        # A page without a text layer may yield None/""; treat it as empty
        # text instead of failing the whole file.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return None


def process_documents(uploaded_files):
    """Build a retrieval QA chain over the uploaded documents.

    Each file is converted to text (.txt, .md, .docx, .pdf), split into
    overlapping chunks, embedded into a FAISS index and wrapped in a
    ``RetrievalQA`` chain that returns its source documents.

    Args:
        uploaded_files: Iterable of uploaded file objects exposing ``name``
            and ``getvalue()``.

    Returns:
        The QA chain, or None when no file produced any usable text.
        Per-file problems are reported in the UI via ``st.warning`` /
        ``st.error`` and the remaining files are still processed.
    """
    documents = []
    for uploaded_file in uploaded_files:
        file_extension = os.path.splitext(uploaded_file.name)[1].lower()
        try:
            content = _read_uploaded_text(uploaded_file, file_extension)
        except Exception as e:
            # Best-effort: one unreadable file must not abort the whole batch.
            st.error(f"Error processing file {uploaded_file.name}: {str(e)}")
            continue
        if content is None:
            st.warning(f"Unsupported file type: {file_extension}")
            continue
        if not content.strip():
            # Empty documents produce no chunks, and an index cannot be built
            # from zero chunks, so skip them up front.
            st.warning(f"No extractable text found in {uploaded_file.name}")
            continue
        documents.append(Document(page_content=content, metadata={"source": uploaded_file.name}))

    if not documents:
        return None

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    if not texts:
        return None

    vectorstore = FAISS.from_documents(texts, get_embeddings())
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    qa_chain = RetrievalQA.from_chain_type(
        llm=get_llm(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa_chain
def get_chatbot_response(user_input, qa_chain=None, personality="default", web_search=False):
    """Answer ``user_input`` with the document chain or the plain chat model.

    Args:
        user_input: The user's question.
        qa_chain: Optional retrieval QA chain; when truthy it answers the
            question, otherwise the chat model is called directly.
        personality: Key passed to ``get_personality_prompt``. It is only
            used on the direct chat-model path; the QA chain is queried
            without a system message.
        web_search: When True, search results are fetched and appended to
            the question before it is answered.

    Returns:
        A ``(response, source_docs, web_results)`` tuple. ``source_docs`` is
        an empty list on the direct chat-model path; ``web_results`` is None
        when ``web_search`` is False, otherwise the (possibly empty) result
        list.
    """
    web_results = None
    if web_search:
        web_results = search_web_duckduckgo(user_input)
        # Only extend the prompt when something was found; an empty
        # "Web search results:" header just adds noise to the question.
        if web_results:
            web_info = "\n\n".join(
                f"Title: {result['title']}\nLink: {result['link']}\nSnippet: {result['snippet']}"
                for result in web_results
            )
            user_input += f"\n\nWeb search results:\n{web_info}"

    if qa_chain:
        # ``invoke`` replaces the deprecated direct call on the chain object.
        result = qa_chain.invoke({"query": user_input})
        response = result['result']
        source_docs = result.get('source_documents', [])
    else:
        messages = [
            SystemMessage(content=get_personality_prompt(personality)),
            HumanMessage(content=user_input),
        ]
        response = get_llm().invoke(messages).content
        source_docs = []

    return response, source_docs, web_results
def get_personality_prompt(personality):
    """Return the system prompt for the selected chatbot personality.

    Args:
        personality: One of "default", "sherlock", "yoda", "shakespeare"
            or "custom". Unknown values fall back to the default prompt.

    Returns:
        The prompt string. For "custom" this is the text stored in
        ``st.session_state.custom_personality``; when that text is blank
        the default prompt is returned so the model never receives an
        empty system message.
    """
    default_prompt = "You are a helpful assistant."

    # Session state is only consulted for the custom personality, so the
    # built-in prompts do not depend on it being initialised.
    if personality == "custom":
        custom_prompt = st.session_state.custom_personality
        if custom_prompt and custom_prompt.strip():
            return custom_prompt
        return default_prompt

    built_in_prompts = {
        "default": default_prompt,
        "sherlock": "You are Sherlock Holmes, the world's greatest detective. Respond with keen observation and deductive reasoning.",
        "yoda": "Wise and cryptic, you are. Like Yoda from Star Wars, speak you must.",
        "shakespeare": "Thou art the Bard himself. In iambic pentameter, respond with eloquence and poetic flair.",
    }
    return built_in_prompts.get(personality, default_prompt)
def search_web_duckduckgo(query: str, num_results: int = 3, max_retries: int = 3):
    """Search the web and return a list of result dictionaries.

    Despite the name, the search goes through the Google Custom Search
    client (``build("customsearch", "v1", ...)``), using the ``api_key`` and
    ``cse_id`` environment variables.

    Args:
        query: Search terms.
        num_results: Number of results to request; capped at 10, the most
            the API returns per request.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        A list of ``{"title", "link", "snippet"}`` dicts. An empty list is
        returned when nothing is found, when the credentials are missing,
        when ``num_results`` is below 1, or when every attempt fails.
    """
    api_key = os.getenv('api_key')
    cse_id = os.getenv('cse_id')
    if not api_key or not cse_id:
        # Without credentials every attempt would fail; skip the retries
        # (and their back-off sleeps) entirely.
        print("Web search is not configured: set the 'api_key' and 'cse_id' environment variables.")
        return []
    if num_results < 1:
        return []
    num_results = min(num_results, 10)

    for attempt in range(max_retries):
        try:
            service = build("customsearch", "v1", developerKey=api_key)
            res = service.cse().list(q=query, cx=cse_id, num=num_results).execute()
        except HttpError as e:
            print(f"HTTP error occurred: {e}. Attempt {attempt + 1} of {max_retries}")
        except Exception as e:
            print(f"An unexpected error occurred: {e}. Attempt {attempt + 1} of {max_retries}")
        else:
            return [
                {
                    "title": item.get("title", ""),
                    "link": item.get("link", ""),
                    "snippet": item.get("snippet", ""),
                }
                for item in res.get("items", [])
            ]
        # Exponential back-off between attempts; no point waiting after the last one.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)

    print("Max retries reached. No results found.")
    return []
def main():
    """Render the chatbot page: sidebar controls, chat transcript and history.

    The sidebar collects uploaded documents, the personality, the web-search
    toggle and the chat mode. Uploaded documents are turned into a QA chain
    that is kept in ``st.session_state`` and only rebuilt when the set of
    uploaded files changes. Each submitted prompt is answered via
    ``get_chatbot_response`` and appended to ``st.session_state.messages``.
    """
    st.set_page_config(page_title="S.H.E.R.L.O.C.K. Chatbot", page_icon="🕵️", layout="wide")
    st.title("S.H.E.R.L.O.C.K. Chatbot")

    # Sidebar
    with st.sidebar:
        # NOTE(review): the image source is empty in this file; the banner is
        # skipped until a real path/URL is filled in.
        sidebar_image = ""
        if sidebar_image:
            st.image(sidebar_image, use_column_width=True)

        # NOTE(review): this emoji was garbled in the source; the folder icon
        # is the most plausible original - confirm.
        st.subheader("📁 Document Upload")
        uploaded_files = st.file_uploader("Upload documents", type=["txt", "md", "docx", "pdf"], accept_multiple_files=True)

        st.subheader("🎭 Chatbot Personality")
        personality = st.selectbox("Choose chatbot personality", ["default", "sherlock", "yoda", "shakespeare", "custom"])
        if personality == "custom":
            st.session_state.custom_personality = st.text_area("Enter custom personality details:", value=st.session_state.custom_personality)

        st.subheader("🌐 Web Search")
        web_search = st.checkbox("Enable web search")

        st.subheader("💬 Chat Mode")
        chat_mode = st.radio("Select chat mode", ["General Chat", "Document Chat"])

        if st.button("Clear Chat History"):
            st.session_state.messages = []
            st.rerun()

    # Main content
    if uploaded_files:
        # The script reruns on every interaction, so rebuild the index only
        # when the uploaded file set actually changes.
        upload_fingerprint = tuple((f.name, f.size) for f in uploaded_files)
        chain_is_stale = (
            "qa_chain" not in st.session_state
            or "qa_chain_fingerprint" not in st.session_state
            or st.session_state["qa_chain_fingerprint"] != upload_fingerprint
        )
        if chain_is_stale:
            st.session_state["qa_chain"] = process_documents(uploaded_files)
            st.session_state["qa_chain_fingerprint"] = upload_fingerprint
        qa_chain = st.session_state["qa_chain"]
        if qa_chain:
            st.success("Documents processed successfully!")
        else:
            st.warning("No valid documents were uploaded or processed.")
    else:
        qa_chain = None

    # Chat interface
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("What is your question?"):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        source_docs = []
        if chat_mode == "General Chat" or not qa_chain:
            response, _, web_results = get_chatbot_response(prompt, personality=personality, web_search=web_search)
        else:
            response, source_docs, web_results = get_chatbot_response(prompt, qa_chain, personality, web_search)

        with st.chat_message("assistant"):
            st.markdown(response)

            if chat_mode == "Document Chat" and qa_chain and source_docs:
                with st.expander("Source Documents"):
                    for doc in source_docs:
                        st.markdown(f"**Source:** {doc.metadata.get('source', 'Unknown')}")
                        st.markdown(doc.page_content[:200] + "...")

            if web_search and web_results:
                with st.expander("Web Search Results"):
                    for result in web_results:
                        st.markdown(f"**[{result['title']}]({result['link']})**")
                        st.markdown(result['snippet'])

        st.session_state.messages.append({"role": "assistant", "content": response})

    # Chat history and download
    with st.sidebar:
        st.subheader("📜 Chat History")
        history_expander = st.expander("View Chat History")
        with history_expander:
            for message in st.session_state.messages:
                st.text(f"{message['role']}: {message['content'][:50]}...")

        if st.session_state.messages:
            chat_history_df = pd.DataFrame(st.session_state.messages)
            csv = chat_history_df.to_csv(index=False)
            st.download_button(
                label="📥 Download Chat History",
                data=csv,
                file_name="chat_history.csv",
                mime="text/csv",
            )

    st.sidebar.markdown("---")
    st.sidebar.markdown("Powered by Falcon-180B and Streamlit")
# Render the page when this file is executed as a script.
if __name__ == "__main__":
    main()