# Snapshot metadata: commit 600acaa, 19,686 bytes.
import streamlit as st
import random
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, SystemMessage
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
import os
from dotenv import load_dotenv
import requests
from bs4 import BeautifulSoup
import pandas as pd
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
import time
from langchain.schema import Document
from docx import Document as DocxDocument
from PyPDF2 import PdfReader
import io
# Load environment variables (e.g. from a local .env file) so credentials and
# endpoints can be configured without hard-coding them in source.
load_dotenv()
AI71_BASE_URL = os.getenv("AI71_BASE_URL", "https://api.ai71.ai/v1/")
# SECURITY: the AI71 key used to be committed here in plain text. It is now read
# from the environment; the previously committed key should be treated as leaked
# and rotated.
AI71_API_KEY = os.getenv("AI71_API_KEY")
# Initialize session state variables. Streamlit re-executes this script on every
# interaction, so values are only seeded when they do not exist yet.
if "custom_personality" not in st.session_state:
    st.session_state.custom_personality = ""
if "messages" not in st.session_state:
    st.session_state.messages = []
# Initialize the Falcon model
@st.cache_resource
def get_llm():
    """Create the shared Falcon-180B chat client.

    The client is built with ``ChatOpenAI`` pointed at the AI71 base URL, with
    streaming enabled. ``st.cache_resource`` ensures a single instance is reused
    across Streamlit reruns.
    """
    llm_settings = dict(
        model="tiiuae/falcon-180B-chat",
        api_key=AI71_API_KEY,
        base_url=AI71_BASE_URL,
        streaming=True,
    )
    return ChatOpenAI(**llm_settings)
# Initialize embeddings
@st.cache_resource
def get_embeddings():
    """Return the shared HuggingFace embedding model (constructed once, cached by Streamlit)."""
    embedding_model = HuggingFaceEmbeddings()
    return embedding_model
def _load_document(uploaded_file):
    """Convert one Streamlit upload into a LangChain ``Document``.

    Supports .txt/.md (decoded as UTF-8), .docx (paragraph text) and .pdf (page
    text). Returns ``None`` after showing a Streamlit warning/error when the type
    is unsupported, parsing fails, or no text could be extracted.
    """
    file_extension = os.path.splitext(uploaded_file.name)[1].lower()
    try:
        if file_extension in (".txt", ".md"):
            content = uploaded_file.getvalue().decode("utf-8")
        elif file_extension == ".docx":
            doc = DocxDocument(io.BytesIO(uploaded_file.getvalue()))
            content = "\n".join(para.text for para in doc.paragraphs)
        elif file_extension == ".pdf":
            pdf_reader = PdfReader(io.BytesIO(uploaded_file.getvalue()))
            # Guard against a missing page text and separate pages with a newline so
            # the last word of one page is not glued to the first word of the next.
            content = "\n".join(page.extract_text() or "" for page in pdf_reader.pages)
        else:
            st.warning(f"Unsupported file type: {file_extension}")
            return None
    except Exception as e:
        st.error(f"Error processing file {uploaded_file.name}: {str(e)}")
        return None
    if not content.strip():
        # e.g. image-only PDFs: nothing to embed, and an empty index cannot be built.
        st.warning(f"No extractable text found in {uploaded_file.name}; skipping.")
        return None
    return Document(page_content=content, metadata={"source": uploaded_file.name})


def process_documents(uploaded_files):
    """Build a retrieval-QA chain over the text of the uploaded files.

    Each file is loaded with ``_load_document``; unreadable or empty files are
    skipped. The remaining text is split into 1000-character chunks (200 overlap),
    embedded into a FAISS index, and wrapped in a "stuff" ``RetrievalQA`` chain
    that retrieves the top 3 chunks and returns its source documents.

    Returns:
        The chain, or ``None`` when no file yielded any text.
    """
    documents = [doc for doc in map(_load_document, uploaded_files) if doc is not None]
    if not documents:
        return None
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)
    if not texts:
        # FAISS.from_documents cannot build an index from zero chunks.
        return None
    vectorstore = FAISS.from_documents(texts, get_embeddings())
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
    return RetrievalQA.from_chain_type(
        llm=get_llm(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
def get_chatbot_response(user_input, qa_chain=None, personality="default", web_search=False):
    """Answer ``user_input``, optionally grounded in uploaded documents and web results.

    Args:
        user_input: The user's question.
        qa_chain: Optional retrieval chain (as returned by ``process_documents``).
            When truthy, the question is answered by the chain and its source
            documents are returned.
        personality: Key passed to ``get_personality_prompt`` for the system prompt
            (used only when no ``qa_chain`` is given).
        web_search: When true, web results are fetched and appended to the question.

    Returns:
        ``(response_text, source_documents, web_results)``. ``web_results`` is the
        (possibly empty) list of hits when ``web_search`` is true, else ``None``.
    """
    web_results = None
    query = user_input
    if web_search:
        web_results = search_web_duckduckgo(user_input)
        # Only add the section when there is something to show the model.
        if web_results:
            web_info = "\n\n".join(
                f"Title: {result['title']}\nLink: {result['link']}\nSnippet: {result['snippet']}"
                for result in web_results
            )
            query += f"\n\nWeb search results:\n{web_info}"

    if qa_chain:
        # NOTE(review): RetrievalQA does not receive the personality system prompt,
        # and any appended web results also become part of the retriever's
        # similarity query.
        result = qa_chain.invoke({"query": query})
        return result["result"], result.get("source_documents", []), web_results

    messages = [
        SystemMessage(content=get_personality_prompt(personality)),
        HumanMessage(content=query),
    ]
    return get_llm().invoke(messages).content, [], web_results
# System prompts for the built-in personalities; "custom" is read from session state.
_PRESET_PERSONALITIES = {
    "default": "You are a helpful assistant.",
    "sherlock": "You are Sherlock Holmes, the world's greatest detective. Respond with keen observation and deductive reasoning.",
    "yoda": "Wise and cryptic, you are. Like Yoda from Star Wars, speak you must.",
    "shakespeare": "Thou art the Bard himself. In iambic pentameter, respond with eloquence and poetic flair.",
}


def get_personality_prompt(personality):
    """Return the system prompt for ``personality``.

    ``"custom"`` returns the user-entered text from ``st.session_state``; if it is
    blank, the default prompt is used rather than sending an empty system
    message. Unknown keys also fall back to the default prompt.
    """
    if personality == "custom":
        custom_prompt = st.session_state.get("custom_personality", "")
        if custom_prompt and custom_prompt.strip():
            return custom_prompt
        return _PRESET_PERSONALITIES["default"]
    return _PRESET_PERSONALITIES.get(personality, _PRESET_PERSONALITIES["default"])
def search_web_duckduckgo(query: str, num_results: int = 3, max_retries: int = 3, api_key=None, cse_id=None):
    """Search the web through the Google Custom Search JSON API.

    Despite its name, this does not query DuckDuckGo; the name is kept so existing
    callers keep working.

    Credentials come from ``api_key``/``cse_id`` or, if omitted, from the
    ``GOOGLE_API_KEY`` / ``GOOGLE_CSE_ID`` environment variables. They were
    previously hard-coded in source; those values should be rotated.

    Returns:
        A list of ``{"title", "link", "snippet"}`` dicts, or ``[]`` when credentials
        are missing, a non-retryable error occurs, or every attempt fails.
    """
    api_key = api_key or os.getenv("GOOGLE_API_KEY")
    cse_id = cse_id or os.getenv("GOOGLE_CSE_ID")
    if not api_key or not cse_id:
        print("Web search disabled: set GOOGLE_API_KEY and GOOGLE_CSE_ID.")
        return []
    # The Custom Search API rejects `num` outside 1..10; clamp instead of failing.
    num = max(1, min(num_results, 10))
    for attempt in range(max_retries):
        try:
            service = build("customsearch", "v1", developerKey=api_key)
            res = service.cse().list(q=query, cx=cse_id, num=num).execute()
            return [
                {
                    "title": item.get("title", ""),
                    "link": item.get("link", ""),
                    "snippet": item.get("snippet", ""),
                }
                for item in res.get("items", [])
            ]
        except HttpError as e:
            print(f"HTTP error occurred: {e}. Attempt {attempt + 1} of {max_retries}")
            status = getattr(getattr(e, "resp", None), "status", None)
            # Bad request / bad credentials / not found will not succeed on retry.
            if status in (400, 401, 404):
                print("Non-retryable HTTP error. No results found.")
                return []
        except Exception as e:
            print(f"An unexpected error occurred: {e}. Attempt {attempt + 1} of {max_retries}")
        # Exponential backoff between attempts, but not after the last one.
        if attempt < max_retries - 1:
            time.sleep(2 ** attempt)
    print("Max retries reached. No results found.")
    return []
def _render_sidebar_controls():
    """Draw the sidebar settings widgets and return the user's selections.

    Returns:
        ``(uploaded_files, personality, web_search, chat_mode)`` from the widgets.
    """
    with st.sidebar:
        # NOTE(review): the original called st.image("") with an empty path, which
        # Streamlit cannot load; add st.image(<logo path or URL>) once one exists.
        # NOTE(review): the section icons were mis-encoded in the source; 💬 and 📥
        # are recovered from the surviving bytes, the others are best guesses.
        st.subheader("📁 Document Upload")
        uploaded_files = st.file_uploader(
            "Upload documents", type=["txt", "md", "docx", "pdf"], accept_multiple_files=True
        )
        st.subheader("🎭 Chatbot Personality")
        personality = st.selectbox(
            "Choose chatbot personality", ["default", "sherlock", "yoda", "shakespeare", "custom"]
        )
        if personality == "custom":
            st.session_state.custom_personality = st.text_area(
                "Enter custom personality details:", value=st.session_state.custom_personality
            )
        st.subheader("🌐 Web Search")
        web_search = st.checkbox("Enable web search")
        st.subheader("💬 Chat Mode")
        chat_mode = st.radio("Select chat mode", ["General Chat", "Document Chat"])
        if st.button("Clear Chat History"):
            st.session_state.messages = []
            st.rerun()
    return uploaded_files, personality, web_search, chat_mode


def _get_qa_chain(uploaded_files):
    """Return a QA chain for ``uploaded_files``, rebuilding it only when they change.

    Streamlit re-executes the script on every interaction; without this cache each
    chat message would re-read and re-embed every uploaded document.
    """
    if not uploaded_files:
        return None
    cache_key = tuple(
        (f.name, getattr(f, "size", None), getattr(f, "file_id", None)) for f in uploaded_files
    )
    if st.session_state.get("qa_chain_key") != cache_key:
        st.session_state.qa_chain = process_documents(uploaded_files)
        st.session_state.qa_chain_key = cache_key
    qa_chain = st.session_state.qa_chain
    if qa_chain:
        st.success("Documents processed successfully!")
    else:
        st.warning("No valid documents were uploaded or processed.")
    return qa_chain


def _render_message_history():
    """Re-draw every stored chat message in the main chat area."""
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])


def _handle_prompt(prompt, qa_chain, personality, web_search, chat_mode):
    """Record the user's prompt, get a response, render it, and store it in history."""
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Without a processed document set, Document Chat falls back to General Chat.
    use_documents = chat_mode == "Document Chat" and bool(qa_chain)
    if use_documents:
        response, source_docs, web_results = get_chatbot_response(prompt, qa_chain, personality, web_search)
    else:
        response, source_docs, web_results = get_chatbot_response(
            prompt, personality=personality, web_search=web_search
        )
    with st.chat_message("assistant"):
        st.markdown(response)
        if use_documents and source_docs:
            with st.expander("Source Documents"):
                for doc in source_docs:
                    st.markdown(f"**Source:** {doc.metadata.get('source', 'Unknown')}")
                    st.markdown(doc.page_content[:200] + "...")
        if web_search and web_results:
            with st.expander("Web Search Results"):
                for result in web_results:
                    st.markdown(f"**[{result['title']}]({result['link']})**")
                    st.markdown(result['snippet'])
    st.session_state.messages.append({"role": "assistant", "content": response})


def _render_history_sidebar():
    """Show truncated chat history in the sidebar and offer it as a CSV download."""
    with st.sidebar:
        st.subheader("📜 Chat History")
        with st.expander("View Chat History"):
            for message in st.session_state.messages:
                st.text(f"{message['role']}: {message['content'][:50]}...")
        if st.session_state.messages:
            csv = pd.DataFrame(st.session_state.messages).to_csv(index=False)
            st.download_button(
                label="📥 Download Chat History",
                data=csv,
                file_name="chat_history.csv",
                mime="text/csv",
            )


def main():
    """Render the S.H.E.R.L.O.C.K. chatbot page for one Streamlit script run."""
    st.set_page_config(page_title="S.H.E.R.L.O.C.K. Chatbot", page_icon="🕵️", layout="wide")
    st.title("S.H.E.R.L.O.C.K. Chatbot")
    uploaded_files, personality, web_search, chat_mode = _render_sidebar_controls()
    qa_chain = _get_qa_chain(uploaded_files)
    _render_message_history()
    if prompt := st.chat_input("What is your question?"):
        _handle_prompt(prompt, qa_chain, personality, web_search, chat_mode)
    _render_history_sidebar()
    st.sidebar.markdown("---")
    st.sidebar.markdown("Powered by Falcon-180B and Streamlit")
# Script entry point (e.g. `streamlit run <this file>`); the stray " |" left over
# from a code-viewer scrape made this line a syntax error.
if __name__ == "__main__":
    main()