import streamlit as st
import pandas as pd
# Import from the PyCaret module that matches the model's task (classification here)
from pycaret.classification import load_model, predict_model
import os
import warnings # Added to potentially suppress warnings
import logging # Added for better debugging in the Space
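# NOTE: the Space is expected to ship a requirements.txt covering these imports
# (at minimum streamlit, pandas, and pycaret), pinned to versions compatible with
# the PyCaret version used to train model.pkl.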
# --- Page Configuration (MUST BE FIRST STREAMLIT COMMAND) ---
APP_TITLE = "my-pycaret-app"
st.set_page_config(page_title=APP_TITLE, layout="centered")
# Configure simple logging for the Streamlit app
# Use Streamlit logger if available, otherwise basic config
try:
    # Attempt to get the logger specific to the Streamlit context
    logger = st.logger.get_logger(__name__)
except AttributeError:  # Fallback for older Streamlit versions or different contexts
    # Basic logging setup if the Streamlit logger isn't available
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - StreamlitApp - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)
# --- Model Configuration ---
MODEL_FILE = "model.pkl" # Relative path within the Space
# --- Load Model ---
# Use cache_resource for efficient loading
@st.cache_resource
def get_model():
    logger.info(f"Attempting to load model from file: {MODEL_FILE}")
    # PyCaret's load_model expects the path without the .pkl extension
    model_load_path = MODEL_FILE.replace('.pkl', '')
    logger.info(f"Calculated PyCaret load path: '{model_load_path}'")
    if not os.path.exists(MODEL_FILE):
        st.error(f"Model file '{MODEL_FILE}' not found in the Space repository.")
        logger.error(f"Model file '{MODEL_FILE}' not found at expected path.")
        return None
    try:
        # Suppress specific warnings during loading if needed
        # warnings.filterwarnings("ignore", category=UserWarning, message=".*Trying to unpickle estimator.*")
        logger.info(f"Calling PyCaret's load_model('{model_load_path}')...")
        # Ensure PyCaret logging doesn't interfere excessively if needed
        # from pycaret.utils.generic import enable_colab
        # enable_colab()  # May help manage output/logging in some environments
        model = load_model(model_load_path)
        logger.info("PyCaret's load_model executed successfully.")
        return model
    except FileNotFoundError:
        # load_model can fail if auxiliary files (e.g. the preprocessing pipeline) are missing
        st.error(f"Error loading model components for '{model_load_path}'. PyCaret's load_model failed, possibly due to missing auxiliary files.")
        logger.exception(f"PyCaret load_model failed for '{model_load_path}', likely due to missing components:")
        return None
    except Exception as e:
        # Catch any other errors raised during model loading
        st.error(f"An unexpected error occurred while loading model '{model_load_path}': {e}")
        logger.exception("Unexpected model loading error details:")  # Log full traceback
        return None
# --- Load the model ---
model = get_model()
# --- App Layout ---
st.title(APP_TITLE)  # Title comes after st.set_page_config and the model loading attempt
if model is None:
    st.error("Model could not be loaded. Please check the application logs in the Space settings for more details. Application cannot proceed.")
else:
    st.success("Model loaded successfully!")  # Indicate success
    st.write("Enter the input features below to get a prediction.")
    # --- Input Widgets ---
    with st.form("prediction_form"):
        st.subheader("Input Features:")
        # Dynamically generated widgets based on the schema
        input_PassengerId = st.number_input(label='PassengerId', format='%f', key='input_PassengerId')
        input_Pclass = st.number_input(label='Pclass', format='%f', key='input_Pclass')
        input_Name = st.number_input(label='Name', format='%f', key='input_Name')
        input_Sex = st.number_input(label='Sex', format='%f', key='input_Sex')
        input_Age = st.number_input(label='Age', format='%f', key='input_Age')
        input_SibSp = st.number_input(label='SibSp', format='%f', key='input_SibSp')
        input_Parch = st.number_input(label='Parch', format='%f', key='input_Parch')
        input_Ticket = st.number_input(label='Ticket', format='%f', key='input_Ticket')
        input_Fare = st.number_input(label='Fare', format='%f', key='input_Fare')
        input_Cabin = st.number_input(label='Cabin', format='%f', key='input_Cabin')
        input_Embarked = st.number_input(label='Embarked', format='%f', key='input_Embarked')
        input_Survived = st.number_input(label='Survived', format='%f', key='input_Survived')
        submitted = st.form_submit_button("Predict")
    # --- Prediction Logic ---
    if submitted:
        try:
            # Build a single-row DataFrame from the form inputs, keyed by the original feature names
            input_data_dict = {'PassengerId': input_PassengerId, 'Pclass': input_Pclass, 'Name': input_Name, 'Sex': input_Sex, 'Age': input_Age, 'SibSp': input_SibSp, 'Parch': input_Parch, 'Ticket': input_Ticket, 'Fare': input_Fare, 'Cabin': input_Cabin, 'Embarked': input_Embarked, 'Survived': input_Survived}
            logger.info(f"Raw input data from form: {input_data_dict}")
            input_data = pd.DataFrame([input_data_dict])
            # Ensure correct dtypes based on the schema before prediction
            logger.info("Applying dtypes based on schema...")
            for feature, f_type in {'PassengerId': 'numerical', 'Pclass': 'numerical', 'Name': 'numerical', 'Sex': 'numerical', 'Age': 'numerical', 'SibSp': 'numerical', 'Parch': 'numerical', 'Ticket': 'numerical', 'Fare': 'numerical', 'Cabin': 'numerical', 'Embarked': 'numerical', 'Survived': 'numerical'}.items():
                if feature in input_data.columns:  # Check that the feature exists
                    try:
                        if f_type == 'numerical':
                            # Convert to numeric, coercing errors (users might enter text)
                            input_data[feature] = pd.to_numeric(input_data[feature], errors='coerce')
                        # Add elif branches for 'categorical' or other types if needed
                        # else:
                        #     input_data[feature] = input_data[feature].astype(str)  # Ensure string type
                    except Exception as type_e:
                        logger.warning(f"Could not convert feature '{feature}' to type '{f_type}'. Error: {type_e}")
                        # Set the value to missing if conversion fails
                        input_data[feature] = pd.NA
                else:
                    logger.warning(f"Feature '{feature}' from schema not found in input form data.")
            # Handle potential NaN values from coercion or failed conversion
            if input_data.isnull().values.any():
                st.warning("Some inputs might be invalid or missing. Missing values will be replaced with 0.")
                logger.warning(f"NaN values found in input data after type conversion/validation. Filling with 0. Data before fill:\n{input_data}")
                # A more robust imputation strategy may be needed depending on the model
                input_data.fillna(0, inplace=True)  # Simple imputation strategy
                logger.info(f"Data after filling NaN with 0:\n{input_data}")
            st.write("Input Data for Prediction (after processing):")
            st.dataframe(input_data)
            # Make the prediction
            logger.info("Calling predict_model...")
            with st.spinner("Predicting..."):
                # Suppress prediction warnings if needed
                # with warnings.catch_warnings():
                #     warnings.simplefilter("ignore")
                predictions = predict_model(model, data=input_data)
            logger.info("Prediction successful.")
            st.subheader("Prediction Result:")
            logger.info(f"Prediction output columns: {predictions.columns.tolist()}")
            # Display relevant prediction columns (adjust based on the PyCaret task)
            # Common columns: 'prediction_label', 'prediction_score'
            pred_col_label = 'prediction_label'
            pred_col_score = 'prediction_score'
            if pred_col_label in predictions.columns:
                st.success(f"Predicted Label: **{predictions[pred_col_label].iloc[0]}**")
            elif pred_col_score in predictions.columns:  # Show the score if no label column is present (e.g. regression)
                st.success(f"Prediction Score: **{predictions[pred_col_score].iloc[0]:.4f}**")
            else:
                # Fallback: display the last column if the expected ones aren't found
                try:
                    last_col_name = predictions.columns[-1]
                    st.info(f"Prediction Output (Column: '{last_col_name}'): **{predictions[last_col_name].iloc[0]}**")
                    logger.warning(f"Could not find 'prediction_label' or 'prediction_score'. Displaying last column: '{last_col_name}'")
                except IndexError:
                    st.error("Prediction result DataFrame is empty.")
                    logger.error("Prediction result DataFrame is empty.")
            # Optionally show the full prediction output
            with st.expander("See Full Prediction Output DataFrame"):
                st.dataframe(predictions)
        except Exception as e:
            st.error(f"An error occurred during prediction: {e}")
            logger.exception("Prediction error details:")  # Log full traceback
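# To run this app locally (assuming the file is saved as app.py and the
# dependencies above are installed):
#   streamlit run app.py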