patrickramos commited on
Commit
cf5350e
·
1 Parent(s): 13a3a28

Update app

Browse files
Files changed (3) hide show
  1. data.py +82 -0
  2. demo.py +9 -5
  3. gradio_function.py +5 -85
data.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import pandas as pd
3
+ import numpy as np
4
+ from gradio_client import Client
5
+ from tqdm.auto import tqdm
6
+
7
+ import os
8
+ import re
9
+
10
+ from translate import translate_pa_outcome, translate_pitch_outcome, jp_pitch_to_en_pitch, jp_pitch_to_pitch_code, translate_pitch_outcome, max_pitch_types
11
+
12
+ # load game data
13
+ game_df = pd.read_csv('game.csv').drop_duplicates()
14
+ assert len(game_df) == len(game_df['game_pk'].unique())
15
+
16
+ # load pa data
17
+ pa_df = []
18
+ for game_pk in tqdm(game_df['game_pk']):
19
+ pa_df.append(pd.read_csv(os.path.join('pa', f'{game_pk}.csv'), dtype={'pa_pk': str}))
20
+ pa_df = pd.concat(pa_df, axis='rows')
21
+
22
+ # load pitch data
23
+ pitch_df = []
24
+ for game_pk in tqdm(game_df['game_pk']):
25
+ pitch_df.append(pd.read_csv(os.path.join('pitch', f'{game_pk}.csv'), dtype={'pa_pk': str}))
26
+ pitch_df = pd.concat(pitch_df, axis='rows')
27
+ pitch_df
28
+
29
+ # load player data
30
+ player_df = pd.read_csv('player.csv')
31
+ player_df
32
+
33
+ # translate pa data
34
+ pa_df['_des'] = pa_df['des'].str.strip()
35
+ pa_df['des'] = pa_df['des'].str.strip()
36
+ pa_df['des_more'] = pa_df['des_more'].str.strip()
37
+ pa_df.loc[pa_df['des'].isna(), 'des'] = pa_df[pa_df['des'].isna()]['des_more']
38
+ pa_df.loc[:, 'des'] = pa_df['des'].apply(lambda item: item.split()[0] if (len(item.split()) > 1 and re.search(r'+\d+点', item)) else item)
39
+ non_home_plate_outcome = (pa_df['des'].isin(['ボール', '見逃し', '空振り'])) | (pa_df['des'].str.endswith('塁けん制'))
40
+ pa_df.loc[non_home_plate_outcome, 'des'] = pa_df.loc[non_home_plate_outcome, 'des_more']
41
+ pa_df['des'] = pa_df['des'].apply(translate_pa_outcome)
42
+
43
+ # translate pitch data
44
+ pitch_df = pitch_df[~pitch_df['pitch_name'].isna()]
45
+ pitch_df['jp_pitch_name'] = pitch_df['pitch_name']
46
+ pitch_df['pitch_name'] = pitch_df['jp_pitch_name'].apply(lambda pitch_name: jp_pitch_to_en_pitch[pitch_name])
47
+ pitch_df['pitch_type'] = pitch_df['jp_pitch_name'].apply(lambda pitch_name: jp_pitch_to_pitch_code[pitch_name])
48
+ pitch_df['description'] = pitch_df['description'].apply(lambda item: item.split()[0] if len(item.split()) > 1 else item)
49
+ pitch_df['description'] = pitch_df['description'].apply(translate_pitch_outcome)
50
+ pitch_df['release_speed'] = pitch_df['release_speed'].replace('-', np.nan)
51
+ pitch_df.loc[~pitch_df['release_speed'].isna(), 'release_speed'] = pitch_df.loc[~pitch_df['release_speed'].isna(), 'release_speed'].str.removesuffix('km/h').astype(int)
52
+ pitch_df['plate_x'] = (pitch_df['plate_x'] + 13) - 80
53
+ pitch_df['plate_z'] = 200 - (pitch_df['plate_z'] + 13) - 100
54
+
55
+ # translate player data
56
+ client = Client("Ramos-Ramos/npb_name_translator")
57
+ en_names = client.predict(
58
+ jp_names='\n'.join(player_df.name.tolist()),
59
+ api_name="/predict"
60
+ )
61
+ player_df['jp_name'] = player_df['name']
62
+ player_df['name'] = [name if name != 'nan' else np.nan for name in en_names.splitlines()]
63
+
64
+ # merge pitch and pa data
65
+ df = pd.merge(pitch_df, pa_df, 'inner', on=['game_pk', 'pa_pk'])
66
+ df = pd.merge(df, player_df.rename(columns={'player_id': 'pitcher'}), 'inner', on='pitcher')
67
+ df['whiff'] = df['description'].isin(['SS', 'K'])
68
+ df['swing'] = ~df['description'].isin(['B', 'BB', 'LS', 'inv_K', 'bunt_K', 'HBP', 'SH', 'SH E', 'SH FC', 'obstruction', 'illegal_pitch', 'defensive_interference'])
69
+ df['csw'] = df['description'].isin(['SS', 'K', 'LS', 'inv_K'])
70
+ df['normal_pitch'] = ~df['description'].isin(['obstruction', 'illegal_pitch', 'defensive_interference']) # guess
71
+
72
+ whiff_rate = df.groupby(['name', 'pitch_name'])
73
+ whiff_rate = (whiff_rate['whiff'].sum() / whiff_rate['swing'].sum() * 100).round(1).rename('Whiff%').reset_index()
74
+
75
+ csw_rate = df.groupby(['name', 'pitch_name'])
76
+ csw_rate = (csw_rate['csw'].sum() / csw_rate['normal_pitch'].sum() * 100).round(1).rename('CSW%').reset_index()
77
+
78
+ pitch_stats = pd.merge(
79
+ whiff_rate,
80
+ csw_rate,
81
+ on=['name', 'pitch_name']
82
+ ).set_index(['name', 'pitch_name'])
demo.py CHANGED
@@ -3,10 +3,13 @@ import gradio as gr
3
  import pandas as pd
4
 
5
  from math import ceil
 
6
 
7
- from gradio_function import player_df, jp_pitch_to_en_pitch, get_data
8
- from translate import max_pitch_types
 
9
 
 
10
 
11
  css = '''
12
  .pitch-usage {height: 256px}
@@ -32,6 +35,7 @@ with gr.Blocks(css=css) as demo:
32
  ''')
33
  player = gr.Dropdown(choices=sorted(player_df['name'].dropna().tolist()), label='Player')
34
  player_info = gr.Markdown()
 
35
  with gr.Row():
36
  with gr.Column():
37
  gr.Markdown('## Placeholder')
@@ -87,9 +91,9 @@ with gr.Blocks(css=css) as demo:
87
  '''
88
  )
89
 
90
- player.input(get_data, inputs=player, outputs=[player_info, usage, *pitch_groups, *pitch_names, *pitch_infos, *pitch_velos, *pitch_maps, pitch_velo_summary])
91
 
92
  demo.launch(
93
- share=True
94
- # debug=True
95
  )
 
3
  import pandas as pd
4
 
5
  from math import ceil
6
+ import os
7
 
8
+ from data import player_df
9
+ from gradio_function import get_data
10
+ from translate import jp_pitch_to_en_pitch, max_pitch_types
11
 
12
+ os.makedirs('files', exist_ok=True)
13
 
14
  css = '''
15
  .pitch-usage {height: 256px}
 
35
  ''')
36
  player = gr.Dropdown(choices=sorted(player_df['name'].dropna().tolist()), label='Player')
37
  player_info = gr.Markdown()
38
+ download_file = gr.DownloadButton(label='Download player data')
39
  with gr.Row():
40
  with gr.Column():
41
  gr.Markdown('## Placeholder')
 
91
  '''
92
  )
93
 
94
+ player.input(get_data, inputs=player, outputs=[player_info, download_file, usage, *pitch_groups, *pitch_names, *pitch_infos, *pitch_velos, *pitch_maps, pitch_velo_summary])
95
 
96
  demo.launch(
97
+ share=True,
98
+ debug=True
99
  )
gradio_function.py CHANGED
@@ -1,95 +1,14 @@
1
 
2
- import pandas as pd
3
- import numpy as np
4
- from tqdm.auto import tqdm
5
  import plotly.express as px
6
  import plotly.graph_objects as go
7
  import plotly.colors as pc
8
  from scipy.stats import gaussian_kde
9
  import numpy as np
 
10
  import gradio as gr
11
- from gradio_client import Client
12
-
13
- from scipy.stats import gaussian_kde
14
- import numpy as np
15
-
16
- import os
17
- import re
18
-
19
- from translate import translate_pa_outcome, translate_pitch_outcome, jp_pitch_to_en_pitch, jp_pitch_to_pitch_code, translate_pitch_outcome, max_pitch_types
20
-
21
- # load game data
22
- game_df = pd.read_csv('game.csv').drop_duplicates()
23
- assert len(game_df) == len(game_df['game_pk'].unique())
24
-
25
- # load pa data
26
- pa_df = []
27
- for game_pk in tqdm(game_df['game_pk']):
28
- pa_df.append(pd.read_csv(os.path.join('pa', f'{game_pk}.csv'), dtype={'pa_pk': str}))
29
- pa_df = pd.concat(pa_df, axis='rows')
30
-
31
- # load pitch data
32
- pitch_df = []
33
- for game_pk in tqdm(game_df['game_pk']):
34
- pitch_df.append(pd.read_csv(os.path.join('pitch', f'{game_pk}.csv'), dtype={'pa_pk': str}))
35
- pitch_df = pd.concat(pitch_df, axis='rows')
36
- pitch_df
37
-
38
- # load player data
39
- player_df = pd.read_csv('player.csv')
40
- player_df
41
-
42
- # translate pa data
43
- pa_df['_des'] = pa_df['des'].str.strip()
44
- pa_df['des'] = pa_df['des'].str.strip()
45
- pa_df['des_more'] = pa_df['des_more'].str.strip()
46
- pa_df.loc[pa_df['des'].isna(), 'des'] = pa_df[pa_df['des'].isna()]['des_more']
47
- pa_df.loc[:, 'des'] = pa_df['des'].apply(lambda item: item.split()[0] if (len(item.split()) > 1 and re.search(r'+\d+点', item)) else item)
48
- non_home_plate_outcome = (pa_df['des'].isin(['ボール', '見逃し', '空振り'])) | (pa_df['des'].str.endswith('塁けん制'))
49
- pa_df.loc[non_home_plate_outcome, 'des'] = pa_df.loc[non_home_plate_outcome, 'des_more']
50
- pa_df['des'] = pa_df['des'].apply(translate_pa_outcome)
51
-
52
- # translate pitch data
53
- pitch_df = pitch_df[~pitch_df['pitch_name'].isna()]
54
- pitch_df['jp_pitch_name'] = pitch_df['pitch_name']
55
- pitch_df['pitch_name'] = pitch_df['jp_pitch_name'].apply(lambda pitch_name: jp_pitch_to_en_pitch[pitch_name])
56
- pitch_df['pitch_type'] = pitch_df['jp_pitch_name'].apply(lambda pitch_name: jp_pitch_to_pitch_code[pitch_name])
57
- pitch_df['description'] = pitch_df['description'].apply(lambda item: item.split()[0] if len(item.split()) > 1 else item)
58
- pitch_df['description'] = pitch_df['description'].apply(translate_pitch_outcome)
59
- pitch_df['release_speed'] = pitch_df['release_speed'].replace('-', np.nan)
60
- pitch_df.loc[~pitch_df['release_speed'].isna(), 'release_speed'] = pitch_df.loc[~pitch_df['release_speed'].isna(), 'release_speed'].str.removesuffix('km/h').astype(int)
61
- pitch_df['plate_x'] = (pitch_df['plate_x'] + 13) - 80
62
- pitch_df['plate_z'] = 200 - (pitch_df['plate_z'] + 13) - 100
63
-
64
- # translate player data
65
- client = Client("Ramos-Ramos/npb_name_translator")
66
- en_names = client.predict(
67
- jp_names='\n'.join(player_df.name.tolist()),
68
- api_name="/predict"
69
- )
70
- player_df['jp_name'] = player_df['name']
71
- player_df['name'] = [name if name != 'nan' else np.nan for name in en_names.splitlines()]
72
-
73
- # merge pitch and pa data
74
- df = pd.merge(pitch_df, pa_df, 'inner', on=['game_pk', 'pa_pk'])
75
- df = pd.merge(df, player_df.rename(columns={'player_id': 'pitcher'}), 'inner', on='pitcher')
76
- df['whiff'] = df['description'].isin(['SS', 'K'])
77
- df['swing'] = ~df['description'].isin(['B', 'BB', 'LS', 'inv_K', 'bunt_K', 'HBP', 'SH', 'SH E', 'SH FC', 'obstruction', 'illegal_pitch', 'defensive_interference'])
78
- df['csw'] = df['description'].isin(['SS', 'K', 'LS', 'inv_K'])
79
- df['normal_pitch'] = ~df['description'].isin(['obstruction', 'illegal_pitch', 'defensive_interference']) # guess
80
-
81
- whiff_rate = df.groupby(['name', 'pitch_name'])
82
- whiff_rate = (whiff_rate['whiff'].sum() / whiff_rate['swing'].sum() * 100).round(1).rename('Whiff%').reset_index()
83
-
84
- csw_rate = df.groupby(['name', 'pitch_name'])
85
- csw_rate = (csw_rate['csw'].sum() / csw_rate['normal_pitch'].sum() * 100).round(1).rename('CSW%').reset_index()
86
-
87
- pitch_stats = pd.merge(
88
- whiff_rate,
89
- csw_rate,
90
- on=['name', 'pitch_name']
91
- ).set_index(['name', 'pitch_name'])
92
 
 
 
93
 
94
  # GRADIO FUNCTIONS
95
 
@@ -343,6 +262,7 @@ def get_data(player):
343
  player_name = f'# {player}'
344
 
345
  _df = df.set_index('name').loc[player]
 
346
  _df_by_pitch_name = _df.set_index('pitch_name')
347
 
348
  usage_fig = px.pie(_df['pitch_name'], names='pitch_name')
@@ -387,4 +307,4 @@ def get_data(player):
387
 
388
  pitch_velo_summary = plot_all_pitch_velo(player=player, player_df=_df_by_pitch_name, pitch_counts=pitch_counts.sort_values(ascending=True))
389
 
390
- return player_name, usage_fig, *pitch_groups, *pitch_names, *pitch_infos, *pitch_velos, *pitch_maps, pitch_velo_summary
 
1
 
 
 
 
2
  import plotly.express as px
3
  import plotly.graph_objects as go
4
  import plotly.colors as pc
5
  from scipy.stats import gaussian_kde
6
  import numpy as np
7
+ import pandas as pd
8
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ from translate import max_pitch_types
11
+ from data import df, pitch_stats
12
 
13
  # GRADIO FUNCTIONS
14
 
 
262
  player_name = f'# {player}'
263
 
264
  _df = df.set_index('name').loc[player]
265
+ _df.to_csv(f'files/npb.csv', index=False)
266
  _df_by_pitch_name = _df.set_index('pitch_name')
267
 
268
  usage_fig = px.pie(_df['pitch_name'], names='pitch_name')
 
307
 
308
  pitch_velo_summary = plot_all_pitch_velo(player=player, player_df=_df_by_pitch_name, pitch_counts=pitch_counts.sort_values(ascending=True))
309
 
310
+ return player_name, 'files/npb.csv', usage_fig, *pitch_groups, *pitch_names, *pitch_infos, *pitch_velos, *pitch_maps, pitch_velo_summary