# Source: Hugging Face Space (status: Sleeping) | File size: 4,842 Bytes | Revision: 778f96f
import streamlit as st
import pandas as pd
import numpy as np
import torch
from predict import TabularTransformer, model_predict
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import shap
# Page-level settings, applied before any other Streamlit element is emitted
_PAGE_SETTINGS = {
    "page_title": "Resistivity Prediction App",
    "page_icon": "🔮",
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)

# Heading plus a short usage blurb shown at the top of the page
_APP_INTRO = """
This app predicts resistivity based on input features. Enter the values for each feature
and click 'Predict' to get the prediction and explanation.
"""
st.title("Resistivity Prediction App")
st.markdown(_APP_INTRO)
@st.cache_resource
def load_model_and_scalers():
    """Load the reference data, fit the min-max scalers and restore the model.

    Reads ``data.xlsx``, treats its first eight columns as the input features
    and the ninth column as the target, fits one ``MinMaxScaler`` on each, and
    loads the ``TabularTransformer`` weights from ``model.pth``. The function is
    wrapped in ``st.cache_resource`` so Streamlit can reuse the returned
    objects instead of rebuilding them on every script rerun.

    Returns:
        tuple: ``(model, scaler_X, scaler_y, feature_names, X)`` where
        ``model`` is the network switched to eval mode, ``scaler_X`` /
        ``scaler_y`` are the fitted feature and target scalers,
        ``feature_names`` is the list of feature column names, and ``X`` is
        the unscaled feature DataFrame.

    Raises:
        Exception: whatever ``pd.read_excel``, ``torch.load`` or
        ``load_state_dict`` raise (missing files, incompatible checkpoint,
        ...); nothing is caught here.
    """
    # Load data for scaling
    df = pd.read_excel('data.xlsx')
    X = df.iloc[:, 0:8]
    y = df.iloc[:, 8]
    feature_names = X.columns.tolist()

    # Fit the scalers on the full data set; the target needs a 2-D column
    # because MinMaxScaler only accepts (n_samples, n_features) input.
    scaler_X = MinMaxScaler()
    scaler_y = MinMaxScaler()
    scaler_X.fit(X)
    scaler_y.fit(y.values.reshape(-1, 1))

    # Load model
    model = TabularTransformer(input_dim=8, output_dim=1)
    # map_location='cpu' lets a checkpoint that was saved from a GPU be
    # restored on a CPU-only host; the rest of the app runs the model on CPU
    # (its outputs are converted with .numpy() directly).
    state_dict = torch.load('model.pth', map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    return model, scaler_X, scaler_y, feature_names, X
def explain_prediction(model, input_df, X_background, scaler_X, scaler_y, feature_names):
    """Explain one prediction with SHAP and save a waterfall plot to disk.

    A ``shap.KernelExplainer`` is built around the model (wrapped so that it
    takes and returns values in original, unscaled units), using up to 100
    rows sampled from ``X_background`` as the background set. The SHAP values
    for ``input_df`` are reduced to a flat vector, drawn as a waterfall plot
    and written to ``shap_explanation.png``.

    Args:
        model: network called on a float tensor of scaled features.
        input_df: DataFrame whose first row is the sample to explain.
        X_background: DataFrame the background sample is drawn from.
        scaler_X: fitted scaler applied to features before the model call.
        scaler_y: fitted scaler whose inverse maps model output back.
        feature_names: feature labels for the plot.

    Returns:
        tuple: ``(explainer.expected_value, shap_values)`` with
        ``shap_values`` flattened to one dimension.
    """

    def predict_fn(X):
        # Scale -> forward pass without gradients -> undo the target scaling
        scaled_inputs = scaler_X.transform(X)
        X_tensor = torch.FloatTensor(scaled_inputs)
        with torch.no_grad():
            scaled_pred = model(X_tensor).numpy()
        return scaler_y.inverse_transform(scaled_pred)

    # Background set: at most 100 rows, fixed seed so the draw is repeatable
    n_background = min(100, len(X_background))
    background_sample = X_background.sample(n=n_background, random_state=42)
    explainer = shap.KernelExplainer(predict_fn, background_sample)

    shap_values = explainer.shap_values(input_df)

    # A list result is reduced to its first entry
    if isinstance(shap_values, list):
        shap_values = np.array(shap_values[0])

    # Anything with more than one axis is collapsed to a flat vector
    if len(shap_values.shape) > 1:
        if shap_values.shape[0] == len(feature_names):
            shap_values = shap_values.T
        shap_values = shap_values.flatten()

    plt.figure(figsize=(12, 8))

    # The waterfall plot takes a single base value; use the first entry when
    # the explainer reports a non-scalar expected value.
    if np.isscalar(explainer.expected_value):
        base_value = explainer.expected_value
    else:
        base_value = explainer.expected_value[0]

    explanation = shap.Explanation(
        values=shap_values,
        base_values=base_value,
        data=input_df.iloc[0].values,
        feature_names=feature_names,
    )
    shap.plots.waterfall(explanation, show=False)

    plt.title('SHAP Value Contributions')
    plt.tight_layout()
    plt.savefig('shap_explanation.png', dpi=300, bbox_inches='tight')
    plt.close()

    return explainer.expected_value, shap_values
# Load model and scalers. Only the loading step is guarded by the
# "missing file / package" message, so later failures are not misreported
# as load errors.
try:
    model, scaler_X, scaler_y, feature_names, X = load_model_and_scalers()
except Exception as e:
    st.error(f"""
Error loading the model and data. Please make sure:
1. The model file 'model.pth' exists
2. The data file 'data.xlsx' exists
3. All required packages are installed
Error details: {str(e)}
""")
    # Nothing below can work without the model and data; end this script run.
    st.stop()

try:
    # Create input fields for features
    st.subheader("Input Features")

    # Two columns: the first half of the features goes left, the rest right
    col1, col2 = st.columns(2)
    first_half = len(feature_names) // 2

    # Dictionary to store input values, keyed by feature name
    input_values = {}

    for i, feature in enumerate(feature_names):
        # The observed range in the data bounds what the widget accepts
        min_val = float(X[feature].min())
        max_val = float(X[feature].max())
        # Clamp the default so floating-point rounding of the mean can never
        # place it outside the [min, max] bounds passed to the widget.
        default_val = min(max(float(X[feature].mean()), min_val), max_val)

        with col1 if i < first_half else col2:
            input_values[feature] = st.number_input(
                f"{feature}",
                min_value=min_val,
                max_value=max_val,
                value=default_val,
                help=f"Range: {min_val:.2f} to {max_val:.2f}"
            )

    # Add predict button
    if st.button("Predict"):
        # Single-row DataFrame with one column per feature
        input_df = pd.DataFrame([input_values])

        # Make prediction
        prediction = model_predict(model, input_df, scaler_X, scaler_y)

        # Display prediction
        st.subheader("Prediction Result")
        st.markdown(f"### Predicted Resistivity: {prediction[0]:.2f}")

        # Calculate and display SHAP values
        st.subheader("Feature Importance Explanation")

        # The return values are not needed here; the call is made for the
        # waterfall plot it saves as 'shap_explanation.png'.
        explain_prediction(
            model, input_df, X, scaler_X, scaler_y, feature_names
        )

        # Display the waterfall plot
        st.image('shap_explanation.png')
except Exception as e:
    st.error(
        "Error while building the input form or computing the prediction "
        f"and its explanation. Error details: {str(e)}"
    )