{ "cells": [ { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "Before running this notebook, run the following notebook:\n", "- 1-load-and-convert-statsbomb-data.ipynb" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2023-12-30T16:26:23.312159Z", "iopub.status.busy": "2023-12-30T16:26:23.311540Z", "iopub.status.idle": "2023-12-30T16:26:23.724636Z", "shell.execute_reply": "2023-12-30T16:26:23.724141Z" }, "pycharm": { "is_executing": false } }, "outputs": [], "source": [ "import os\n", "import tqdm\n", "import pandas as pd\n", "import numpy as np" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2023-12-30T16:26:23.726741Z", "iopub.status.busy": "2023-12-30T16:26:23.726493Z", "iopub.status.idle": "2023-12-30T16:26:24.456434Z", "shell.execute_reply": "2023-12-30T16:26:24.455729Z" }, "pycharm": { "is_executing": false } }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "import socceraction.spadl as spadl\n", "import socceraction.vaep.features as fs\n", "import socceraction.xthreat as xthreat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Select data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2023-12-30T16:26:24.458812Z", "iopub.status.busy": "2023-12-30T16:26:24.458582Z", "iopub.status.idle": "2023-12-30T16:26:24.480218Z", "shell.execute_reply": "2023-12-30T16:26:24.479643Z" }, "pycharm": { "is_executing": false } }, "outputs": [], "source": [ "# Configure file and folder names, use SPADL format.\n", "datafolder = \"../data-fifa\"\n", "spadl_h5 = os.path.join(datafolder, \"spadl-statsbomb.h5\")\n", "xT_h5 = os.path.join(datafolder, \"xT.h5\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2023-12-30T16:26:24.482774Z", "iopub.status.busy":
"2023-12-30T16:26:24.482616Z", "iopub.status.idle": "2023-12-30T16:26:25.773057Z", "shell.execute_reply": "2023-12-30T16:26:25.772465Z" }, "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "nb of games: 64\n" ] } ], "source": [ "games = pd.read_hdf(spadl_h5, \"games\")\n", "print(\"nb of games:\", len(games))" ] },
 { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2023-12-30T16:26:25.776110Z", "iopub.status.busy": "2023-12-30T16:26:25.775862Z", "iopub.status.idle": "2023-12-30T16:26:28.078380Z", "shell.execute_reply": "2023-12-30T16:26:28.077809Z" }, "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████████████████████████████████████████████████████████████████████████████| 64/64 [00:02<00:00, 28.53it/s]\n" ] } ], "source": [
  "# Read the SPADL actions of every game and stack them into one DataFrame.\n",
  "# The store is opened read-only because this cell only reads from it.\n",
  "actions_per_game = []\n",
  "\n",
  "with pd.HDFStore(spadl_h5, mode=\"r\") as spadlstore:\n",
  "    for game in tqdm.tqdm(list(games.itertuples())):\n",
  "        actions = spadlstore[f\"actions/game_{game.game_id}\"]\n",
  "        actions = spadl.add_names(actions)\n",
  "        # Orient all actions in the same playing direction (left to right).\n",
  "        actions = spadl.play_left_to_right(actions, game.home_team_id)\n",
  "        actions_per_game.append(actions)\n",
  "\n",
  "# `A` is the name the model-fitting cell further down reads.\n",
  "A = pd.concat(actions_per_game)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Load pre-trained model" ] },
 { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2023-12-30T16:26:28.081074Z", "iopub.status.busy": "2023-12-30T16:26:28.080909Z", "iopub.status.idle": "2023-12-30T16:26:28.131183Z", "shell.execute_reply": "2023-12-30T16:26:28.130518Z" } }, "outputs": [], "source": [ "# uncomment the lines below if you get an SSLError\n", "# import ssl\n", "# ssl._create_default_https_context = ssl._create_unverified_context\n", "\n", "url_grid = \"https://karun.in/blog/data/open_xt_12x8_v1.json\"\n", "xTModel = xthreat.load_model(url_grid)" ]
}, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train a custom model\n", "\n", "Fit an xT grid on the actions loaded above. Note that this cell rebinds `xTModel`, so it replaces the pre-trained model loaded in the previous section." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2023-12-30T16:26:28.134226Z", "iopub.status.busy": "2023-12-30T16:26:28.133930Z", "iopub.status.idle": "2023-12-30T16:26:29.242405Z", "shell.execute_reply": "2023-12-30T16:26:29.241865Z" }, "pycharm": { "is_executing": false, "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# iterations: 45\n" ] } ], "source": [ "xTModel = xthreat.ExpectedThreat(l=16, w=12)\n", "xTModel.fit(A);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute xT ratings" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2023-12-30T16:26:29.244349Z", "iopub.status.busy": "2023-12-30T16:26:29.244191Z", "iopub.status.idle": "2023-12-30T16:26:29.345267Z", "shell.execute_reply": "2023-12-30T16:26:29.344721Z" } }, "outputs": [ { "data": { "text/html": [ "
\n", " | type_name | \n", "start_x | \n", "start_y | \n", "end_x | \n", "end_y | \n", "xT_value | \n", "
---|---|---|---|---|---|---|
0 | \n", "pass | \n", "52.0625 | \n", "34.425 | \n", "43.3125 | \n", "33.575 | \n", "-0.000860 | \n", "
1 | \n", "dribble | \n", "43.3125 | \n", "33.575 | \n", "44.1875 | \n", "34.425 | \n", "-0.000255 | \n", "
2 | \n", "pass | \n", "44.1875 | \n", "34.425 | \n", "40.6875 | \n", "22.525 | \n", "-0.000446 | \n", "
3 | \n", "dribble | \n", "40.6875 | \n", "22.525 | \n", "42.4375 | \n", "21.675 | \n", "0.000000 | \n", "
4 | \n", "pass | \n", "42.4375 | \n", "21.675 | \n", "56.4375 | \n", "1.275 | \n", "0.001047 | \n", "
5 | \n", "dribble | \n", "56.4375 | \n", "1.275 | \n", "57.3125 | \n", "2.125 | \n", "0.000000 | \n", "
7 | \n", "pass | \n", "21.4375 | \n", "49.725 | \n", "27.5625 | \n", "66.725 | \n", "-0.000299 | \n", "
9 | \n", "dribble | \n", "83.5625 | \n", "14.025 | \n", "82.6875 | \n", "14.025 | \n", "0.000000 | \n", "
10 | \n", "pass | \n", "82.6875 | \n", "14.025 | \n", "80.0625 | \n", "3.825 | \n", "-0.003468 | \n", "
11 | \n", "dribble | \n", "80.0625 | \n", "3.825 | \n", "77.4375 | \n", "12.325 | \n", "-0.000025 | \n", "