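"""Streamlit RAG chat app: answers questions about AI using context retrieved
from the Blueprint for an AI Bill of Rights and NIST AI 600-1 PDFs, built on
LangChain, Chroma, and OpenAI models."""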
import asyncio
import os
from io import BytesIO

import aiohttp
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import SystemMessage, HumanMessage, AIMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from PyPDF2 import PdfReader

# Set up API key
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]

# Set up prompts
system_template = (
    "You are an AI assistant answering questions about AI. "
    "Use the following context to answer the user's question. "
    "If you cannot find the answer in the context, say you don't know the answer "
    "but you can try to help with related information."
)
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
# chat_prompt combines both templates; the pipeline below builds its messages
# manually so it can interleave chat history, but the template is kept for reference.
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

# Retrieval-augmented QA pipeline: retrieves relevant chunks from the vector
# store and sends them to the chat model together with the user's question.
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db
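
    # Retrieve the top-k most similar chunks, truncate the combined context,
    # and generate an answer asynchronously, conditioned on prior chat history.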
    async def arun_pipeline(self, user_query: str, chat_history: list):
        context_docs = self.vector_db.similarity_search(user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)

        # Rough cap so the prompt stays within the model's context window
        # (characters are used as a cheap proxy for tokens)
        max_context_length = 12000
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]

        messages = [SystemMessage(content=system_template)]
        messages.extend(chat_history)
        messages.append(HumanMessage(content=human_template.format(context=context_prompt, question=user_query)))

        response = await self.llm.agenerate([messages])
        return {"response": response.generations[0][0].text}

# PDF processing functions
async def fetch_pdf(session, url):
    async with session.get(url) as response:
        if response.status == 200:
            return await response.read()
        return None
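
# Extract the text from a downloaded PDF and split it into overlapping
# chunks sized for embedding and retrieval.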
async def process_pdf(pdf_content):
    pdf_reader = PdfReader(BytesIO(pdf_content))
    # extract_text() can return None for image-only pages, so fall back to ""
    text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)
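
# Synchronous wrapper so the async setup in main() can run inside
# Streamlit's (synchronous) script thread.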
@st.cache_resource  # avoid re-fetching and re-embedding the PDFs on every rerun (Streamlit >= 1.18)
def initialize_pipeline():
    return asyncio.run(main())

# Main execution
async def main():
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]
    all_chunks = []
    # Download all PDFs concurrently, then chunk each one
    async with aiohttp.ClientSession() as session:
        pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])
    for pdf_content in pdf_contents:
        if pdf_content:
            chunks = await process_pdf(pdf_content)
            all_chunks.extend(chunks)
    # Embed the chunks and index them in an in-memory Chroma vector store
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)
    chat_openai = ChatOpenAI()
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)

# Streamlit UI
st.title("Ask About AI!")

# Initialize session state for chat history
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
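
# Fetch and index the PDFs (cached after the first run)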
pipeline = initialize_pipeline()

# Display chat history
for message in st.session_state.chat_history:
    if isinstance(message, HumanMessage):
        st.write("You:", message.content)
    elif isinstance(message, AIMessage):
        st.write("AI:", message.content)

user_query = st.text_input("Enter your question about AI:")
# Guard against Streamlit reruns re-submitting the lingering text_input value
if user_query and user_query != st.session_state.get("last_query"):
    st.session_state.last_query = user_query
    # Add user message to chat history
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    with st.spinner("Generating response..."):
        # Pass the history *without* the just-appended question; the pipeline
        # adds it back wrapped in retrieved context, so it isn't sent twice
        result = asyncio.run(pipeline.arun_pipeline(user_query, st.session_state.chat_history[:-1]))
    # Add AI response to chat history
    ai_message = AIMessage(content=result["response"])
    st.session_state.chat_history.append(ai_message)
    # Display the latest response
    st.write("AI:", result["response"])

# Add a button to clear chat history
if st.button("Clear Chat History"):
    st.session_state.chat_history = []
    st.rerun()  # on Streamlit < 1.27 use st.experimental_rerun() instead
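
# To run locally (assuming this file is saved as app.py): streamlit run app.py
# On Hugging Face Spaces the app is started automatically.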