import streamlit as st
import pandas as pd
import numpy as np
import torch
from predict import TabularTransformer, model_predict
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import shap
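
# Note: the local `predict` module (imported above) is assumed to provide the trained
# TabularTransformer architecture and a model_predict(model, input_df, scaler_X, scaler_y)
# helper that returns predictions already inverse-transformed to resistivity units.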

# Set page config
st.set_page_config(
    page_title="Resistivity Prediction App",
    page_icon="🔮",
    layout="wide"
)

# Title and description
st.title("Resistivity Prediction App")
st.markdown("""
This app predicts resistivity based on input features. Enter the values for each feature
and click 'Predict' to get the prediction and explanation.
""")

# Cache the loaded model and fitted scalers so they are built once, not on every rerun
# (st.cache_resource requires Streamlit >= 1.18).
@st.cache_resource
def load_model_and_scalers():
    # Load the training data; the first 8 columns are features, column 9 is resistivity
    df = pd.read_excel('data.xlsx')
    X = df.iloc[:, 0:8]
    y = df.iloc[:, 8]
    feature_names = X.columns.tolist()

    # Initialize scalers
    scaler_X = MinMaxScaler()
    scaler_y = MinMaxScaler()

    # Fit scalers
    scaler_X.fit(X)
    scaler_y.fit(y.values.reshape(-1, 1))

    # Load model; map_location='cpu' lets the checkpoint load on machines without a GPU
    model = TabularTransformer(input_dim=8, output_dim=1)
    model.load_state_dict(torch.load('model.pth', map_location='cpu'))
    model.eval()

    return model, scaler_X, scaler_y, feature_names, X

def explain_prediction(model, input_df, X_background, scaler_X, scaler_y, feature_names):
    # Create a prediction function for SHAP
    def predict_fn(X):
        X_tensor = torch.FloatTensor(scaler_X.transform(X))
        with torch.no_grad():
            scaled_pred = model(X_tensor).numpy()
        return scaler_y.inverse_transform(scaled_pred)

    # Use a subset of training data as background
    background_sample = X_background.sample(n=min(100, len(X_background)), random_state=42)
    explainer = shap.KernelExplainer(predict_fn, background_sample)

    # Calculate SHAP values for the input
    shap_values = explainer.shap_values(input_df)

    # Handle different SHAP value formats
    if isinstance(shap_values, list):
        shap_values = np.array(shap_values[0])

    # Ensure correct shape for waterfall plot
    if len(shap_values.shape) > 1:
        if shap_values.shape[0] == len(feature_names):
            shap_values = shap_values.T
        shap_values = shap_values.flatten()

    # Create waterfall plot
    plt.figure(figsize=(12, 8))
    shap.plots.waterfall(
        shap.Explanation(
            values=shap_values,
            base_values=explainer.expected_value if np.isscalar(explainer.expected_value)
            else explainer.expected_value[0],
            data=input_df.iloc[0].values,
            feature_names=feature_names
        ),
        show=False
    )
    plt.title('SHAP Value Contributions')
    plt.tight_layout()
    plt.savefig('shap_explanation.png', dpi=300, bbox_inches='tight')
    plt.close()

    return explainer.expected_value, shap_values

# Load model and scalers
try:
    model, scaler_X, scaler_y, feature_names, X = load_model_and_scalers()

    # Create input fields for features
    st.subheader("Input Features")

    # Create two columns for input fields
    col1, col2 = st.columns(2)

    # Dictionary to store input values
    input_values = {}

    # Create input fields split between two columns
    for i, feature in enumerate(feature_names):
        # Get min and max values for each feature
        min_val = float(X[feature].min())
        max_val = float(X[feature].max())

        # Add input field to alternating columns
        with col1 if i < len(feature_names) // 2 else col2:
            input_values[feature] = st.number_input(
                f"{feature}",
                min_value=min_val,
                max_value=max_val,
                value=float(X[feature].mean()),
                help=f"Range: {min_val:.2f} to {max_val:.2f}"
            )

    # Add predict button
    if st.button("Predict"):
        # Create input DataFrame
        input_df = pd.DataFrame([input_values])

        # Make prediction
        prediction = model_predict(model, input_df, scaler_X, scaler_y)

        # Display prediction
        st.subheader("Prediction Result")
        st.markdown(f"### Predicted Resistivity: {prediction[0]:.2f}")

        # Calculate and display SHAP values
        st.subheader("Feature Importance Explanation")

        # Get SHAP values using the training data as background
        expected_value, shap_values = explain_prediction(
            model, input_df, X, scaler_X, scaler_y, feature_names
        )

        # Display the waterfall plot
        st.image('shap_explanation.png')

except Exception as e:
    st.error(f"""
    Error loading the model and data. Please make sure:
    1. The model file 'model.pth' exists
    2. The data file 'data.xlsx' exists
    3. All required packages are installed

    Error details: {str(e)}
    """)