import os
import pycaret
from pycaret.datasets import get_data
# import pycaret clustering and init setup
from pycaret.clustering import *
# import ClusteringExperiment and init the class
from pycaret.clustering import ClusteringExperiment

import matplotlib.pyplot as plt
import matplotlib as mpl

import numpy as np
import streamlit as st

# For measuring the inference time.
import time

def main():
    """Train a k-means model on PyCaret's 'jewellery' dataset and show it in Streamlit.

    A single ``ClusteringExperiment`` is configured with a fixed
    ``session_id`` so results are reproducible. The cluster-labelled data is
    shown as a table. Pressing the "Prediction" button renders PyCaret's
    cluster plot inside the Streamlit page.
    """
    data = get_data('jewellery')

    # One explicit experiment object.
    # Running the functional ``setup`` as well would repeat the same
    # preprocessing and keep state in PyCaret's module-level globals.
    exp = ClusteringExperiment()
    exp.setup(data, session_id=123)

    # Train k-means, then attach each row's cluster label to the data.
    kmeans = exp.create_model('kmeans')
    kmeans_cluster = exp.assign_model(kmeans)

    # Render explicitly instead of relying on a bare expression.
    st.dataframe(kmeans_cluster)

    # NOTE(review): Streamlit reruns the whole script on every widget
    # interaction, so clicking the button retrains the model above.
    # Consider caching the training step if this becomes slow.
    if st.button("Prediction"):
        # display_format='streamlit' makes PyCaret draw the plot in the app
        # rather than through its default notebook/browser display.
        exp.plot_model(kmeans, plot='cluster', display_format='streamlit')
        
# Entry point: `python app.py` and `streamlit run app.py` both execute this
# module as "__main__", so both paths reach main().
if __name__ == "__main__":
    main()