File size: 9,961 Bytes
3a79137
 
 
c95e38c
3a79137
 
 
 
 
c95e38c
 
 
 
3a79137
 
 
 
 
 
 
 
 
 
 
c95e38c
3a79137
c95e38c
3a79137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c95e38c
3a79137
 
 
c95e38c
3a79137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175

import streamlit as st
import pandas as pd
# Make sure to import the correct module dynamically based on the task
from pycaret.classification import load_model, predict_model
import os
import warnings # Added to potentially suppress warnings
import logging # Added for better debugging in the Space

# --- Page Configuration (MUST BE FIRST STREAMLIT COMMAND) ---
APP_TITLE = "my-pycaret-app"
st.set_page_config(page_title=APP_TITLE, layout="centered")


def _build_logger():
    """Return a logger for this app.

    Prefers Streamlit's context-aware logger; on older Streamlit versions
    (where ``st.logger`` does not exist) falls back to stdlib logging.
    """
    try:
        return st.logger.get_logger(__name__)
    except AttributeError:
        # Older Streamlit or non-standard context: configure basic stdlib logging.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - StreamlitApp - %(levelname)s - %(message)s')
        return logging.getLogger(__name__)


logger = _build_logger()


# --- Model Configuration ---
MODEL_FILE = "model.pkl" # Relative path within the Space


# --- Load Model ---
# cache_resource ensures the (potentially large) pickle is loaded once per
# session instead of on every Streamlit rerun.
@st.cache_resource
def get_model():
    """Load the PyCaret pipeline stored in MODEL_FILE.

    Returns:
        The fitted PyCaret pipeline, or ``None`` if the model file is missing
        or loading fails (the error is shown in the UI and logged).
    """
    logger.info("Attempting to load model from file: %s", MODEL_FILE)
    # PyCaret's load_model expects the path WITHOUT the .pkl extension.
    # Use os.path.splitext instead of str.replace so only the trailing
    # extension is stripped (replace would also mangle '.pkl' mid-path).
    model_load_path, _ = os.path.splitext(MODEL_FILE)
    logger.info("Calculated PyCaret load path: '%s'", model_load_path)

    if not os.path.exists(MODEL_FILE):
        st.error(f"Model file '{MODEL_FILE}' not found in the Space repository.")
        logger.error("Model file '%s' not found at expected path.", MODEL_FILE)
        return None
    try:
        logger.info("Calling PyCaret's load_model('%s')...", model_load_path)
        model = load_model(model_load_path)
        logger.info("PyCaret's load_model executed successfully.")
        return model
    except FileNotFoundError:
        # load_model itself can raise this when auxiliary files (e.g. the
        # preprocessing pipeline) are missing even though model.pkl exists.
        st.error(f"Error loading model components for '{model_load_path}'. PyCaret's load_model failed, possibly missing auxiliary files.")
        logger.exception(f"PyCaret load_model failed for '{model_load_path}', likely due to missing components:")
        return None
    except Exception as e:
        # Catch-all so the app degrades to an error banner instead of crashing.
        st.error(f"An unexpected error occurred loading model '{model_load_path}': {e}")
        logger.exception("Unexpected model loading error details:")  # Log full traceback
        return None

# --- Load the model ---
# Cached by @st.cache_resource, so script reruns reuse the same pipeline object.
model = get_model()

# --- App Layout ---
st.title(APP_TITLE) # Title now comes after page config and model loading attempt

if model is None:
    # get_model() already surfaced the specific failure; this is the terminal banner.
    st.error("Model could not be loaded. Please check the application logs in the Space settings for more details. Application cannot proceed.")
else:
    st.success("Model loaded successfully!") # Indicate success
    st.write("Enter the input features below to get a prediction.")

    # --- Input Widgets ---
    # One widget per schema feature; each value is captured in a local variable
    # when the form is submitted, and also stored in session state under `key`.
    with st.form("prediction_form"):
        st.subheader("Input Features:")
        # Dynamically generated widgets based on schema
        # NOTE(review): every feature uses a float number_input, including fields
        # that are textual in the raw Titanic data (Name, Sex, Ticket, Cabin,
        # Embarked) — presumably the model was trained on numerically encoded
        # values; verify against the training pipeline.
        input_PassengerId = st.number_input(label='PassengerId', format='%f', key='input_PassengerId')
        input_Pclass = st.number_input(label='Pclass', format='%f', key='input_Pclass')
        input_Name = st.number_input(label='Name', format='%f', key='input_Name')
        input_Sex = st.number_input(label='Sex', format='%f', key='input_Sex')
        input_Age = st.number_input(label='Age', format='%f', key='input_Age')
        input_SibSp = st.number_input(label='SibSp', format='%f', key='input_SibSp')
        input_Parch = st.number_input(label='Parch', format='%f', key='input_Parch')
        input_Ticket = st.number_input(label='Ticket', format='%f', key='input_Ticket')
        input_Fare = st.number_input(label='Fare', format='%f', key='input_Fare')
        input_Cabin = st.number_input(label='Cabin', format='%f', key='input_Cabin')
        input_Embarked = st.number_input(label='Embarked', format='%f', key='input_Embarked')
        # NOTE(review): 'Survived' is normally the prediction TARGET for this
        # dataset; confirm it is genuinely a model input and not a schema bug.
        input_Survived = st.number_input(label='Survived', format='%f', key='input_Survived')
        submitted = st.form_submit_button("Predict")

    # --- Prediction Logic ---
    if submitted:
        try:
            # Create DataFrame from inputs using original feature names as keys
            # The values are automatically fetched by Streamlit using the keys assigned to widgets
            input_data_dict = {'PassengerId': input_PassengerId, 'Pclass': input_Pclass, 'Name': input_Name, 'Sex': input_Sex, 'Age': input_Age, 'SibSp': input_SibSp, 'Parch': input_Parch, 'Ticket': input_Ticket, 'Fare': input_Fare, 'Cabin': input_Cabin, 'Embarked': input_Embarked, 'Survived': input_Survived} # single-row record keyed by original feature names
            logger.info(f"Raw input data from form: {input_data_dict}")
            input_data = pd.DataFrame([input_data_dict])

            # Ensure correct dtypes based on schema before prediction
            logger.info("Applying dtypes based on schema...")
            # The schema marks every feature 'numerical'; coerce each column accordingly.
            for feature, f_type in {'PassengerId': 'numerical', 'Pclass': 'numerical', 'Name': 'numerical', 'Sex': 'numerical', 'Age': 'numerical', 'SibSp': 'numerical', 'Parch': 'numerical', 'Ticket': 'numerical', 'Fare': 'numerical', 'Cabin': 'numerical', 'Embarked': 'numerical', 'Survived': 'numerical'}.items():
                 if feature in input_data.columns: # Check if feature exists
                     try:
                         if f_type == 'numerical':
                             # Convert to numeric, coercing errors (users might enter text)
                             input_data[feature] = pd.to_numeric(input_data[feature], errors='coerce')
                         # Add elif for 'categorical' or other types if needed
                         # else:
                         #     input_data[feature] = input_data[feature].astype(str) # Ensure string type
                     except Exception as type_e:
                         logger.warning(f"Could not convert feature '{feature}' to type '{f_type}'. Error: {type_e}")
                         # Decide how to handle type conversion errors, e.g., set to NaN or keep original
                         input_data[feature] = pd.NA # Set to missing if conversion fails

                 else:
                     logger.warning(f"Feature '{feature}' from schema not found in input form data.")


            # Handle potential NaN values from coercion or failed conversion
            if input_data.isnull().values.any():
                 st.warning("Some inputs might be invalid or missing. Attempting to handle missing values (e.g., replacing with 0).")
                 logger.warning(f"NaN values found in input data after type conversion/validation. Filling with 0. Data before fill:\n{input_data}")
                 # More robust imputation might be needed depending on the model
                 input_data.fillna(0, inplace=True) # Simple imputation strategy
                 logger.info(f"Data after filling NaN with 0:\n{input_data}")


            st.write("Input Data for Prediction (after processing):")
            st.dataframe(input_data)

            # Make prediction
            logger.info("Calling predict_model...")
            with st.spinner("Predicting..."):
                # Suppress prediction warnings if needed
                # with warnings.catch_warnings():
                #    warnings.simplefilter("ignore")
                predictions = predict_model(model, data=input_data)
                logger.info("Prediction successful.")

            st.subheader("Prediction Result:")
            logger.info(f"Prediction output columns: {predictions.columns.tolist()}")

            # Display relevant prediction columns (adjust based on PyCaret task)
            # Common columns: 'prediction_label', 'prediction_score'
            pred_col_label = 'prediction_label'
            pred_col_score = 'prediction_score'

            if pred_col_label in predictions.columns:
                st.success(f"Predicted Label: **{predictions[pred_col_label].iloc[0]}**")
            elif pred_col_score in predictions.columns: # Show score if label not present (e.g., regression)
                 st.success(f"Prediction Score: **{predictions[pred_col_score].iloc[0]:.4f}**")
            else:
                 # Fallback: Display the last column as prediction if specific ones aren't found
                 try:
                     last_col_name = predictions.columns[-1]
                     st.info(f"Prediction Output (Column: '{last_col_name}'): **{predictions[last_col_name].iloc[0]}**")
                     logger.warning(f"Could not find 'prediction_label' or 'prediction_score'. Displaying last column: '{last_col_name}'")
                 except IndexError:
                     st.error("Prediction result DataFrame is empty.")
                     logger.error("Prediction result DataFrame is empty.")


            # Show full prediction output optionally
            with st.expander("See Full Prediction Output DataFrame"):
                st.dataframe(predictions)

        except Exception as e:
            st.error(f"An error occurred during prediction: {e}")
            logger.exception("Prediction error details:") # Log full traceback