brian-yu-nexusflow commited on
Commit
64d7e88
β€’
1 Parent(s): daa4f11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -2
app.py CHANGED
@@ -10,6 +10,8 @@ from time import sleep
10
 
11
  import inspect
12
 
 
 
13
  from random import randint
14
 
15
  from urllib.parse import quote
@@ -179,6 +181,7 @@ class RavenDemo(gr.Blocks):
179
  self.summary_model_client = InferenceClient(config.summary_model_endpoint)
180
 
181
  self.max_num_steps = 20
 
182
 
183
  with self:
184
  gr.HTML(HEADER_HTML)
@@ -299,6 +302,10 @@ class RavenDemo(gr.Blocks):
299
  *steps,
300
  )
301
 
 
 
 
 
302
  user_input = gr.Textbox(interactive=False)
303
  raven_function_call = ""
304
  summary_model_summary = ""
@@ -307,7 +314,8 @@ class RavenDemo(gr.Blocks):
307
  gmaps_html = ""
308
  steps_accordion = gr.Accordion(open=True)
309
  steps = [gr.Textbox(value="", visible=False) for _ in range(self.max_num_steps)]
310
- yield get_returns()
 
311
 
312
  raven_prompt = self.functions_helper.get_prompt(
313
  query.replace("'", r"\'").replace('"', r"\"")
@@ -328,7 +336,18 @@ class RavenDemo(gr.Blocks):
328
  r_calls = [c.strip() for c in raven_function_call.split(";") if c.strip()]
329
  f_r_calls = []
330
  for r_c in r_calls:
331
- f_r_call = format_str(r_c.strip(), mode=Mode())
 
 
 
 
 
 
 
 
 
 
 
332
  f_r_calls.append(f_r_call)
333
 
334
  raven_function_call = "; ".join(f_r_calls)
@@ -424,6 +443,21 @@ class RavenDemo(gr.Blocks):
424
  user_input = gr.Textbox(interactive=True, autofocus=False)
425
  yield get_returns()
426
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
  def get_summary_model_prompt(self, results: List, query: str) -> None:
428
  # TODO check what outputs are returned and return them properly
429
  ALLOWED_KEYS = [
 
10
 
11
  import inspect
12
 
13
+ import ast
14
+
15
  from random import randint
16
 
17
  from urllib.parse import quote
 
181
  self.summary_model_client = InferenceClient(config.summary_model_endpoint)
182
 
183
  self.max_num_steps = 20
184
+ self.function_call_name_set = set([f.name for f in FUNCTIONS])
185
 
186
  with self:
187
  gr.HTML(HEADER_HTML)
 
302
  *steps,
303
  )
304
 
305
+ def on_error():
306
+ initial_return[0] = gr.Textbox(interactive=True, autofocus=False)
307
+ return initial_return
308
+
309
  user_input = gr.Textbox(interactive=False)
310
  raven_function_call = ""
311
  summary_model_summary = ""
 
314
  gmaps_html = ""
315
  steps_accordion = gr.Accordion(open=True)
316
  steps = [gr.Textbox(value="", visible=False) for _ in range(self.max_num_steps)]
317
+ initial_return = list(get_returns())
318
+ yield initial_return
319
 
320
  raven_prompt = self.functions_helper.get_prompt(
321
  query.replace("'", r"\'").replace('"', r"\"")
 
336
  r_calls = [c.strip() for c in raven_function_call.split(";") if c.strip()]
337
  f_r_calls = []
338
  for r_c in r_calls:
339
+ try:
340
+ f_r_call = format_str(r_c.strip(), mode=Mode())
341
+ except:
342
+ yield on_error()
343
+ gr.Warning(ERROR_MESSAGE)
344
+ return
345
+
346
+ if not self.whitelist_function_names(f_r_call):
347
+ yield on_error()
348
+ gr.Warning(ERROR_MESSAGE)
349
+ return
350
+
351
  f_r_calls.append(f_r_call)
352
 
353
  raven_function_call = "; ".join(f_r_calls)
 
443
  user_input = gr.Textbox(interactive=True, autofocus=False)
444
  yield get_returns()
445
 
446
+ def whitelist_function_names(self, function_call_str: str) -> bool:
447
+ """
448
+ Defensive function name whitelisting inspired by @evan-nexusflow
449
+ """
450
+ for expr in ast.walk(ast.parse(function_call_str)):
451
+ if not isinstance(expr, ast.Call):
452
+ continue
453
+
454
+ expr: ast.Call
455
+ function_name = expr.func.id
456
+ if function_name not in self.function_call_name_set:
457
+ return False
458
+
459
+ return True
460
+
461
  def get_summary_model_prompt(self, results: List, query: str) -> None:
462
  # TODO check what outputs are returned and return them properly
463
  ALLOWED_KEYS = [