import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
import streamlit as st
import torch
from sklearn.preprocessing import MinMaxScaler

from predict import TabularTransformer, model_predict

# Set page config
st.set_page_config(
    page_title="Resistivity Prediction App",
    page_icon="🔮",
    layout="wide"
)

# Title and description
st.title("Resistivity Prediction App")
st.markdown("""
This app predicts resistivity based on input features.
Enter the values for each feature and click 'Predict' to get the prediction and explanation.
""")


@st.cache_resource
def load_model_and_scalers():
    """Load the training data, fit the scalers and load the trained model.

    Cached with ``st.cache_resource`` so the spreadsheet is read and the
    model is loaded once per server process instead of on every rerun.

    Returns:
        ``(model, scaler_X, scaler_y, feature_names, X)`` where ``X`` is the
        unscaled feature DataFrame (columns 0-7 of ``data.xlsx``). It is used
        for the input ranges and as the SHAP background.
    """
    df = pd.read_excel('data.xlsx')
    X = df.iloc[:, 0:8]   # first 8 columns are the model inputs
    y = df.iloc[:, 8]     # column 8 is the target
    feature_names = X.columns.tolist()

    # Min-max scalers fitted on this data; presumably the same scaling the
    # model was trained with -- confirm against the training script.
    scaler_X = MinMaxScaler().fit(X)
    scaler_y = MinMaxScaler().fit(y.to_numpy().reshape(-1, 1))

    model = TabularTransformer(input_dim=X.shape[1], output_dim=1)
    # map_location="cpu": inference below uses CPU tensors, and a checkpoint
    # saved from a GPU run would otherwise fail to load on a CPU-only host.
    model.load_state_dict(torch.load('model.pth', map_location='cpu'))
    model.eval()

    return model, scaler_X, scaler_y, feature_names, X


def explain_prediction(model, input_df, X_background, scaler_X, scaler_y,
                       feature_names, output='shap_explanation.png'):
    """Explain one prediction with SHAP's KernelExplainer and save a waterfall plot.

    Args:
        model: trained torch model applied to min-max scaled features.
        input_df: one-row DataFrame of unscaled feature values to explain.
        X_background: unscaled training features; up to 100 rows (fixed
            ``random_state``) are sampled as the explainer background.
        scaler_X: fitted feature scaler.
        scaler_y: fitted target scaler, used to map model output back to the
            original target units.
        feature_names: feature names in model input order.
        output: destination of the PNG plot -- a file path or a writable
            binary file-like object such as ``io.BytesIO``. Defaults to
            ``'shap_explanation.png'``. Pass a buffer so concurrent sessions
            do not overwrite one shared file.

    Returns:
        ``(expected_value, shap_values)``: the explainer's ``expected_value``
        as SHAP returns it, and a flat array of per-feature SHAP values.
    """
    def predict_fn(data):
        # Re-attach column names: the scaler was fitted on a DataFrame, and
        # sklearn warns when transform() receives a bare array instead.
        frame = pd.DataFrame(np.asarray(data), columns=feature_names)
        X_tensor = torch.as_tensor(scaler_X.transform(frame), dtype=torch.float32)
        with torch.no_grad():
            scaled_pred = model(X_tensor).numpy()
        return scaler_y.inverse_transform(scaled_pred)

    # Use a subset of training data as background
    background_sample = X_background.sample(n=min(100, len(X_background)), random_state=42)
    explainer = shap.KernelExplainer(predict_fn, background_sample)
    shap_values = explainer.shap_values(input_df)

    # Depending on the SHAP version this is a list with one array per model
    # output, or a single array (possibly with extra axes). Normalise it to a
    # flat vector with one value per feature.
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim > 1 and shap_values.shape[0] == len(feature_names):
        shap_values = shap_values.T
    shap_values = shap_values.flatten()

    # expected_value may be a Python scalar, a 0-d array or a 1-element array;
    # np.ravel handles all three (indexing a 0-d array would raise).
    base_value = np.ravel(explainer.expected_value)[0]

    fig = plt.figure(figsize=(12, 8))
    try:
        shap.plots.waterfall(
            shap.Explanation(
                values=shap_values,
                base_values=base_value,
                data=input_df.iloc[0].values,
                feature_names=feature_names,
            ),
            show=False,
        )
        plt.title('SHAP Value Contributions')
        plt.tight_layout()
        plt.savefig(output, format='png', dpi=300, bbox_inches='tight')
    finally:
        # pyplot keeps every open figure alive; close even on failure so a
        # failing explanation does not leak a figure per rerun.
        plt.close()     # the current figure (the one drawn on)
        plt.close(fig)  # no-op if it was the same figure

    return explainer.expected_value, shap_values


def _first_prediction_value(prediction):
    """Return the first predicted value as a float.

    ``model_predict``'s return type is not visible from this file. This
    accepts a flat per-row result as well as an ``(n, 1)`` column: it keeps
    indexing while the element is still a list, tuple or non-0-d array.
    """
    first = prediction[0]
    while isinstance(first, (list, tuple)) or getattr(first, "ndim", 0) > 0:
        first = first[0]
    return float(first)


# Load model and scalers; without them nothing else on the page can work.
try:
    model, scaler_X, scaler_y, feature_names, X = load_model_and_scalers()
except Exception as e:
    st.error(
        "Error loading the model and data. Please make sure:\n"
        "1. The model file 'model.pth' exists\n"
        "2. The data file 'data.xlsx' exists\n"
        "3. All required packages are installed\n\n"
        f"Error details: {e}"
    )
    st.stop()

# Create input fields for features
st.subheader("Input Features")
col1, col2 = st.columns(2)

# Feature name -> value entered by the user
input_values = {}
half = len(feature_names) // 2

for i, feature in enumerate(feature_names):
    # Bound each input by the range seen in the training data
    min_val = float(X[feature].min())
    max_val = float(X[feature].max())

    # First half of the features goes in the left column, the rest in the right
    with (col1 if i < half else col2):
        input_values[feature] = st.number_input(
            f"{feature}",
            min_value=min_val,
            max_value=max_val,
            value=float(X[feature].mean()),
            help=f"Range: {min_val:.2f} to {max_val:.2f}",
        )

if st.button("Predict"):
    # One-row DataFrame with columns in model input order
    input_df = pd.DataFrame([input_values], columns=feature_names)

    try:
        prediction = model_predict(model, input_df, scaler_X, scaler_y)
    except Exception as e:
        st.error(f"Prediction failed: {e}")
        st.stop()

    st.subheader("Prediction Result")
    st.markdown(f"### Predicted Resistivity: {_first_prediction_value(prediction):.2f}")

    # Calculate and display SHAP values. Failures here are reported on their
    # own so they do not hide the prediction shown above.
    st.subheader("Feature Importance Explanation")
    plot_buffer = io.BytesIO()  # per-run buffer: no shared file between sessions
    try:
        explain_prediction(
            model, input_df, X, scaler_X, scaler_y, feature_names,
            output=plot_buffer,
        )
    except Exception as e:
        st.error(f"Could not compute the SHAP explanation: {e}")
    else:
        st.image(plot_buffer.getvalue())