Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,7 +6,7 @@ from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTem
|
|
6 |
from langchain_community.vectorstores import Chroma
|
7 |
from langchain_community.embeddings import OpenAIEmbeddings
|
8 |
from langchain.chat_models import ChatOpenAI
|
9 |
-
from langchain.schema import SystemMessage, HumanMessage
|
10 |
from PyPDF2 import PdfReader
|
11 |
import aiohttp
|
12 |
from io import BytesIO
|
@@ -15,7 +15,7 @@ from io import BytesIO
|
|
15 |
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
16 |
|
17 |
# Set up prompts
|
18 |
-
system_template = "Use the following context to answer
|
19 |
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
|
20 |
|
21 |
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
|
@@ -29,7 +29,7 @@ class RetrievalAugmentedQAPipeline:
|
|
29 |
self.llm = llm
|
30 |
self.vector_db = vector_db
|
31 |
|
32 |
-
async def arun_pipeline(self, user_query: str):
|
33 |
context_docs = self.vector_db.similarity_search(user_query, k=2)
|
34 |
context_list = [doc.page_content for doc in context_docs]
|
35 |
context_prompt = "\n".join(context_list)
|
@@ -38,7 +38,9 @@ class RetrievalAugmentedQAPipeline:
|
|
38 |
if len(context_prompt) > max_context_length:
|
39 |
context_prompt = context_prompt[:max_context_length]
|
40 |
|
41 |
-
messages =
|
|
|
|
|
42 |
|
43 |
response = await self.llm.agenerate([messages])
|
44 |
return {"response": response.generations[0][0].text}
|
@@ -86,13 +88,36 @@ async def main():
|
|
86 |
# Streamlit UI
|
87 |
st.title("Ask About AI!")
|
88 |
|
|
|
|
|
|
|
|
|
89 |
pipeline = initialize_pipeline()
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
user_query = st.text_input("Enter your question about AI:")
|
92 |
|
93 |
if user_query:
|
|
|
|
|
|
|
94 |
with st.spinner("Generating response..."):
|
95 |
-
result = asyncio.run(pipeline.arun_pipeline(user_query))
|
96 |
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
from langchain_community.vectorstores import Chroma
|
7 |
from langchain_community.embeddings import OpenAIEmbeddings
|
8 |
from langchain.chat_models import ChatOpenAI
|
9 |
+
from langchain.schema import SystemMessage, HumanMessage, AIMessage
|
10 |
from PyPDF2 import PdfReader
|
11 |
import aiohttp
|
12 |
from io import BytesIO
|
|
|
15 |
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
|
16 |
|
17 |
# Set up prompts
|
18 |
+
system_template = "You are an AI assistant answering questions about AI. Use the following context to answer the user's question. If you cannot find the answer in the context, say you don't know the answer but you can try to help with related information."
|
19 |
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
|
20 |
|
21 |
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
|
|
|
29 |
self.llm = llm
|
30 |
self.vector_db = vector_db
|
31 |
|
32 |
+
async def arun_pipeline(self, user_query: str, chat_history: list):
|
33 |
context_docs = self.vector_db.similarity_search(user_query, k=2)
|
34 |
context_list = [doc.page_content for doc in context_docs]
|
35 |
context_prompt = "\n".join(context_list)
|
|
|
38 |
if len(context_prompt) > max_context_length:
|
39 |
context_prompt = context_prompt[:max_context_length]
|
40 |
|
41 |
+
messages = [SystemMessage(content=system_template)]
|
42 |
+
messages.extend(chat_history)
|
43 |
+
messages.append(HumanMessage(content=human_template.format(context=context_prompt, question=user_query)))
|
44 |
|
45 |
response = await self.llm.agenerate([messages])
|
46 |
return {"response": response.generations[0][0].text}
|
|
|
88 |
# Streamlit UI
|
89 |
st.title("Ask About AI!")
|
90 |
|
91 |
+
# Initialize session state for chat history
|
92 |
+
if "chat_history" not in st.session_state:
|
93 |
+
st.session_state.chat_history = []
|
94 |
+
|
95 |
pipeline = initialize_pipeline()
|
96 |
|
97 |
+
# Display chat history
|
98 |
+
for message in st.session_state.chat_history:
|
99 |
+
if isinstance(message, HumanMessage):
|
100 |
+
st.write("You:", message.content)
|
101 |
+
elif isinstance(message, AIMessage):
|
102 |
+
st.write("AI:", message.content)
|
103 |
+
|
104 |
user_query = st.text_input("Enter your question about AI:")
|
105 |
|
106 |
if user_query:
|
107 |
+
# Add user message to chat history
|
108 |
+
st.session_state.chat_history.append(HumanMessage(content=user_query))
|
109 |
+
|
110 |
with st.spinner("Generating response..."):
|
111 |
+
result = asyncio.run(pipeline.arun_pipeline(user_query, st.session_state.chat_history))
|
112 |
|
113 |
+
# Add AI response to chat history
|
114 |
+
ai_message = AIMessage(content=result["response"])
|
115 |
+
st.session_state.chat_history.append(ai_message)
|
116 |
+
|
117 |
+
# Display the latest response
|
118 |
+
st.write("AI:", result["response"])
|
119 |
+
|
120 |
+
# Add a button to clear chat history
|
121 |
+
if st.button("Clear Chat History"):
|
122 |
+
st.session_state.chat_history = []
|
123 |
+
st.experimental_rerun()
|