DeanGumas committed on
Commit
50a9093
·
1 Parent(s): 7dc8863

Updated evaluation function to give tolerance for slight floating point differences

Browse files
Files changed (2) hide show
  1. test_pretrained.ipynb +60 -17
  2. train-data/sql_train.tsv +1 -1
test_pretrained.ipynb CHANGED
@@ -385,7 +385,7 @@
385
  },
386
  {
387
  "cell_type": "code",
388
- "execution_count": 17,
389
  "metadata": {},
390
  "outputs": [
391
  {
@@ -400,21 +400,24 @@
400
  "name": "stdout",
401
  "output_type": "stream",
402
  "text": [
403
- "How many games had at least one team with 30+ assists?\n",
404
- "SELECT COUNT(*) FROM game WHERE ast_home >= 30 OR ast_away >= 30;\n",
405
- "11305\n",
406
  "SQLite:\n",
407
- "SELECT COUNT(*) \n",
408
  "FROM game \n",
409
- "WHERE ast_home >= 30 OR ast_away >= 30;\n",
410
  "\n",
411
- "[(11305,)]\n",
412
- "SQL matched? True\n",
 
413
  "Result matched? True\n"
414
  ]
415
  }
416
  ],
417
  "source": [
 
 
418
  "def compare_result(sample_query, sample_result, query_output):\n",
419
  " # Clean model output to only have the query output\n",
420
  " if query_output[0:7] == \"SQLite:\":\n",
@@ -435,38 +438,77 @@
435
  " sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
436
  " query_match = (query == sample_query)\n",
437
  "\n",
 
 
 
 
438
  " # Check if this is a multi-line query\n",
439
  " if \"|\" in sample_result or \"(\" in sample_result:\n",
 
 
440
  " if \"(\" in sample_result:\n",
441
  " sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
442
  " result_list = sample_result.split(\",\") \n",
443
  " else:\n",
444
  " result_list = sample_result.split(\"|\") \n",
445
  "\n",
 
446
  " for i in range(len(result_list)):\n",
447
  " result_list[i] = str(result_list[i]).strip()\n",
 
 
448
  " result = False\n",
449
  " for row in rows:\n",
450
  " for r in row:\n",
451
- " if str(r) in result_list:\n",
452
- " return query_match, True\n",
 
 
 
 
 
 
 
453
  " if len(rows) == 1:\n",
454
  " for r in rows[0]:\n",
455
  " if r == str(len(result_list)):\n",
456
- " return query_match, True\n",
457
- " return query_match, result\n",
 
 
458
  " else:\n",
459
  " print(rows)\n",
460
  " result = False\n",
 
461
  " for row in rows:\n",
462
  " for r in row:\n",
 
463
  " if str(r) in str(sample_result):\n",
464
- " return query_match, True\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  "\n",
466
  " # Compare results and return\n",
467
- " return query_match, result\n",
468
  " except:\n",
469
- " return False, False\n",
470
  "\n",
471
  "# Obtain sample\n",
472
  "sample = df.sample(n=1)\n",
@@ -484,8 +526,9 @@
484
  "print(query_output)\n",
485
  "\n",
486
  "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
487
- "print(\"SQL matched? \" + str(result[0]))\n",
488
- "print(\"Result matched? \" + str(result[1]))"
 
489
  ]
490
  }
491
  ],
 
385
  },
386
  {
387
  "cell_type": "code",
388
+ "execution_count": 65,
389
  "metadata": {},
390
  "outputs": [
391
  {
 
400
  "name": "stdout",
401
  "output_type": "stream",
402
  "text": [
403
+ "What is the total number of assists by the Chicago Bulls at home?\n",
404
+ "SELECT SUM(ast_home) as total_assists FROM game WHERE team_name_home = 'Chicago Bulls';\n",
405
+ "45090.0\n",
406
  "SQLite:\n",
407
+ "SELECT SUM(ast_home) \n",
408
  "FROM game \n",
409
+ "WHERE team_name_home = 'Chicago Bulls';\n",
410
  "\n",
411
+ "[(45090.0,)]\n",
412
+ "Statement valid? True\n",
413
+ "SQLite matched? False\n",
414
  "Result matched? True\n"
415
  ]
416
  }
417
  ],
418
  "source": [
419
+ "import math\n",
420
+ "\n",
421
  "def compare_result(sample_query, sample_result, query_output):\n",
422
  " # Clean model output to only have the query output\n",
423
  " if query_output[0:7] == \"SQLite:\":\n",
 
438
  " sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
439
  " query_match = (query == sample_query)\n",
440
  "\n",
441
+ " # If the queries match, the results clearly also match\n",
442
+ " if query_match:\n",
443
+ " return True, True, True\n",
444
+ "\n",
445
  " # Check if this is a multi-line query\n",
446
  " if \"|\" in sample_result or \"(\" in sample_result:\n",
447
+ " print(rows)\n",
448
+ " # Create list of results by stripping separators and splitting on them\n",
449
  " if \"(\" in sample_result:\n",
450
  " sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
451
  " result_list = sample_result.split(\",\") \n",
452
  " else:\n",
453
  " result_list = sample_result.split(\"|\") \n",
454
  "\n",
455
+ " # Strip all results in list\n",
456
  " for i in range(len(result_list)):\n",
457
  " result_list[i] = str(result_list[i]).strip()\n",
458
+ " \n",
459
+ " # Loop through model result and see if it matches training example\n",
460
  " result = False\n",
461
  " for row in rows:\n",
462
  " for r in row:\n",
463
+ " for res in result_list:\n",
464
+ " try:\n",
465
+ " if math.isclose(float(r), float(res), abs_tol=0.5):\n",
466
+ " return True, query_match, True\n",
467
+ " except:\n",
468
+ " if r in res or res in r:\n",
469
+ " return True, query_match, True\n",
470
+ " \n",
471
+ " # Check if the model returned a count of the examples rather than the examples themselves\n",
472
  " if len(rows) == 1:\n",
473
  " for r in rows[0]:\n",
474
  " if r == str(len(result_list)):\n",
475
+ " return True, query_match, True\n",
476
+ " \n",
477
+ " return True, query_match, result\n",
478
+ " # Else the sample result is a single value or string\n",
479
  " else:\n",
480
  " print(rows)\n",
481
  " result = False\n",
482
+ " # Loop through model result and see if it contains the sample result\n",
483
  " for row in rows:\n",
484
  " for r in row:\n",
485
+ " # Check by string\n",
486
  " if str(r) in str(sample_result):\n",
487
+ " try:\n",
488
+ " if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
489
+ " return True, query_match, True\n",
490
+ " except:\n",
491
+ " return True, query_match, True\n",
492
+ " # Check by number, using try in case the cast to float fails\n",
493
+ " try:\n",
494
+ " if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
495
+ " return True, query_match, True\n",
496
+ " except:\n",
497
+ " pass\n",
498
+ "\n",
499
+ " # Check if the model returned a list of examples instead of a total count (both acceptable)\n",
500
+ " try:\n",
501
+ " if len(rows) > 1 and len(rows) == int(sample_result):\n",
502
+ " return True, query_match, True\n",
503
+ " if len(rows[0]) > 1 and rows[0][1] is not None and len(rows[0]) == int(sample_result):\n",
504
+ " return True, query_match, True\n",
505
+ " except:\n",
506
+ " pass\n",
507
  "\n",
508
  " # Compare results and return\n",
509
+ " return True, query_match, result\n",
510
  " except:\n",
511
+ " return False, False, False\n",
512
  "\n",
513
  "# Obtain sample\n",
514
  "sample = df.sample(n=1)\n",
 
526
  "print(query_output)\n",
527
  "\n",
528
  "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
529
+ "print(\"Statement valid? \" + str(result[0]))\n",
530
+ "print(\"SQLite matched? \" + str(result[1]))\n",
531
+ "print(\"Result matched? \" + str(result[2]))"
532
  ]
533
  }
534
  ],
train-data/sql_train.tsv CHANGED
@@ -476,7 +476,7 @@ How many away games did the Chicago Bulls play in the 2022 season? SELECT COUNT(
476
  How many home games did the Boston Celtics play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22018'; 41.0
477
  How many home games did the Boston Celtics play in the 2020 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22020'; 36.0
478
  What is the average number of fg_pct in home games by the Chicago Bulls? SELECT AVG(fg_pct_home) FROM game WHERE team_name_home = 'Chicago Bulls'; 0.4636694306246544
479
- In which season did the Los Angeles Lakers have the highest average ast at home? SELECT season_id, AVG(ast_home) as avg_stat FROM game WHERE team_name_home = 'Los Angeles Lakers' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 1969.0
480
  What is the average number of ft_pct in home games by the Los Angeles Lakers? SELECT AVG(ft_pct_home) FROM game WHERE team_name_home = 'Los Angeles Lakers'; 0.7450706106870195
481
  In which season did the Golden State Warriors have the highest average reb at home? SELECT season_id, AVG(reb_home) as avg_stat FROM game WHERE team_name_home = 'Golden State Warriors' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 1974.0
482
  How many away games did the Miami Heat play in the 1999 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Miami Heat' AND season_id = '21999'; 41.0
 
476
  How many home games did the Boston Celtics play in the 2018 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22018'; 41.0
477
  How many home games did the Boston Celtics play in the 2020 season? SELECT COUNT(*) FROM game WHERE team_name_home = 'Boston Celtics' AND season_id = '22020'; 36.0
478
  What is the average number of fg_pct in home games by the Chicago Bulls? SELECT AVG(fg_pct_home) FROM game WHERE team_name_home = 'Chicago Bulls'; 0.4636694306246544
479
+ In which season did the Los Angeles Lakers have the highest average ast at home? SELECT season_id, AVG(ast_home) as avg_stat FROM game WHERE team_name_home = 'Los Angeles Lakers' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 41969|36.6666666666667
480
  What is the average number of ft_pct in home games by the Los Angeles Lakers? SELECT AVG(ft_pct_home) FROM game WHERE team_name_home = 'Los Angeles Lakers'; 0.7450706106870195
481
  In which season did the Golden State Warriors have the highest average reb at home? SELECT season_id, AVG(reb_home) as avg_stat FROM game WHERE team_name_home = 'Golden State Warriors' GROUP BY season_id ORDER BY avg_stat DESC LIMIT 1; 1974.0
482
  How many away games did the Miami Heat play in the 1999 season? SELECT COUNT(*) FROM game WHERE team_name_away = 'Miami Heat' AND season_id = '21999'; 41.0