# patrickramos's picture
# Translate team names
# e5250f2
# import pandas as pd
import polars as pl
import numpy as np
# from gradio_client import Client
from tqdm.auto import tqdm
import os
import re
from seasons import SEASONS
from translate import (
translate_pa_outcome, translate_pitch_outcome,
jp_pitch_to_en_pitch, jp_pitch_to_pitch_code,
jp_team_to_en_team, jp_team_to_en_full_team,
max_pitch_types
)
def _build_bb_type_table():
    """Build the hit_type -> batted-ball label lookup used by identify_bb_type.

    Categories are listed in priority order. Code 247 appears in both the
    fly_ball and pop_up ranges; ``setdefault`` keeps the first label seen, so
    it stays 'fly_ball'.
    """
    categories = (
        ('ground_ball', (range(1, 10), range(40, 49))),
        ('line_drive', (range(58, 67), range(201, 209))),
        ('fly_ball', (range(28, 31), range(55, 58), range(107, 110), range(247, 251))),
        ('pop_up', (range(49, 55), range(101, 107), range(242, 248))),
        # Codes 31 and 32 are recognised but carry no batted-ball type.
        (None, ((31, 32),)),
    )
    table = {}
    for label, code_groups in categories:
        for codes in code_groups:
            for code in codes:
                table.setdefault(code, label)
    return table


# Built once at import time; identify_bb_type is called once per row.
_BB_TYPE_BY_HIT_TYPE = _build_bb_type_table()


def identify_bb_type(hit_type):
    """Classify a numeric hit_type code as a batted-ball type.

    Args:
        hit_type: Integer hit-type code.

    Returns:
        'ground_ball', 'line_drive', 'fly_ball' or 'pop_up', or None for the
        codes 31 and 32.

    Raises:
        ValueError: If ``hit_type`` is not one of the known codes.
    """
    try:
        return _BB_TYPE_BY_HIT_TYPE[hit_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which are unknown codes too.
        raise ValueError(f'Unexpect hit_type {hit_type}') from None
# Root folder; each season's CSV files are read from DATA_DIR/<season>.
DATA_DIR = 'data'
# Season identifiers as strings so they can be joined into file paths.
SEASONS = list(map(str, SEASONS))
# Per-season frames are accumulated in these lists and concatenated later.
game_df = []
pa_df = []
pitch_df = []
player_df = []
df = []
# Load, translate and merge the game / PA / pitch / player tables of each
# season, then append the per-season frames to the module-level lists.
for season in SEASONS:
    season_dir = os.path.join(DATA_DIR, season)

    # load game data
    _game_df = pl.read_csv(os.path.join(season_dir, 'game.csv')).unique()
    assert len(_game_df) == len(_game_df['game_pk'].unique())

    # load pa data (one CSV per game)
    _pa_df = pl.concat([
        pl.read_csv(os.path.join(season_dir, 'pa', f'{game_pk}.csv'), schema_overrides={'pa_pk': str})
        for game_pk in tqdm(_game_df['game_pk'])
    ])

    # load pitch data (one CSV per game)
    _pitch_dfs = [
        pl.read_csv(
            os.path.join(season_dir, 'pitch', f'{game_pk}.csv'),
            schema_overrides={'pitch_id': pl.Int64, 'pitch_number': pl.Int64, 'pa_pk': str, 'on_1b': pl.Int64, 'on_2b': pl.Int64, 'on_3b': pl.Int64}
        )
        for game_pk in tqdm(_game_df['game_pk'])
    ]
    try:
        _pitch_df = pl.concat(_pitch_dfs)
    except Exception:
        # Dump every per-game schema to help locate the mismatching file, then
        # re-raise: continuing with a list instead of a DataFrame would only
        # fail later with a less helpful error.
        rows = []
        for __pitch_df in _pitch_dfs:
            row = dict(zip(__pitch_df.columns, __pitch_df.dtypes))
            print(row)
            rows.append(row)
        print(pl.DataFrame(rows))
        raise

    # load player data
    _player_df = pl.read_csv(os.path.join(season_dir, 'player.csv'))

    # translate game data: keep the Japanese team names under jp_* and replace
    # home_team / away_team with their English short and full names
    _game_df = (
        _game_df
        .with_columns(
            pl.col('home_team').alias('jp_home_team'),
            pl.col('away_team').alias('jp_away_team')
        )
        .with_columns(
            pl.col('home_team').replace_strict(jp_team_to_en_team),
            pl.col('home_team').replace_strict(jp_team_to_en_full_team).alias('full_home_team'),
            pl.col('away_team').replace_strict(jp_team_to_en_team),
            pl.col('away_team').replace_strict(jp_team_to_en_full_team).alias('full_away_team')
        )
    )

    # translate pa data
    _pa_df = (
        _pa_df
        .with_columns(
            # _des keeps the stripped original description
            pl.col('des').str.strip_chars().alias('_des'),
            pl.col('des').str.strip_chars(),
            pl.col('des_more').str.strip_chars()
        )
        .with_columns(
            # fall back to des_more when des is missing
            pl.col('des').fill_null(pl.col('des_more'))
        )
        .with_columns(
            # When des has several space-separated tokens and carries a
            # "+N点" marker, keep only the first token. The plus sign sits in
            # a character class because a bare leading '+' is an invalid
            # regex; the full-width form is accepted as well (assumption:
            # the raw data may use either - TODO confirm).
            pl.when(
                (pl.col('des').str.split(' ').list.len() > 1) &
                (pl.col('des').str.contains(r'[+＋]\d+点'))
            )
            .then(pl.col('des').str.split(' ').list.first())
            .otherwise(pl.col('des'))
            .alias('des')
        )
        .with_columns(
            # For these pitch-level descriptions (and pickoff throws) use
            # des_more as the PA outcome instead.
            pl.when(
                pl.col('des').is_in(['ボール', '見逃し', '空振り']) |
                pl.col('des').str.ends_with('塁けん制')
            )
            .then(
                pl.col('des_more')
            )
            .otherwise(
                pl.col('des')
            )
            .alias('des')
        )
        .with_columns(
            pl.col('des').map_elements(translate_pa_outcome, return_dtype=str)
        )
        .with_columns(
            # raw bb_type looks like 'dakyu<N>'; keep the numeric part as hit_type
            pl.col('bb_type').str.strip_prefix('dakyu').cast(int).alias('hit_type')
        )
        .with_columns(
            pl.col('hit_type').map_elements(identify_bb_type, return_dtype=str).alias('bb_type')
        )
    )

    # translate pitch data
    _pitch_df = (
        _pitch_df
        .filter(pl.col('pitch_name').is_not_null())
        .with_columns(
            pl.col('pitch_name').alias('jp_pitch_name')
        )
        .with_columns(
            pl.col('jp_pitch_name').replace_strict(jp_pitch_to_en_pitch).alias('pitch_name'),
            pl.col('jp_pitch_name').replace_strict(jp_pitch_to_pitch_code).alias('pitch_type'),
            pl.col('description').str.split(' ').list.first().map_elements(translate_pitch_outcome, return_dtype=str),
            # '-' marks a missing speed; otherwise drop the 'km/h' unit suffix
            pl.when(
                pl.col('release_speed') != '-'
            )
            .then(
                pl.col('release_speed').str.strip_suffix('km/h')
            )
            .otherwise(
                None
            )
            .alias('release_speed'),
            # Shift / flip the raw plot coordinates (constants kept from the
            # original code; their meaning is not visible here).
            ((pl.col('plate_x') + 13) - 80).alias('plate_x'),
            (200 - (pl.col('plate_z') + 13) - 100).alias('plate_z'),
        )
        .with_columns(
            # Cast in a separate step: the original author noted that casting
            # inside the strip_suffix expression did not work.
            pl.col('release_speed').cast(int),
        )
    )

    # translate player data: the register maps (jp_name, jp_team) to en_name
    register = (
        pl.read_csv(os.path.join(season_dir, 'register.csv'))
        .with_columns(
            pl.col('en_name').str.replace(',', '').alias('en_name'),
        )
        .select(
            pl.col('en_name'),
            pl.col('jp_team').alias('team'),
            pl.col('jp_name').alias('name')
        )
    )
    _player_df = (
        _player_df
        .join(register, on=['name', 'team'], how='inner')
        .with_columns(
            pl.col('en_name').alias('name'),
            pl.col('team').alias('jp_team')
        )
        .with_columns(
            pl.col('jp_team').replace_strict(jp_team_to_en_team).alias('team'),
            pl.col('jp_team').replace_strict(jp_team_to_en_full_team).alias('full_team'),
        )
        .drop(pl.col('en_name'))
    )

    # merge pitch and pa data, attach pitcher and game info, derive flags
    _df = (
        (
            _pitch_df
            .join(_pa_df, on=['game_pk', 'pa_pk'], how='inner')
            .join(_player_df.rename({'player_id': 'pitcher'}), on='pitcher', how='inner')
            .join(_game_df, on=['game_pk'])
        )
        .with_columns(
            pl.col('description').is_in(['SS', 'K']).alias('whiff'),
            ~pl.col('description').is_in(['B', 'BB', 'LS', 'inv_K', 'bunt_K', 'HBP', 'SH', 'SH E', 'SH FC', 'obstruction', 'illegal_pitch', 'defensive_interference']).alias('swing'),
            pl.col('description').is_in(['SS', 'K', 'LS', 'inv_K']).alias('csw'),
            ~pl.col('description').is_in(['obstruction', 'illegal_pitch', 'defensive_interference']).alias('normal_pitch'),  # guess
            pl.col('game_date').str.to_datetime()
        )
    ).sort(['game_pk', 'pa_pk', 'pitch_id'])

    # add players to pa_df
    # Some PAs do not show up in the pitch data, so pitcher info is attached to
    # the PA table itself as well (presumably for PA-level stats; the original
    # note was left unfinished).
    _pa_df = _pa_df.join(_player_df.rename({'player_id': 'pitcher'}), on='pitcher', how='inner')

    # add season dfs to main dfs
    game_df.append(_game_df)
    pa_df.append(_pa_df)
    pitch_df.append(_pitch_df)
    player_df.append(_player_df)
    df.append(_df)
def compare(list_0, list_1):
    """Print which items are exclusive to each of the two given lists.

    Args:
        list_0: First collection of items (e.g. column names).
        list_1: Second collection of items.

    Returns:
        None. The two difference lists are written to stdout, each keeping
        the order of the list it was taken from.
    """
    only_in_0 = list(filter(lambda item: item not in list_1, list_0))
    print(f'In 0 but not in 1: {only_in_0}')
    only_in_1 = list(filter(lambda item: item not in list_0, list_1))
    print(f'In 1 but not in 0: {only_in_1}')
# combine all season dfs
def _concat_seasons(frames, label):
    """Concatenate per-season frames, reporting column mismatches on failure.

    Args:
        frames: List of per-season DataFrames.
        label: Name printed ahead of the diagnostics.

    Returns:
        The concatenated DataFrame.

    Raises:
        Exception: Whatever ``pl.concat`` raised, re-raised after the column
            differences against the first season have been printed.
    """
    try:
        return pl.concat(frames)
    except Exception:
        print(label)
        # Compare every later season with the first one, so this works for
        # any number of seasons rather than exactly two.
        for other in frames[1:]:
            compare(frames[0].columns, other.columns)
        # Re-raise: carrying on with a list instead of a DataFrame would only
        # fail later with a misleading error.
        raise


game_df = pl.concat(game_df)
pa_df = _concat_seasons(pa_df, 'pa_df')
pitch_df = _concat_seasons(pitch_df, 'pitch_df')
player_df = pl.concat(player_df).unique()
df = _concat_seasons(df, 'df')

# Check the combined table; the per-season `_game_df` is already checked
# right after it is loaded.
assert len(game_df) == len(game_df['game_pk'].unique())
# pitch_stats, rhb_pitch_stats, lhb_pitch_stats = [
# (
# _df
# .group_by(['name', 'pitch_name'])
# .agg(
# ((pl.col('whiff').sum() / pl.col('swing').sum()) * 100).round(1).alias('Whiff%'),
# ((pl.col('csw').sum() / pl.col('normal_pitch').sum()) * 100).round(1).alias('CSW%'),
# pl.col('release_speed').mean().round(1).alias('Velocity'),
# pl.len().alias('Count')
# )
# .sort(['name', 'Count'], descending=[False, True])
# # .rename({'name': 'Player', 'pitch_name': 'Pitch'})
# )
# for _df
# in (
# df,
# df.filter(pl.col('stand') == 'R'),
# df.filter(pl.col('stand') == 'L'),
# )
# ]
# league_pitch_stats, rhb_league_pitch_stats, lhb_league_pitch_stats = [
# _df.group_by('pitch_name').agg(pl.col('release_speed').mean().round(1).alias('Velocity'))
# for _df
# in (
# df,
# df.filter(pl.col('stand') == 'R'),
# df.filter(pl.col('stand') == 'L'),
# )
# ]
def compute_pitch_stats(df):
    """Aggregate pitch metrics for every (name, pitch_name) pair.

    Args:
        df: Frame with the columns 'name', 'pitch_name', 'whiff', 'swing',
            'csw', 'normal_pitch' and 'release_speed'.

    Returns:
        One row per (name, pitch_name) with 'Whiff%' (whiffs per swing),
        'CSW%' (csw per normal pitch), 'Velocity' (mean release speed) and
        'Count' (number of rows); the first three are rounded to one decimal.
        Rows are ordered by name ascending, then Count descending.
    """
    whiff_rate = (pl.col('whiff').sum() / pl.col('swing').sum()) * 100
    csw_rate = (pl.col('csw').sum() / pl.col('normal_pitch').sum()) * 100
    mean_speed = pl.col('release_speed').mean()

    per_pitch = df.group_by(['name', 'pitch_name']).agg(
        whiff_rate.round(1).alias('Whiff%'),
        csw_rate.round(1).alias('CSW%'),
        mean_speed.round(1).alias('Velocity'),
        pl.len().alias('Count'),
    )
    return per_pitch.sort(['name', 'Count'], descending=[False, True])
# Stats per (name, pitch_name) computed over the combined frame of all seasons.
pitch_stats = compute_pitch_stats(df)
def compute_league_pitch_stats(df):
    """Return the mean release speed of each pitch type.

    Args:
        df: Frame with 'pitch_name' and 'release_speed' columns.

    Returns:
        One row per pitch_name with a 'Velocity' column holding the mean
        release_speed rounded to one decimal.
    """
    mean_velocity = pl.col('release_speed').mean().round(1).alias('Velocity')
    grouped = df.group_by('pitch_name')
    return grouped.agg(mean_velocity)
# Mean release speed per pitch_name over the combined frame of all seasons.
league_pitch_stats = compute_league_pitch_stats(df)
if __name__ == '__main__':
    # Show the merged frame's dimensions and column names, then drop into the
    # debugger for interactive inspection.
    for summary in (df.shape, df.columns):
        print(summary)
    breakpoint()