# rag / app.py
# dfasd
# Update app.py
# 3af14c7 verified
# raw
# history blame
# 4.66 kB
from dotenv import load_dotenv
import os
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import CharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_openai import ChatOpenAI
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts import PromptTemplate
import time
# Load variables from a .env file into the process environment before any
# configuration value is read.
load_dotenv()

# API key passed to the OpenAI-backed LangChain clients; None when unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Shared module-level objects, currently disabled. The lines are kept because
# they record the splitter/embedding configuration this app was written for.
# text_splitter = CharacterTextSplitter(separator = "\n", chunk_size=1000, chunk_overlap=200, length_function = len)
# embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
# retrieval_qa_chat_prompt = hub.pull("langchain-ai/retrieval-qa-chat")
# llm = ChatOpenAI(api_key=OPENAI_API_KEY)

# Root directory; each named Chroma database lives in a sub-directory of it.
vectordb_path = "./vector_db"
def query():
    """Answer a prompt from the documents stored in a named Chroma database.

    Reads a JSON body from the ambient ``request`` object with the keys
    ``prompt`` (the user's question), ``title`` (the mode; only ``"rag"`` is
    handled) and ``db`` (the database name under ``vectordb_path``).

    Returns:
        ``jsonify({"success": "ok", "response": ...})`` for a POST in ``"rag"``
        mode: the model's answer, or ``"Please select database."`` when no
        database name was supplied. ``None`` for any other method or mode.

    NOTE(review): ``request`` and ``jsonify`` are not imported in the visible
    part of this file; they look like Flask's - confirm where they come from.
    """
    if request.method != "POST":
        return None

    # Parse the body once; a missing body is treated as an empty payload
    # instead of failing on ``None.get``.
    payload = request.get_json() or {}
    prompt = payload.get("prompt")
    title = payload.get("title")
    db_name = payload.get("db")

    if title != "rag":
        return None

    # A missing/None name is handled like the empty string so it never
    # reaches os.path.join (which raises TypeError on None).
    if not db_name:
        return jsonify({"success": "ok", "response": "Please select database."})

    template = (
        "Please answer to human's input based on context. If the input is not "
        "mentioned in context, output something like 'I don't know'.\n"
        "Context: {context}\n"
        "Human: {human_input}\n"
        "Your Response as Chatbot:"
    )
    prompt_s = PromptTemplate(
        input_variables=["human_input", "context"],
        template=template,
    )

    # Built locally: no module-level ``embeddings`` object is defined, so
    # relying on a global here raised NameError.
    embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)

    # NOTE(review): db_name comes from the request and is joined into a
    # filesystem path unvalidated - consider restricting it to known names.
    store = Chroma(
        persist_directory=os.path.join(vectordb_path, db_name),
        embedding_function=embeddings,
    )
    docs = store.similarity_search(prompt)

    llm = ChatOpenAI(model="gpt-4-1106-preview", api_key=OPENAI_API_KEY)
    stuff_chain = load_qa_chain(llm, chain_type="stuff", prompt=prompt_s)
    output = stuff_chain(
        {"input_documents": docs, "human_input": prompt},
        return_only_outputs=False,
    )
    return jsonify({"success": "ok", "response": output["output_text"]})
def uploadDocuments():
    """Save uploaded files and index their contents into a Chroma database.

    Reads ``dbname`` from ``request.form`` and the uploads from the
    ``files[]`` field of ``request.files``. Each file is written under
    ``uploads/``, loaded (``.txt`` as UTF-8 text, anything else as PDF),
    split into chunks and added to the database at ``vectordb_path/dbname``.

    Returns:
        ``{"success": "db"}`` when no database name was given,
        ``{"success": "bad"}`` when no usable file was uploaded,
        ``{"success": "ok"}`` once every file has been indexed.

    NOTE(review): ``request`` is not imported in the visible part of this
    file; it looks like Flask's - confirm where it comes from.
    """
    dbname = request.form.get('dbname')
    # Missing field (None) is handled like an empty name.
    if not dbname:
        return {"success": "db"}

    # Use the real uploads. The previous hard-coded list of filename strings
    # failed on ``file.save`` / ``file.filename``, which plain strings lack.
    uploaded_files = [f for f in request.files.getlist('files[]') if f.filename]
    if not uploaded_files:
        return {"success": "bad"}

    # Built locally with the settings from the disabled module-level
    # definitions; no global ``text_splitter`` / ``embeddings`` exists.
    text_splitter = CharacterTextSplitter(
        separator="\n", chunk_size=1000, chunk_overlap=200, length_function=len
    )
    embeddings = OpenAIEmbeddings(api_key=OPENAI_API_KEY)
    persist_directory = os.path.join(vectordb_path, dbname)

    # Saving fails if the target directory does not exist yet.
    os.makedirs("uploads", exist_ok=True)

    for file in uploaded_files:
        # Keep only the final path component so a client-supplied name cannot
        # point outside uploads/ (minimal guard, not full sanitisation).
        filename = os.path.basename(file.filename)
        if not filename:
            continue
        saved_path = f"uploads/{filename}"
        file.save(saved_path)

        if filename.endswith(".txt"):
            loader = TextLoader(saved_path, encoding='utf-8')
        else:
            loader = PyPDFLoader(saved_path)

        texts = text_splitter.split_documents(loader.load())
        Chroma.from_documents(texts, embeddings, persist_directory=persist_directory)

    return {'success': "ok"}
def dbcreate():
    """Create the directory for a new vector database.

    Reads ``dbname`` from the JSON body of the ambient ``request`` object and
    creates ``vectordb_path/dbname``.

    Returns:
        ``{"success": "ok"}`` when the directory was created,
        ``{"success": "bad"}`` when it already exists or no name was given.

    NOTE(review): ``request`` is not imported in the visible part of this
    file; it looks like Flask's - confirm where it comes from.
    """
    dbname = (request.get_json() or {}).get("dbname")
    # A missing/empty name would crash os.path.join (None) or target the
    # store root itself (""), so reject it up front.
    if not dbname:
        return {'success': 'bad'}

    # NOTE(review): dbname comes from the request and is joined into a
    # filesystem path unvalidated - consider restricting allowed characters.
    try:
        # Create-and-catch avoids the race between an exists() check and
        # makedirs() when two requests ask for the same name.
        os.makedirs(os.path.join(vectordb_path, dbname))
    except FileExistsError:
        return {'success': 'bad'}
    return {'success': "ok"}
import gradio as gr

# Transcript widget: user/bot avatar images and a fixed 600-pixel height.
chatbot = gr.Chatbot(
    avatar_images=["user.png", "bot.jpg"],
    height=600,
)

# Custom button handed to the chat interface as its clear control.
clear_but = gr.Button(value="Clear Chat")

# Text-only chat UI with the retry and undo buttons turned off.
# NOTE(review): ``fn`` is an empty string, not a callable, so no handler is
# wired to incoming messages - confirm which function should answer them.
# NOTE(review): the *_btn keyword arguments may depend on the installed
# Gradio version - confirm against the pinned release.
demo = gr.ChatInterface(
    fn="",
    title="Mediate.com Chatbot Prototype",
    multimodal=False,
    retry_btn=None,
    undo_btn=None,
    clear_btn=clear_but,
    chatbot=chatbot,
)

if __name__ == "__main__":
    # Start the UI only when the file is executed as a script.
    demo.launch(debug=True)