dfasd committed on
Commit aaf8725 · verified · 1 Parent(s): ef07738

Update app.py

Files changed (1)
  1. app.py +56 -96
app.py CHANGED
@@ -1,118 +1,78 @@
  from dotenv import load_dotenv
  import os
- from langchain_community.document_loaders import TextLoader
- from langchain_community.vectorstores import Chroma
- from langchain_text_splitters import CharacterTextSplitter
  from langchain_community.document_loaders import PyPDFLoader
- from langchain.text_splitter import CharacterTextSplitter
  from langchain_openai import OpenAIEmbeddings
  from langchain_openai import ChatOpenAI
- from langchain.chains.combine_documents import create_stuff_documents_chain
- from langchain.chains import create_retrieval_chain
  from langchain import hub
- from langchain_core.prompts import ChatPromptTemplate
- from langchain.chains.question_answering import load_qa_chain
- from langchain.prompts import PromptTemplate

- import time
  load_dotenv()
-
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

- # text_splitter = CharacterTextSplitter(separator = "\n", chunk_size=1000, chunk_overlap=200, length_function = len)
- # embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
- # retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
- # llm = ChatOpenAI(api_key=OPENAI_API_KEY)

- vectordb_path = "./vector_db"

- def query():
-     if request.method == "POST":
-         prompt = request.get_json().get("prompt")
-         title = request.get_json().get("title")
-         db = request.get_json().get("db")
-
-         # if title == "search":
-         #     response = tavily.search(query=prompt, include_images=True, include_answer=True, max_results=5)
-
-         #     output = response['answer'] + "\n"
-         #     for res in response['results']:
-         #         output += f"\nTitle: {res['title']}\nURL: {res['url']}\nContent: {res['content']}\n"
-
-         #     data = {"success": "ok", "response": output, "images": response['images']}

-         #     return jsonify(data)

-         if title == "rag":
-             if db != "":
-                 template = """Please answer to human's input based on context. If the input is not mentioned in context, output something like 'I don't know'.
-                 Context: {context}
-                 Human: {human_input}
-                 Your Response as Chatbot:"""
-
-                 prompt_s = PromptTemplate(
-                     input_variables=["human_input", "context"],
-                     template=template
-                 )

-                 db = Chroma(persist_directory=os.path.join(vectordb_path, db), embedding_function=embeddings)
-
-                 docs = db.similarity_search(prompt)
-
-                 llm = ChatOpenAI(model="gpt-4-1106-preview", api_key=OPENAI_API_KEY)

-                 stuff_chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_s)
-                 output = stuff_chain({"input_documents": docs, "human_input": prompt}, return_only_outputs=False)

-                 final_answer = output["output_text"]

-                 data = {"success": "ok", "response": final_answer}
-
-                 return jsonify(data)
-             else:
-                 data = {"success": "ok", "response": "Please select database."}

-                 return jsonify(data)
-
- def uploadDocuments():
-     # uploaded_files = request.files.getlist('files[]')
-     uploaded_files = ['annualreport2223.pdf', 'Airbus-Annual-Report-2023.pdf']
-     dbname = request.form.get('dbname')
-     if dbname == "":
-         return {"success": "db"}
-
-     if len(uploaded_files) > 0:
-         for file in uploaded_files:
-             file.save(f"uploads/{file.filename}")
-
-             if file.filename.endswith(".txt"):
-                 loader = TextLoader(f"uploads/{file.filename}", encoding='utf-8')
-             else:
-                 loader = PyPDFLoader(f"uploads/{file.filename}")
-
-             data = loader.load()
-             texts = text_splitter.split_documents(data)
-
-             Chroma.from_documents(texts, embeddings, persist_directory=os.path.join(vectordb_path, dbname))
-
-         return {'success': "ok"}

-     else:
-         return {"success": "bad"}

- def dbcreate():
-     dbname = request.get_json().get("dbname")
-
-     if not os.path.exists(os.path.join(vectordb_path, dbname)):
-         os.makedirs(os.path.join(vectordb_path, dbname))
-         return {'success': "ok"}
-     else:
-         return {'success': 'bad'}
-
- import gradio as gr
- chatbot = gr.Chatbot(avatar_images=["user.png", "bot.jpg"], height=600)
- clear_but = gr.Button(value="Clear Chat")
- demo = gr.ChatInterface(fn="", title="Mediate.com Chatbot Prototype", multimodal=False, retry_btn=None, undo_btn=None, clear_btn=clear_but, chatbot=chatbot)

- if __name__ == "__main__":
-     demo.launch(debug=True)
  from dotenv import load_dotenv
  import os
+ import gradio as gr
  from langchain_community.document_loaders import PyPDFLoader
+ from langchain_text_splitters import CharacterTextSplitter
  from langchain_openai import OpenAIEmbeddings
+ from langchain_community.vectorstores import Chroma
+ from langchain_core.runnables import RunnablePassthrough
  from langchain_openai import ChatOpenAI
  from langchain import hub
+ from langchain_core.output_parsers import StrOutputParser

+ # Load environment variables
  load_dotenv()
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

+ # Initialize components
+ text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len)
+ embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
+ llm = ChatOpenAI(model="gpt-4-1106-preview", api_key=OPENAI_API_KEY)
+ vectordb_path = './vector_db'
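+ # chunk_size and chunk_overlap are measured in characters here (length_function=len)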

+ # Load and process documents
+ uploaded_files = ['airbus.pdf', 'annualreport2223.pdf']
+ dbname = 'vector_db'
+ vectorstore = None
+
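+ # all PDFs go into one Chroma collection: the first file creates it, later files are appended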
+ for file in uploaded_files:
+     loader = PyPDFLoader(file)
+     data = loader.load()
+     texts = text_splitter.split_documents(data)
+
+     if vectorstore is None:
+         vectorstore = Chroma.from_documents(documents=texts, embedding=embeddings, persist_directory=os.path.join(vectordb_path, dbname))
+     else:
+         vectorstore.add_documents(texts)
+
+ vectorstore.persist()
+ retriever = vectorstore.as_retriever()
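+ # as_retriever() gives plain top-k similarity search (LangChain's default k is 4)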
+
+ # Load prompt template
+ prompt = hub.pull("rlm/rag-prompt")
+ print(prompt)
+
+ def format_docs(docs):
+     return "\n\n".join(doc.page_content for doc in docs)
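+
+ # LCEL pipeline: the retrieved documents are flattened by format_docs into the
+ # prompt's {context} slot, while the incoming question passes through unchanged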
+ rag_chain = (
+     {"context": retriever | format_docs, "question": RunnablePassthrough()}
+     | prompt
+     | llm
+     | StrOutputParser()
+ )
+
+ # Gradio interface
+ def rag_bot(query, chat_history):
+     # the chain expects the bare question string; chat_history is unused by this chain
+     response = rag_chain.invoke(query)
+     return response
+
+ chatbot = gr.Chatbot(avatar_images=["user.jpg", "bot.png"], height=600)
+ clear_but = gr.Button(value="Clear Chat")
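+ # NOTE: chatbot and clear_but are created here but never wired into the Interface below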
+
+ def chat(query, chat_history):
+     chat_history = chat_history or []  # Gradio passes None as the initial state
+     response = rag_bot(query, chat_history)
+     chat_history.append((query, response))
+     return chat_history, chat_history
+
+ demo = gr.Interface(
+     fn=chat,
+     inputs=["text", "state"],
+     outputs=["chatbot", "state"],
+     title="RAG Chatbot Prototype",
+     description="A Chatbot using Retrieval-Augmented Generation (RAG) with PDF files.",
+     allow_flagging="never",
+ )
+
+ if __name__ == '__main__':
+     demo.launch(debug=True, share=True)