import os

from dotenv import load_dotenv
import gradio as gr
from langchain import hub
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import CharacterTextSplitter

# Load environment variables
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
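# Optional guard (an addition, not required by any library here): fail fast if
# the key is missing instead of failing later inside the OpenAI client.
if not OPENAI_API_KEY:
    raise EnvironmentError("OPENAI_API_KEY is not set; add it to your .env file.")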

# Initialize components
# Split on newlines into ~1000-character chunks with a 200-character overlap
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len)
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
llm = ChatOpenAI(model="gpt-4-1106-preview", api_key=OPENAI_API_KEY)
vectordb_path = './vector_db'

# Load and process documents
uploaded_files = ['airbus.pdf', 'annualreport2223.pdf']
dbname = 'vector_db'
vectorstore = None

# Index every PDF into a single Chroma collection: build it from the first
# file's chunks, then append the remaining files.
for file in uploaded_files:
    loader = PyPDFLoader(file)
    data = loader.load()
    texts = text_splitter.split_documents(data)

    if vectorstore is None:
        vectorstore = Chroma.from_documents(documents=texts, embedding=embeddings, persist_directory=os.path.join(vectordb_path, dbname))
    else:
        vectorstore.add_documents(texts)

vectorstore.persist()  # writes the index to disk (a no-op on newer Chroma versions, which persist automatically)
retriever = vectorstore.as_retriever()
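
# as_retriever() defaults to similarity search; to cap how many chunks are
# retrieved per query, one option (standard LangChain retriever kwargs) is:
# retriever = vectorstore.as_retriever(search_kwargs={"k": 4})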

# Pull the community RAG prompt template (expects "context" and "question" variables)
prompt = hub.pull("rlm/rag-prompt")
print(prompt)  # inspect the pulled template during development

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)
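
# Quick sanity check: the chain takes the question as a plain string, since
# RunnablePassthrough and the retriever both receive the raw input.
# (Illustrative question; uncomment to run it against the indexed PDFs.)
# print(rag_chain.invoke("What were the key financial highlights?"))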

# Gradio interface
def rag_bot(query, chat_history):
    # The chain takes the bare question string (RunnablePassthrough and the
    # retriever both receive the raw input), and rlm/rag-prompt has no
    # chat_history slot, so the history is not forwarded to the chain.
    response = rag_chain.invoke(query)
    return response

# Custom Chatbot component, wired into the Interface outputs below.
# (gr.Interface supplies its own Clear button, so no separate button is needed.)
chatbot = gr.Chatbot(avatar_images=["user.jpg", "bot.png"], height=600)

def chat(query, chat_history):
    chat_history = chat_history or []  # the state component is None on the first call
    response = rag_bot(query, chat_history)
    chat_history.append((query, response))
    return chat_history, chat_history

demo = gr.Interface(
    fn=chat,
    inputs=["text", "state"],
    outputs=[chatbot, "state"],
    title="RAG Chatbot Prototype",
    description="A Chatbot using Retrieval-Augmented Generation (RAG) with PDF files.",
    allow_flagging="never",
)

if __name__ == '__main__':
    demo.launch(debug=True, share=True)  # share=True creates a temporary public link