Spaces: Running on Zero
jedick committed
Commit f8c72d3 · Parent(s): dac4e7d
Use session state for LangChain graph

app.py CHANGED
@@ -18,55 +18,58 @@ import os
 # Setup environment variables
 load_dotenv(dotenv_path=".env", override=True)

-# Global
-COMPUTE = "local"
+# Global setting for search type
 search_type = "hybrid"

-# Global variables for LangChain graph
+# Global variables for LangChain graph: use dictionaries to store user-specific instances
+# https://www.gradio.app/guides/state-in-blocks
+graph_instances = {"local": {}, "remote": {}}


-def run_workflow(input, history, thread_id):
+def cleanup_graph(request: gr.Request):
+    if request.session_hash in graph_instances["local"]:
+        del graph_instances["local"][request.session_hash]
+        print(f"Deleted local graph for session {request.session_hash}")
+    if request.session_hash in graph_instances["remote"]:
+        del graph_instances["remote"][request.session_hash]
+        print(f"Deleted remote graph for session {request.session_hash}")
+
+
+def run_workflow(input, history, compute_mode, thread_id, session_hash):
     """The main function to run the chat workflow"""

-    if COMPUTE == "local":
-        # We don't want the app to switch into remote mode without notification,
-        # so ask the user to do it
+    # Error if user tries to run local mode without GPU
+    if compute_mode == "local":
         if not torch.cuda.is_available():
             raise gr.Error(
                 "Local mode requires GPU. Please select remote mode.",
                 print_exception=False,
             )
+
+    # Get graph for compute mode
+    graph = graph_instances[compute_mode].get(session_hash)
+    if graph is not None:
+        print(f"Get {compute_mode} graph for session {session_hash}")

     if graph is None:
         # Notify when we're loading the local model because it takes some time
-        if COMPUTE == "local":
+        if compute_mode == "local":
             gr.Info(
                 f"Please wait for the local model to load",
                 duration=15,
                 title=f"Model loading...",
             )
         # Get the chat model and build the graph
-        chat_model = GetChatModel(COMPUTE)
-        graph_builder = BuildGraph(chat_model, COMPUTE, search_type)
+        chat_model = GetChatModel(compute_mode)
+        graph_builder = BuildGraph(chat_model, compute_mode, search_type)
         # Compile the graph with an in-memory checkpointer
         memory = MemorySaver()
         graph = graph_builder.compile(checkpointer=memory)
         # Set global graph for compute mode
-        # Notify when model finishes loading
-        gr.Success(f"{COMPUTE}", duration=4, title=f"Model loaded!")
-        print(f"Set graph for {COMPUTE}, {search_type}!")
+        graph_instances[compute_mode][session_hash] = graph
+        print(f"Set {compute_mode} graph for session {session_hash}")
+        # Notify when model finishes loading
+        gr.Success(f"{compute_mode}", duration=4, title=f"Model loaded")

     print(f"Using thread_id: {thread_id}")

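This hunk replaces a single global graph with one compiled graph per browser session, keyed by Gradio's `request.session_hash` and built lazily on first use. A minimal, self-contained sketch of the same pattern, assuming the `gradio` and `langgraph` packages; the one-node echo graph and component names are illustrative stand-ins for `GetChatModel`/`BuildGraph`, not code from app.py:

```python
import gradio as gr
from langgraph.graph import StateGraph, MessagesState, START
from langgraph.checkpoint.memory import MemorySaver

graph_instances = {}  # session_hash -> compiled graph


def build_graph():
    # Stand-in for GetChatModel() + BuildGraph(): a one-node echo graph
    def node(state: MessagesState):
        return {"messages": [("ai", "ok")]}

    builder = StateGraph(MessagesState)
    builder.add_node("node", node)
    builder.add_edge(START, "node")
    # In-memory checkpointer: conversation state is stored per thread_id
    return builder.compile(checkpointer=MemorySaver())


def respond(message, request: gr.Request):
    # Gradio injects gr.Request when it appears in the signature;
    # request.session_hash identifies the browser session
    graph = graph_instances.get(request.session_hash)
    if graph is None:
        graph = build_graph()
        graph_instances[request.session_hash] = graph
    # Here the session hash doubles as the thread_id; app.py tracks a
    # separate thread_id value
    config = {"configurable": {"thread_id": request.session_hash}}
    result = graph.invoke({"messages": [("user", message)]}, config)
    return result["messages"][-1].content


def cleanup_graph(request: gr.Request):
    # Drop the session's graph when the tab is closed or refreshed
    graph_instances.pop(request.session_hash, None)


with gr.Blocks() as demo:
    box = gr.Textbox(label="Message")
    out = gr.Textbox(label="Reply")
    box.submit(respond, [box], [out])
    demo.unload(cleanup_graph)

if __name__ == "__main__":
    demo.launch()
```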
@@ -180,13 +183,16 @@ def run_workflow(input, history, thread_id):
     yield history, None, citations


-def to_workflow(*args):
+def to_workflow(request: gr.Request, *args):
     """Wrapper function to call function with or without @spaces.GPU"""
-    if COMPUTE == "local":
-        for value in run_workflow_local(*args):
+    compute_mode = args[2]
+    # Add session_hash to arguments
+    new_args = args + (request.session_hash,)
+    if compute_mode == "local":
+        for value in run_workflow_local(*new_args):
             yield value
-    if COMPUTE == "remote":
-        for value in run_workflow_remote(*args):
+    if compute_mode == "remote":
+        for value in run_workflow_remote(*new_args):
             yield value

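The wrapper now reads the compute mode from the event arguments instead of a global, and forwards the session hash to `run_workflow`. A sketch of the dispatch, assuming the Hugging Face `spaces` package; `run_workflow` below is a stand-in generator, and `run_workflow_local`/`run_workflow_remote` are defined elsewhere in the real app.py:

```python
import gradio as gr
import spaces


def run_workflow(input, history, compute_mode, thread_id, session_hash):
    # Stand-in for the real workflow generator
    yield f"[{compute_mode}] {input!r} (session {session_hash})"


# ZeroGPU attaches a GPU only while the decorated function runs, so
# wrapping one variant lets local mode request a GPU and remote mode skip it
run_workflow_local = spaces.GPU(run_workflow)
run_workflow_remote = run_workflow


def to_workflow(request: gr.Request, *args):
    # args mirrors the inputs list [input, chatbot, compute_mode, thread_id],
    # so the Radio's value is the third positional argument
    compute_mode = args[2]
    new_args = args + (request.session_hash,)
    if compute_mode == "local":
        yield from run_workflow_local(*new_args)
    if compute_mode == "remote":
        yield from run_workflow_remote(*new_args)
```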
@@ -236,7 +242,7 @@ with gr.Blocks(
             "local",
             "remote",
         ],
-        value=COMPUTE,
+        value=("local" if torch.cuda.is_available() else "remote"),
         label="Compute Mode",
        info=(None if torch.cuda.is_available() else "NOTE: local mode requires GPU"),
        render=False,
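With the `COMPUTE` global gone, the Radio's default is derived from the hardware at startup. In isolation, a sketch of the component assembled from this hunk (the surrounding layout is omitted):

```python
import torch
import gradio as gr

# Default to "local" only when a GPU is actually present, so a CPU-only
# Space starts in "remote" mode instead of raising on first submit
compute_mode = gr.Radio(
    choices=["local", "remote"],
    value=("local" if torch.cuda.is_available() else "remote"),
    label="Compute Mode",
    render=False,  # rendered later inside the layout
)
```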
@@ -444,10 +450,6 @@ with gr.Blocks(
         """Return updated value for a component"""
         return gr.update(value=value)

-    def set_compute(compute_mode):
-        global COMPUTE
-        COMPUTE = compute_mode
-
     def set_avatar(compute_mode):
         if compute_mode == "remote":
             image_file = "images/cloud.png"
@@ -475,13 +477,6 @@ with gr.Blocks(
         # Display the content in the textbox
         return content, change_visibility(True)

-    # def update_citations(citations):
-    #     if citations == []:
-    #         # Blank out and hide the citations textbox when new input is submitted
-    #         return "", change_visibility(False)
-    #     else:
-    #         return citations, change_visibility(True)
-
     # --------------
     # Event handlers
     # --------------
@@ -495,11 +490,6 @@ with gr.Blocks(
         return component.clear()

     compute_mode.change(
-        # Update global COMPUTE variable
-        set_compute,
-        [compute_mode],
-        api_name=False,
-    ).then(
         # Change the app status text
         get_status_text,
         [compute_mode],
@@ -527,7 +517,7 @@ with gr.Blocks(
     input.submit(
         # Submit input to the chatbot
         to_workflow,
-        [input, chatbot, thread_id],
+        [input, chatbot, compute_mode, thread_id],
         [chatbot, retrieved_emails, citations_text],
         api_name=False,
     )
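Adding `compute_mode` to the inputs list is what makes the deleted `set_compute`/`global COMPUTE` round-trip unnecessary: Gradio hands the handler the current value of every listed component on each event. A small runnable illustration of that mechanism (names are illustrative):

```python
import gradio as gr


def greet(name, mode):
    # `mode` is the Radio's value at the moment of the submit event
    return f"({mode}) Hello, {name}!"


with gr.Blocks() as demo:
    mode = gr.Radio(["local", "remote"], value="remote", label="Mode")
    name = gr.Textbox(label="Name")
    out = gr.Textbox(label="Greeting")
    # mode's live value arrives as the second handler argument
    name.submit(greet, [name, mode], [out])

if __name__ == "__main__":
    demo.launch()
```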
@@ -661,6 +651,9 @@ with gr.Blocks(
     )
     # fmt: on

+    # Clean up graph instances when page is closed/refreshed
+    demo.unload(cleanup_graph)
+

 if __name__ == "__main__":
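Gradio's unload event runs the registered callback once per session when the browser tab is closed or refreshed, so `demo.unload(cleanup_graph)` is the counterpart to the lazy construction in `run_workflow`: without it, every visitor would leave a compiled graph and its in-memory checkpointer behind for the lifetime of the process.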