"""Streamlit app that learns which face images a user likes.

Images are served from an S3 bucket. Every like/dislike is stored as an empty
marker file under ``./users/<alias>/likes``. Once enough votes exist, a
LightGBM classifier is trained on the precomputed image embeddings and its
predictions steer which image is shown next.
"""
import io
import os
import re
from glob import glob

import lightgbm as lgb
import numpy as np
import pandas as pd
import requests
import streamlit as st
import streamlit.components.v1 as components

os.environ['S3_BUCKET'] = "seriouslytestfaces"


def get_s3_url(key):
    """Return the public S3 URL for *key* in the configured bucket.

    Spaces in the key are replaced by ``+``; no other escaping is applied.
    """
    url = 'https://s3.amazonaws.com/%s/%s' % (os.environ['S3_BUCKET'], key.replace(' ', '+'))
    return url


@st.cache_resource
def _load_embeddings():
    """Read the embeddings table once per server process instead of per rerun."""
    return pd.read_parquet('./embeddings.parquet')


# One row per image; the index values are used as S3 object keys below.
embeddings = _load_embeddings()


def create_dir(directory):
    """Create *directory* (and parents) if it does not exist yet."""
    # exist_ok avoids the check-then-create race between concurrent sessions.
    os.makedirs(directory, exist_ok=True)


def setup_user():
    """Create the current alias' folders and initialise the vote counters."""
    create_dir(f'./users/{st.session_state.name}')
    create_dir(f'./users/{st.session_state.name}/likes')
    create_dir(f'./users/{st.session_state.name}/models')
    if 'count' not in st.session_state:
        st.session_state.count = 0
        st.session_state.neg = 0
        st.session_state.pos = 0


def check_image_url_accessible(url):
    """Return True if *url* answers 200 with an image Content-Type.

    A HEAD request is tried first to save bandwidth; if it does not return
    200, a streaming GET is used as a fallback. Any ``requests`` error counts
    as "not accessible".
    """
    try:
        response = requests.head(url, allow_redirects=True, timeout=5)
        if response.status_code != 200:
            response.close()
            response = requests.get(url, stream=True, timeout=5)
        try:
            if response.status_code != 200:
                return False
            content_type = response.headers.get("Content-Type", "")
            return "image" in content_type
        finally:
            # A streamed response keeps its connection checked out until it
            # is closed, so always release it.
            response.close()
    except requests.RequestException:
        return False


def get_filename(max_attempts=10):
    """Pick the next image to show and return its embeddings index value.

    When predictions are available, images are sampled with probability
    proportional to ``pred ** 4`` and the chosen prediction is stored in
    ``st.session_state.pred``; otherwise an image is drawn uniformly.
    Candidates whose URL is not reachable are skipped.

    Args:
        max_attempts: How many candidates to try before giving up.

    Raises:
        RuntimeError: If no reachable image was found in *max_attempts* tries.
    """
    has_model = 'preds' in st.session_state
    preds = weights = None
    if has_model:
        preds = st.session_state.preds
        weights = preds ** 4
        total = weights.sum()
        # Fall back to uniform sampling if the weights cannot be normalised.
        weights = weights / total if total > 0 else None

    for _ in range(max_attempts):
        if has_model:
            choice = np.random.choice(len(preds), p=weights)
            candidate = embeddings.index[choice]
        else:
            choice = None
            candidate = np.random.choice(embeddings.index)
        if check_image_url_accessible(get_s3_url(candidate)):
            if has_model:
                st.session_state.pred = preds[choice]
            return candidate
    raise RuntimeError(f'No accessible image found after {max_attempts} attempts')


def liked(filename, like):
    """Record a vote for *filename* and update the session counters.

    The vote is stored as an empty file named ``<basename>.T`` or
    ``<basename>.F`` in the alias' ``likes`` folder.
    """
    marker = f'./users/{st.session_state.name}/likes/' + filename.split('/')[-1] + '.' + str(like)[:1]
    with open(marker, 'a'):
        pass
    st.session_state.count += 1
    if like:
        st.session_state.pos += 1
    else:
        st.session_state.neg += 1


def _voted_names(suffix):
    """Return the image names that have a marker file ending in ``.<suffix>``."""
    pattern = f'./users/{st.session_state.name}/likes/*.{suffix}'
    # Normalise Windows separators, keep the basename, drop the 2-char suffix.
    return [path.replace('\\', '/').split('/')[-1][:-2] for path in glob(pattern)]


def get_train_data():
    """Build the training matrix from the current alias' recorded votes.

    Also stores the labels, indexed by image name, in
    ``st.session_state.labels``.

    Returns:
        Tuple ``(X, labels)``: stacked embeddings (likes first, then
        dislikes) and the matching 1/0 label array.
    """
    true_files = _voted_names('T')
    false_files = _voted_names('F')
    true_embeddings = embeddings.loc[true_files].values
    false_embeddings = embeddings.loc[false_files].values
    labels = np.array([1] * len(true_embeddings) + [0] * len(false_embeddings))
    st.session_state.labels = pd.Series(labels, index=true_files + false_files).rename('label')
    X = np.vstack([true_embeddings, false_embeddings])
    return X, labels


def train_model(X, labels):
    """Train a LightGBM binary classifier, or return None if data is too thin.

    Training is refused (with a toast) for fewer than 30 samples or when
    more than 90% of the labels belong to one class. On success the Pearson
    correlation between in-sample predictions and labels is stored in
    ``st.session_state.score``.
    """
    if len(labels) < 30:
        st.toast('Not enough data')
        return None
    if labels.mean() > 0.9:
        st.toast('Not enough negatives')
        return None
    if labels.mean() < 0.1:
        st.toast('Not enough positives')
        return None
    train_data = lgb.Dataset(X, label=labels)
    num_round = 10
    param = {'num_leaves': 30, 'objective': 'binary', 'metric': 'binary'}
    bst = lgb.train(param, train_data, num_round)
    in_sample_preds = bst.predict(X)
    in_sample_score = np.corrcoef([in_sample_preds, np.array(labels)])[0][1]
    st.session_state.score = in_sample_score
    st.toast(f'Score = {in_sample_score:.1%}')
    return bst


def rank_candidates(bst):
    """Return the model's prediction for every row of the embeddings table."""
    return bst.predict(embeddings.values)


def train():
    """Train on the recorded votes, save the model and cache its predictions."""
    X, labels = get_train_data()
    bst = train_model(X, labels)
    if bst is None:
        return
    filename = f'./users/{st.session_state.name}/models/model.txt'
    bst.save_model(filename)
    st.session_state.preds = rank_candidates(bst)
    st.balloons()


def cleanup():
    """Delete the alias' recorded votes and reset all per-session model state."""
    for path in glob(f'./users/{st.session_state.name}/likes/*'):
        os.remove(path)
    # 'labels' and 'score' are dropped too so no stale model output survives.
    for var in 'preds pred count pos neg labels score'.split():
        if var in st.session_state:
            del st.session_state[var]
    st.session_state.count = 0
    st.session_state.neg = 0
    st.session_state.pos = 0


def get_extremes(n=4):
    """Return ``(best, worst)``: dicts of the *n* highest/lowest predictions.

    Both dicts map image name to prediction and are empty when no
    predictions are available.
    """
    if 'preds' not in st.session_state:
        return {}, {}
    preds = pd.Series(st.session_state.preds, index=embeddings.index).sort_values(ascending=False)
    return preds.iloc[:n].to_dict(), preds.iloc[-n:].to_dict()


def get_strange(n=4):
    """Return the votes that disagree most with the model.

    Returns:
        Tuple ``(surprising_dislikes, surprising_likes)`` of dicts mapping
        image name to ``pred - label``: the *n* largest differences and the
        *n* smallest. Both are empty when labels or predictions are missing.
    """
    if 'labels' not in st.session_state or 'preds' not in st.session_state:
        return {}, {}
    labels = st.session_state.labels
    preds = pd.Series(st.session_state.preds, index=embeddings.index).loc[labels.index].rename('pred')
    data = pd.concat([labels, preds], axis=1)
    data['diff'] = data['pred'] - data['label']
    diffs = data.sort_values('diff', ascending=False)['diff']
    surprising_dislikes = diffs.iloc[:n].to_dict()
    surprising_likes = diffs.iloc[-n:].to_dict()
    return surprising_dislikes, surprising_likes


def _show_thumbnails(container, items, as_delta=False):
    """Render one thumbnail per entry of *items* in a row inside *container*.

    With ``as_delta`` the negated value is shown as the metric delta;
    otherwise the value is shown as the metric value. Empty *items* render
    nothing.
    """
    if not items:
        return
    columns = container.columns(len(items))
    for column, (file, value) in zip(columns, items.items()):
        if as_delta:
            column.metric("", "", f'{-value:.0%}')
        else:
            column.metric("", f'{value:.0%}')
        column.image(get_s3_url(file), width=100)


st.title('What does attractive mean to you?')
st.session_state.name = st.text_input(label='Invent a unique alias (and remember it)')

# Sanitise before testing for emptiness: an alias made only of stripped
# characters must not create folders directly under ./users/.
st.session_state.name = re.sub(r'[^A-Za-z0-9 ]+', '', st.session_state.name)[:100]

if st.session_state.name:
    setup_user()
    st.subheader(f"Let's start {st.session_state.name}")

    c1, c2 = st.columns(2)
    c1.progress(min(st.session_state.count / 40, 1.))
    p_liked = (st.session_state.pos / st.session_state.count) if st.session_state.count else 0
    c2.metric('%age liked so far', f'{p_liked:.1%}')

    try:
        filename = get_filename()
    except RuntimeError as err:
        st.error(str(err))
        st.stop()

    cc1, cc2 = st.columns(2)
    c1, c2 = cc1.columns(2)
    c1.button('Why not', on_click=liked, args=[filename, True])
    c2.button('Nope', on_click=liked, args=[filename, False])
    key = get_s3_url(filename)
    cc1.image(key, width=400)

    c1, c2 = cc2.columns(2)
    if st.session_state.count > 40 and st.session_state.pos > 5 and st.session_state.neg > 5:
        c1.button('Train', on_click=train, args=[])
        if st.session_state.count == 41:
            st.balloons()
            st.toast('Ready for training')
    c2.button('Start over', on_click=cleanup, args=[])

    if 'preds' in st.session_state:
        cc1.write('Here is our guess')
        c1, c2 = cc1.columns(2)
        c1.metric("Probability you will like", f'{st.session_state.pred:.1%}')
        c2.metric("Overall model accuracy", f'{st.session_state.score:.0%}')

        best, worst = get_extremes()
        cc2.subheader('Predicted best')
        _show_thumbnails(cc2, best)
        cc2.subheader('Predicted worst')
        _show_thumbnails(cc2, worst)

        cc1.subheader('Where you confused me')
        surprising_dislikes, surprising_likes = get_strange()
        cc1.write("You didn't like my picks")
        _show_thumbnails(cc1, surprising_dislikes, as_delta=True)
        cc1.write("You liked these more than I thought")
        _show_thumbnails(cc1, surprising_likes, as_delta=True)