# Hugging Face Space (status when this file was captured: Sleeping)
import os
import re
import time

import streamlit as st
from langchain.chains import create_history_aware_retriever
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.vectorstores import Pinecone as LangchainPinecone
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq
from pinecone import Pinecone
# Embedding setup: the BAAI/bge-small-en model, run on CPU, producing
# normalized embedding vectors.
model_name = "BAAI/bge-small-en"
model_kwargs = dict(device="cpu")
encode_kwargs = dict(normalize_embeddings=True)
embeddings = HuggingFaceBgeEmbeddings(
    encode_kwargs=encode_kwargs,
    model_kwargs=model_kwargs,
    model_name=model_name,
)
# Pinecone setup
# SECURITY: the Pinecone API key used to be hard-coded in this file. It is now
# read from the environment (e.g. configured as a Space secret). The key that
# was previously committed must be treated as leaked and rotated.
pinecone_api_key = os.environ.get("PINECONE_API_KEY")
if not pinecone_api_key:
    st.error("PINECONE_API_KEY is not set; configure it as an environment variable or Space secret.")
    st.stop()
pc = Pinecone(api_key=pinecone_api_key)
text_field = "text"  # metadata key under which each stored chunk's text lives
index_name = "contentengine"
index = pc.Index(index_name)
# Pass the Embeddings object itself. Passing a bare `embed_query` callable is
# deprecated in LangChain's Pinecone wrapper and triggers a warning.
vectorstore = LangchainPinecone(index, embeddings, text_field)
# Retriever setup: fetch at most one chunk, and only if its relevance score
# reaches 0.5 (otherwise no context is returned).
retriever = vectorstore.as_retriever(
    search_kwargs=dict(k=1, score_threshold=0.5),
    search_type="similarity_score_threshold",
)
# LLM shared by the question-rewriting step and the answering step.
# SECURITY: the Groq API key used to be hard-coded in this file. It is now read
# from the environment (e.g. configured as a Space secret). The key that was
# previously committed must be treated as leaked and rotated.
groq_api_key = os.environ.get("GROQ_API_KEY")
if not groq_api_key:
    st.error("GROQ_API_KEY is not set; configure it as an environment variable or Space secret.")
    st.stop()
llm = ChatGroq(model="llama3-8b-8192", api_key=groq_api_key, max_tokens=4096)
# Retriever prompt setup
# System instructions for rewriting a follow-up question into a standalone one.
# The conversation and the latest question are supplied by the
# MessagesPlaceholder and the human message below, so they are NOT repeated
# here. A `{chat_history}` slot inside a system *string* template would render
# the message list a second time, as its Python repr.
retriever_prompt = """
Given a chat history and the latest user question which might reference context in the chat history,
formulate a standalone question which can be understood without the chat history.
Do NOT answer the question, just reformulate it if needed and otherwise return it as is.
"""
contextualize_q_prompt = ChatPromptTemplate.from_messages([
    ("system", retriever_prompt),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
])
# Uses the LLM to rewrite the question when there is prior history, then
# retrieves documents with the rewritten question.
history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
from langchain_core.prompts import PromptTemplate
# Answer-generation prompt.
# `template` is the system message. It holds a description of the assistant's
# domain, followed by the retrieved passages, which create_stuff_documents_chain
# inserts at {context}.
# Prior turns are passed through a MessagesPlaceholder, so the model sees real
# chat messages. A plain-string PromptTemplate would have rendered
# {chat_history} as the repr of a list of message objects.
template = """
Description: This Content Engine is designed to analyze and compare key information across multiple Form 10-K filings for major companies, specifically Alphabet, Tesla, and Uber. The system uses Retrieval-Augmented Generation (RAG) to retrieve and summarize insights, highlight differences, and answer user queries on various financial and operational topics, such as risk factors, revenue, and business models.
Answer the user's question using the retrieved context below. If the context does not contain the answer, say that you don't know.
Context: {context}
"""
custom_rag_prompt = ChatPromptTemplate.from_messages([
    ("system", template),
    MessagesPlaceholder(variable_name="chat_history"),
    ("human", "{input}"),
])
question_answering_chain = create_stuff_documents_chain(llm, custom_rag_prompt)
# Retrieval (with question rewriting) feeding into the stuff-documents QA chain;
# its output dict carries the generated text under "answer".
rag_chain = create_retrieval_chain(history_aware_retriever, question_answering_chain)
# ======================================================= Streamlit UI =======================================================
st.title("Chat with Content Engine")

# Streamlit re-executes this script on every interaction. Create the session's
# message history only on the first run and reuse it afterwards. The messages
# themselves are kept under st.session_state["chat_messages"].
if "chat_history" not in st.session_state:
    st.session_state["chat_history"] = StreamlitChatMessageHistory(key="chat_messages")
# Message history setup
def get_chat_history(session_id=None) -> BaseChatMessageHistory:
    """Return this browser session's chat history for RunnableWithMessageHistory.

    Args:
        session_id: Session id taken from the runnable's ``configurable``
            config. It is ignored, because ``st.session_state`` is already
            scoped to a single browser session. The parameter is optional so
            the factory works whether or not the installed langchain_core
            passes the id positionally.

    Returns:
        The StreamlitChatMessageHistory stored at ``st.session_state.chat_history``.
    """
    return st.session_state.chat_history
# Wrap the RAG chain so each call reads past turns from, and appends the new
# turn to, the Streamlit-backed history. The question comes in under "input",
# prior turns are injected as "chat_history", and the reply is read from "answer".
conversational_rag_chain = RunnableWithMessageHistory(
    rag_chain,
    get_chat_history,
    history_messages_key="chat_history",
    input_messages_key="input",
    output_messages_key="answer",
)
# Function to interact with the chatbot
def chat_with_bot(query: str) -> str:
    """Send ``query`` through the conversational RAG chain and return the reply.

    All calls share the fixed session id "streamlit_session".

    Args:
        query: The user's question as typed in the chat box.

    Returns:
        The text the chain produced under its "answer" key.
    """
    session_config = {"configurable": {"session_id": "streamlit_session"}}
    chain_output = conversational_rag_chain.invoke({"input": query}, config=session_config)
    return chain_output["answer"]
# Replay the stored conversation on every rerun. Each message's type
# ("human"/"ai") selects the chat bubble it is shown in.
for past_msg in st.session_state.chat_history.messages:
    speaker = past_msg.type
    with st.chat_message(speaker):
        st.markdown(past_msg.content)
# Accept user input
if user_input := st.chat_input("Enter your question here..."):
    # Echo the user's message immediately. The chain itself records the turn
    # in the chat history.
    with st.chat_message("human"):
        st.markdown(user_input)

    # Display assistant response in chat message container
    with st.chat_message("ai"):
        with st.spinner("Thinking..."):
            response = chat_with_bot(user_input)

        message_placeholder = st.empty()
        full_response = "⚠️ **_Reminder: Please double-check information._** \n\n"
        # Simulated streaming: reveal the finished answer one word at a time.
        # Splitting at whitespace->non-whitespace boundaries keeps each piece's
        # trailing whitespace, so the pieces join back to exactly `response`.
        # Updating once per character re-rendered the markdown and slept 10 ms
        # for every character, which made long answers painfully slow.
        for chunk in re.split(r"(?<=\s)(?=\S)", response):
            full_response += chunk
            time.sleep(0.01)
            # ":white_circle:" acts as a typing cursor while text is appearing.
            message_placeholder.markdown(full_response + ":white_circle:")
        # Final render without the cursor, which previously stayed on screen.
        # HTML rendering stays disabled because the answer is model output that
        # can echo retrieved or user-supplied text.
        message_placeholder.markdown(full_response)