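"""Streamlit app: retrieval-augmented Q&A over two AI-governance PDFs
(the White House Blueprint for an AI Bill of Rights and NIST AI 600-1),
using LangChain, Chroma, and OpenAI chat models."""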
import streamlit as st
import asyncio
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage, AIMessage
from PyPDF2 import PdfReader
import aiohttp
from io import BytesIO
# Set up API key
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
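# st.secrets reads from .streamlit/secrets.toml locally, or from the
# hosting platform's secrets settings when deployed.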
# Set up prompts
system_template = "You are an AI assistant answering questions about AI. Use the following context to answer the user's question. If you cannot find the answer in the context, say you don't know the answer but you can try to help with related information."
system_message_prompt = SystemMessagePromptTemplate.from_template(system_template)
human_template = "Context:\n{context}\n\nQuestion:\n{question}"
human_message_prompt = HumanMessagePromptTemplate.from_template(human_template)
chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
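# Note: chat_prompt is kept for reference, but arun_pipeline below assembles
# messages manually so the chat history can sit between the system and human turns.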
# Define RetrievalAugmentedQAPipeline class
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str, chat_history: list):
        # Retrieve the two chunks most similar to the query
        context_docs = self.vector_db.similarity_search(user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)

        # Cap the stuffed context at 12,000 characters
        max_context_length = 12000
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]

        # System message, then prior turns, then the context-stuffed question
        messages = [SystemMessage(content=system_template)]
        messages.extend(chat_history)
        messages.append(HumanMessage(content=human_template.format(context=context_prompt, question=user_query)))

        response = await self.llm.agenerate([messages])
        return {"response": response.generations[0][0].text}
# PDF processing functions
async def fetch_pdf(session, url):
    async with session.get(url) as response:
        if response.status == 200:
            return await response.read()
        return None
async def process_pdf(pdf_content):
    pdf_reader = PdfReader(BytesIO(pdf_content))
    # extract_text() can return None for pages with no extractable text
    text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)
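# 500-character chunks with 40 characters of overlap keep each retrieved
# passage small relative to the 12,000-character context cap above.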
@st.cache_resource
def initialize_pipeline():
    return asyncio.run(main())
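# st.cache_resource runs initialize_pipeline() once per server process;
# subsequent Streamlit reruns reuse the same pipeline and vector store.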
# Main execution
async def main():
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]
    all_chunks = []

    # Download both PDFs concurrently, then split each into chunks
    async with aiohttp.ClientSession() as session:
        pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])
    for pdf_content in pdf_contents:
        if pdf_content:
            chunks = await process_pdf(pdf_content)
            all_chunks.extend(chunks)

    # No persist_directory is set, so Chroma builds an in-memory index
    # that is rebuilt on every cold start
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)
    chat_openai = ChatOpenAI()
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)
# Streamlit UI
st.title("Ask About AI!")
# Initialize session state for chat history
if "chat_history" not in st.session_state:
st.session_state.chat_history = []
pipeline = initialize_pipeline()
# Display chat history
for message in st.session_state.chat_history:
    if isinstance(message, HumanMessage):
        st.write("You:", message.content)
    elif isinstance(message, AIMessage):
        st.write("AI:", message.content)
user_query = st.text_input("Enter your question about AI:")
if user_query:
    with st.spinner("Generating response..."):
        # Generate against the existing history, then record the exchange;
        # appending the user message first would duplicate the question in
        # the prompt, since arun_pipeline adds it itself
        result = asyncio.run(pipeline.arun_pipeline(user_query, st.session_state.chat_history))
    # Add the exchange to chat history
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    st.session_state.chat_history.append(AIMessage(content=result["response"]))
    # Display the latest response
    st.write("AI:", result["response"])
# Add a button to clear chat history
if st.button("Clear Chat History"):
    st.session_state.chat_history = []
    st.experimental_rerun()  # renamed to st.rerun() in newer Streamlit releases