dfasd commited on
Commit
385b1cf
·
verified ·
1 Parent(s): f2e81a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, jsonify, request, redirect, url_for
2
+ from flask_wtf.csrf import CSRFProtect
3
+
4
+ # from tavily import TavilyClient
5
+
6
+ from dotenv import load_dotenv
7
+ import os
8
+
9
+ from langchain_community.document_loaders import TextLoader
10
+ from langchain_community.vectorstores import Chroma
11
+ from langchain_text_splitters import CharacterTextSplitter
12
+ from langchain_community.document_loaders import PyPDFLoader
13
+ from langchain.text_splitter import CharacterTextSplitter
14
+ from langchain_openai import OpenAIEmbeddings
15
+ from langchain_openai import ChatOpenAI
16
+ from langchain.chains.combine_documents import create_stuff_documents_chain
17
+ from langchain.chains import create_retrieval_chain
18
+ from langchain import hub
19
+ from langchain_core.prompts import ChatPromptTemplate
20
+ from langchain.chains.question_answering import load_qa_chain
21
+ from langchain.prompts import PromptTemplate
22
+
23
+ import time
24
+ load_dotenv()
25
+
26
+ # TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
27
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
28
+ # tavily = TavilyClient(api_key=TAVILY_API_KEY)
29
+
30
+ app = Flask(__name__, static_folder='static')
31
+ app.config['SECRET_KEY'] = 'secret'
32
+ csrf = CSRFProtect(app)
33
+
34
+ text_splitter = CharacterTextSplitter(separator = "\n", chunk_size=1000, chunk_overlap=200, length_function = len)
35
+ embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
36
+ retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
37
+ llm = ChatOpenAI(api_key=OPENAI_API_KEY)
38
+
39
+ vectordb_path = "./vector_db"
40
+
41
+ @app.route('/')
42
+ def home():
43
+ return redirect(url_for('search_view'))
44
+
45
+ @app.route('/search_view')
46
+ def search_view():
47
+ return render_template('search.html')
48
+
49
+ @app.route('/rag_view')
50
+ def rag_view():
51
+ dbs = [f.name for f in os.scandir(vectordb_path) if f.is_dir()]
52
+ return render_template('rag.html', dbs = dbs)
53
+
54
+ @app.route('/query', methods=['POST'])
55
+ def query():
56
+ if request.method == "POST":
57
+ prompt = request.get_json().get("prompt")
58
+ title = request.get_json().get("title")
59
+ db = request.get_json().get("db")
60
+
61
+ # if title == "search":
62
+ # response = tavily.search(query=prompt, include_images=True, include_answer=True, max_results=5)
63
+
64
+ # output = response['answer'] + "\n"
65
+ # for res in response['results']:
66
+ # output += f"\nTitle: {res['title']}\nURL: {res['url']}\nContent: {res['content']}\n"
67
+
68
+ # data = {"success": "ok", "response": output, "images": response['images']}
69
+
70
+ # return jsonify(data)
71
+
72
+ if title == "rag":
73
+ if db != "":
74
+ template = """Please answer to human's input based on context. If the input is not mentioned in context, output something like 'I don't know'.
75
+ Context: {context}
76
+ Human: {human_input}
77
+ Your Response as Chatbot:"""
78
+
79
+ prompt_s = PromptTemplate(
80
+ input_variables=["human_input", "context"],
81
+ template=template
82
+ )
83
+
84
+ db = Chroma(persist_directory=os.path.join(vectordb_path, db), embedding_function=embeddings)
85
+
86
+ docs = db.similarity_search(prompt)
87
+
88
+ llm = ChatOpenAI(model="gpt-4-1106-preview", api_key=OPENAI_API_KEY)
89
+
90
+ stuff_chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_s)
91
+ output = stuff_chain({"input_documents": docs, "human_input": prompt}, return_only_outputs=False)
92
+
93
+ final_answer = output["output_text"]
94
+ # prompt = ChatPromptTemplate.from_messages(
95
+ # [("system", "Please answer to user's query based on following context.\n\nContext: {context}")]
96
+ # )
97
+
98
+
99
+ # chain = create_stuff_documents_chain(llm, prompt)
100
+
101
+ # answer = chain.invoke({"context": docs, "prompt": prompt})
102
+
103
+ data = {"success": "ok", "response": final_answer}
104
+
105
+ return jsonify(data)
106
+ else:
107
+ data = {"success": "ok", "response": "Please select database."}
108
+
109
+ return jsonify(data)
110
+
111
+ @app.route('/uploadDocuments', methods=['POST'])
112
+ @csrf.exempt
113
+ def uploadDocuments():
114
+ # uploaded_files = request.files.getlist('files[]')
115
+ dbname = request.form.get('dbname')
116
+ uploaded_files = ['https://www.airbus.com/sites/g/files/jlcbta136/files/2024-03/Airbus-Annual-Report-2023.pdf', 'https://www.singaporeair.com/saar5/pdf/Investor-Relations/Annual-Report/annualreport2223.pdf']
117
+ if dbname == "":
118
+ return {"success": "db"}
119
+
120
+ if len(uploaded_files) > 0:
121
+ for file in uploaded_files:
122
+ file.save(f"uploads/{file.filename}")
123
+
124
+ if file.filename.endswith(".txt"):
125
+ loader = TextLoader(f"uploads/{file.filename}", encoding='utf-8')
126
+ else:
127
+ loader = PyPDFLoader(f"uploads/{file.filename}")
128
+
129
+ data = loader.load()
130
+ texts = text_splitter.split_documents(data)
131
+
132
+ Chroma.from_documents(texts, embeddings, persist_directory=os.path.join(vectordb_path, dbname))
133
+
134
+ return {'success': "ok"}
135
+
136
+ else:
137
+ return {"success": "bad"}
138
+
139
+ @app.route('/dbcreate', methods=['POST'])
140
+ @csrf.exempt
141
+ def dbcreate():
142
+ dbname = request.get_json().get("dbname")
143
+
144
+ if not os.path.exists(os.path.join(vectordb_path, dbname)):
145
+ os.makedirs(os.path.join(vectordb_path, dbname))
146
+ return {'success': "ok"}
147
+ else:
148
+ return {'success': 'bad'}
149
+
150
+ # import gradio as gr
151
+ # chatbot = gr.Chatbot(avatar_images=["user.png", "bot.jpg"], height=600)
152
+ # clear_but = gr.Button(value="Clear Chat")
153
+ # demo = gr.ChatInterface(fn=search, title="Mediate.com Chatbot Prototype", multimodal=False, retry_btn=None, undo_btn=None, clear_btn=clear_but, chatbot=chatbot)
154
+
155
+
156
+ if __name__ == '__main__':
157
+ app.run(debug=True)