Spaces:
Sleeping
Sleeping
import streamlit as st | |
import random | |
from langchain.chat_models import ChatOpenAI | |
from langchain.schema import HumanMessage, SystemMessage | |
from langchain.document_loaders import TextLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain_community.vectorstores import FAISS | |
from langchain.chains import RetrievalQA | |
import os | |
from dotenv import load_dotenv | |
import requests | |
from bs4 import BeautifulSoup | |
import pandas as pd | |
from googleapiclient.discovery import build | |
from googleapiclient.errors import HttpError | |
import time | |
from langchain.schema import Document | |
from docx import Document as DocxDocument | |
from PyPDF2 import PdfReader | |
import io | |
# Pull variables from a local .env file (if present) into the environment.
load_dotenv()

# Credentials and endpoint handed to ChatOpenAI in get_llm().
AI71_API_KEY = os.getenv("AI71_API_KEY")
AI71_BASE_URL = "https://api.ai71.ai/v1/"

# Seed per-session state once; later script runs keep whatever is stored.
if "messages" not in st.session_state:
    st.session_state["messages"] = []
if "custom_personality" not in st.session_state:
    st.session_state["custom_personality"] = ""
# Initialize the Falcon model
def get_llm():
    """Create a chat client for the Falcon model served at ``AI71_BASE_URL``.

    Returns:
        A new ``ChatOpenAI`` instance with streaming enabled. A fresh client
        is constructed on every call.
    """
    client_settings = {
        "model": "tiiuae/falcon-180B-chat",
        "api_key": AI71_API_KEY,
        "base_url": AI71_BASE_URL,
        "streaming": True,
    }
    return ChatOpenAI(**client_settings)
# Initialize embeddings
def get_embeddings():
    """Return a new ``HuggingFaceEmbeddings`` built with the library defaults."""
    embedding_model = HuggingFaceEmbeddings()
    return embedding_model
def _extract_text(uploaded_file, file_extension):
    """Return the text of one uploaded file, or ``None`` for unsupported types."""
    if file_extension in (".txt", ".md"):
        return uploaded_file.getvalue().decode("utf-8")
    if file_extension == ".docx":
        docx_doc = DocxDocument(io.BytesIO(uploaded_file.getvalue()))
        return "\n".join(para.text for para in docx_doc.paragraphs)
    if file_extension == ".pdf":
        pdf_reader = PdfReader(io.BytesIO(uploaded_file.getvalue()))
        # `or ""` guards against a page that yields no text at all.
        return "".join(page.extract_text() or "" for page in pdf_reader.pages)
    return None


def process_documents(uploaded_files):
    """Build a retrieval QA chain over the uploaded files.

    Supported extensions are ``.txt``, ``.md``, ``.docx`` and ``.pdf``
    (matched case-insensitively). Files that are unsupported, fail to parse,
    or contain no text are reported in the Streamlit UI and skipped, so one
    bad upload does not block the others.

    Args:
        uploaded_files: Iterable of uploaded-file objects exposing ``.name``
            and ``.getvalue()`` (as returned by ``st.file_uploader``).

    Returns:
        A ``RetrievalQA`` chain that also returns its source documents, or
        ``None`` when no usable text could be extracted from any file.
    """
    documents = []
    for uploaded_file in uploaded_files:
        file_extension = os.path.splitext(uploaded_file.name)[1].lower()
        try:
            content = _extract_text(uploaded_file, file_extension)
        except Exception as e:
            # Best-effort boundary: parsers raise many different error types,
            # and a single broken upload should not abort the whole batch.
            st.error(f"Error processing file {uploaded_file.name}: {str(e)}")
            continue

        if content is None:
            st.warning(f"Unsupported file type: {file_extension}")
            continue
        if not content.strip():
            # An empty document produces no chunks, and building the FAISS
            # index from an empty chunk list fails.
            st.warning(f"No extractable text found in {uploaded_file.name}")
            continue

        documents.append(Document(page_content=content, metadata={"source": uploaded_file.name}))

    if not documents:
        return None

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    if not texts:
        return None

    vectorstore = FAISS.from_documents(texts, get_embeddings())
    # Retrieve the three most similar chunks for each question.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    qa_chain = RetrievalQA.from_chain_type(
        llm=get_llm(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
    return qa_chain
def get_chatbot_response(user_input, qa_chain=None, personality="default", web_search=False):
    """Answer ``user_input`` with the document QA chain or the plain chat model.

    Args:
        user_input: The user's question.
        qa_chain: Optional retrieval QA chain. When truthy it answers the
            question; otherwise the chat model from ``get_llm()`` is used.
        personality: Key passed to ``get_personality_prompt``. It only
            affects the plain chat path; the QA chain is invoked with the
            query alone, so the persona is not applied there.
        web_search: When true, search results are fetched and appended to
            the prompt text.

    Returns:
        A tuple ``(response, source_docs, web_results)``. ``source_docs`` is
        an empty list on the plain chat path. ``web_results`` is the list of
        search hits (possibly empty) when ``web_search`` is true, else
        ``None``.
    """
    web_results = None
    if web_search:
        web_results = search_web_duckduckgo(user_input)
        # Only extend the prompt when there is something to show; an empty
        # "Web search results:" header would just mislead the model.
        if web_results:
            web_info = "\n\n".join(
                f"Title: {result['title']}\nLink: {result['link']}\nSnippet: {result['snippet']}"
                for result in web_results
            )
            user_input += f"\n\nWeb search results:\n{web_info}"

    if qa_chain:
        # NOTE: the augmented text (including any web results) is the query
        # the chain receives, so it is also what the retriever searches with.
        result = qa_chain.invoke({"query": user_input})
        response = result['result']
        source_docs = result.get('source_documents', [])
    else:
        messages = [
            SystemMessage(content=get_personality_prompt(personality)),
            HumanMessage(content=user_input),
        ]
        response = get_llm().invoke(messages).content
        source_docs = []

    return response, source_docs, web_results
def get_personality_prompt(personality):
    """Map a personality key to the system prompt sent to the chat model.

    Args:
        personality: One of ``"default"``, ``"sherlock"``, ``"yoda"``,
            ``"shakespeare"`` or ``"custom"``.

    Returns:
        The matching prompt string. ``"custom"`` yields the text stored in
        ``st.session_state.custom_personality``; any unrecognised key falls
        back to the default prompt.
    """
    # Read up front, regardless of which key was requested.
    custom_prompt = st.session_state.custom_personality
    if personality == "custom":
        return custom_prompt

    built_in_prompts = {
        "default": "You are a helpful assistant.",
        "sherlock": "You are Sherlock Holmes, the world's greatest detective. Respond with keen observation and deductive reasoning.",
        "yoda": "Wise and cryptic, you are. Like Yoda from Star Wars, speak you must.",
        "shakespeare": "Thou art the Bard himself. In iambic pentameter, respond with eloquence and poetic flair.",
    }
    fallback_prompt = built_in_prompts["default"]
    return built_in_prompts.get(personality, fallback_prompt)
def search_web_duckduckgo(query: str, num_results: int = 3, max_retries: int = 3):
    """Search the web and return a list of result dicts.

    Despite its name, this function queries the Google Custom Search API
    through ``googleapiclient``; the name is kept for existing callers.

    Credentials are read from the ``api_key`` and ``cse_id`` environment
    variables. Failed attempts are retried with exponential backoff
    (1s, 2s, 4s, ...) between attempts.

    Args:
        query: Search terms.
        num_results: Number of results requested from the API.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        A list of ``{"title", "link", "snippet"}`` dicts. The list is empty
        when nothing was found, the credentials are missing, or every
        attempt failed.
    """
    api_key = os.getenv('api_key')
    cse_id = os.getenv('cse_id')
    if not api_key or not cse_id:
        # Without credentials every attempt is doomed; skip the retry loop
        # instead of blocking the UI through the whole backoff schedule.
        print("Web search is not configured: set the 'api_key' and 'cse_id' environment variables.")
        return []

    for attempt in range(max_retries):
        try:
            service = build("customsearch", "v1", developerKey=api_key)
            res = service.cse().list(q=query, cx=cse_id, num=num_results).execute()
            # "items" is absent from the response when there are no hits.
            return [
                {
                    "title": item["title"],
                    "link": item["link"],
                    "snippet": item.get("snippet", ""),
                }
                for item in res.get("items", [])
            ]
        except HttpError as e:
            print(f"HTTP error occurred: {e}. Attempt {attempt + 1} of {max_retries}")
        except Exception as e:
            print(f"An unexpected error occurred: {e}. Attempt {attempt + 1} of {max_retries}")

        # Back off only when another attempt will actually follow.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)

    print("Max retries reached. No results found.")
    return []
def main():
    """Render the Streamlit chatbot page.

    Lays out the sidebar (document upload, personality, web search toggle,
    chat mode, clear-history button), replays the stored conversation,
    answers a newly submitted prompt, and offers the chat history as a CSV
    download.
    """
    st.set_page_config(page_title="S.H.E.R.L.O.C.K. Chatbot", page_icon="🕵️", layout="wide")
    st.title("S.H.E.R.L.O.C.K. Chatbot")

    # Sidebar
    # NOTE(review): the bare "π" prefixes in several labels below look like
    # mis-decoded emoji whose original glyph cannot be recovered from this
    # file; they are kept as-is. Replace them with the intended icons.
    with st.sidebar:
        # No image source is configured; st.image cannot load an empty
        # path, so only render once one is filled in.
        sidebar_image = ""
        if sidebar_image:
            st.image(sidebar_image, use_column_width=True)

        st.subheader("π Document Upload")
        uploaded_files = st.file_uploader("Upload documents", type=["txt", "md", "docx", "pdf"], accept_multiple_files=True)

        st.subheader("π Chatbot Personality")
        personality = st.selectbox("Choose chatbot personality", ["default", "sherlock", "yoda", "shakespeare", "custom"])
        if personality == "custom":
            st.session_state.custom_personality = st.text_area("Enter custom personality details:", value=st.session_state.custom_personality)

        st.subheader("π Web Search")
        web_search = st.checkbox("Enable web search")

        st.subheader("💬 Chat Mode")
        chat_mode = st.radio("Select chat mode", ["General Chat", "Document Chat"])

        if st.button("Clear Chat History"):
            st.session_state.messages = []
            st.rerun()

    # Main content
    # NOTE(review): this runs on every script run while files are uploaded,
    # so the documents are re-embedded on each interaction.
    if uploaded_files:
        qa_chain = process_documents(uploaded_files)
        if qa_chain:
            st.success("Documents processed successfully!")
        else:
            st.warning("No valid documents were uploaded or processed.")
    else:
        qa_chain = None

    # Chat interface
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("What is your question?"):
        st.chat_message("user").markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Document mode needs a usable chain; otherwise fall back to plain
        # chat. One call covers both paths so source_docs is always bound.
        use_documents = chat_mode == "Document Chat" and bool(qa_chain)
        response, source_docs, web_results = get_chatbot_response(
            prompt,
            qa_chain if use_documents else None,
            personality,
            web_search,
        )

        with st.chat_message("assistant"):
            st.markdown(response)

            if use_documents and source_docs:
                with st.expander("Source Documents"):
                    for doc in source_docs:
                        st.markdown(f"**Source:** {doc.metadata.get('source', 'Unknown')}")
                        st.markdown(doc.page_content[:200] + "...")

            if web_search and web_results:
                with st.expander("Web Search Results"):
                    for result in web_results:
                        st.markdown(f"**[{result['title']}]({result['link']})**")
                        st.markdown(result['snippet'])

        st.session_state.messages.append({"role": "assistant", "content": response})

    # Chat history and download
    with st.sidebar:
        st.subheader("π Chat History")
        history_expander = st.expander("View Chat History")
        with history_expander:
            for message in st.session_state.messages:
                st.text(f"{message['role']}: {message['content'][:50]}...")

        if st.session_state.messages:
            chat_history_df = pd.DataFrame(st.session_state.messages)
            csv = chat_history_df.to_csv(index=False)
            st.download_button(
                label="📥 Download Chat History",
                data=csv,
                file_name="chat_history.csv",
                mime="text/csv",
            )

    st.sidebar.markdown("---")
    st.sidebar.markdown("Powered by Falcon-180B and Streamlit")
# Build the page only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()