# Streamlit inference app for a PyCaret model, deployed as a Hugging Face Space.
import logging  # Added for better debugging in the Space
import os
import warnings  # Added to potentially suppress warnings

import pandas as pd
import streamlit as st
# Make sure to import the correct module dynamically based on the task
from pycaret.classification import load_model, predict_model
# --- Page Configuration (MUST BE FIRST STREAMLIT COMMAND) ---
APP_TITLE = "my-pycaret-app"
st.set_page_config(page_title=APP_TITLE, layout="centered", initial_sidebar_state="collapsed")

# --- Logging ---
# Prefer Streamlit's context-aware logger; fall back to a stdlib logger on
# older Streamlit versions (or contexts) where st.logger is unavailable.
try:
    logger = st.logger.get_logger(__name__)
except AttributeError:
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - StreamlitApp - %(levelname)s - %(message)s')
    logger = logging.getLogger(__name__)

# --- Model Configuration ---
MODEL_FILE = "model.pkl"  # Relative path within the Space

# --- Processed Schema (drives the input widgets and dtype coercion below) ---
# NOTE(review): every feature is typed 'numerical', including columns that look
# categorical/text (Name, Sex, Ticket, Cabin, Embarked) and the target itself
# (Survived) -- verify this matches the training pipeline.
APP_SCHEMA = {'PassengerId': {'type': 'numerical'}, 'Pclass': {'type': 'numerical'}, 'Name': {'type': 'numerical'}, 'Sex': {'type': 'numerical'}, 'Age': {'type': 'numerical'}, 'SibSp': {'type': 'numerical'}, 'Parch': {'type': 'numerical'}, 'Ticket': {'type': 'numerical'}, 'Fare': {'type': 'numerical'}, 'Cabin': {'type': 'numerical'}, 'Embarked': {'type': 'numerical'}, 'Survived': {'type': 'numerical'}}
# --- Load Model ---
# BUGFIX: the original comment promised cache_resource but the decorator was
# missing, so the model was re-deserialized on every Streamlit rerun.
@st.cache_resource
def get_model():
    """Load the PyCaret pipeline saved at MODEL_FILE.

    Returns:
        The fitted pipeline, or None when the file is missing or
        deserialization fails (the error is shown in the UI and logged).
    """
    logger.info(f"Attempting to load model from file: {MODEL_FILE}")
    # PyCaret's load_model expects the path WITHOUT the .pkl extension.
    # Strip only a trailing '.pkl' (str.replace would also mangle a '.pkl'
    # occurring mid-name).
    model_load_path = MODEL_FILE[:-4] if MODEL_FILE.endswith('.pkl') else MODEL_FILE
    logger.info(f"Calculated PyCaret load path: '{model_load_path}'")
    if not os.path.exists(MODEL_FILE):
        st.error(f"Model file '{MODEL_FILE}' not found in the Space repository.")
        logger.error(f"Model file '{MODEL_FILE}' not found at expected path.")
        return None
    try:
        logger.info(f"Calling PyCaret's load_model('{model_load_path}')...")
        model = load_model(model_load_path)
        logger.info("PyCaret's load_model executed successfully.")
        return model
    except FileNotFoundError:
        # load_model can fail even when model.pkl exists, e.g. when auxiliary
        # artifacts (preprocessing pipeline files) were not uploaded with it.
        st.error(f"Error loading model components for '{model_load_path}'. PyCaret's load_model failed, possibly missing auxiliary files.")
        logger.exception(f"PyCaret load_model failed for '{model_load_path}', likely due to missing components:")
        return None
    except Exception as e:
        # Catch-all boundary: surface the error in the UI and keep the app up.
        st.error(f"An unexpected error occurred loading model '{model_load_path}': {e}")
        logger.exception("Unexpected model loading error details:")  # Log full traceback
        return None
# --- Load the model ---
model = get_model()

# --- App Layout ---
st.title(APP_TITLE)  # Title must come after st.set_page_config

if model is None:
    st.error("Model could not be loaded. Please check the application logs in the Space settings for more details. Application cannot proceed.")
else:
    st.success("Model loaded successfully!")
    st.markdown("Provide the input features below to generate a prediction using the deployed model.")

    # --- Input Section ---
    st.header("Model Inputs")
    with st.form("prediction_form"):
        # One widget per schema feature; widget keys mirror the feature names.
        input_PassengerId = st.number_input(label='PassengerId', format='%f', key='input_PassengerId')
        input_Pclass = st.number_input(label='Pclass', format='%f', key='input_Pclass')
        input_Name = st.number_input(label='Name', format='%f', key='input_Name')
        input_Sex = st.number_input(label='Sex', format='%f', key='input_Sex')
        input_Age = st.number_input(label='Age', format='%f', key='input_Age')
        input_SibSp = st.number_input(label='SibSp', format='%f', key='input_SibSp')
        input_Parch = st.number_input(label='Parch', format='%f', key='input_Parch')
        input_Ticket = st.number_input(label='Ticket', format='%f', key='input_Ticket')
        input_Fare = st.number_input(label='Fare', format='%f', key='input_Fare')
        input_Cabin = st.number_input(label='Cabin', format='%f', key='input_Cabin')
        input_Embarked = st.number_input(label='Embarked', format='%f', key='input_Embarked')
        input_Survived = st.number_input(label='Survived', format='%f', key='input_Survived')
        submitted = st.form_submit_button("π Get Prediction")

    # --- Prediction Logic & Output Section ---
    if submitted:
        st.header("Prediction Output")
        try:
            # Assemble a single-row frame keyed by the original feature names.
            input_data_dict = {'PassengerId': input_PassengerId, 'Pclass': input_Pclass, 'Name': input_Name, 'Sex': input_Sex, 'Age': input_Age, 'SibSp': input_SibSp, 'Parch': input_Parch, 'Ticket': input_Ticket, 'Fare': input_Fare, 'Cabin': input_Cabin, 'Embarked': input_Embarked, 'Survived': input_Survived}
            logger.info(f"Raw input data from form: {input_data_dict}")
            input_data = pd.DataFrame([input_data_dict])

            # Coerce each column to the dtype declared in the schema.
            logger.info("Applying dtypes based on schema...")
            for feature, details in APP_SCHEMA.items():
                feature_type = details.get("type", "text").lower()
                if feature not in input_data.columns:
                    logger.warning(f"Feature '{feature}' from schema not found in input form data.")
                    continue
                try:
                    current_value = input_data[feature].iloc[0]
                    if pd.isna(current_value):
                        continue  # nothing to coerce
                    if feature_type == 'numerical':
                        # Coerce: unparseable values become NaN instead of raising.
                        input_data[feature] = pd.to_numeric(input_data[feature], errors='coerce')
                    elif feature_type == 'categorical':
                        # PyCaret's predict_model generally expects object/str
                        # dtype for categorical columns.
                        input_data[feature] = input_data[feature].astype(str)
                except Exception as type_e:
                    logger.warning(f"Could not convert feature '{feature}' (value: {current_value}) to type '{feature_type}'. Error: {type_e}")
                    input_data[feature] = pd.NA  # treat unconvertible input as missing

            # Impute NaNs introduced by coercion: numerical columns -> 0.
            if input_data.isnull().values.any():
                st.warning("Some inputs might be invalid or missing. Attempting to handle missing values (e.g., replacing with 0 for numerical). Check logs for details.")
                logger.warning(f"NaN values found in input data after type conversion/validation. Filling numerical with 0. Data before fill:\n{input_data}")
                for feature, details in APP_SCHEMA.items():
                    if feature in input_data.columns and details.get("type") == "numerical" and input_data[feature].isnull().any():
                        # BUGFIX: plain assignment instead of column-level
                        # fillna(inplace=True), which is the pandas
                        # chained-assignment anti-pattern and may not write back.
                        input_data[feature] = input_data[feature].fillna(0)
                logger.info(f"Data after filling NaN:\n{input_data}")

            st.markdown("##### Input Data Sent to Model (after processing):")
            st.dataframe(input_data)

            # Make prediction
            logger.info("Calling predict_model...")
            with st.spinner("Predicting..."):
                predictions = predict_model(model, data=input_data)
            logger.info("Prediction successful.")

            st.markdown("##### Prediction Result:")
            logger.info(f"Prediction output columns: {predictions.columns.tolist()}")
            # Common PyCaret output columns: 'prediction_label', 'prediction_score'
            pred_col_label = 'prediction_label'
            pred_col_score = 'prediction_score'
            if pred_col_label in predictions.columns:
                st.success(f"Predicted Label: **{predictions[pred_col_label].iloc[0]}**")
                # BUGFIX: the original gated this on `pycaret_task_module`,
                # a name never defined anywhere in the file (NameError at
                # runtime). A label column only appears for classification,
                # so it is safe to show the score whenever it is present.
                if pred_col_score in predictions.columns:
                    st.info(f"Prediction Score: **{predictions[pred_col_score].iloc[0]:.4f}**")
            elif pred_col_score in predictions.columns:
                # Score without a label: regression-style output.
                st.success(f"Predicted Value: **{predictions[pred_col_score].iloc[0]:.4f}**")
            else:
                # Fallback: display the last non-input column as the prediction.
                try:
                    output_columns = [col for col in predictions.columns if col not in input_data.columns]
                    if output_columns:
                        last_col_name = output_columns[-1]
                        st.info(f"Prediction Output (Column: '{last_col_name}'): **{predictions[last_col_name].iloc[0]}**")
                        logger.warning(f"Could not find standard prediction columns. Displaying last non-input column: '{last_col_name}'")
                    else:  # If only input columns are returned (unlikely)
                        st.warning("Prediction output seems to only contain input columns.")
                except IndexError:
                    st.error("Prediction result DataFrame is empty or has unexpected format.")
                    logger.error("Prediction result DataFrame is empty or has unexpected format.")

            # Show full prediction output optionally
            with st.expander("View Full Prediction Output DataFrame"):
                st.dataframe(predictions)
        except Exception as e:
            st.error(f"An error occurred during prediction: {e}")
            logger.exception("Prediction error details:")  # Log full traceback