Spaces:
Build error
Build error
import plotly.express as px | |
import plotly.graph_objects as go | |
import plotly.colors as pc | |
from scipy.stats import gaussian_kde | |
import numpy as np | |
import pandas as pd | |
import gradio as gr | |
from translate import max_pitch_types | |
from data import df, pitch_stats | |
# GRADIO FUNCTIONS | |
# location maps | |
def fit_pred_kde(data, X, Y):
    """Fit a Gaussian KDE to `data` and evaluate it on the grid (`X`, `Y`).

    `data` is passed unchanged to `scipy.stats.gaussian_kde`. `X` and `Y` are
    same-shaped coordinate grids (this module builds them with `np.meshgrid`).

    Returns the estimated density as an array with the same shape as `X`.
    """
    # The estimator is evaluated at points laid out as (2, n_points), so each
    # grid is flattened into one row before the call.
    grid_points = np.vstack([np.ravel(X), np.ravel(Y)])
    estimator = gaussian_kde(data)
    density = estimator(grid_points)
    return density.reshape(*X.shape)
# Plot geometry shared by the location maps below.
# plot_s: side length of the square area the KDE is evaluated over.
plot_s = 256
# Outer reference rectangle (height, width).
# NOTE(review): "sz" presumably stands for strike zone -- confirm.
sz_h, sz_w = 200, 160
# Inner reference rectangle: the outer one shrunk by a fixed margin per side.
# NOTE(review): "h" presumably stands for the heart of the zone -- confirm.
h_h = sz_h - 2 * 40
h_w = sz_w - 2 * 32
# Unit-spaced sample positions centred on 0, and the matching 2-D grid.
kde_range = np.arange(-plot_s / 2, plot_s / 2, 1)
X, Y = np.meshgrid(kde_range, kde_range)
def coordinatify(h, w):
    """Return corner coordinates of an `h` x `w` rectangle centred on the origin.

    The result is a dict with keys `x0`, `y0`, `x1`, `y1`, which is the form
    the callers in this module unpack into `fig.add_shape(...)`.
    """
    return {
        'x0': -w / 2,
        'y0': -h / 2,
        'x1': w / 2,
        'y1': h / 2,
    }
# Contour colorscale: a fully transparent stop at 0 followed by the OrRd
# sequential palette spread evenly over (0, 1], the last colour landing on 1.
colorscale = [[0, 'rgba(0, 0, 0, 0)']] + [
    [step / len(pc.sequential.OrRd), shade]
    for step, shade in enumerate(pc.sequential.OrRd, start=1)
]
def plot_pitch_map(player=None, loc=None, pitch_type=None, pitch_name=None):
    """Build a contour figure of the kernel density of pitch locations.

    Exactly one of `player` or `loc` must be given.

    * With `player`, exactly one of `pitch_type` or `pitch_name` must also be
      given; the `plate_x` / `plate_z` columns for that player and pitch are
      looked up in the module-level `df`.
    * With `loc`, the object must provide `to_numpy()`; its transpose is what
      gets passed to `fit_pred_kde`.

    Returns a `go.Figure` holding two reference rectangles and the density
    contour, drawn on equally scaled axes.

    Raises `AssertionError` when the argument combination is invalid.
    """
    assert (loc is None) != (player is None), 'exactly one of `player` or `loc` must be specified'
    if player is not None and loc is None:
        assert (pitch_type is None) != (pitch_name is None), 'exactly one of `pitch_type` or `pitch_name` must be specified'
        if pitch_type:
            pitch_col, pitch_val = 'pitch_type', pitch_type
        else:
            pitch_col, pitch_val = 'pitch_name', pitch_name
        by_player_and_pitch = df.set_index(['name', pitch_col])
        loc = by_player_and_pitch.loc[(player, pitch_val), ['plate_x', 'plate_z']]

    density = fit_pred_kde(loc.to_numpy().T, X, Y)
    # Contour levels run from a small positive floor up to the peak density,
    # split into five equal steps.
    floor = 1e-5
    peak = density.max()

    fig = go.Figure()
    # Outer and inner reference rectangles, both centred on the origin.
    fig.add_shape(type="rect", line_color='gray', **coordinatify(sz_h, sz_w))
    fig.add_shape(type="rect", line_color='dimgray', **coordinatify(h_h, h_w))
    fig.add_trace(go.Contour(
        z=density,
        x=kde_range,
        y=kde_range,
        colorscale=colorscale,
        zmin=floor,
        zmax=peak,
        contours=dict(start=floor, end=peak, size=(peak - floor) / 5),
        showscale=False,
    ))

    half = plot_s / 2
    fig.update_layout(
        xaxis={'range': [-half, half + 1]},
        # Anchor y to x at ratio 1 so the plot is not stretched.
        yaxis={'range': [-half, half + 1], 'scaleanchor': 'x', 'scaleratio': 1},
    )
    return fig
def plot_empty_pitch_map():
    """Return a placeholder location figure for pitches with too few samples.

    The figure contains only a centred text annotation and uses the same axis
    ranges and 1:1 scaling as `plot_pitch_map`.
    """
    half = plot_s / 2
    placeholder = go.Figure()
    placeholder.add_annotation(
        text='No visualization<br>as less than 10 pitches thrown',
        x=0,
        y=0,
        showarrow=False,
    )
    placeholder.update_layout(
        xaxis={'range': [-half, half + 1]},
        yaxis={'range': [-half, half + 1], 'scaleanchor': 'x', 'scaleratio': 1},
    )
    return placeholder
# velo distribution | |
def plot_pitch_velo(player=None, velos=None, pitch_type=None, pitch_name=None):
    """Build a one-sided violin figure of release-speed values.

    Exactly one of `player` or `velos` must be given.

    * With `player`, exactly one of `pitch_type` or `pitch_name` must also be
      given; the `release_speed` column for that player and pitch is looked up
      in the module-level `df`.
    * With `velos`, the values are plotted as-is.

    Returns a `go.Figure` with fixed axis ranges (x: 125-170, y: 0-0.3).

    Raises `AssertionError` when the argument combination is invalid.
    """
    # The message names `velos`: this function has no `loc` parameter.
    assert (velos is None) != (player is None), 'exactly one of `player` or `velos` must be specified'
    if velos is None and player is not None:
        assert (pitch_type is None) != (pitch_name is None), 'exactly one of `pitch_type` or `pitch_name` must be specified'
        pitch_val = pitch_type or pitch_name
        pitch_col = 'pitch_type' if pitch_type else 'pitch_name'
        velos = df.set_index(['name', pitch_col]).loc[(player, pitch_val), 'release_speed']

    fig = go.Figure(data=go.Violin(
        x=velos,
        side='positive',
        hoveron='points',
        points=False,
        meanline_visible=True,
        name='Velocity Distribution',
    ))
    fig.update_layout(
        xaxis=dict(
            title='Velocity',
            range=[125, 170],
            scaleratio=2,
        ),
        yaxis=dict(
            title='Frequency',
            range=[0, 0.3],
            scaleanchor='x',
            scaleratio=1,
            # Three ticks: bottom, middle and top of the fixed y range.
            tickvals=np.linspace(0, 0.3, 3),
            ticktext=np.linspace(0, 0.3, 3),
        ),
        autosize=True,
        # The axis ranges are fixed, so the zoom/rescale buttons are removed.
        modebar_remove=['zoom', 'autoScale', 'resetScale'],
    )
    return fig
def plot_empty_pitch_velo():
    """Return a placeholder velocity figure for pitches with too few samples.

    The figure contains only a text annotation placed at the centre of the
    same fixed axis ranges that `plot_pitch_velo` uses, with a single y tick.
    """
    placeholder = go.Figure()
    placeholder.add_annotation(
        text='No visualization<br>as less than 10 pitches thrown',
        # Midpoints of the x range [125, 170] and the y range [0, 0.3].
        x=(170 + 125) / 2,
        y=0.3 / 2,
        showarrow=False,
    )
    velocity_axis = {'title': 'Velocity', 'range': [125, 170], 'scaleratio': 2}
    frequency_axis = {
        'title': 'Frequency',
        'range': [0, 0.3],
        'scaleanchor': 'x',
        'scaleratio': 1,
        'tickvals': [0.15],
        'ticktext': [0.15],
    }
    placeholder.update_layout(
        xaxis=velocity_axis,
        yaxis=frequency_axis,
        autosize=True,
        modebar_remove=['zoom', 'autoScale', 'resetScale'],
    )
    return placeholder
def plot_all_pitch_velo(player=None, player_df=None, pitch_counts=None, min_pitches=10):
    """Build horizontal violins comparing a player's velocities with the league's.

    Args:
        player: Player name. Used to look the player up in the module-level
            `df` when `player_df` is not given, and as the label of the two
            overall (all-pitch) violins.
        player_df: The player's rows, indexed by pitch name and containing a
            `release_speed` column. Built from `df` when omitted.
        pitch_counts: Mapping-like with `.items()` yielding
            `(pitch_name, count)`; its order is the order pitches are drawn.
            Derived from `player_df`'s index (ascending counts) when omitted.
        min_pitches: Minimum count for a pitch to get its own violin; pitches
            below it get a text placeholder instead.

    Returns:
        A `go.Figure`. Per pitch it holds a gray league violin plus either the
        player's violin or the placeholder; after those come the player's
        overall violin and the league's overall violin.

    Raises:
        AssertionError: if `pitch_counts` is passed while `player_df` is left
            to be built from `player`.
    """
    if player_df is None and player is not None:
        assert pitch_counts is None, '`pitch_counts` must be `None` if `player_df` is None'
        player_df = df.sort_values('name').set_index('name').loc[player].sort_values('pitch_name').set_index('pitch_name')
        pitch_counts = player_df.index.value_counts(ascending=True)
    elif player_df is not None and pitch_counts is None:
        # The counts are fully determined by the frame's pitch-name index, so
        # derive them instead of failing later on `None.items()`.
        pitch_counts = player_df.index.value_counts(ascending=True)

    league_df = df.set_index('pitch_name')
    n_pitch_types = len(pitch_counts)
    # Keep the placeholder wording in step with the actual threshold.
    too_few_text = f'No visualization as less than {min_pitches} pitches thrown'

    fig = go.Figure()
    # Horizontal anchor for placeholders: midpoint of the player's speed range.
    velo_center = (player_df['release_speed'].min() + player_df['release_speed'].max()) / 2

    for i, (pitch_name, count) in enumerate(pitch_counts.items()):
        # A list indexer always yields a Series, even when only one row
        # matches, so `len()` below is safe.
        league_velos = league_df.loc[[pitch_name], 'release_speed']
        fig.add_trace(go.Violin(
            x=league_velos,
            y=[pitch_name] * len(league_velos),
            line_color='gray',
            side='positive',
            orientation='h',
            meanline_visible=True,
            points=False,
            legendgroup='NPB',
            legendrank=1,
            showlegend=False,
            name='NPB',
        ))
        # Later pitches in `pitch_counts` get a lower legend rank.
        legend_rank = 2 + (n_pitch_types - i)
        if count >= min_pitches:
            velos = player_df.loc[[pitch_name], 'release_speed']
            fig.add_trace(go.Violin(
                x=velos,
                y=[pitch_name] * len(velos),
                side='positive',
                orientation='h',
                meanline_visible=True,
                points=False,
                legendgroup=pitch_name,
                legendrank=legend_rank,
                name=pitch_name,
            ))
        else:
            # NOTE(review): `hovertext=False` is kept from the original; it
            # may have been meant to suppress hover (`hoverinfo='skip'`) --
            # confirm.
            fig.add_trace(go.Scatter(
                x=[velo_center],
                y=[pitch_name],
                text=[too_few_text],
                textposition='top center',
                hovertext=False,
                mode="lines+text",
                legendgroup=pitch_name,
                legendrank=legend_rank,
                name=pitch_name,
            ))

    # Overall distributions across every pitch: the player first, then the league.
    fig.add_trace(go.Violin(
        x=player_df['release_speed'],
        y=[player] * len(player_df),
        side='positive',
        orientation='h',
        meanline_visible=True,
        points=False,
        legendrank=0,
        name=player,
    ))
    fig.add_trace(go.Violin(
        x=league_df['release_speed'],
        y=[player] * len(league_df),
        line_color='gray',
        side='positive',
        orientation='h',
        meanline_visible=True,
        points=False,
        legendgroup='NPB',
        legendrank=1,
        name='NPB',
    ))
    fig.update_xaxes(title='Velocity')
    return fig
def get_data(player):
    """Assemble every Gradio output value for the selected `player`.

    Side effect: writes the player's rows to `files/npb.csv`.

    Returns a flat tuple, in this order:
        1. a markdown heading with the player's name,
        2. the path of the CSV just written,
        3. a pie chart of pitch usage,
        4. five runs of `max_pitch_types` `gr.update(...)` values each --
           groups, names, stat tables, velocity plots, location plots --
           where slots beyond the player's pitch types are hidden,
        5. the all-pitch velocity summary figure.
    """
    # A pitch needs at least this many samples to get real plots. 10 matches
    # the wording of the "less than 10 pitches" placeholders and the default
    # threshold of `plot_all_pitch_velo`.
    min_pitches = 10

    player_name = f'# {player}'
    _df = df.set_index('name').loc[player]
    _df.to_csv('files/npb.csv', index=False)
    _df_by_pitch_name = _df.set_index('pitch_name')

    usage_fig = px.pie(_df['pitch_name'], names='pitch_name')
    usage_fig.update_traces(
        texttemplate='%{percent:.1%}',
        hovertemplate=f'<b>{player}</b><br>' + 'threw a <b>%{label}</b><br><b>%{percent:.1%}</b> of the time (<b>%{value}</b> pitches)',
    )

    pitch_counts = _df['pitch_name'].value_counts()
    pitch_groups = []
    pitch_names = []
    pitch_infos = []
    pitch_velos = []
    pitch_maps = []
    for pitch_name, count in pitch_counts.items():
        pitch_groups.append(gr.update(visible=True))
        pitch_names.append(gr.update(value=f'### {pitch_name}', visible=True))
        pitch_infos.append(gr.update(
            value=pd.DataFrame([{
                'Whiff%': pitch_stats.loc[(player, pitch_name), 'Whiff%'].item(),
                'CSW%': pitch_stats.loc[(player, pitch_name), 'CSW%'].item(),
            }]),
            visible=True,
        ))
        # `>=`, not `>`: exactly `min_pitches` samples is not "less than" it.
        if count >= min_pitches:
            pitch_velos.append(gr.update(
                value=plot_pitch_velo(velos=_df_by_pitch_name.loc[pitch_name, 'release_speed']),
                visible=True,
            ))
            pitch_maps.append(gr.update(value=plot_pitch_map(player, pitch_name=pitch_name), label='Pitch location', visible=True))
        else:
            pitch_velos.append(gr.update(value=plot_empty_pitch_velo(), visible=True))
            pitch_maps.append(gr.update(value=plot_empty_pitch_map(), label=pitch_name, visible=True))

    # All five lists grow in lockstep above, so one loop pads them together
    # with hidden updates for the unused component slots.
    for _ in range(max_pitch_types - len(pitch_names)):
        pitch_groups.append(gr.update(visible=False))
        pitch_names.append(gr.update(value=None, visible=False))
        pitch_infos.append(gr.update(value=None, visible=False))
        pitch_velos.append(gr.update(value=None, visible=False))
        pitch_maps.append(gr.update(value=None, visible=False))

    pitch_velo_summary = plot_all_pitch_velo(
        player=player,
        player_df=_df_by_pitch_name,
        pitch_counts=pitch_counts.sort_values(ascending=True),
        min_pitches=min_pitches,
    )
    return player_name, 'files/npb.csv', usage_fig, *pitch_groups, *pitch_names, *pitch_infos, *pitch_velos, *pitch_maps, pitch_velo_summary