Spaces:
Sleeping
Sleeping
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
from sklearn.model_selection import train_test_split | |
import torch | |
# Load batting stats and salaries and join them on player, season and team.
batting = pd.read_csv('Batting.csv')
salaries = pd.read_csv('Salaries.csv')
df = pd.merge(batting, salaries, on=['playerID', 'yearID', 'teamID'])

# Batting average: hits per at-bat.
df['BA'] = df['H'] / df['AB']
# On-base percentage: (H + BB + HBP) / (AB + BB + HBP + SF).
df['OBP'] = (df['H'] + df['BB'] + df['HBP']) / (df['AB'] + df['BB'] + df['HBP'] + df['SF'])
# Slugging percentage: total bases per at-bat. 'H' already counts every
# double, triple and home run once (as one base), so each of those only adds
# its *extra* bases here (1, 2 and 3). The previous formula
# (H + 2*2B + 3*3B + 4*HR) counted those hits twice and inflated SLG.
df['SLG'] = (df['H'] + df['2B'] + 2 * df['3B'] + 3 * df['HR']) / df['AB']

features = ['BA', 'OBP', 'SLG', 'HR', 'RBI', 'SB']

# A zero denominator yields NaN (0/0) or +/-inf in the rate columns; treat
# both as missing. Only drop rows whose model inputs or target are missing,
# instead of any row that has a NaN in some unrelated column of the merge.
df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=features + ['salary'])

X = df[features]
y = df['salary']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert every split the same way. No NaN can remain after the dropna above,
# so no fillna is needed.
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)
class LinearRegression(torch.nn.Module):
    """Single-output linear model: one ``torch.nn.Linear`` layer, no activation.

    Args:
        input_dim: Number of input features per sample.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        # One weight per input feature plus a bias, producing one output value.
        self.linear = torch.nn.Linear(in_features=input_dim, out_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine layer to ``x``; the output's last dimension is 1."""
        prediction = self.linear(x)
        return prediction
# Build and train the salary model.
#
# Salaries are raw dollar amounts and the features mix rates (BA/OBP/SLG) with
# counting stats (HR/RBI/SB). Plain SGD on those raw scales produces enormous
# gradients and diverges to inf/NaN. So we train on standardized features and
# targets, then fold the scaling back into the linear layer. The resulting
# `model` still maps raw stats -> salary in dollars, so callers can keep
# passing unscaled inputs.
model = LinearRegression(len(features))
criterion = torch.nn.MSELoss()
# The learning rate suits standardized data (1e-4 was chosen for raw scales).
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
num_epochs = 1000

# Standardization statistics from the training split only. A zero or NaN std
# (constant column, or a single sample) falls back to 1 to avoid dividing by 0.
x_mean = X_train_tensor.mean(dim=0)
x_std = X_train_tensor.std(dim=0)
x_std = torch.where(x_std > 0, x_std, torch.ones_like(x_std))
y_mean = y_train_tensor.mean()
y_std = y_train_tensor.std()
if not y_std > 0:
    y_std = torch.ones_like(y_std)

X_train_scaled = (X_train_tensor - x_mean) / x_std
y_train_scaled = ((y_train_tensor - y_mean) / y_std).unsqueeze(1)

for epoch in range(num_epochs):
    outputs = model(X_train_scaled)
    loss = criterion(outputs, y_train_scaled)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Fold the standardization into the layer so it accepts raw inputs.
# The trained layer computes
#     y = y_mean + y_std * (W_s @ ((x - x_mean) / x_std) + b_s)
# which equals W_raw @ x + b_raw, with:
#     W_raw = y_std * W_s / x_std
#     b_raw = y_mean + y_std * (b_s - sum(W_s * x_mean / x_std))
with torch.no_grad():
    w_scaled = model.linear.weight.clone()  # shape (1, len(features))
    b_scaled = model.linear.bias.clone()    # shape (1,)
    model.linear.weight.copy_(y_std * w_scaled / x_std)
    model.linear.bias.copy_(
        y_mean + y_std * (b_scaled - (w_scaled * x_mean / x_std).sum(dim=1))
    )
# Define prediction function for Gradio
def predict_salary(BA, OBP, SLG, HR, RBI, SB):
    """Predict a player's salary from six batting statistics.

    The arguments are passed to ``model`` in the order BA, OBP, SLG, HR, RBI,
    SB, matching the ``features`` list the model was trained on.

    Args:
        BA: Batting average.
        OBP: On-base percentage.
        SLG: Slugging percentage.
        HR: Home runs.
        RBI: Runs batted in.
        SB: Stolen bases.

    Returns:
        The prediction formatted as a dollar string such as ``"$1,234.56"``
        (``"-$1,234.56"`` when negative), or a short prompt if any statistic
        is missing (``None``).
    """
    stats = [BA, OBP, SLG, HR, RBI, SB]
    # An empty Gradio Number box may arrive as None, and torch cannot build a
    # tensor from None, so ask for the missing input instead of raising.
    if any(value is None for value in stats):
        return "Please fill in all six statistics."
    with torch.no_grad():
        stats_tensor = torch.tensor([stats], dtype=torch.float32)
        predicted_salary = model(stats_tensor).item()
    # Put the sign before the currency symbol ("-$1.00" rather than "$-1.00").
    sign = "-" if predicted_salary < 0 else ""
    return f"{sign}${abs(predicted_salary):,.2f}"
# Web UI: one numeric input box per argument of predict_salary (same order)
# and a plain-text output showing the formatted salary.
_input_labels = [
    "Batting Average (BA)",
    "On-base Percentage (OBP)",
    "Slugging Percentage (SLG)",
    "Home Runs (HR)",
    "Runs Batted In (RBI)",
    "Stolen Bases (SB)",
]
demo = gr.Interface(
    fn=predict_salary,
    inputs=[gr.components.Number(label=text) for text in _input_labels],
    outputs="text",
    title="Baseball Player Salary Predictor",
)

# Start the server; share=True requests a public share link.
demo.launch(share=True)