jedick committed
Commit 355c5a2 · 1 Parent(s): 27b6f54

Revert model download before running workflow

Files changed (5)
  1. app.py +2 -5
  2. graph.py +0 -3
  3. main.py +0 -8
  4. mods/tool_calling_llm.py +4 -0
  5. retriever.py +3 -3
app.py CHANGED
@@ -4,7 +4,7 @@ from graph import BuildGraph
 from retriever import db_dir
 from langgraph.checkpoint.memory import MemorySaver
 from dotenv import load_dotenv
-from main import openai_model, model_id, DownloadChatModel
+from main import openai_model, model_id
 from util import get_sources, get_start_end_months
 from mods.tool_calling_llm import extract_think
 import requests
@@ -82,6 +82,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
     if compute_mode == "local":
         gr.Info(
             f"Please wait for the local model to load",
+            duration=15,
             title=f"Model loading...",
         )
         # Get the chat model and build the graph
@@ -210,10 +211,6 @@ def to_workflow(request: gr.Request, *args):
     # Add session_hash to arguments
     new_args = args + (request.session_hash,)
     if compute_mode == "local":
-        # If graph hasn't been instantiated, download model before running workflow
-        graph = graph_instances[compute_mode].get(request.session_hash)
-        if graph is None:
-            DownloadChatModel()
         # Call the workflow function with the @spaces.GPU decorator
         for value in run_workflow_local(*new_args):
             yield value
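
For reference on the duration=15 argument added above: Gradio's gr.Info shows a toast notification from inside an event handler, and duration sets roughly how many seconds it stays visible. A minimal sketch, with the message and title mirroring the app and the Blocks wiring purely illustrative:

import gradio as gr

def start_local_model():
    # Keep the toast on screen for ~15 seconds while the local model loads
    gr.Info("Please wait for the local model to load", duration=15, title="Model loading...")
    return "Loading started"

with gr.Blocks() as demo:
    run = gr.Button("Run")
    status = gr.Textbox()
    run.click(start_local_model, outputs=status)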
graph.py CHANGED
@@ -12,9 +12,6 @@ from retriever import BuildRetriever
 from prompts import query_prompt, generate_prompt, generic_tools_template
 from mods.tool_calling_llm import ToolCallingLLM
 
-# Local modules
-from retriever import BuildRetriever
-
 # For tracing (disabled)
 # os.environ["LANGSMITH_TRACING"] = "true"
 # os.environ["LANGSMITH_PROJECT"] = "R-help-chat"
main.py CHANGED
@@ -5,7 +5,6 @@ from langchain_core.output_parsers import StrOutputParser
 from langgraph.checkpoint.memory import MemorySaver
 from langchain_core.messages import ToolMessage
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-from huggingface_hub import snapshot_download
 from datetime import datetime
 from dotenv import load_dotenv
 import os
@@ -129,13 +128,6 @@ def ProcessDirectory(path, compute_mode):
         print(f"Chroma: no change for {file_path}")
 
 
-def DownloadChatModel():
-    """
-    Downloads a chat model to the local Hugging Face cache.
-    """
-    snapshot_download(model_id)
-
-
 def GetChatModel(compute_mode):
     """
     Get a chat model.
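
The removed DownloadChatModel pre-cached the model with huggingface_hub.snapshot_download; after this revert the weights are presumably fetched on first use when GetChatModel loads them via transformers. A minimal sketch of the reverted pre-download approach, assuming model_id names a Hub repository (the value shown here is a placeholder):

from huggingface_hub import snapshot_download

model_id = "org/model-name"  # assumption: stands in for the repo id defined in main.py

def download_chat_model():
    # Fetch all files for model_id into the local Hugging Face cache so a later
    # from_pretrained() call can load them without downloading again
    snapshot_download(model_id)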
mods/tool_calling_llm.py CHANGED
@@ -177,6 +177,10 @@ class ToolCallingLLM(BaseChatModel, ABC):
         # Extract <think>...</think> content and text after </think> for further processing 20250726 jmd
         think_text, post_think = extract_think(response_message.content)
 
+        ## For debugging
+        # print("post_think")
+        # print(post_think)
+
         # Parse output for JSON (support multiple objects separated by commas)
         try:
             parsed_json_results = json.loads(f"[{post_think}]")
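
The json.loads(f"[{post_think}]") line that follows the new debugging comments parses one or more comma-separated JSON objects by wrapping them in brackets so they form a valid JSON array. A small sketch of that trick; the tool-call field names are illustrative, not taken from the project:

import json

# Two JSON objects separated by a comma, as a model might emit for multiple tool calls
post_think = '{"tool": "search", "args": {"query": "lm"}}, {"tool": "search", "args": {"query": "glm"}}'

# Wrapping the string in [...] turns it into a JSON array, so json.loads returns a list of dicts
parsed_json_results = json.loads(f"[{post_think}]")
for call in parsed_json_results:
    print(call["tool"], call["args"])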
retriever.py CHANGED
@@ -174,9 +174,9 @@ def BuildRetrieverDense(compute_mode: str, top_k=6):
         # Get top k documents
         search_kwargs={"k": top_k},
     )
-    ## Release GPU memory
-    ## https://github.com/langchain-ai/langchain/discussions/10668
-    # torch.cuda.empty_cache()
+    # Fix for ValueError('Could not connect to tenant default_tenant. Are you sure it exists?')
+    # https://github.com/langchain-ai/langchain/issues/26884
+    chromadb.api.client.SharedSystemClient.clear_system_cache()
     return retriever
 
 
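
The new clear_system_cache() call works around chromadb's shared-client cache, which can raise ValueError('Could not connect to tenant default_tenant. Are you sure it exists?') when a client is re-created with different settings in the same process (see the linked issue). A minimal sketch of the pattern, assuming the langchain_chroma package and a caller-supplied embedding function; the helper name and arguments are illustrative:

import chromadb.api.client
from langchain_chroma import Chroma

def build_dense_retriever(db_dir, embedding_function, top_k=6):
    # Open the persisted Chroma collection and expose it as a retriever
    vectorstore = Chroma(persist_directory=db_dir, embedding_function=embedding_function)
    retriever = vectorstore.as_retriever(search_kwargs={"k": top_k})
    # Drop chromadb's cached client/system instances so a later client with
    # different settings doesn't hit the default_tenant connection error
    chromadb.api.client.SharedSystemClient.clear_system_cache()
    return retriever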