sudip2003 committed on
Commit f270dea · verified · 1 Parent(s): 4d25685

Create app.py

Files changed (1): app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
+
+ import streamlit as st
+ from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+ from pinecone import Pinecone
+ from langchain.vectorstores import Pinecone as LangchainPinecone
+ from langchain_groq import ChatGroq
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate
+ from langchain.chains import create_retrieval_chain, create_history_aware_retriever
+ from langchain.chains.combine_documents import create_stuff_documents_chain
+ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
+ from langchain_core.runnables.history import RunnableWithMessageHistory
+ import time
+ import os
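+
+ # Note: langchain.vectorstores.Pinecone is the legacy community wrapper;
+ # newer releases provide langchain_pinecone.PineconeVectorStore in the same
+ # role. The legacy import is kept here as this code depends on it.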
+
+
+ # Embedding setup: BGE small English model on CPU, with normalized
+ # embeddings so similarity scores behave like cosine similarity
+ model_name = "BAAI/bge-small-en"
+ model_kwargs = {"device": "cpu"}
+ encode_kwargs = {"normalize_embeddings": True}
+ embeddings = HuggingFaceBgeEmbeddings(
+     model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs
+ )
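+
+ # Assumption about the existing index (not stated in this commit):
+ # BAAI/bge-small-en produces 384-dimensional vectors, so the "contentengine"
+ # index below must have been created with dimension=384 and a cosine or
+ # dot-product metric for the scores to be meaningful.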
+
+ # Pinecone setup; the API key is read from the environment instead of
+ # being hardcoded in the source
+ pc = Pinecone(api_key=os.environ["PINECONE_API_KEY"])
+ text_field = "text"
+ index_name = "contentengine"
+ index = pc.Index(index_name)
+ vectorstore = LangchainPinecone(index, embeddings.embed_query, text_field)
+
+ # Retriever setup: return at most one chunk, and only if its similarity
+ # score clears the 0.5 threshold
+ retriever = vectorstore.as_retriever(
+     search_type="similarity_score_threshold",
+     search_kwargs={"k": 1, "score_threshold": 0.5},
+ )
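+
+ # With k=1 and score_threshold=0.5, an off-topic query can retrieve nothing
+ # at all; the chain then answers from the prompt and chat history alone, so
+ # the threshold doubles as a guard against stuffing irrelevant chunks.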
+
+
+ # LLM setup; the Groq API key is likewise read from the environment
+ llm = ChatGroq(model="llama3-8b-8192", api_key=os.environ["GROQ_API_KEY"], max_tokens=4096)
+
+
+ # Retriever prompt setup. The chat history and the latest question are
+ # supplied by the MessagesPlaceholder and human message below, so the system
+ # prompt carries only the reformulation instruction; repeating {chat_history}
+ # and {input} inside the system string would render them twice.
+ retriever_prompt = """
+ Given a chat history and the latest user question which might reference context in the chat history,
+ formulate a standalone question which can be understood without the chat history.
+ Do NOT answer the question, just reformulate it if needed and otherwise return it as is.
+ """
+ contextualize_q_prompt = ChatPromptTemplate.from_messages([
+     ("system", retriever_prompt),
+     MessagesPlaceholder(variable_name="chat_history"),
+     ("human", "{input}"),
+ ])
+
+ history_aware_retriever = create_history_aware_retriever(llm, retriever, contextualize_q_prompt)
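+ # create_history_aware_retriever asks the LLM to rewrite the question as a
+ # standalone query when history is present (and skips the rewrite otherwise),
+ # then passes the rewritten text to the vector-store retriever.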
+
+ template = """
+ Context: This Content Engine analyzes and compares key information across Form 10-K filings for major companies, specifically Alphabet, Tesla, and Uber. The system uses Retrieval-Augmented Generation (RAG) to retrieve and summarize insights, highlight differences, and answer user queries on financial and operational topics such as risk factors, revenue, and business models.
+
+ Chat History: {chat_history}
+ Context: {context}
+ Human: {input}
+
+ Answer:
+ """
+
+ # Define the PromptTemplate with its three input variables: the running chat
+ # history, the retrieved documents, and the latest user question
+ custom_rag_prompt = PromptTemplate(template=template, input_variables=["chat_history", "context", "input"])
+
+ question_answering_chain = create_stuff_documents_chain(llm, custom_rag_prompt)
+ rag_chain = create_retrieval_chain(history_aware_retriever, question_answering_chain)
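+ # rag_chain.invoke({"input": ..., "chat_history": [...]}) returns a dict
+ # carrying the original input keys plus "context" (the retrieved documents)
+ # and "answer" (the generated reply).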
+
+ # ======================================================= Streamlit UI =======================================================
+
+ st.title("Chat with Content Engine")
+
+ # Initialize chat history
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = StreamlitChatMessageHistory(key="chat_messages")
+
+ # Message history factory. RunnableWithMessageHistory calls it with a
+ # session_id (a zero-argument function would raise a TypeError here);
+ # Streamlit already scopes session_state to one browser session, so the
+ # argument is accepted and ignored.
+ def get_chat_history(session_id: str = "streamlit_session") -> StreamlitChatMessageHistory:
+     return st.session_state.chat_history
+
+ # Wrap the RAG chain so each turn is read from and written to the Streamlit chat history
+ conversational_rag_chain = RunnableWithMessageHistory(
+     rag_chain,
+     get_chat_history,
+     input_messages_key="input",
+     history_messages_key="chat_history",
+     output_messages_key="answer",
+ )
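+ # On each call the wrapper injects the stored messages as "chat_history" and
+ # afterwards appends the new human turn and the "answer" message to the store,
+ # so the UI code needs no manual history bookkeeping.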
+
+ # Function to interact with the chatbot
+ def chat_with_bot(query: str) -> str:
+     result = conversational_rag_chain.invoke(
+         {"input": query},
+         config={
+             "configurable": {"session_id": "streamlit_session"}
+         },
+     )
+     return result["answer"]
+
+ # Display chat messages from history
+ for message in st.session_state.chat_history.messages:
+     with st.chat_message(message.type):
+         st.markdown(message.content)
+
+ # Accept user input
+ if user_input := st.chat_input("Enter your question here..."):
+
+     # Display user message in chat message container
+     with st.chat_message("human"):
+         st.markdown(user_input)
+
+     # Display assistant response in chat message container
+     with st.chat_message("ai"):
+         with st.spinner("Thinking..."):
+             response = chat_with_bot(user_input)
+         # The full answer is already available; replay it character by
+         # character with a trailing cursor to simulate token streaming
+         message_placeholder = st.empty()
+         full_response = "⚠️ **_Reminder: Please double-check information._** \n\n"
+         for chunk in response:
+             full_response += chunk
+             time.sleep(0.01)
+             message_placeholder.markdown(full_response + ":white_circle:", unsafe_allow_html=True)
+         # Drop the cursor once the replay is finished
+         message_placeholder.markdown(full_response, unsafe_allow_html=True)
+