CarPricePred / app.py
Solar-Iz's picture
Update app.py
daaf663
raw
history blame
2.88 kB
import streamlit as st
import pandas as pd
import shap
from catboost import CatBoostRegressor
# Загрузка модели
final_model = CatBoostRegressor()
final_model.load_model('/Model/best_model.cbm')
# Загрузка данных для предобработки
df = pd.read_csv('/Dataset/car_data.txt', sep=',')
# Получаем уникальные бренды и их модели
unique_brands = df['brand'].unique()
brand_model_mapping = {brand: df[df['brand'] == brand]['model'].unique() for brand in unique_brands}
# Заголовок приложения
st.title("Прогнозирование цены автомобиля")
# Создаем окошки для ввода параметров
inputs = {}
# Выбор бренда
selected_brand = st.selectbox("Выберите бренд авто", unique_brands)
# Выбор модели в зависимости от бренда
selected_model = st.selectbox("Выберите модель авто", brand_model_mapping[selected_brand])
inputs['brand'] = selected_brand
inputs['model'] = selected_model
# Получаем категориальные столбцы из вашего DataFrame
categorical_columns = ['brand', 'model', 'поколение', 'тип продавца', 'состояние',
'модификация', 'тип двигателя', 'коробка передач', 'привод',
'комплектация', 'тип кузова', 'цвет', 'авито оценка']
# Остальные окошки для ввода параметров
for column in categorical_columns:
if column not in ['brand', 'model']:
if column == 'год выпуска' or column == 'пробег' or column == 'объем двигателя':
inputs[column] = st.number_input(f"Введите значение для {column}")
else:
inputs[column] = st.text_input(f"Введите значение для {column}")
# Кнопка для запуска предсказания
if st.button("Предсказать цену"):
# Создаем DataFrame из введенных данных
input_data = pd.DataFrame(inputs, index=[0])
# Получение предсказания
predicted_price = final_model.predict(input_data)[0]
# Вывод результатов
st.write(f"Прогнозируемая цена авто: {predicted_price} руб.")
# # Расчет важности фичей с использованием SHAP
# explainer = shap.TreeExplainer(final_model)
# shap_values = explainer.shap_values(input_data)
# # Отображение SHAP force plot
# st.write("SHAP Force Plot:")
# shap.force_plot(explainer.expected_value, shap_values, input_data, matplotlib=True, show=False)
# st.pyplot()