jedick committed on
Commit
b42e964
·
1 Parent(s): f6e2d8a

Improve parsing of JSON for tool calls

Browse files
Files changed (4) hide show
  1. app.py +3 -2
  2. data.py +5 -3
  3. main.py +1 -1
  4. mods/tool_calling_llm.py +9 -5
app.py CHANGED
@@ -30,10 +30,10 @@ print(f"Using embedding checkpoints from {embedding_ckpt_dir}")
30
 
31
  # Download and extract data if data directory is not present
32
  if not os.path.isdir(db_dir):
33
- print("Downloading data ... ", end = "")
34
  download_data()
35
  print("done!")
36
- print("Extracting data ... ", end = "")
37
  extract_data()
38
  print("done!")
39
 
@@ -44,6 +44,7 @@ search_type = "hybrid"
44
  # https://www.gradio.app/guides/state-in-blocks
45
  graph_instances = {"local": {}, "remote": {}}
46
 
 
47
  def cleanup_graph(request: gr.Request):
48
  if request.session_hash in graph_instances["local"]:
49
  del graph_instances["local"][request.session_hash]
 
30
 
31
  # Download and extract data if data directory is not present
32
  if not os.path.isdir(db_dir):
33
+ print("Downloading data ... ", end="")
34
  download_data()
35
  print("done!")
36
+ print("Extracting data ... ", end="")
37
  extract_data()
38
  print("done!")
39
 
 
44
  # https://www.gradio.app/guides/state-in-blocks
45
  graph_instances = {"local": {}, "remote": {}}
46
 
47
+
48
  def cleanup_graph(request: gr.Request):
49
  if request.session_hash in graph_instances["local"]:
50
  del graph_instances["local"][request.session_hash]
data.py CHANGED
@@ -4,6 +4,7 @@ import shutil
4
  import boto3
5
  import os
6
 
 
7
  def download_file_from_bucket(bucket_name, s3_key, output_file):
8
  """Download file from S3 bucket"""
9
 
@@ -17,6 +18,7 @@ def download_file_from_bucket(bucket_name, s3_key, output_file):
17
  bucket = s3_resource.Bucket(bucket_name)
18
  bucket.download_file(Key=s3_key, Filename=output_file)
19
 
 
20
  def download_dropbox_file(shared_url, output_file):
21
  """Download file from Dropbox"""
22
 
@@ -35,9 +37,8 @@ def download_dropbox_file(shared_url, output_file):
35
  file.write(chunk)
36
  print(f"File downloaded successfully as '{output_file}'")
37
  else:
38
- print(
39
- f"Failed to download file. HTTP Status Code: {response.status_code}"
40
- )
41
 
42
  def download_data():
43
  """Download the email database"""
@@ -50,6 +51,7 @@ def download_data():
50
  # output_filename = "db.zip"
51
  # download_dropbox_file(shared_link, output_filename)
52
 
 
53
  def extract_data():
54
  """Extract the db.zip file"""
55
 
 
4
  import boto3
5
  import os
6
 
7
+
8
  def download_file_from_bucket(bucket_name, s3_key, output_file):
9
  """Download file from S3 bucket"""
10
 
 
18
  bucket = s3_resource.Bucket(bucket_name)
19
  bucket.download_file(Key=s3_key, Filename=output_file)
20
 
21
+
22
  def download_dropbox_file(shared_url, output_file):
23
  """Download file from Dropbox"""
24
 
 
37
  file.write(chunk)
38
  print(f"File downloaded successfully as '{output_file}'")
39
  else:
40
+ print(f"Failed to download file. HTTP Status Code: {response.status_code}")
41
+
 
42
 
43
  def download_data():
44
  """Download the email database"""
 
51
  # output_filename = "db.zip"
52
  # download_dropbox_file(shared_link, output_filename)
53
 
54
+
55
  def extract_data():
56
  """Extract the db.zip file"""
57
 
main.py CHANGED
@@ -41,7 +41,7 @@ model_id = os.getenv("MODEL_ID")
41
  if model_id is None:
42
  # model_id = "HuggingFaceTB/SmolLM3-3B"
43
  # model_id = "google/gemma-3-12b-it"
44
- model_id = "Qwen/Qwen3-8B"
45
 
46
  # Suppress these messages:
47
  # INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
 
41
  if model_id is None:
42
  # model_id = "HuggingFaceTB/SmolLM3-3B"
43
  # model_id = "google/gemma-3-12b-it"
44
+ model_id = "Qwen/Qwen3-14B"
45
 
46
  # Suppress these messages:
47
  # INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings "HTTP/1.1 200 OK"
mods/tool_calling_llm.py CHANGED
@@ -181,16 +181,20 @@ class ToolCallingLLM(BaseChatModel, ABC):
181
  # print("post_think")
182
  # print(post_think)
183
 
 
 
184
  # Parse output for JSON (support multiple objects separated by commas)
185
  try:
186
- # Works for one or more JSON objects not enclosed in "[]"
187
- parsed_json_results = json.loads(f"[{post_think}]")
 
 
188
  except:
189
  try:
190
- # Works for one or more JSON objects already enclosed in "[]"
191
- parsed_json_results = json.loads(f"{post_think}")
192
  except json.JSONDecodeError:
193
- # Return entire response if JSON wasn't parsed (or is missing)
194
  return AIMessage(content=response_message.content)
195
 
196
  # print("parsed_json_results")
 
181
  # print("post_think")
182
  # print(post_think)
183
 
184
+ # Remove trailing comma (if there is one)
185
+ post_think = post_think.rstrip(",")
186
  # Parse output for JSON (support multiple objects separated by commas)
187
  try:
188
+ # Works for one JSON object, or multiple JSON objects enclosed in "[]"
189
+ parsed_json_results = json.loads(f"{post_think}")
190
+ if not isinstance(parsed_json_results, list):
191
+ parsed_json_results = [parsed_json_results]
192
  except:
193
  try:
194
+ # Works for multiple JSON objects not enclosed in "[]"
195
+ parsed_json_results = json.loads(f"[{post_think}]")
196
  except json.JSONDecodeError:
197
+ # Return entire response if JSON wasn't parsed or is missing
198
  return AIMessage(content=response_message.content)
199
 
200
  # print("parsed_json_results")