# Moneyball / app.py
# umairrrkhan — "Update app.py" (commit 74f79ed, verified)
import gradio as gr
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import torch
# Read the per-season batting lines and the salary records, then keep only
# the (playerID, yearID, teamID) combinations present in both (inner join).
# NOTE(review): file/column names suggest the Lahman baseball database — confirm.
batting = pd.read_csv('Batting.csv')
salaries = pd.read_csv('Salaries.csv')
df = batting.merge(salaries, on=['playerID', 'yearID', 'teamID'])
# Derived rate statistics. Rows whose denominator is zero yield NaN (0/0)
# or +/-inf (x/0); both are removed at the end of this block.
df['BA'] = df['H'] / df['AB']
df['OBP'] = (df['H'] + df['BB'] + df['HBP']) / (df['AB'] + df['BB'] + df['HBP'] + df['SF'])
# Slugging percentage = total bases / at-bats.
# Total bases = 1B + 2*2B + 3*3B + 4*HR with 1B = H - 2B - 3B - HR,
# which simplifies to H + 2B + 2*3B + 3*HR. H already counts every hit once,
# so each extra-base hit only adds its *extra* bases.
df['SLG'] = (df['H'] + df['2B'] + 2 * df['3B'] + 3 * df['HR']) / df['AB']
# dropna() ignores inf, so convert inf to NaN first; otherwise a stray inf
# would reach the training tensors and poison the loss.
df = df.replace([np.inf, -np.inf], np.nan).dropna()
# Model inputs and regression target. predict_salary passes its arguments
# to the model in this same column order.
features = ['BA', 'OBP', 'SLG', 'HR', 'RBI', 'SB']
X = df[features]
y = df['salary']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


def _to_float32_tensor(data):
    """Convert a pandas DataFrame or Series to a float32 torch tensor."""
    return torch.tensor(data.to_numpy(), dtype=torch.float32)


# Build the train and test tensors identically. df.dropna() earlier in the
# script already removed rows with missing values. Filling NaN in only the
# training split (as before) would make the two splits disagree.
X_train_tensor = _to_float32_tensor(X_train)
y_train_tensor = _to_float32_tensor(y_train)
X_test_tensor = _to_float32_tensor(X_test)
y_test_tensor = _to_float32_tensor(y_test)
class LinearRegression(torch.nn.Module):
    """A single affine layer mapping ``input_dim`` features to one output."""

    def __init__(self, input_dim):
        super().__init__()
        # Keep the attribute name ``linear``: it determines the state_dict keys.
        self.linear = torch.nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x):
        """Apply the affine layer to the last dimension of ``x`` (size input_dim -> 1)."""
        prediction = self.linear(x)
        return prediction
# Training configuration: one linear output trained on mean-squared error
# with plain SGD.
num_epochs = 1_000
model = LinearRegression(input_dim=len(features))
criterion = torch.nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

# Full-batch gradient descent: every epoch is one step over the whole
# training set. Gradients are cleared first so they do not accumulate
# across epochs (the forward pass does not touch .grad).
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X_train_tensor)
    # The target is reshaped to (N, 1) to match the model's output shape.
    loss = criterion(outputs, y_train_tensor.unsqueeze(1))
    loss.backward()
    optimizer.step()
# Define prediction function for Gradio
def predict_salary(BA, OBP, SLG, HR, RBI, SB):
    """Predict a player's salary from six season batting statistics.

    This is the Gradio callback. The argument order must match both the
    interface's ``inputs`` list and the training column order
    ``['BA', 'OBP', 'SLG', 'HR', 'RBI', 'SB']``.

    Returns:
        The model output formatted as a dollar string (e.g. ``'$1,234.50'``),
        or a message listing the fields that were left empty.
    """
    stats = {'BA': BA, 'OBP': OBP, 'SLG': SLG, 'HR': HR, 'RBI': RBI, 'SB': SB}
    # Depending on the Gradio version, an empty gr.Number field can arrive as
    # None. torch cannot convert None, so report it instead of raising.
    missing = [name for name, value in stats.items() if value is None]
    if missing:
        return f"Please fill in: {', '.join(missing)}"
    with torch.no_grad():
        stats_tensor = torch.tensor([list(stats.values())], dtype=torch.float32)
        predicted_salary = model(stats_tensor).item()
    return f'${predicted_salary:,.2f}'
# Gradio UI: one numeric input per model feature, in the same order as
# predict_salary's parameters. The formatted prediction is shown as text.
demo = gr.Interface(
    fn=predict_salary,
    inputs=[
        gr.Number(label=label)
        for label in (
            "Batting Average (BA)",
            "On-base Percentage (OBP)",
            "Slugging Percentage (SLG)",
            "Home Runs (HR)",
            "Runs Batted In (RBI)",
            "Stolen Bases (SB)",
        )
    ],
    outputs="text",
    title="Baseball Player Salary Predictor",
)

# Start the web server; share=True also requests a public share link.
demo.launch(share=True)