"""Streamlit demo: predict fashion-accessory categories and find similar products.

The user picks a query image either by browsing the bundled sample images
or by uploading a file, then asks for category predictions
(``inference.get_predictions``) or visually similar products
(``inference.get_nearest_k``).
"""
import os

import streamlit as st
from PIL import Image

from inference import get_predictions, get_nearest_k

st.title('Fashion accessories prediction and search Demo')

# Number of images in the sample database; sample navigation wraps modulo this.
tot_index = 44065

# multiple folders to handle 10,000 files folder limit in git
sample_path1 = './data/small_images_0_9999'
sample_path2 = './data/small_images_10000_19999'
sample_path3 = './data/small_images_20000_29999'
sample_path4 = './data/small_images_30000_39999'
sample_path5 = './data/small_images_40000_49999'
_SAMPLE_DIRS = (sample_path1, sample_path2, sample_path3, sample_path4, sample_path5)

# Images per folder used for the index -> folder lookup. Boundaries fall at
# 9999, 19998, 29997 and 39996, exactly as in the original lookup chain.
# NOTE(review): the folder names suggest 10000 images per folder
# (0-9999, 10000-19999, ...); confirm which split matches the files on disk.
_FILES_PER_FOLDER = 9999

# Only neighbours strictly closer than this are shown by the similarity search.
_MAX_DISTANCE = 300


def _image_path(index):
    """Return the path of the sample image ``<index>.jpg`` in its folder.

    Indices below the first boundary map to the first folder and indices at or
    past the last boundary map to the last folder.
    """
    folder_no = int(index) // _FILES_PER_FOLDER
    folder_no = min(max(folder_no, 0), len(_SAMPLE_DIRS) - 1)
    return os.path.join(_SAMPLE_DIRS[folder_no], str(index) + '.jpg')


if 'image_index' not in st.session_state:
    st.session_state['image_index'] = 0
# Which source ('sample_button' or 'upload_button') the query image comes from.
if 'which_button' not in st.session_state:
    st.session_state['which_button'] = 'sample_button'

sample_col, upload_col = st.tabs(['Select from sample images', 'Upload file'])

img = None  # the uploaded image, if any, on this run
with upload_col:
    uploaded_file = st.file_uploader("Select a picture from your computer(png/jpg) :", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is not None:
        img = Image.open(uploaded_file)
        st.image(img, caption='Uploaded Image')
        # Switch the query source only on an explicit click; defaulting to True
        # here previously selected the (missing) upload on every rerun.
        if st.button("Use uploaded image"):
            st.session_state['which_button'] = 'upload_button'

with sample_col:
    st.write("Select one from these available samples: ")
    current_index = st.session_state['image_index']
    prev_col, next_col = st.columns(2)
    with prev_col:
        prev_clicked = st.button('prev_image')
    with next_col:
        next_clicked = st.button('next_image')
    if prev_clicked:
        current_index = (current_index - 1) % tot_index
    if next_clicked:
        current_index = (current_index + 1) % tot_index
    st.session_state['image_index'] = current_index

    sample_image = Image.open(_image_path(current_index))
    st.image(sample_image, caption='Chosen image')
    if st.button("Use this Sample"):
        st.session_state['which_button'] = 'sample_button'

# Use the upload only if it was chosen AND is still present; otherwise
# (e.g. the file was removed from the uploader) fall back to the sample.
if st.session_state['which_button'] == 'upload_button' and img is not None:
    query_image = img
else:
    query_image = sample_image

classification_button, search_button = st.columns(2)
with classification_button:
    predict_clicked = st.button("Get categories predictions")
with search_button:
    search_clicked = st.button("Get similar looking products")

if predict_clicked:
    predictions = get_predictions(query_image)
    st.markdown('**The model predictions along with their probabilities are :**')
    st.table(predictions)
elif search_clicked:
    top_k_preds = get_nearest_k(query_image)
    # top_k_preds[0][0] holds the distances and top_k_preds[1][0] the matching
    # database indices for the single query image (presumably a
    # (distances, indices) pair from a k-NN search -- confirm in inference).
    # Pair them before filtering so the result does not depend on sort order.
    matches = [
        (dist, nearest_index)
        for dist, nearest_index in zip(top_k_preds[0][0], top_k_preds[1][0])
        if dist < _MAX_DISTANCE
    ]
    st.markdown('**The top 5 similar product predictions are :**')
    if not matches:
        st.markdown('No similar visually looking similar products found in the database.')
    else:
        pred_cols = st.columns(len(matches))
        for col, (dist, nearest_index) in zip(pred_cols, matches):
            with col:
                temp_img = Image.open(_image_path(nearest_index))
                st.image(temp_img, caption=str(round(dist, 2)) + ' distance')