# AIE4Midterm / app.py
# Hugging Face Space by Technocoloredgeek; last change: commit a4b393c
# ("Update app.py", verified), 4.65 kB.
import streamlit as st
import asyncio
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage, AIMessage
from PyPDF2 import PdfReader
import aiohttp
from io import BytesIO
# Set up API key: copy the OpenAI key from Streamlit secrets into the process
# environment. The OpenAI clients below are constructed without an explicit
# key, presumably relying on this variable being present.
os.environ.update(OPENAI_API_KEY=st.secrets["OPENAI_API_KEY"])
# Set up prompts.
# system_template / human_template are also used directly (as plain strings)
# by the retrieval pipeline; chat_prompt is the equivalent LangChain template.
system_template = (
    "You are an AI assistant answering questions about AI. "
    "Use the following context to answer the user's question. "
    "If you cannot find the answer in the context, say you don't know the answer "
    "but you can try to help with related information."
)
human_template = "Context:\n{context}\n\nQuestion:\n{question}"

system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages(
    [system_message_prompt, human_message_prompt]
)
# Define RetrievalAugmentedQAPipeline class
class RetrievalAugmentedQAPipeline:
    """Answer questions with a chat model, grounded in passages from a vector store.

    For each query the pipeline retrieves the ``k`` most similar chunks, joins
    them into a context string (truncated to ``max_context_length``
    characters), and sends the system prompt, the prior chat turns and a
    context+question message to the chat model.
    """

    def __init__(
        self,
        llm: ChatOpenAI,
        vector_db: Chroma,
        k: int = 2,
        max_context_length: int = 12000,
    ) -> None:
        """Store the collaborators and retrieval settings.

        Args:
            llm: Chat model used to generate the answer.
            vector_db: Vector store searched for supporting context.
            k: Number of chunks retrieved per query (default 2).
            max_context_length: Maximum number of characters of joined context
                placed in the prompt; longer context is cut off (default 12000).
        """
        self.llm = llm
        self.vector_db = vector_db
        self.k = k
        self.max_context_length = max_context_length

    async def arun_pipeline(self, user_query: str, chat_history: list):
        """Retrieve context for ``user_query`` and ask the chat model to answer it.

        Args:
            user_query: The question to answer.
            chat_history: Earlier conversation messages (the UI passes
                HumanMessage/AIMessage objects). This should hold prior turns
                only: the current question is appended here together with its
                retrieved context, so including it in the history sends it twice.

        Returns:
            ``{"response": text}`` where ``text`` is the first generation
            returned by the model.
        """
        context_docs = self.vector_db.similarity_search(user_query, k=self.k)
        context_prompt = "\n".join(doc.page_content for doc in context_docs)
        # Bound the retrieved text so the prompt size stays predictable;
        # slicing is a no-op when the context is already short enough.
        context_prompt = context_prompt[: self.max_context_length]

        messages = [
            SystemMessage(content=system_template),
            *chat_history,
            HumanMessage(
                content=human_template.format(context=context_prompt, question=user_query)
            ),
        ]
        response = await self.llm.agenerate([messages])
        return {"response": response.generations[0][0].text}
# PDF processing functions
async def fetch_pdf(session, url):
    """Download ``url`` using ``session`` and return the response body.

    Returns:
        The raw body bytes when the server answers with HTTP 200, otherwise None.
    """
    async with session.get(url) as response:
        if response.status != 200:
            return None
        return await response.read()
async def process_pdf(pdf_content, chunk_size=500, chunk_overlap=40):
    """Extract the text of a PDF and split it into overlapping chunks.

    Declared ``async`` so callers can ``await`` it, but it contains no awaits:
    parsing and splitting run synchronously and block the event loop meanwhile.

    Args:
        pdf_content: Raw bytes of the PDF document.
        chunk_size: ``chunk_size`` passed to RecursiveCharacterTextSplitter
            (default 500).
        chunk_overlap: ``chunk_overlap`` passed to RecursiveCharacterTextSplitter
            (default 40).

    Returns:
        The list of text chunks produced by the splitter.
    """
    pdf_reader = PdfReader(BytesIO(pdf_content))
    # Join the extracted text of all pages, one page after another.
    text = "\n".join(page.extract_text() for page in pdf_reader.pages)
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_text(text)
@st.cache_resource
def initialize_pipeline():
    """Build the retrieval QA pipeline by running the async ``main()`` to completion.

    Wrapped in ``st.cache_resource`` so the returned pipeline is reused across
    Streamlit reruns instead of re-downloading and re-indexing the PDFs.
    """
    pipeline = asyncio.run(main())
    return pipeline
# Main execution
async def main(pdf_urls=None):
    """Download the source PDFs, index their text in Chroma, and build the QA pipeline.

    Args:
        pdf_urls: URLs of the PDFs to index. Defaults to the Blueprint for an
            AI Bill of Rights and NIST AI 600-1 documents.

    Returns:
        A RetrievalAugmentedQAPipeline over a Chroma store of the PDF chunks.

    Raises:
        RuntimeError: If no text could be indexed (every download failed or the
            PDFs yielded no chunks), since the app would have nothing to retrieve.
    """
    if pdf_urls is None:
        pdf_urls = [
            "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
            "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
        ]
    # Materialize once: the URLs are iterated twice (download and zip below).
    pdf_urls = list(pdf_urls)

    # Download concurrently; fetch_pdf returns None for non-200 responses.
    async with aiohttp.ClientSession() as session:
        pdf_contents = await asyncio.gather(*(fetch_pdf(session, url) for url in pdf_urls))

    all_chunks = []
    failed_urls = []
    # gather() preserves input order, so results line up with pdf_urls.
    for url, pdf_content in zip(pdf_urls, pdf_contents):
        if not pdf_content:
            failed_urls.append(url)
            continue
        all_chunks.extend(await process_pdf(pdf_content))

    if not all_chunks:
        raise RuntimeError(
            f"No PDF text could be indexed from {pdf_urls!r}; "
            f"downloads failed for {failed_urls!r}."
        )

    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)
    chat_openai = ChatOpenAI()
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)
# Streamlit UI
st.title("Ask About AI!")

# Initialize session state for chat history
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

pipeline = initialize_pipeline()

# Display chat history
for message in st.session_state.chat_history:
    if isinstance(message, HumanMessage):
        st.write("You:", message.content)
    elif isinstance(message, AIMessage):
        st.write("AI:", message.content)

# Submit questions through a form that clears on submit. A bare st.text_input
# keeps its value across reruns, so any later rerun (including clicking
# "Clear Chat History") would re-ask the same question and re-append it.
with st.form("query_form", clear_on_submit=True):
    user_query = st.text_input("Enter your question about AI:")
    submitted = st.form_submit_button("Ask")

if submitted and user_query.strip():
    # Pass only the earlier turns: arun_pipeline appends the current question
    # itself (with retrieved context), so adding it to history first would
    # send it to the model twice.
    prior_turns = list(st.session_state.chat_history)
    with st.spinner("Generating response..."):
        result = asyncio.run(pipeline.arun_pipeline(user_query, prior_turns))

    # Record the exchange now that it completed.
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    st.session_state.chat_history.append(AIMessage(content=result["response"]))

    # Display the latest exchange (the form has cleared the input box)
    st.write("You:", user_query)
    st.write("AI:", result["response"])

# Add a button to clear chat history
if st.button("Clear Chat History"):
    st.session_state.chat_history = []
    # Newer Streamlit provides st.rerun; older releases only had experimental_rerun.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()