"""Gradio app that predicts a baseball player's salary from batting statistics.

Player-season batting lines (``Batting.csv``) are joined to salaries
(``Salaries.csv``) on player, year and team. Rate stats are then derived, and a
single-layer linear regression is fit with PyTorch and served through a Gradio
form. Salaries are used exactly as recorded in the CSV; no inflation
adjustment is applied here.
"""
import gradio as gr
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

batting = pd.read_csv('Batting.csv')
salaries = pd.read_csv('Salaries.csv')
df = pd.merge(batting, salaries, on=['playerID', 'yearID', 'teamID'])

# Rate stats. A zero denominator (e.g. a season with no at-bats) gives NaN for
# 0/0 or +/-inf for x/0. inf is mapped to NaN below so those rows are dropped
# together with genuinely missing values.
df['BA'] = df['H'] / df['AB']
df['OBP'] = (df['H'] + df['BB'] + df['HBP']) / (df['AB'] + df['BB'] + df['HBP'] + df['SF'])
# Slugging percentage = total bases / at-bats. H already credits one base per
# hit, so doubles, triples and home runs add 1, 2 and 3 extra bases.
df['SLG'] = (df['H'] + df['2B'] + 2 * df['3B'] + 3 * df['HR']) / df['AB']

features = ['BA', 'OBP', 'SLG', 'HR', 'RBI', 'SB']
df = df.replace([np.inf, -np.inf], np.nan)
# Only require the columns the model actually uses. A NaN in an unrelated
# column should not discard an otherwise usable season.
df = df.dropna(subset=features + ['salary'])

X = df[features]
y = df['salary']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test.values, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)


class LinearRegression(torch.nn.Module):
    """Linear model that standardises its inputs and target internally.

    ``forward`` takes raw stats of shape (batch, input_dim) and returns
    predictions in raw salary units, shape (batch, 1). Training happens on the
    standardised scale via ``forward_scaled``. With raw salaries in the
    millions and counting stats in the hundreds, plain SGD on unscaled data
    produces enormous gradients and diverges to inf/NaN.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, 1)
        # Identity scaling until set_scaling() is called. These are registered
        # as buffers: saved with state_dict() but not returned by parameters().
        self.register_buffer('x_mean', torch.zeros(input_dim))
        self.register_buffer('x_std', torch.ones(input_dim))
        self.register_buffer('y_mean', torch.zeros(()))
        self.register_buffer('y_std', torch.ones(()))

    @torch.no_grad()
    def set_scaling(self, x, y):
        """Record mean/std of features ``x`` (n, input_dim) and target ``y`` (n,)."""
        # clamp_min guards against a constant column (std == 0).
        self.x_mean.copy_(x.mean(dim=0))
        self.x_std.copy_(x.std(dim=0).clamp_min(1e-8))
        self.y_mean.copy_(y.mean())
        self.y_std.copy_(y.std().clamp_min(1e-8))

    def scale_target(self, y):
        """Map raw target values onto the standardised scale."""
        return (y - self.y_mean) / self.y_std

    def forward_scaled(self, x):
        """Predict in standardised target units from raw features."""
        return self.linear((x - self.x_mean) / self.x_std)

    def forward(self, x):
        """Predict in raw salary units from raw features."""
        return self.forward_scaled(x) * self.y_std + self.y_mean


model = LinearRegression(len(features))
model.set_scaling(X_train_tensor, y_train_tensor)
criterion = torch.nn.MSELoss()
# Features and target are standardised, so a much larger step than 1e-4 is
# both stable and needed to converge within num_epochs.
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
num_epochs = 1000
y_train_scaled = model.scale_target(y_train_tensor).unsqueeze(1)

for epoch in range(num_epochs):
    outputs = model.forward_scaled(X_train_tensor)
    loss = criterion(outputs, y_train_scaled)
    # Fail loudly instead of silently serving NaN predictions.
    if not torch.isfinite(loss):
        raise RuntimeError(f'Training diverged at epoch {epoch}: loss={loss.item()}')
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Report held-out error in salary units so the fit can be sanity-checked.
with torch.no_grad():
    test_mse = criterion(model(X_test_tensor), y_test_tensor.unsqueeze(1)).item()
print(f'Test RMSE: ${test_mse ** 0.5:,.2f}')


def predict_salary(BA, OBP, SLG, HR, RBI, SB):
    """Return the model's salary prediction for one stat line as a ``$`` string.

    Arguments are in the same order as ``features``. Gradio passes ``None``
    for an empty Number field; in that case a prompt is returned instead of
    raising inside torch.
    """
    stats = [BA, OBP, SLG, HR, RBI, SB]
    if any(value is None for value in stats):
        return 'Please fill in every statistic.'
    with torch.no_grad():
        stats_tensor = torch.tensor([stats], dtype=torch.float32)
        predicted_salary = model(stats_tensor).item()
    return f'${predicted_salary:,.2f}'


demo = gr.Interface(
    fn=predict_salary,
    inputs=[
        gr.Number(label="Batting Average (BA)"),
        gr.Number(label="On-base Percentage (OBP)"),
        gr.Number(label="Slugging Percentage (SLG)"),
        gr.Number(label="Home Runs (HR)"),
        gr.Number(label="Runs Batted In (RBI)"),
        gr.Number(label="Stolen Bases (SB)"),
    ],
    outputs="text",
    title="Baseball Player Salary Predictor",
)

if __name__ == "__main__":
    demo.launch(share=True)