File size: 3,057 Bytes
385b1cf
 
aaf8725
385b1cf
aaf8725
385b1cf
aaf8725
 
385b1cf
 
aaf8725
620b6be
 
 
 
aaf8725
385b1cf
 
 
aaf8725
 
 
 
385b1cf
620b6be
aaf8725
620b6be
aaf8725
385b1cf
620b6be
 
 
 
 
385b1cf
620b6be
 
 
 
385b1cf
 
620b6be
 
385b1cf
620b6be
 
 
 
385b1cf
620b6be
 
 
 
385b1cf
620b6be
 
385b1cf
620b6be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385b1cf
aaf8725
 
620b6be
 
 
 
 
 
875f069
aaf8725
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from dotenv import load_dotenv
import os
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import CharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain.chains import create_history_aware_retriever
from langchain.prompts import PromptTemplate
from langchain.chains.question_answering import load_qa_chain
import pydantic
# Load environment variables (expects OPENAI_API_KEY in .env or the environment)
load_dotenv()
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# Initialize components shared by ingestion (create_vectordb) and querying (rag_bot)
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len)
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
llm = ChatOpenAI(model="gpt-4-1106-preview", api_key=OPENAI_API_KEY)

# Persistence layout: the Chroma collection lives under vectordb_path/dbname
vectordb_path = './vector_db'
dbname = 'vector_db'
# PDFs to ingest; paths are relative to the working directory
uploaded_files = ['airbus.pdf', 'annualreport2223.pdf']
# Populated lazily by create_vectordb(); None until the first ingestion
vectorstore = None

def create_vectordb():
    """Ingest every PDF in ``uploaded_files`` into the persisted Chroma store.

    Loads each PDF, splits it into overlapping chunks with the module-level
    ``text_splitter``, and either creates the Chroma collection (first file)
    or appends to it (subsequent files). The collection is persisted under
    ``vectordb_path/dbname``.

    Side effects: rebinds the module-level ``vectorstore`` and writes the
    Chroma persistence directory to disk.
    """
    # Fix: without this declaration, assigning ``vectorstore`` below makes it
    # function-local, and the first ``vectorstore is None`` check raises
    # UnboundLocalError.
    global vectorstore

    for file in uploaded_files:
        loader = PyPDFLoader(file)
        data = loader.load()
        texts = text_splitter.split_documents(data)

        if vectorstore is None:
            vectorstore = Chroma.from_documents(documents=texts, embedding=embeddings, persist_directory=os.path.join(vectordb_path, dbname))
        else:
            vectorstore.add_documents(texts)


def rag_bot(query, chat_history):
    """Answer ``query`` using retrieval-augmented generation over the Chroma store.

    Retrieves documents similar to ``query`` from the persisted vector store,
    stuffs them into a prompt, and asks the LLM to answer strictly from that
    context.

    Args:
        query: The user's question.
        chat_history: Conversation history from Gradio (currently unused —
            kept for interface compatibility with the ChatInterface callback).

    Returns:
        The model's answer text.

    Raises:
        pydantic.ValidationError: If the QA chain cannot be constructed.
    """
    print(f"Received query: {query}")

    template = """Please answer to human's input based on context. If the input is not mentioned in context, output something like 'I don't know'.
    Context: {context}
    Human: {human_input}
    Your Response as Chatbot:"""

    prompt_s = PromptTemplate(
        input_variables=["human_input", "context"],
        template=template
    )

    # Fix: open the same directory create_vectordb() persists to
    # (vectordb_path/dbname); previously this read vectordb_path alone,
    # i.e. a different (empty) collection.
    vectorstore = Chroma(persist_directory=os.path.join(vectordb_path, dbname), embedding_function=embeddings)

    docs = vectorstore.similarity_search(query)

    try:
        stuff_chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_s)
    except pydantic.ValidationError as e:
        # Fix: previously this only printed and fell through, so the call
        # below raised NameError on the undefined ``stuff_chain``. Log and
        # re-raise so the failure is explicit.
        print(f"Validation error: {e}")
        raise

    output = stuff_chain({"input_documents": docs, "human_input": query}, return_only_outputs=False)

    final_answer = output["output_text"]
    print(f"Final Answer ---> {final_answer}")

    return final_answer

def chat(query, chat_history):
    """Gradio ChatInterface callback: delegate to the RAG pipeline and return its answer."""
    return rag_bot(query, chat_history)

# Chat display area with custom avatar images for the user and the bot.
chatbot = gr.Chatbot(avatar_images=["user.jpg", "bot.png"], height=600)
# Custom clear button passed into the ChatInterface below.
clear_but = gr.Button(value="Clear Chat")
# Wire the chat() callback into a text-only chat UI; retry/undo controls disabled.
demo = gr.ChatInterface(fn=chat, title="RAG Chatbot Prototype", multimodal=False, retry_btn=None, undo_btn=None, clear_btn=clear_but, chatbot=chatbot)

if __name__ == '__main__':
    # share=True exposes a public Gradio link; debug=True surfaces tracebacks in the UI.
    demo.launch(debug=True, share=True)