# File-viewer header captured with this script:
# author: dperales · commit af9f583 ("Update app.py") · size 1.14 kB
import os
import pycaret
from pycaret.datasets import get_data
# import pycaret clustering and init setup
from pycaret.clustering import *
# import ClusteringExperiment and init the class
from pycaret.clustering import ClusteringExperiment
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import streamlit as st
import plotly.graph_objs as go
# For measuring the inference time.
import time
def main(dataset="jewellery", session_id=123):
    """Cluster a PyCaret sample dataset with k-means and show it in Streamlit.

    Loads the dataset and sets up one ``ClusteringExperiment``. It then trains
    a k-means model, shows the data with its assigned cluster labels and, when
    the "Prediction" button is pressed, renders PyCaret's cluster plot in the
    app.

    Parameters
    ----------
    dataset : str
        Name of the PyCaret sample dataset passed to ``get_data``.
    session_id : int
        Seed forwarded to ``setup`` so results are reproducible across reruns.
    """
    data = get_data(dataset)

    # A single experiment object drives setup, training and plotting. Mixing it
    # with the functional API would run setup twice on every Streamlit rerun.
    exp = ClusteringExperiment()
    exp.setup(data, session_id=session_id)

    kmeans = exp.create_model("kmeans")
    kmeans_cluster = exp.assign_model(kmeans)
    # Render explicitly rather than relying on Streamlit "magic" for a bare name.
    st.write(kmeans_cluster)

    if st.button("Prediction"):
        # plot_model(save=True) writes the plot to disk and returns the saved
        # file's path, not a figure, so it cannot be indexed with ['data'].
        # display_format='streamlit' makes PyCaret render the figure in the app.
        exp.plot_model(kmeans, plot="cluster", display_format="streamlit")
if __name__ == "__main__":
    # Script entry point. NOTE(review): assumes the runner (e.g. `streamlit run`)
    # executes this file as __main__ — confirm for the deployment target.
    main()