Updated evaluation function to give tolerance for slight floating point differences
Browse files- test_pretrained.ipynb +60 -17
- train-data/sql_train.tsv +1 -1
test_pretrained.ipynb
CHANGED
@@ -385,7 +385,7 @@
|
|
385 |
},
|
386 |
{
|
387 |
"cell_type": "code",
|
388 |
-
"execution_count":
|
389 |
"metadata": {},
|
390 |
"outputs": [
|
391 |
{
|
@@ -400,21 +400,24 @@
|
|
400 |
"name": "stdout",
|
401 |
"output_type": "stream",
|
402 |
"text": [
|
403 |
-
"
|
404 |
-
"SELECT
|
405 |
-
"
|
406 |
"SQLite:\n",
|
407 |
-
"SELECT
|
408 |
"FROM game \n",
|
409 |
-
"WHERE
|
410 |
"\n",
|
411 |
-
"[(
|
412 |
-
"
|
|
|
413 |
"Result matched? True\n"
|
414 |
]
|
415 |
}
|
416 |
],
|
417 |
"source": [
|
|
|
|
|
418 |
"def compare_result(sample_query, sample_result, query_output):\n",
|
419 |
" # Clean model output to only have the query output\n",
|
420 |
" if query_output[0:7] == \"SQLite:\":\n",
|
@@ -435,38 +438,77 @@
|
|
435 |
" sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
|
436 |
" query_match = (query == sample_query)\n",
|
437 |
"\n",
|
|
|
|
|
|
|
|
|
438 |
" # Check if this is a multi-line query\n",
|
439 |
" if \"|\" in sample_result or \"(\" in sample_result:\n",
|
|
|
|
|
440 |
" if \"(\" in sample_result:\n",
|
441 |
" sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
|
442 |
" result_list = sample_result.split(\",\") \n",
|
443 |
" else:\n",
|
444 |
" result_list = sample_result.split(\"|\") \n",
|
445 |
"\n",
|
|
|
446 |
" for i in range(len(result_list)):\n",
|
447 |
" result_list[i] = str(result_list[i]).strip()\n",
|
|
|
|
|
448 |
" result = False\n",
|
449 |
" for row in rows:\n",
|
450 |
" for r in row:\n",
|
451 |
-
"
|
452 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
453 |
" if len(rows) == 1:\n",
|
454 |
" for r in rows[0]:\n",
|
455 |
" if r == str(len(result_list)):\n",
|
456 |
-
" return query_match, True\n",
|
457 |
-
"
|
|
|
|
|
458 |
" else:\n",
|
459 |
" print(rows)\n",
|
460 |
" result = False\n",
|
|
|
461 |
" for row in rows:\n",
|
462 |
" for r in row:\n",
|
|
|
463 |
" if str(r) in str(sample_result):\n",
|
464 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
"\n",
|
466 |
" # Compare results and return\n",
|
467 |
-
" return query_match, result\n",
|
468 |
" except:\n",
|
469 |
-
" return False, False\n",
|
470 |
"\n",
|
471 |
"# Obtain sample\n",
|
472 |
"sample = df.sample(n=1)\n",
|
@@ -484,8 +526,9 @@
|
|
484 |
"print(query_output)\n",
|
485 |
"\n",
|
486 |
"result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
|
487 |
-
"print(\"
|
488 |
-
"print(\"
|
|
|
489 |
]
|
490 |
}
|
491 |
],
|
|
|
385 |
},
|
386 |
{
|
387 |
"cell_type": "code",
|
388 |
+
"execution_count": 65,
|
389 |
"metadata": {},
|
390 |
"outputs": [
|
391 |
{
|
|
|
400 |
"name": "stdout",
|
401 |
"output_type": "stream",
|
402 |
"text": [
|
403 |
+
"What is the total number of assists by the Chicago Bulls at home?\n",
|
404 |
+
"SELECT SUM(ast_home) as total_assists FROM game WHERE team_name_home = 'Chicago Bulls';\n",
|
405 |
+
"45090.0\n",
|
406 |
"SQLite:\n",
|
407 |
+
"SELECT SUM(ast_home) \n",
|
408 |
"FROM game \n",
|
409 |
+
"WHERE team_name_home = 'Chicago Bulls';\n",
|
410 |
"\n",
|
411 |
+
"[(45090.0,)]\n",
|
412 |
+
"Statement valid? True\n",
|
413 |
+
"SQLite matched? False\n",
|
414 |
"Result matched? True\n"
|
415 |
]
|
416 |
}
|
417 |
],
|
418 |
"source": [
|
419 |
+
"import math\n",
|
420 |
+
"\n",
|
421 |
"def compare_result(sample_query, sample_result, query_output):\n",
|
422 |
" # Clean model output to only have the query output\n",
|
423 |
" if query_output[0:7] == \"SQLite:\":\n",
|
|
|
438 |
" sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
|
439 |
" query_match = (query == sample_query)\n",
|
440 |
"\n",
|
441 |
+
" # If the queries match, the results clearly also match\n",
|
442 |
+
" if query_match:\n",
|
443 |
+
" return True, True, True\n",
|
444 |
+
"\n",
|
445 |
" # Check if this is a multi-line query\n",
|
446 |
" if \"|\" in sample_result or \"(\" in sample_result:\n",
|
447 |
+
" print(rows)\n",
|
448 |
+
" # Create list of results by stripping separators and splitting on them\n",
|
449 |
" if \"(\" in sample_result:\n",
|
450 |
" sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
|
451 |
" result_list = sample_result.split(\",\") \n",
|
452 |
" else:\n",
|
453 |
" result_list = sample_result.split(\"|\") \n",
|
454 |
"\n",
|
455 |
+
" # Strip all results in list\n",
|
456 |
" for i in range(len(result_list)):\n",
|
457 |
" result_list[i] = str(result_list[i]).strip()\n",
|
458 |
+
" \n",
|
459 |
+
" # Loop through model result and see if it matches training example\n",
|
460 |
" result = False\n",
|
461 |
" for row in rows:\n",
|
462 |
" for r in row:\n",
|
463 |
+
" for res in result_list:\n",
|
464 |
+
" try:\n",
|
465 |
+
" if math.isclose(float(r), float(res), abs_tol=0.5):\n",
|
466 |
+
" return True, query_match, True\n",
|
467 |
+
" except:\n",
|
468 |
+
" if r in res or res in r:\n",
|
469 |
+
" return True, query_match, True\n",
|
470 |
+
" \n",
|
471 |
+
" # Check if the model returned a sum of examples as opposed to the whole thing\n",
|
472 |
" if len(rows) == 1:\n",
|
473 |
" for r in rows[0]:\n",
|
474 |
" if r == str(len(result_list)):\n",
|
475 |
+
" return True, query_match, True\n",
|
476 |
+
" \n",
|
477 |
+
" return True, query_match, result\n",
|
478 |
+
" # Else the sample result is a single value or string\n",
|
479 |
" else:\n",
|
480 |
" print(rows)\n",
|
481 |
" result = False\n",
|
482 |
+
" # Loop through model result and see if it contains the sample result\n",
|
483 |
" for row in rows:\n",
|
484 |
" for r in row:\n",
|
485 |
+
" # Check by string\n",
|
486 |
" if str(r) in str(sample_result):\n",
|
487 |
+
" try:\n",
|
488 |
+
" if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
|
489 |
+
" return True, query_match, True\n",
|
490 |
+
" except:\n",
|
491 |
+
" return True, query_match, True\n",
|
492 |
+
" # Check by number, using try incase the cast as float fails\n",
|
493 |
+
" try:\n",
|
494 |
+
" if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
|
495 |
+
" return True, query_match, True\n",
|
496 |
+
" except:\n",
|
497 |
+
" pass\n",
|
498 |
+
"\n",
|
499 |
+
" # Check if the model returned a list of examples instead of a total sum (both acceptable)\n",
|
500 |
+
" try:\n",
|
501 |
+
" if len(rows) > 1 and len(rows) == int(sample_result):\n",
|
502 |
+
" return True, query_match, True\n",
|
503 |
+
" if len(rows[0]) > 1 and rows[0][1] is not None and len(rows[0]) == int(sample_result):\n",
|
504 |
+
" return True, query_match, True\n",
|
505 |
+
" except:\n",
|
506 |
+
" pass\n",
|
507 |
"\n",
|
508 |
" # Compare results and return\n",
|
509 |
+
" return True, query_match, result\n",
|
510 |
" except:\n",
|
511 |
+
" return False, False, False\n",
|
512 |
"\n",
|
513 |
"# Obtain sample\n",
|
514 |
"sample = df.sample(n=1)\n",
|
|
|
526 |
"print(query_output)\n",
|
527 |
"\n",
|
528 |
"result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
|
529 |
+
"print(\"Statement valid? \" + str(result[0]))\n",
|
530 |
+
"print(\"SQLite matched? \" + str(result[1]))\n",
|
531 |
+
"print(\"Result matched? \" + str(result[2]))"
|
532 |
]
|
533 |
}
|
534 |
],
|
train-data/sql_train.tsv
CHANGED
@@ -476,7 +476,7 @@ How many away games did the Chicago Bulls play in the 2022 season? SELECT COUNT(
|
|
476 |
How many home games did the Boston Celtics play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22018'; 41.0
|
477 |
How many home games did the Boston Celtics play in the 2020 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22020'; 36.0
|
478 |
What is the average number of fg_pct in home games by the Chicago Bulls? SELECT AVG(fg_pct_home) FROM game WHERE team_name_home = 'Chicago Bulls'; 0.4636694306246544
|
479 |
-
In which season did the Los Angeles Lakers have the highest average ast at home? SELECT season_id, AVG(ast_home) as avg_stat FROM game WHERE team_name_home = 'Los Angeles Lakers' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1;
|
480 |
What is the average number of ft_pct in home games by the Los Angeles Lakers? SELECT AVG(ft_pct_home) FROM game WHERE team_name_home = 'Los Angeles Lakers'; 0.7450706106870195
|
481 |
In which season did the Golden State Warriors have the highest average reb at home? SELECT season_id, AVG(reb_home) as avg_stat FROM game WHERE team_name_home = 'Golden State Warriors' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 1974.0
|
482 |
How many away games did the Miami Heat play in the 1999 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Miami Heat' AND season_id = '21999'; 41.0
|
|
|
476 |
How many home games did the Boston Celtics play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22018'; 41.0
|
477 |
How many home games did the Boston Celtics play in the 2020 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22020'; 36.0
|
478 |
What is the average number of fg_pct in home games by the Chicago Bulls? SELECT AVG(fg_pct_home) FROM game WHERE team_name_home = 'Chicago Bulls'; 0.4636694306246544
|
479 |
+
In which season did the Los Angeles Lakers have the highest average ast at home? SELECT season_id, AVG(ast_home) as avg_stat FROM game WHERE team_name_home = 'Los Angeles Lakers' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 41969|36.6666666666667
|
480 |
What is the average number of ft_pct in home games by the Los Angeles Lakers? SELECT AVG(ft_pct_home) FROM game WHERE team_name_home = 'Los Angeles Lakers'; 0.7450706106870195
|
481 |
In which season did the Golden State Warriors have the highest average reb at home? SELECT season_id, AVG(reb_home) as avg_stat FROM game WHERE team_name_home = 'Golden State Warriors' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 1974.0
|
482 |
How many away games did the Miami Heat play in the 1999 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Miami Heat' AND season_id = '21999'; 41.0
|