# Resistivity / app.py
# Uploaded by Sompote ("Upload 6 files"), commit 778f96f (verified).
import streamlit as st
import pandas as pd
import numpy as np
import torch
from predict import TabularTransformer, model_predict
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import shap
# Page-level settings: title, icon and wide layout.
st.set_page_config(
    page_title="Resistivity Prediction App",
    # The icon literal was previously mis-encoded (the UTF-8 bytes of the
    # crystal-ball emoji decoded as a Thai code page); restored to the emoji.
    page_icon="🔮",
    layout="wide"
)

# Title and short usage description shown at the top of the page.
st.title("Resistivity Prediction App")
st.markdown("""
This app predicts resistivity based on input features. Enter the values for each feature
and click 'Predict' to get the prediction and explanation.
""")
@st.cache_resource
def load_model_and_scalers(data_path="data.xlsx", model_path="model.pth"):
    """Load the training data, fit the scalers and load the trained model.

    Decorated with ``st.cache_resource`` so Streamlit reuses the result across
    script reruns instead of reloading for every interaction.

    Args:
        data_path: Excel workbook whose first 8 columns are the input features
            and whose 9th column is the target.
        model_path: Path to the saved ``TabularTransformer`` state dict.

    Returns:
        Tuple ``(model, scaler_X, scaler_y, feature_names, X)`` where ``X`` is
        the raw (unscaled) feature DataFrame; the app reuses it for the input
        widget ranges and as the SHAP background data.
    """
    df = pd.read_excel(data_path)
    # Positional column layout: features in columns 0-7, target in column 8.
    X = df.iloc[:, 0:8]
    y = df.iloc[:, 8]
    feature_names = X.columns.tolist()

    # Refit min-max scalers on the same data. This presumably reproduces the
    # scaling used at training time -- TODO confirm the model was trained on
    # data.xlsx with MinMaxScaler.
    scaler_X = MinMaxScaler().fit(X)
    scaler_y = MinMaxScaler().fit(y.values.reshape(-1, 1))

    # Derive the input width from the data instead of hard-coding it.
    model = TabularTransformer(input_dim=X.shape[1], output_dim=1)
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host; the parameters end up on the CPU model either way.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # inference mode (e.g. disables dropout)
    return model, scaler_X, scaler_y, feature_names, X
def _flatten_single_row_shap(shap_values, n_features):
    """Reduce Kernel SHAP output for one row / one model output to a 1-D array.

    Raises:
        ValueError: If the number of values does not equal ``n_features``.
    """
    # Depending on the shap version and the model's output shape, the result
    # may be a list with one array per output, or a single array with extra
    # singleton axes such as (1, n_features) or (1, n_features, 1).
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    values = np.asarray(shap_values, dtype=float).reshape(-1)
    if values.size != n_features:
        raise ValueError(
            f"Expected {n_features} SHAP values for a single prediction, "
            f"got {values.size}"
        )
    return values


def explain_prediction(model, input_df, X_background, scaler_X, scaler_y, feature_names):
    """Explain one prediction with Kernel SHAP and save a waterfall plot.

    The plot is written to ``shap_explanation.png`` in the current working
    directory.

    Args:
        model: Trained torch model mapping scaled features to a scaled target.
        input_df: Single-row DataFrame of raw (unscaled) feature values.
        X_background: Raw training features; up to 100 rows are sampled
            (fixed seed) as the SHAP background data.
        scaler_X: Fitted feature scaler applied before calling the model.
        scaler_y: Fitted target scaler; its inverse puts SHAP values in the
            original target units.
        feature_names: Feature names in model-input column order.

    Returns:
        Tuple ``(expected_value, shap_values)``: the explainer's base value and
        a 1-D array with one SHAP value per feature.

    Raises:
        ValueError: If SHAP returns a number of values that does not match
            ``feature_names``.
    """
    def predict_fn(data):
        # KernelExplainer typically calls the model with plain NumPy arrays;
        # restore the column names so the DataFrame-fitted scaler sees the
        # same feature names it was fitted with.
        frame = pd.DataFrame(np.asarray(data), columns=feature_names)
        X_tensor = torch.FloatTensor(scaler_X.transform(frame))
        with torch.no_grad():
            scaled_pred = model(X_tensor).numpy()
        # A 1-D output makes SHAP treat the model as single-output, so
        # expected_value is a scalar and shap_values has one row per sample.
        return scaler_y.inverse_transform(scaled_pred.reshape(-1, 1)).ravel()

    # Use a subset of training data as background.
    background_sample = X_background.sample(
        n=min(100, len(X_background)), random_state=42
    )
    explainer = shap.KernelExplainer(predict_fn, background_sample)

    shap_values = _flatten_single_row_shap(
        explainer.shap_values(input_df), len(feature_names)
    )
    # expected_value may be a scalar or a per-output array; take the first.
    base_value = float(np.asarray(explainer.expected_value, dtype=float).reshape(-1)[0])

    fig = plt.figure(figsize=(12, 8))
    try:
        shap.plots.waterfall(
            shap.Explanation(
                values=shap_values,
                base_values=base_value,
                data=input_df.iloc[0].values,
                feature_names=feature_names,
            ),
            show=False,
        )
        plt.title('SHAP Value Contributions')
        plt.tight_layout()
        plt.savefig('shap_explanation.png', dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting fails, so repeated reruns do
        # not accumulate open matplotlib figures.
        plt.close(fig)
    return explainer.expected_value, shap_values
# Load the model and scalers. Only a failure here gets the "loading" message;
# prediction/explanation failures are reported separately below.
try:
    model, scaler_X, scaler_y, feature_names, X = load_model_and_scalers()
except Exception as e:
    st.error(f"""
Error loading the model and data. Please make sure:
1. The model file 'model.pth' exists
2. The data file 'data.xlsx' exists
3. All required packages are installed
Error details: {str(e)}
""")
else:
    st.subheader("Input Features")

    # Two columns: the first half of the features on the left, the rest on
    # the right.
    col1, col2 = st.columns(2)
    half = len(feature_names) // 2

    input_values = {}
    for i, feature in enumerate(feature_names):
        # Bound each input by the range seen in the training data and default
        # it to the training mean.
        min_val = float(X[feature].min())
        max_val = float(X[feature].max())
        with col1 if i < half else col2:
            input_values[feature] = st.number_input(
                f"{feature}",
                min_value=min_val,
                max_value=max_val,
                value=float(X[feature].mean()),
                help=f"Range: {min_val:.2f} to {max_val:.2f}"
            )

    if st.button("Predict"):
        # Keep columns in training order for the scaler and the model.
        input_df = pd.DataFrame([input_values], columns=feature_names)

        try:
            prediction = model_predict(model, input_df, scaler_X, scaler_y)
        except Exception as e:
            st.error(f"Prediction failed. Error details: {str(e)}")
        else:
            st.subheader("Prediction Result")
            # model_predict's return shape isn't visible here; ravel handles
            # both (n,) and (n, 1) results before formatting.
            predicted = float(np.ravel(prediction)[0])
            st.markdown(f"### Predicted Resistivity: {predicted:.2f}")

            st.subheader("Feature Importance Explanation")
            try:
                # Uses the training data as the SHAP background.
                expected_value, shap_values = explain_prediction(
                    model, input_df, X, scaler_X, scaler_y, feature_names
                )
            except Exception as e:
                st.error(f"Could not compute the SHAP explanation. Error details: {str(e)}")
            else:
                # explain_prediction saves the waterfall plot to this file.
                st.image('shap_explanation.png')