jedick committed
Commit f8c72d3 · 1 Parent(s): dac4e7d

Use session state for LangChain graph

Files changed (1):
  app.py +41 -48
app.py CHANGED
@@ -18,55 +18,58 @@ import os
 # Setup environment variables
 load_dotenv(dotenv_path=".env", override=True)
 
-# Global settings for compute_mode and search_type
-COMPUTE = "local"
+# Global setting for search type
 search_type = "hybrid"
 
-# Global variables for LangChain graph
-graph_local = None
-graph_remote = None
+# Global variables for LangChain graph: use dictionaries to store user-specific instances
+# https://www.gradio.app/guides/state-in-blocks
+graph_instances = {"local": {}, "remote": {}}
 
 
-def run_workflow(input, history, thread_id):
+def cleanup_graph(request: gr.Request):
+    if request.session_hash in graph_instances["local"]:
+        del graph_instances["local"][request.session_hash]
+        print(f"Deleted local graph for session {request.session_hash}")
+    if request.session_hash in graph_instances["remote"]:
+        del graph_instances["remote"][request.session_hash]
+        print(f"Deleted remote graph for session {request.session_hash}")
+
+
+def run_workflow(input, history, compute_mode, thread_id, session_hash):
     """The main function to run the chat workflow"""
 
-    # Get global graph depending on compute mode
-    global graph_local, graph_remote
-    if COMPUTE == "local":
-        # We don't want the app to switch into remote mode without notification,
-        # so ask the user to do it
+    # Error if user tries to run local mode without GPU
+    if compute_mode == "local":
         if not torch.cuda.is_available():
             raise gr.Error(
                 "Local mode requires GPU. Please select remote mode.",
                 print_exception=False,
             )
-        graph = graph_local
-    if COMPUTE == "remote":
-        graph = graph_remote
+
+    # Get graph for compute mode
+    graph = graph_instances[compute_mode].get(session_hash)
+    if graph is not None:
+        print(f"Get {compute_mode} graph for session {session_hash}")
 
     if graph is None:
         # Notify when we're loading the local model because it takes some time
-        if COMPUTE == "local":
+        if compute_mode == "local":
             gr.Info(
                 f"Please wait for the local model to load",
                 duration=15,
                 title=f"Model loading...",
             )
         # Get the chat model and build the graph
-        chat_model = GetChatModel(COMPUTE)
-        graph_builder = BuildGraph(chat_model, COMPUTE, search_type)
+        chat_model = GetChatModel(compute_mode)
+        graph_builder = BuildGraph(chat_model, compute_mode, search_type)
         # Compile the graph with an in-memory checkpointer
         memory = MemorySaver()
         graph = graph_builder.compile(checkpointer=memory)
         # Set global graph for compute mode
-        if COMPUTE == "local":
-            graph_local = graph
-        if COMPUTE == "remote":
-            graph_remote = graph
-
-        # Notify when model finishes loading
-        gr.Success(f"{COMPUTE}", duration=4, title=f"Model loaded!")
-        print(f"Set graph for {COMPUTE}, {search_type}!")
+        graph_instances[compute_mode][session_hash] = graph
+        print(f"Set {compute_mode} graph for session {session_hash}")
+        # Notify when model finishes loading
+        gr.Success(f"{compute_mode}", duration=4, title=f"Model loaded")
 
     print(f"Using thread_id: {thread_id}")
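The hunk above replaces the module-level COMPUTE, graph_local, and graph_remote globals with a dict keyed by Gradio's per-tab session hash, following the pattern in the linked guide (https://www.gradio.app/guides/state-in-blocks). A minimal, self-contained sketch of just that pattern, separate from app.py (the resources dict and component names here are illustrative, not from the app):

    import gradio as gr

    # session_hash -> expensive per-user object (a string stands in for a compiled graph)
    resources = {}

    def answer(message, request: gr.Request):
        # Gradio injects `request` because the parameter is type-hinted gr.Request
        key = request.session_hash
        if key not in resources:
            resources[key] = f"resource-{key[:8]}"  # built once per browser session
        return f"[{resources[key]}] {message}"

    def cleanup(request: gr.Request):
        # Drop this session's entry so closed tabs don't leak memory
        resources.pop(request.session_hash, None)

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Message")
        out = gr.Textbox(label="Reply")
        box.submit(answer, [box], [out])
        demo.unload(cleanup)  # fires when the tab is closed or refreshed

    if __name__ == "__main__":
        demo.launch()

Because each entry lives under its own session hash, two simultaneous users no longer share one graph, and the unload hook (registered at the bottom of this diff) drops the entry when the tab goes away, so stale sessions don't pin a loaded model in memory.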
@@ -180,13 +183,16 @@ def run_workflow(input, history, thread_id):
     yield history, None, citations
 
 
-def to_workflow(*args):
+def to_workflow(request: gr.Request, *args):
     """Wrapper function to call function with or without @spaces.GPU"""
-    if COMPUTE == "local":
-        for value in run_workflow_local(*args):
+    compute_mode = args[2]
+    # Add session_hash to arguments
+    new_args = args + (request.session_hash,)
+    if compute_mode == "local":
+        for value in run_workflow_local(*new_args):
             yield value
-    if COMPUTE == "remote":
-        for value in run_workflow_remote(*args):
+    if compute_mode == "remote":
+        for value in run_workflow_remote(*new_args):
             yield value
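to_workflow now declares request: gr.Request as its first parameter. Gradio fills in any parameter type-hinted gr.Request automatically, and it does not take a slot in the event's inputs list, so the component values still arrive in order through *args; with inputs=[input, chatbot, compute_mode, thread_id] (see the input.submit hunk below), args[2] is the compute-mode Radio's value. A small sketch of those mechanics, with hypothetical components:

    import gradio as gr

    def handler(request: gr.Request, *args):
        # args holds the values of `inputs` in order; the request takes no slot
        text, mode = args
        return f"{mode}: {text} (session {request.session_hash[:8]})"

    with gr.Blocks() as demo:
        text = gr.Textbox(label="Input")
        mode = gr.Radio(["local", "remote"], value="remote", label="Mode")
        out = gr.Textbox(label="Output")
        text.submit(handler, [text, mode], [out])

One caveat: indexing args[2] couples to_workflow to the order of the inputs list, so reordering one without the other would silently pass the wrong value.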
@@ -236,7 +242,7 @@ with gr.Blocks(
             "local",
             "remote",
         ],
-        value=COMPUTE,
+        value=("local" if torch.cuda.is_available() else "remote"),
        label="Compute Mode",
         info=(None if torch.cuda.is_available() else "NOTE: local mode requires GPU"),
         render=False,
@@ -444,10 +450,6 @@ with gr.Blocks(
         """Return updated value for a component"""
         return gr.update(value=value)
 
-    def set_compute(compute_mode):
-        global COMPUTE
-        COMPUTE = compute_mode
-
     def set_avatar(compute_mode):
         if compute_mode == "remote":
             image_file = "images/cloud.png"
@@ -475,13 +477,6 @@ with gr.Blocks(
         # Display the content in the textbox
         return content, change_visibility(True)
 
-    # def update_citations(citations):
-    #     if citations == []:
-    #         # Blank out and hide the citations textbox when new input is submitted
-    #         return "", change_visibility(False)
-    #     else:
-    #         return citations, change_visibility(True)
-
     # --------------
     # Event handlers
     # --------------
@@ -495,11 +490,6 @@ with gr.Blocks(
         return component.clear()
 
     compute_mode.change(
-        # Update global COMPUTE variable
-        set_compute,
-        [compute_mode],
-        api_name=False,
-    ).then(
         # Change the app status text
         get_status_text,
         [compute_mode],
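With the COMPUTE global gone, the first link of this chain (which only copied the Radio's value into the global) is dropped, and the chain now starts directly at the status-text update. Handlers that need the mode read it per event from the inputs list instead. A sketch of that chaining style, with hypothetical components:

    import gradio as gr

    with gr.Blocks() as demo:
        mode = gr.Radio(["local", "remote"], value="remote", label="Compute Mode")
        status = gr.Markdown()
        mode.change(
            # Read the current value per event; no global needed
            lambda m: f"Compute mode: {m}",
            [mode],
            [status],
        ).then(
            # Chained step runs after the first completes
            lambda m: print(f"Switched to {m}"),
            [mode],
            None,
        )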
@@ -527,7 +517,7 @@ with gr.Blocks(
     input.submit(
         # Submit input to the chatbot
         to_workflow,
-        [input, chatbot, thread_id],
+        [input, chatbot, compute_mode, thread_id],
         [chatbot, retrieved_emails, citations_text],
         api_name=False,
     )
@@ -661,6 +651,9 @@ with gr.Blocks(
     )
     # fmt: on
 
+    # Clean up graph instances when page is closed/refreshed
+    demo.unload(cleanup_graph)
+
 
 if __name__ == "__main__":
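This commit doesn't touch the checkpointer, but it is why per-session graphs also mean per-session chat memory: each compiled graph carries its own MemorySaver, and the thread_id threaded through run_workflow selects which conversation to resume within it. A toy LangGraph example of that mechanism (the echo graph below is a stand-in, not app.py's BuildGraph):

    from langgraph.checkpoint.memory import MemorySaver
    from langgraph.graph import StateGraph, MessagesState, START, END

    def respond(state: MessagesState):
        # Echo the latest user message; a real node would call the chat model
        return {"messages": [("ai", f"echo: {state['messages'][-1].content}")]}

    builder = StateGraph(MessagesState)
    builder.add_node("respond", respond)
    builder.add_edge(START, "respond")
    builder.add_edge("respond", END)
    graph = builder.compile(checkpointer=MemorySaver())

    # Each chat thread keeps its own history under its thread_id
    config = {"configurable": {"thread_id": "thread-1"}}
    graph.invoke({"messages": [("user", "hello")]}, config)
    print(graph.get_state(config).values["messages"])  # both messages for thread-1

Since the graph and its MemorySaver now live inside graph_instances[compute_mode][session_hash], deleting the entry in cleanup_graph discards that conversation history along with the graph.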