# Spaces:
# Sleeping
# Sleeping
import os
import re
from glob import glob
from urllib.parse import quote

import lightgbm as lgb
import numpy as np
import pandas as pd
import streamlit as st
import streamlit.components.v1 as components
# Default S3 bucket that holds the images (read by get_s3_url). Use setdefault
# so a value already configured in the environment is respected rather than
# silently overwritten.
os.environ.setdefault('S3_BUCKET', "seriouslytestfaces")
import io
def get_s3_url(key):
    """Return the public path-style S3 URL for object *key*.

    The bucket name is read from the ``S3_BUCKET`` environment variable
    (KeyError if unset). The key is percent-encoded so spaces, ``+``, ``#``,
    ``?``, ``%`` and non-ASCII characters survive; ``/`` is kept so keys with
    folder prefixes still map to the right object.

    Args:
        key: S3 object key, e.g. an entry of ``embeddings.index``.

    Returns:
        str: ``https://s3.amazonaws.com/<bucket>/<encoded key>``.
    """
    bucket = os.environ['S3_BUCKET']
    return f'https://s3.amazonaws.com/{bucket}/{quote(key, safe="/")}'
@st.cache_data
def _load_embeddings(path='./embeddings.parquet'):
    """Read the image-embedding table from *path*.

    Streamlit re-executes this whole script on every widget interaction;
    caching means the parquet file is decoded once rather than on every
    rerun. The cache is not invalidated if the file changes on disk, so
    restart the app after replacing it.
    """
    return pd.read_parquet(path)


# One row per image. The index holds the S3 object keys passed to get_s3_url,
# and the row values are the features fed to the LightGBM model.
embeddings = _load_embeddings()
def create_dir(directory):
    """Create *directory* (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check, which
    avoids a race when two reruns or sessions create the same directory at
    the same time.
    """
    os.makedirs(directory, exist_ok=True)
def setup_user():
    """Prepare storage and session counters for the current alias.

    Ensures ``./users/<name>/``, its ``likes`` and its ``models``
    subdirectories exist, then initialises the ``count``, ``neg`` and ``pos``
    vote counters to zero the first time this session gets here.
    """
    user_root = f'./users/{st.session_state.name}'
    for path in (user_root, f'{user_root}/likes', f'{user_root}/models'):
        create_dir(path)

    if 'count' in st.session_state:
        return
    st.session_state.count = 0
    st.session_state.neg = 0
    st.session_state.pos = 0
import requests


def check_image_url_accessible(url):
    """Return True if *url* answers HTTP 200 with an image Content-Type.

    A HEAD request is tried first to avoid downloading the body. If it does
    not return 200 (some servers reject HEAD), a streamed GET is used as a
    fallback. Any ``requests.RequestException`` (timeout, connection error,
    ...) yields False.

    Args:
        url: absolute URL to probe.

    Returns:
        bool: whether the URL looks like a reachable image.
    """
    def _is_image(response):
        # Substring match, so 'image/jpeg', 'image/png', ... all pass.
        return (response.status_code == 200
                and 'image' in response.headers.get('Content-Type', ''))

    try:
        response = requests.head(url, allow_redirects=True, timeout=5)
        if response.status_code == 200:
            return _is_image(response)
        # stream=True skips downloading the body. The context manager closes
        # the response, so the unread connection is returned to the pool
        # instead of leaking.
        with requests.get(url, stream=True, timeout=5) as response:
            return _is_image(response)
    except requests.RequestException:
        return False
def get_filename(max_attempts=1000):
    """Pick the next image key to show the user.

    Once a model has been trained (``preds`` in session state), an image is
    sampled with probability proportional to ``pred ** 4``, which strongly
    favours images the model expects the user to like. The chosen image's
    prediction is stored in ``st.session_state.pred``. Before training, an
    image is drawn uniformly from ``embeddings.index``.

    Candidates whose S3 URL is not reachable as an image are skipped and
    another is drawn. This used to be done by recursion, which could exhaust
    the stack.

    Args:
        max_attempts: how many candidates to try before giving up.

    Returns:
        The chosen key from ``embeddings.index``.

    Raises:
        RuntimeError: if no accessible image is found in ``max_attempts`` draws.
    """
    weights = None
    if 'preds' in st.session_state:
        weights = st.session_state.preds ** 4
        # ndarray.sum() rather than builtin sum(): avoids a Python-level loop.
        weights = weights / weights.sum()

    for _ in range(max_attempts):
        if weights is not None:
            choice = np.random.choice(len(weights), p=weights)
            st.session_state.pred = st.session_state.preds[choice]
            key = embeddings.index[choice]
        else:
            key = np.random.choice(embeddings.index)
        if check_image_url_accessible(get_s3_url(key)):
            return key

    raise RuntimeError(f'No accessible image found after {max_attempts} attempts')
st.title('What does attractive mean to you?')
# The alias is used as the user's folder name under ./users/.
st.session_state.name = st.text_input(
    'Invent a unique alias (and remember it)',
)
def liked(filename, like):
    """Button callback: record a like/dislike vote for *filename*.

    The vote is stored as an empty marker file named
    ``<basename of filename>.T`` (like) or ``.F`` (dislike) in the user's
    ``likes`` directory; the training data is later rebuilt from these
    markers. The session counters ``count`` and ``pos``/``neg`` are
    incremented.

    Args:
        filename: image key as returned by get_filename.
        like: True for a like, False for a dislike.
    """
    stem = f'./users/{st.session_state.name}/likes/' + filename.split('/')[-1]
    # If the same image was previously voted the other way, drop that marker
    # so the image is not used as both a positive and a negative example.
    try:
        os.remove(stem + '.' + str(not like)[:1])
    except FileNotFoundError:
        pass
    with open(stem + '.' + str(like)[:1], 'a'):
        pass

    st.session_state.count += 1
    if like:
        st.session_state.pos += 1
    else:
        st.session_state.neg += 1
def get_train_data():
    """Build the training set from the current user's vote marker files.

    Returns:
        (X, labels): ``X`` stacks the embedding rows of liked images followed
        by those of disliked images; ``labels`` holds 1 for each liked and 0
        for each disliked image, in the same order.

    Side effect:
        Stores the labels as a Series named ``label``, indexed by image key,
        in ``st.session_state.labels``.
    """
    likes_dir = f'./users/{st.session_state.name}/likes'

    def _keys_with_suffix(suffix):
        # A marker path ends in '<key>.<suffix>'; normalise Windows
        # separators, keep the basename and strip the two-char suffix.
        keys = []
        for path in glob(f'{likes_dir}/*.{suffix}'):
            keys.append(path.replace('\\', '/').split('/')[-1][:-2])
        return keys

    positive_keys = _keys_with_suffix('T')
    negative_keys = _keys_with_suffix('F')
    positive_rows = embeddings.loc[positive_keys].values
    negative_rows = embeddings.loc[negative_keys].values

    labels = np.array([1] * len(positive_rows) + [0] * len(negative_rows))
    st.session_state.labels = pd.Series(
        labels, index=positive_keys + negative_keys
    ).rename('label')
    features = np.vstack([positive_rows, negative_rows])
    return features, labels
def train_model(X, labels, num_round=10, params=None):
    """Fit a LightGBM binary classifier on the user's votes.

    Training is skipped (a toast is shown and None returned) when there are
    fewer than 30 samples, or when either class makes up less than 10% of
    them.

    Args:
        X: 2-D feature matrix, one row per voted image.
        labels: 1 for liked, 0 for disliked, aligned with the rows of X. Any
            array-like is accepted.
        num_round: number of boosting rounds.
        params: optional LightGBM parameters merged over the defaults.

    Returns:
        The trained booster, or None if training was skipped.

    Side effects:
        Stores the in-sample score in ``st.session_state.score`` and toasts it.
    """
    labels = np.asarray(labels)
    if len(labels) < 30:
        st.toast('Not enough data')
        return None
    if labels.mean() > 0.9:
        st.toast('Not enough negatives')
        return None
    if labels.mean() < 0.1:
        st.toast('Not enough positives')
        return None

    param = {'num_leaves': 30, 'objective': 'binary', 'metric': 'binary'}
    if params:
        param.update(params)
    bst = lgb.train(param, lgb.Dataset(X, label=labels), num_round)

    # "Score" is the Pearson correlation between predictions and labels on
    # the training data itself. It is optimistic (no hold-out set) and is NaN
    # when the booster predicts a constant.
    in_sample_preds = bst.predict(X)
    in_sample_score = np.corrcoef(in_sample_preds, labels)[0, 1]
    st.session_state.score = in_sample_score
    st.toast(f'Score = {in_sample_score:.1%}')
    return bst
def rank_candidates(bst):
    """Score every image in ``embeddings`` with the trained booster.

    Despite the name, nothing is sorted here. The result is the booster's
    prediction for each row, aligned with ``embeddings.index``.
    """
    candidate_matrix = embeddings.values
    scores = bst.predict(candidate_matrix)
    return scores
def train():
    """'Train' button callback: fit a model on the votes and score all images.

    If train_model declines to train (returns None), nothing else happens.
    Otherwise the booster is saved to ``./users/<name>/models/model.txt``,
    the per-image predictions are stored in ``st.session_state.preds`` and
    balloons are shown.
    """
    booster = train_model(*get_train_data())
    if booster is not None:
        booster.save_model(f'./users/{st.session_state.name}/models/model.txt')
        st.session_state.preds = rank_candidates(booster)
        st.balloons()
def cleanup():
    """'Start over' callback: erase the user's votes and reset session state.

    Deletes every file in the user's ``likes`` directory, removes any
    ``preds``, ``pred``, ``count``, ``pos`` and ``neg`` entries from session
    state, then sets the three vote counters back to zero.
    """
    for marker in glob(f'./users/{st.session_state.name}/likes/*'):
        os.remove(marker)

    present = [var for var in ('preds', 'pred', 'count', 'pos', 'neg')
               if var in st.session_state]
    for var in present:
        del st.session_state[var]

    st.session_state.count = 0
    st.session_state.neg = 0
    st.session_state.pos = 0
def get_extremes(n=4):
    """Return the images with the highest and lowest model predictions.

    Returns:
        (best, worst): dicts mapping image key to prediction. ``best`` holds
        the first ``n`` entries of the predictions sorted descending, and
        ``worst`` the last ``n`` entries. Returns None when no predictions
        are in session state.
    """
    if 'preds' not in st.session_state:
        return None
    ranked = pd.Series(st.session_state.preds, index=embeddings.index)
    ranked = ranked.sort_values(ascending=False)
    best = ranked.iloc[:n].to_dict()
    worst = ranked.iloc[-n:].to_dict()
    return best, worst
def get_strange(n=4):
    """Return the voted images where the model disagreed most with the user.

    For each labelled image the difference ``pred - label`` is computed.

    Returns:
        (surprising_dislikes, surprising_likes): dicts mapping image key to
        that difference. The first holds the ``n`` largest differences
        (predicted high, user said no); the second holds the ``n`` smallest
        (predicted low, user said yes). Returns None when no labels are in
        session state.
    """
    if 'labels' not in st.session_state:
        return None
    labels = st.session_state.labels
    all_preds = pd.Series(st.session_state.preds, index=embeddings.index)
    voted_preds = all_preds.loc[labels.index].rename('pred')

    frame = pd.concat([labels, voted_preds], axis=1)
    frame['diff'] = frame['pred'] - frame['label']
    ordered = frame.sort_values('diff', ascending=False)['diff']
    return ordered.iloc[:n].to_dict(), ordered.iloc[-n:].to_dict()
def _show_image_strip(container, scores, metric_args):
    """Lay out one column per image in *scores*: a metric above a thumbnail.

    Args:
        container: Streamlit container (e.g. a column) to draw into.
        scores: mapping of image key -> score, shown in iteration order.
        metric_args: callable mapping a score to the positional arguments
            for ``metric`` (label, value[, delta]).
    """
    columns = container.columns(len(scores))
    for column, (file, score) in zip(columns, scores.items()):
        column.metric(*metric_args(score))
        column.image(get_s3_url(file), width=100)


def _as_value(score):
    # Metric with an empty label and the score as its value.
    return "", f'{score:.0%}'


def _as_delta(diff):
    # Metric with empty label and value; the negated (pred - label)
    # difference is shown as the delta.
    return "", "", f'{-diff:.0%}'


# Sanitize before testing for emptiness. Previously an alias made only of
# disallowed characters passed the check and collapsed to '', which put that
# user's likes/models directly under ./users/, shared with everyone else.
_raw_alias = st.session_state.name or ''
st.session_state.name = re.sub(r'[^A-Za-z0-9 ]+', '', _raw_alias)[:100]
if _raw_alias and not st.session_state.name:
    st.warning('Please use only letters, digits and spaces in your alias')

if st.session_state.name:
    setup_user()
    st.subheader(f"Let's start {st.session_state.name}")

    # Progress towards the 40 votes needed before training is offered.
    c1, c2 = st.columns(2)
    my_bar = c1.progress(min(st.session_state.count / 40, 1.))
    p_liked = (st.session_state.pos / st.session_state.count) if st.session_state.count else 0
    c2.metric('%age liked so far', f'{p_liked:.1%}')

    # Voting area: the current image plus like/dislike buttons. The buttons
    # bind this run's filename, so the callback records the image shown.
    filename = get_filename()
    cc1, cc2 = st.columns(2)
    c1, c2 = cc1.columns(2)
    c1.button('Why not', on_click=liked, args=[filename, True])
    c2.button('Nope', on_click=liked, args=[filename, False])
    cc1.image(get_s3_url(filename), width=400)

    c1, c2 = cc2.columns(2)
    if st.session_state.count > 40 and st.session_state.pos > 5 and st.session_state.neg > 5:
        c1.button('Train', on_click=train, args=[])
        if st.session_state.count == 41:
            st.balloons()
            st.toast('Ready for training')
    c2.button('Start over', on_click=cleanup, args=[])

    if 'preds' in st.session_state:
        cc1.write('Here is our guess')
        c1, c2 = cc1.columns(2)
        c1.metric("Probability you will like", f'{st.session_state.pred:.1%}')
        # NOTE(review): score is the in-sample correlation computed in
        # train_model, not a classification accuracy.
        c2.metric("Overall model accuracy", f'{st.session_state.score:.0%}')

        best, worst = get_extremes()
        cc2.subheader('Predicted best')
        _show_image_strip(cc2, best, _as_value)
        cc2.subheader('Predicted worst')
        _show_image_strip(cc2, worst, _as_value)

        cc1.subheader('Where you confused me')
        surprising_dislikes, surprising_likes = get_strange()
        cc1.write("You didn't like my picks")
        _show_image_strip(cc1, surprising_dislikes, _as_delta)
        cc1.write("You liked these more than I thought")
        _show_image_strip(cc1, surprising_likes, _as_delta)