DeanGumas committed on
Commit
9f2b199
·
1 Parent(s): 88abe86

Created evaluation loop for running on full dataframes

Browse files
Files changed (1) hide show
  1. test_pretrained.ipynb +106 -64
test_pretrained.ipynb CHANGED
@@ -26,14 +26,16 @@
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
- "What was the largest lead the Golden State Warriors had in a game during the 2018 season?\n",
30
- "SELECT MAX(other_stats.largest_lead_home) FROM other_stats JOIN game ON other_stats.game_id = game.game_id WHERE game.team_name_home = 'Golden State Warriors' AND game.season_id = '22018';\n",
31
- "44\n"
32
  ]
33
  }
34
  ],
35
  "source": [
36
  "import pandas as pd \n",
 
 
37
  "\n",
38
  "# Load dataset and check length\n",
39
  "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
@@ -58,16 +60,7 @@
58
  "cell_type": "code",
59
  "execution_count": 2,
60
  "metadata": {},
61
- "outputs": [
62
- {
63
- "name": "stderr",
64
- "output_type": "stream",
65
- "text": [
66
- "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
67
- " from .autonotebook import tqdm as notebook_tqdm\n"
68
- ]
69
- }
70
- ],
71
  "source": [
72
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
73
  "import torch\n",
@@ -77,7 +70,8 @@
77
  "\n",
78
  "# Load model and tokenizer\n",
79
  "tokenizer = AutoTokenizer.from_pretrained(\"./deepseek-coder-1.3b-instruct\")\n",
80
- "model = AutoModelForCausalLM.from_pretrained(\"./deepseek-coder-1.3b-instruct\", torch_dtype=torch.bfloat16, device_map=device) "
 
81
  ]
82
  },
83
  {
@@ -288,27 +282,15 @@
288
  "execution_count": 4,
289
  "metadata": {},
290
  "outputs": [
291
- {
292
- "name": "stderr",
293
- "output_type": "stream",
294
- "text": [
295
- "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\generation\\configuration_utils.py:634: UserWarning: `do_sample` is set to `False`. However, `top_p` is set to `0.95` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `top_p`.\n",
296
- " warnings.warn(\n",
297
- "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
298
- "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n",
299
- "The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
300
- "c:\\Users\\Dean\\AppData\\Local\\Programs\\Python\\Python312\\Lib\\site-packages\\transformers\\integrations\\sdpa_attention.py:53: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:555.)\n",
301
- " attn_output = torch.nn.functional.scaled_dot_product_attention(\n"
302
- ]
303
- },
304
  {
305
  "name": "stdout",
306
  "output_type": "stream",
307
  "text": [
308
  "SQLite:\n",
309
- "SELECT MAX(largest_lead_home) \n",
310
- "FROM other_stats \n",
311
- "WHERE team_name_home = 'Golden State Warriors' AND season_id = '22018';\n",
 
312
  "\n"
313
  ]
314
  }
@@ -340,18 +322,8 @@
340
  "name": "stdout",
341
  "output_type": "stream",
342
  "text": [
343
- "cleaned\n"
344
- ]
345
- },
346
- {
347
- "ename": "OperationalError",
348
- "evalue": "no such column: team_name_home",
349
- "output_type": "error",
350
- "traceback": [
351
- "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
352
- "\u001b[1;31mOperationalError\u001b[0m Traceback (most recent call last)",
353
- "Cell \u001b[1;32mIn[5], line 15\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m 14\u001b[0m query \u001b[38;5;241m=\u001b[39m query_output\n\u001b[1;32m---> 15\u001b[0m \u001b[43mcursor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mquery\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 16\u001b[0m rows \u001b[38;5;241m=\u001b[39m cursor\u001b[38;5;241m.\u001b[39mfetchall()\n\u001b[0;32m 17\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m row \u001b[38;5;129;01min\u001b[39;00m rows:\n",
354
- "\u001b[1;31mOperationalError\u001b[0m: no such column: team_name_home"
355
  ]
356
  }
357
  ],
@@ -370,10 +342,14 @@
370
  " query = query_output[4:]\n",
371
  "else:\n",
372
  " query = query_output\n",
373
- "cursor.execute(query)\n",
374
- "rows = cursor.fetchall()\n",
375
- "for row in rows:\n",
376
- " print(row)"
 
 
 
 
377
  ]
378
  },
379
  {
@@ -385,30 +361,22 @@
385
  },
386
  {
387
  "cell_type": "code",
388
- "execution_count": 65,
389
  "metadata": {},
390
  "outputs": [
391
- {
392
- "name": "stderr",
393
- "output_type": "stream",
394
- "text": [
395
- "The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.\n",
396
- "Setting `pad_token_id` to `eos_token_id`:32021 for open-end generation.\n"
397
- ]
398
- },
399
  {
400
  "name": "stdout",
401
  "output_type": "stream",
402
  "text": [
403
- "What is the total number of assists by the Chicago Bulls at home?\n",
404
- "SELECT SUM(ast_home) as total_assists FROM game WHERE team_name_home = 'Chicago Bulls';\n",
405
- "45090.0\n",
406
  "SQLite:\n",
407
- "SELECT SUM(ast_home) \n",
408
- "FROM game \n",
409
- "WHERE team_name_home = 'Chicago Bulls';\n",
 
410
  "\n",
411
- "[(45090.0,)]\n",
412
  "Statement valid? True\n",
413
  "SQLite matched? False\n",
414
  "Result matched? True\n"
@@ -444,7 +412,7 @@
444
  "\n",
445
  " # Check if this is a multi-line query\n",
446
  " if \"|\" in sample_result or \"(\" in sample_result:\n",
447
- " print(rows)\n",
448
  " # Create list of results by stripping separators and splitting on them\n",
449
  " if \"(\" in sample_result:\n",
450
  " sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
@@ -477,7 +445,7 @@
477
  " return True, query_match, result\n",
478
  " # Else the sample result is a single value or string\n",
479
  " else:\n",
480
- " print(rows)\n",
481
  " result = False\n",
482
  " # Loop through model result and see if it contains the sample result\n",
483
  " for row in rows:\n",
@@ -530,6 +498,80 @@
530
  "print(\"SQLite matched? \" + str(result[1]))\n",
531
  "print(\"Result matched? \" + str(result[2]))"
532
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  }
534
  ],
535
  "metadata": {
 
26
  "Total dataset examples: 1044\n",
27
  "\n",
28
  "\n",
29
+ "What was the combined rebound total for the Toronto Raptors and Brooklyn Nets in their highest scoring game against each other?\n",
30
+ "SELECT MAX(g.pts_home + g.pts_away) AS total_points, g.reb_home + g.reb_away AS total_rebounds FROM game g WHERE (g.team_name_home = 'Toronto Raptors' AND g.team_name_away = 'Brooklyn Nets') OR (g.team_name_home = 'Brooklyn Nets' AND g.team_name_away = 'Toronto Raptors') ORDER BY total_points DESC LIMIT 1;\n",
31
+ "272.0 | 101.0 \n"
32
  ]
33
  }
34
  ],
35
  "source": [
36
  "import pandas as pd \n",
37
+ "import warnings\n",
38
+ "warnings.filterwarnings(\"ignore\")\n",
39
  "\n",
40
  "# Load dataset and check length\n",
41
  "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
 
60
  "cell_type": "code",
61
  "execution_count": 2,
62
  "metadata": {},
63
+ "outputs": [],
 
 
 
 
 
 
 
 
 
64
  "source": [
65
  "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
66
  "import torch\n",
 
70
  "\n",
71
  "# Load model and tokenizer\n",
72
  "tokenizer = AutoTokenizer.from_pretrained(\"./deepseek-coder-1.3b-instruct\")\n",
73
+ "model = AutoModelForCausalLM.from_pretrained(\"./deepseek-coder-1.3b-instruct\", torch_dtype=torch.bfloat16, device_map=device) \n",
74
+ "model.generation_config.pad_token_id = tokenizer.pad_token_id"
75
  ]
76
  },
77
  {
 
282
  "execution_count": 4,
283
  "metadata": {},
284
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  {
286
  "name": "stdout",
287
  "output_type": "stream",
288
  "text": [
289
  "SQLite:\n",
290
+ "SELECT SUM(reb_home + reb_away) AS combined_rebounds\n",
291
+ "FROM game\n",
292
+ "WHERE (team_name_home = 'Toronto Raptors' AND team_name_away = 'Brooklyn Nets')\n",
293
+ "OR (team_name_home = 'Brooklyn Nets' AND team_name_away = 'Toronto Raptors');\n",
294
  "\n"
295
  ]
296
  }
 
322
  "name": "stdout",
323
  "output_type": "stream",
324
  "text": [
325
+ "cleaned\n",
326
+ "(4350.0,)\n"
 
 
 
 
 
 
 
 
 
 
327
  ]
328
  }
329
  ],
 
342
  " query = query_output[4:]\n",
343
  "else:\n",
344
  " query = query_output\n",
345
+ "\n",
346
+ "try:\n",
347
+ "    cursor.execute(query)\n",
348
+ "    rows = cursor.fetchall()\n",
349
+ "    for row in rows:\n",
350
+ "        print(row)\n",
351
+ "except Exception as e:  # model-generated SQL may be invalid; report it instead of hiding the failure\n",
352
+ "    print(\"Query failed: \" + str(e))"
353
  ]
354
  },
355
  {
 
361
  },
362
  {
363
  "cell_type": "code",
364
+ "execution_count": 6,
365
  "metadata": {},
366
  "outputs": [
 
 
 
 
 
 
 
 
367
  {
368
  "name": "stdout",
369
  "output_type": "stream",
370
  "text": [
371
+ "What was the three-point shooting percentage for the Los Angeles Clippers in games against the Los Angeles Lakers?\n",
372
+ "SELECT AVG( CASE WHEN team_name_home = 'LA Clippers' THEN fg3_pct_home ELSE fg3_pct_away END ) AS avg_3pt_percentage FROM game WHERE (team_name_home = 'LA Clippers' AND team_name_away = 'Los Angeles Lakers') OR (team_name_home = 'Los Angeles Lakers' AND team_name_away = 'LA Clippers');\n",
373
+ "0.3734705882\n",
374
  "SQLite:\n",
375
+ "SELECT team_name_home, team_name_away, AVG(fg3_pct_home) AS three_point_percentage\n",
376
+ "FROM game\n",
377
+ "WHERE team_name_home = 'Los Angeles Clippers' AND team_name_away = 'Los Angeles Lakers'\n",
378
+ "GROUP BY team_name_home, team_name_away;\n",
379
  "\n",
 
380
  "Statement valid? True\n",
381
  "SQLite matched? False\n",
382
  "Result matched? True\n"
 
412
  "\n",
413
  " # Check if this is a multi-line query\n",
414
  " if \"|\" in sample_result or \"(\" in sample_result:\n",
415
+ " #print(rows)\n",
416
  " # Create list of results by stripping separators and splitting on them\n",
417
  " if \"(\" in sample_result:\n",
418
  " sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
 
445
  " return True, query_match, result\n",
446
  " # Else the sample result is a single value or string\n",
447
  " else:\n",
448
+ " #print(rows)\n",
449
  " result = False\n",
450
  " # Loop through model result and see if it contains the sample result\n",
451
  " for row in rows:\n",
 
498
  "print(\"SQLite matched? \" + str(result[1]))\n",
499
  "print(\"Result matched? \" + str(result[2]))"
500
  ]
501
+ },
502
+ {
503
+ "cell_type": "markdown",
504
+ "metadata": {},
505
+ "source": [
506
+ "## Create function to evaluate pretrained model on full datasets"
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "code",
511
+ "execution_count": 9,
512
+ "metadata": {},
513
+ "outputs": [
514
+ {
515
+ "name": "stdout",
516
+ "output_type": "stream",
517
+ "text": [
518
+ "Less than 90 results:\n",
519
+ "Percent valid: 0.0653061224489796\n",
520
+ "Percent SQLite matched: 0.00816326530612245\n",
521
+ "Percent result matched: 0.024489795918367346\n"
522
+ ]
523
+ }
524
+ ],
525
+ "source": [
526
+ "def run_evaluation(nba_df, title, max_examples=20):\n",
527
+ "    \"\"\"Evaluate the model on up to max_examples rows of nba_df (None = every row)\n",
528
+ "    and print valid / SQL-matched / result-matched rates over the rows evaluated.\"\"\"\n",
529
+ "    counter = 0\n",
530
+ "    num_valid = num_sql_matched = num_result_matched = 0\n",
531
+ "    for index, row in nba_df.iterrows():\n",
532
+ "        # Stop once the requested number of examples has been evaluated\n",
533
+ "        if max_examples is not None and counter >= max_examples:\n",
534
+ "            break\n",
535
+ "\n",
536
+ "        # Create message with sample query and run model (greedy decoding, so no sampling flags)\n",
537
+ "        message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
538
+ "        inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
539
+ "        outputs = model.generate(inputs, attention_mask=torch.ones_like(inputs), max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
540
+ "\n",
541
+ "        # Obtain output, decoding only the newly generated tokens\n",
542
+ "        query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
543
+ "\n",
544
+ "        # Evaluate model result\n",
545
+ "        valid, sql_matched, result_matched = compare_result(row[\"sql_query\"], row[\"result\"], query_output)\n",
546
+ "        num_valid += bool(valid)\n",
547
+ "        num_sql_matched += bool(sql_matched)\n",
548
+ "        num_result_matched += bool(result_matched)\n",
549
+ "\n",
550
+ "        # Report progress every 50 examples\n",
551
+ "        counter += 1\n",
552
+ "        if counter % 50 == 0:\n",
553
+ "            print(\"Completed \" + str(counter))\n",
554
+ "\n",
555
+ "    # Print evaluation results as fractions of the rows actually evaluated\n",
556
+ "    evaluated = max(counter, 1)  # avoids dividing by zero when no rows were evaluated\n",
557
+ "    print(title + \" results:\")\n",
558
+ "    print(\"Percent valid: \" + str(num_valid / evaluated))\n",
559
+ "    print(\"Percent SQLite matched: \" + str(num_sql_matched / evaluated))\n",
560
+ "    print(\"Percent result matched: \" + str(num_result_matched / evaluated))\n",
561
+ "\n",
562
+ "less_than_90_df = pd.read_csv(\"./train-data/less_than_90.tsv\", sep='\\t')\n",
563
+ "run_evaluation(less_than_90_df, \"Less than 90\")\n",
564
+ "\n",
565
+ "# Run evaluation on all training data\n",
566
+ "#run_evaluation(df, \"All training data\")"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "markdown",
571
+ "metadata": {},
572
+ "source": [
573
+ "# Evaluate on the less-than-90 dataset"
574
+ ]
575
  }
576
  ],
577
  "metadata": {