Spaces:
Runtime error
Runtime error
Commit
·
858002b
1
Parent(s):
9777d44
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import openai
|
2 |
+
import random
|
3 |
+
import time
|
4 |
+
import gradio as gr
|
5 |
+
import os
|
6 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
7 |
+
from langchain.vectorstores import DeepLake
|
8 |
+
from langchain.chat_models import ChatOpenAI
|
9 |
+
from langchain.chains import ConversationalRetrievalChain
|
10 |
+
from langchain.document_loaders import TextLoader
|
11 |
+
from langchain.text_splitter import CharacterTextSplitter
|
12 |
+
from langchain.document_loaders import PyPDFDirectoryLoader
|
13 |
+
from langchain.memory import ConversationBufferMemory
|
14 |
+
from langchain.llms import OpenAI
|
15 |
+
|
16 |
+
def set_api_key(key):
|
17 |
+
os.environ["OPENAI_API_KEY"] = key
|
18 |
+
return f"Your API Key has been set to: {key}"
|
19 |
+
|
20 |
+
def reset_api_key():
|
21 |
+
os.environ["OPENAI_API_KEY"] = ""
|
22 |
+
return "Your API Key has been reset"
|
23 |
+
|
24 |
+
def get_api_key():
|
25 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
26 |
+
return api_key
|
27 |
+
|
28 |
+
def set_model(model):
|
29 |
+
os.environ["OPENAI_MODEL"] = model
|
30 |
+
return f"{model} selected"
|
31 |
+
|
32 |
+
def get_model():
|
33 |
+
model = os.getenv("OPENAI_MODEL")
|
34 |
+
return model
|
35 |
+
|
36 |
+
def upload_file(files):
|
37 |
+
file_paths = [file.name for file in files]
|
38 |
+
return file_paths
|
39 |
+
|
40 |
+
def create_vectorstore(files):
|
41 |
+
pdf_dir = files.name
|
42 |
+
pdf_loader = PyPDFDirectoryLoader(pdf_dir)
|
43 |
+
pdf_docs = pdf_loader.load_and_split()
|
44 |
+
text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
|
45 |
+
texts = text_splitter.split_documents(pdf_docs)
|
46 |
+
embeddings = OpenAIEmbeddings()
|
47 |
+
db = DeepLake.from_documents(texts, dataset_path="./documentation_db", embedding=embeddings, overwrite=True)
|
48 |
+
return "Vectorstore Successfully Created"
|
49 |
+
|
50 |
+
def respond(message, chat_history):
|
51 |
+
|
52 |
+
# Get embeddings
|
53 |
+
embeddings = OpenAIEmbeddings()
|
54 |
+
|
55 |
+
#Connect to existing vectorstore
|
56 |
+
db = DeepLake(dataset_path="./documentation_db", embedding_function=embeddings, read_only=True)
|
57 |
+
#Set retriever settings
|
58 |
+
retriever = db.as_retriever(search_kwargs={"distance_metric":'cos',
|
59 |
+
"fetch_k":10,
|
60 |
+
"maximal_marginal_relevance":True,
|
61 |
+
"k":10})
|
62 |
+
|
63 |
+
if len(chat_history) != 0:
|
64 |
+
chat_history = [(chat_history[0][0], chat_history[0][1])]
|
65 |
+
|
66 |
+
model = get_model()
|
67 |
+
# Create ChatOpenAI and ConversationalRetrievalChain
|
68 |
+
model = ChatOpenAI(model=model)
|
69 |
+
qa = ConversationalRetrievalChain.from_llm(model, retriever)
|
70 |
+
|
71 |
+
bot_message = qa({"question": message, "chat_history": chat_history})
|
72 |
+
chat_history = [(message, bot_message["answer"])]
|
73 |
+
time.sleep(1)
|
74 |
+
return "", chat_history
|
75 |
+
|
76 |
+
with gr.Blocks() as demo:
|
77 |
+
#create chat history
|
78 |
+
chat_history = []
|
79 |
+
with gr.Row():
|
80 |
+
#create textbox for API input
|
81 |
+
api_input = gr.Textbox(label = "API Key",
|
82 |
+
placeholder = "Please provide your OpenAI API key here.")
|
83 |
+
#create textbox to validate API
|
84 |
+
api_key_status = gr.Textbox(label = "API Key Status",
|
85 |
+
placeholder = "Your API Key has not be set yet. Please enter your key.",
|
86 |
+
interactive = False)
|
87 |
+
#create button to submit API key
|
88 |
+
api_submit_button = gr.Button("Submit")
|
89 |
+
#set api_submit_button functionality
|
90 |
+
api_submit_button.click(set_api_key, inputs=api_input, outputs=api_key_status)
|
91 |
+
#create button to reset API key
|
92 |
+
api_reset_button = gr.Button("Clear API Key from session")
|
93 |
+
#set api_reset_button functionality
|
94 |
+
api_reset_button.click(reset_api_key, outputs=api_key_status)
|
95 |
+
|
96 |
+
with gr.Row():
|
97 |
+
with gr.Column():
|
98 |
+
#create dropdown to select model (gpt-3.5-turbo or gpt4)
|
99 |
+
model_selection = gr.Dropdown(
|
100 |
+
["gpt-3.5-turbo", "gpt-4"], label="Model Selection", info="Please ensure you provide the API Key that corresponds to the Model you select!"
|
101 |
+
)
|
102 |
+
#create button to submit model selection
|
103 |
+
model_submit_button = gr.Button("Submit Model Selection")
|
104 |
+
model_status = gr.Textbox(label = "Selected Model", interactive = False, lines=4)
|
105 |
+
#set model_submit_button functionality
|
106 |
+
model_submit_button.click(set_model, inputs=model_selection, outputs=model_status)
|
107 |
+
|
108 |
+
file_output = gr.File(label = "Uploaded files - Please note these files are persistent and will not be automatically deleted")
|
109 |
+
upload_button = gr.UploadButton("Click to Upload a PDF File", file_types=["pdf"], file_count="multiple")
|
110 |
+
upload_button.upload(upload_file, upload_button, file_output)
|
111 |
+
create_vectorstore_button = gr.Button("Click to create the vectorstore for your uploaded documents")
|
112 |
+
db_output = gr.Textbox(label = "Vectorstore Status")
|
113 |
+
create_vectorstore_button.click(create_vectorstore, inputs=file_output, outputs = db_output)
|
114 |
+
|
115 |
+
chatbot = gr.Chatbot(label="ChatGPT Powered Grant Writing Assistant")
|
116 |
+
msg = gr.Textbox(label="User Prompt", placeholder="Your Query Here")
|
117 |
+
clear = gr.Button("Clear")
|
118 |
+
|
119 |
+
msg.submit(respond, inputs = [msg, chatbot], outputs = [msg, chatbot])
|
120 |
+
clear.click(lambda: None, None, chatbot, queue=False)
|
121 |
+
|
122 |
+
|
123 |
+
demo.launch()
|