File size: 6,009 Bytes
385b1cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import os
import secrets
import time

from dotenv import load_dotenv
from flask import Flask, render_template, jsonify, request, redirect, url_for
from flask_wtf.csrf import CSRFProtect
# from tavily import TavilyClient
from langchain import hub
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain_text_splitters import CharacterTextSplitter
# NOTE: duplicate of the import above; this later binding is the one that wins.
from langchain.text_splitter import CharacterTextSplitter
load_dotenv()
# TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# tavily = TavilyClient(api_key=TAVILY_API_KEY)

app = Flask(__name__, static_folder='static')
# Flask-WTF signs CSRF tokens with SECRET_KEY. The previous hard-coded 'secret' was
# public, so anyone could forge tokens/sessions. Read it from the environment; the
# random fallback keeps local runs working, but it changes on every restart and
# differs between worker processes, so set SECRET_KEY explicitly when deploying.
app.config['SECRET_KEY'] = os.getenv("SECRET_KEY") or secrets.token_hex(32)
csrf = CSRFProtect(app)

# Shared splitter: chunks of up to 1000 characters (measured with len), split on
# newlines, with 200 characters of overlap between consecutive chunks.
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len)
embeddings = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
# NOTE(review): pulled from LangChain Hub at import time, but not referenced
# anywhere else in this file.
retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
# NOTE(review): /query builds its own model instance and does not use this one.
llm = ChatOpenAI(api_key=OPENAI_API_KEY)
# Each vector store lives in its own sub-directory of this path.
vectordb_path = "./vector_db"
@app.route('/')
def home():
    """Send visitors of the site root to the search page."""
    search_url = url_for('search_view')
    return redirect(search_url)
@app.route('/search_view')
def search_view():
    """Render the search page template."""
    page = 'search.html'
    return render_template(page)
@app.route('/rag_view')
def rag_view():
    """Render the RAG page with the names of the existing vector stores.

    Every sub-directory of ``vectordb_path`` is treated as one store; the names
    are passed to ``rag.html`` as ``dbs`` in sorted order.
    """
    dbs = []
    try:
        # Context manager closes the directory handle deterministically.
        with os.scandir(vectordb_path) as entries:
            dbs = sorted(entry.name for entry in entries if entry.is_dir())
    except FileNotFoundError:
        # The store directory only appears after the first /dbcreate; show an empty
        # list instead of failing the page with a 500.
        pass
    return render_template('rag.html', dbs=dbs)
# Prompt for the "rag" mode: {context} receives the retrieved chunks and
# {human_input} the user's question. Built once instead of on every request.
_RAG_PROMPT = PromptTemplate(
    input_variables=["human_input", "context"],
    template=(
        "Please answer to human's input based on context. If the input is not mentioned in context, output something like 'I don't know'.\n"
        "Context: {context}\n"
        "Human: {human_input}\n"
        "Your Response as Chatbot:"
    ),
)


@app.route('/query', methods=['POST'])
def query():
    """Answer a chat prompt against one of the local vector stores.

    Expects a JSON body with:
      * ``prompt`` - the user's question (non-empty string),
      * ``title``  - the mode; only ``"rag"`` is handled here,
      * ``db``     - name of an existing store directory under ``vectordb_path``.

    Responds with JSON ``{"success": ..., "response": ...}``. Unsupported modes,
    unknown databases and empty prompts get ``"success": "bad"`` with HTTP 400.
    """
    # silent=True: a missing or non-JSON body yields None instead of raising, so
    # every failure path can still answer with the usual JSON shape.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    prompt = payload.get("prompt")
    title = payload.get("title")
    db = payload.get("db")

    # NOTE: a Tavily-backed "search" mode used to be handled here; it is disabled
    # together with the tavily client at the top of the file.
    if title != "rag":
        # The original fell off the end here and returned None, which Flask turns
        # into a 500; answer explicitly instead.
        return jsonify({"success": "bad", "response": "Unsupported query type."}), 400

    if not db:
        return jsonify({"success": "ok", "response": "Please select database."})

    # Only accept a bare, existing directory name: `db` comes from the client, and
    # Chroma would otherwise open (and create) any path such as "../../somewhere".
    if not isinstance(db, str) or db in (".", "..") or os.path.basename(db) != db:
        return jsonify({"success": "bad", "response": "Unknown database."}), 400
    db_dir = os.path.join(vectordb_path, db)
    if not os.path.isdir(db_dir):
        return jsonify({"success": "bad", "response": "Unknown database."}), 400

    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"success": "bad", "response": "Please enter a prompt."}), 400

    store = Chroma(persist_directory=db_dir, embedding_function=embeddings)
    docs = store.similarity_search(prompt)

    # Local name avoids shadowing the module-level `llm`, which uses a different model.
    rag_llm = ChatOpenAI(model="gpt-4-1106-preview", api_key=OPENAI_API_KEY)
    stuff_chain = load_qa_chain(rag_llm, chain_type="stuff", prompt=_RAG_PROMPT)
    # `invoke` replaces the deprecated Chain.__call__; the result dict carries the
    # generated answer under "output_text".
    output = stuff_chain.invoke({"input_documents": docs, "human_input": prompt})
    return jsonify({"success": "ok", "response": output["output_text"]})
@app.route('/uploadDocuments', methods=['POST'])
@csrf.exempt
def uploadDocuments():
    """Index uploaded documents into the vector store named ``dbname``.

    Form fields: ``dbname`` (target store under ``vectordb_path``) and one or more
    files under ``files[]``. Files ending in ``.txt`` (any case) are loaded as
    UTF-8 text; everything else is loaded as PDF. Each file is saved under
    ``uploads/``, split with ``text_splitter`` and added to the store.

    Returns ``{"success": "db"}`` when no usable database name was given,
    ``{"success": "bad"}`` when no files were uploaded, else ``{"success": "ok"}``.
    """
    dbname = request.form.get('dbname')
    # The original replaced this with a hard-coded list of URL strings (a debugging
    # leftover); str has no .save()/.filename, so every request crashed. Browsers
    # submit an empty FileStorage when no file is chosen; drop those.
    uploaded_files = [f for f in request.files.getlist('files[]') if f.filename]

    # dbname is client input: require a single path component so the store cannot
    # be written outside vectordb_path.
    if not dbname or dbname in (".", "..") or os.path.basename(dbname) != dbname:
        return {"success": "db"}
    if not uploaded_files:
        return {"success": "bad"}

    os.makedirs("uploads", exist_ok=True)
    texts = []
    for file in uploaded_files:
        # Keep only the last component of the client-supplied name (treating "\" as
        # a separator too) so "../x"-style names cannot escape the uploads directory.
        filename = os.path.basename(file.filename.replace("\\", "/"))
        if filename in ("", ".", ".."):
            continue
        path = os.path.join("uploads", filename)
        file.save(path)
        if filename.lower().endswith(".txt"):
            loader = TextLoader(path, encoding='utf-8')
        else:
            loader = PyPDFLoader(path)
        texts.extend(text_splitter.split_documents(loader.load()))

    # One write for all files; skip it when nothing produced any chunks.
    if texts:
        Chroma.from_documents(texts, embeddings, persist_directory=os.path.join(vectordb_path, dbname))
    return {'success': "ok"}
@app.route('/dbcreate', methods=['POST'])
@csrf.exempt
def dbcreate():
    """Create an empty vector-store directory named ``dbname`` under ``vectordb_path``.

    Expects a JSON body ``{"dbname": ...}``. Returns ``{"success": "ok"}`` when the
    directory was created, and ``{"success": "bad"}`` when the name is missing,
    is not a single plain path component, or the database already exists.
    """
    # silent=True: a missing or non-JSON body becomes None rather than an error page.
    payload = request.get_json(silent=True)
    dbname = payload.get("dbname") if isinstance(payload, dict) else None
    # Reject anything that could escape vectordb_path (e.g. "../x", "/abs", "a/b").
    if not isinstance(dbname, str) or dbname in ("", ".", "..") or os.path.basename(dbname) != dbname:
        return {'success': 'bad'}
    try:
        # EAFP: makedirs raises if the directory exists, so there is no race between
        # a separate exists() check and the creation.
        os.makedirs(os.path.join(vectordb_path, dbname))
    except FileExistsError:
        return {'success': 'bad'}
    return {'success': "ok"}
# import gradio as gr
# chatbot = gr.Chatbot(avatar_images=["user.png", "bot.jpg"], height=600)
# clear_but = gr.Button(value="Clear Chat")
# demo = gr.ChatInterface(fn=search, title="Mediate.com Chatbot Prototype", multimodal=False, retry_btn=None, undo_btn=None, clear_btn=clear_but, chatbot=chatbot)
if __name__ == '__main__':
    # debug=True turns on the Werkzeug interactive debugger, which lets anyone who
    # can reach the server execute arbitrary code; only enable it on request
    # (FLASK_DEBUG=1).
    app.run(debug=os.getenv("FLASK_DEBUG", "0") == "1")