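"""Streamlit RAG demo: ask questions about two public AI-governance PDFs.

The app downloads the White House "Blueprint for an AI Bill of Rights" and
NIST AI 600-1, chunks and embeds them into an in-memory Chroma store, and
answers questions with an OpenAI chat model grounded in retrieved context.

Dependencies (inferred from the imports below): streamlit, langchain,
langchain-community, chromadb, openai, PyPDF2, aiohttp.
"""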
import streamlit as st
import asyncio
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import OpenAIEmbeddings
from langchain_community.chat_models import ChatOpenAI
from langchain.schema import SystemMessage, HumanMessage, AIMessage
from PyPDF2 import PdfReader
import aiohttp
from io import BytesIO

# Set up API key
os.environ["OPENAI_API_KEY"] = st.secrets["OPENAI_API_KEY"]
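# st.secrets reads from .streamlit/secrets.toml locally, or from the deployed
# app's secrets settings, so the key is never hard-coded in source.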

# Prompt templates (formatted directly when messages are built in the pipeline below)
system_template = "You are an AI assistant answering questions about AI. Use the following context to answer the user's question. If you cannot find the answer in the context, say you don't know the answer but you can try to help with related information."

human_template = "Context:\n{context}\n\nQuestion:\n{question}"

# A minimal RAG pipeline: retrieve context from Chroma, then query the chat model
class RetrievalAugmentedQAPipeline:
    def __init__(self, llm: ChatOpenAI, vector_db: Chroma) -> None:
        self.llm = llm
        self.vector_db = vector_db

    async def arun_pipeline(self, user_query: str, chat_history: list):
        # Retrieve the two chunks most similar to the query
        context_docs = self.vector_db.similarity_search(user_query, k=2)
        context_list = [doc.page_content for doc in context_docs]
        context_prompt = "\n".join(context_list)

        # Crude character-based cap to keep the prompt within the model's
        # context window; token-based truncation would be more precise.
        max_context_length = 12000
        if len(context_prompt) > max_context_length:
            context_prompt = context_prompt[:max_context_length]

        # Build the message list: system instructions, prior turns, then the
        # current question wrapped with its retrieved context.
        messages = [SystemMessage(content=system_template)]
        messages.extend(chat_history)
        messages.append(HumanMessage(content=human_template.format(context=context_prompt, question=user_query)))

        response = await self.llm.agenerate([messages])
        return {"response": response.generations[0][0].text}
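# Note: similarity_search is synchronous and briefly blocks the event loop;
# LangChain's VectorStore base class also exposes an asimilarity_search
# wrapper if a fully async retrieval step is preferred here.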

# PDF processing functions
async def fetch_pdf(session, url):
    # Download a PDF; return the raw bytes, or None on a non-200 status
    async with session.get(url) as response:
        if response.status == 200:
            return await response.read()
        return None

async def process_pdf(pdf_content):
    # Declared async for symmetry with the fetch step, though parsing itself
    # is synchronous. extract_text() can return None for image-only pages,
    # so substitute an empty string before joining.
    pdf_reader = PdfReader(BytesIO(pdf_content))
    text = "\n".join([page.extract_text() or "" for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=40)
    return text_splitter.split_text(text)
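# Note: chunk_size and chunk_overlap are measured in characters here
# (RecursiveCharacterTextSplitter defaults to len as its length function),
# so two retrieved chunks stay well under the 12,000-character context cap.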

@st.cache_resource
def initialize_pipeline():
    # Build the pipeline once per process; st.cache_resource keeps it across
    # Streamlit reruns so the PDFs are not re-fetched and re-embedded.
    return asyncio.run(main())

# Main execution
async def main():
    pdf_urls = [
        "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
        "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
    ]

    all_chunks = []
    # Fetch all PDFs concurrently; the bytes are fully read before the
    # session closes, so processing can safely happen afterwards
    async with aiohttp.ClientSession() as session:
        pdf_contents = await asyncio.gather(*[fetch_pdf(session, url) for url in pdf_urls])

    for pdf_content in pdf_contents:
        if pdf_content:
            chunks = await process_pdf(pdf_content)
            all_chunks.extend(chunks)

    # Embed the chunks into an in-memory Chroma index
    embeddings = OpenAIEmbeddings()
    vector_db = Chroma.from_texts(all_chunks, embeddings)

    chat_openai = ChatOpenAI()  # uses the library's default chat model
    return RetrievalAugmentedQAPipeline(vector_db=vector_db, llm=chat_openai)

# Streamlit UI
st.title("Ask About AI!")

# Initialize session state for chat history
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

pipeline = initialize_pipeline()
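# The first run blocks here while the PDFs are downloaded and embedded;
# later reruns reuse the pipeline cached by st.cache_resource.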

# Display chat history
for message in st.session_state.chat_history:
    if isinstance(message, HumanMessage):
        st.write("You:", message.content)
    elif isinstance(message, AIMessage):
        st.write("AI:", message.content)

user_query = st.text_input("Enter your question about AI:")
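# Note: st.text_input keeps its value across reruns, so the same query is
# re-submitted on every widget interaction; st.chat_input returns a value
# only on submit and would avoid that.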

if user_query:
    with st.spinner("Generating response..."):
        # Pass the history from before this turn; arun_pipeline appends the
        # current question itself, so adding it to the history first would
        # duplicate the question in the prompt.
        result = asyncio.run(pipeline.arun_pipeline(user_query, st.session_state.chat_history))

    # Record both sides of the turn in the chat history
    st.session_state.chat_history.append(HumanMessage(content=user_query))
    st.session_state.chat_history.append(AIMessage(content=result["response"]))

    # Display the latest response
    st.write("AI:", result["response"])

# Add a button to clear chat history
if st.button("Clear Chat History"):
    st.session_state.chat_history = []
    st.rerun()  # st.experimental_rerun() on older Streamlit versions
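# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py
# with OPENAI_API_KEY set in .streamlit/secrets.toml.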