File size: 8,126 Bytes
b1f21a7
 
 
 
 
 
 
 
 
fcf19e3
 
b1f21a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3f4d361
 
 
 
b1f21a7
0b9ad90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1f21a7
 
 
 
 
 
0b9ad90
 
b0bb9db
0b9ad90
 
 
5c562da
0b9ad90
 
 
5559590
0b9ad90
 
 
 
b1f21a7
 
 
 
 
 
 
3f4d361
 
 
 
 
b1f21a7
 
 
 
 
 
 
5c562da
b1f21a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0bb9db
b1f21a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9616f0
b1f21a7
 
3f4d361
b1f21a7
 
 
 
3f4d361
 
 
 
 
 
b1f21a7
 
 
 
 
 
 
 
 
 
 
 
5c562da
b1f21a7
 
e2cefa5
 
 
 
b1f21a7
 
 
 
 
 
 
3f4d361
b1f21a7
f6230c6
 
 
 
 
 
b1f21a7
 
 
 
 
 
 
 
 
 
cabb6d5
3f4d361
cabb6d5
f6230c6
 
b1f21a7
 
 
 
f6230c6
 
 
 
b1f21a7
 
 
 
 
 
 
 
 
 
 
 
 
 
f6230c6
b1f21a7
 
 
e2cefa5
b1f21a7
 
e2cefa5
 
 
b1f21a7
 
 
e2cefa5
 
 
b1f21a7
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import os
import re
import numpy as np
from glob import glob
import lightgbm as lgb



# Bucket name consumed by get_s3_url() when it builds public image URLs.
# NOTE(review): this unconditionally overwrites any S3_BUCKET value already
# present in the process environment.
os.environ['S3_BUCKET'] = "seriouslytestfaces"

import io

def get_s3_url(key):
    """Build the public S3 URL for an object key.

    The bucket name is read from the ``S3_BUCKET`` environment variable and
    every space in ``key`` is replaced by ``+``.
    """
    bucket = os.environ['S3_BUCKET']
    encoded_key = key.replace(' ', '+')
    return f'https://s3.amazonaws.com/{bucket}/{encoded_key}'

# Embedding table loaded once per script run. Its index values are used as S3
# object keys (see get_filename) and its rows as model features
# (see get_train_data / rank_candidates).
embeddings = pd.read_parquet('./embeddings.parquet')

def create_dir(directory):
    """Create ``directory`` (and any missing parents) if it does not exist.

    Uses ``exist_ok=True`` instead of a separate existence check so that two
    concurrent sessions creating the same directory cannot race between the
    check and the creation.
    """
    os.makedirs(directory, exist_ok=True)

def setup_user():
    """Ensure the current user's folders and session counters exist.

    Creates ``./users/<name>`` with its ``likes`` and ``models`` subfolders,
    and zeroes the ``count``/``neg``/``pos`` counters when ``count`` has not
    been stored in the session yet.
    """
    user_root = f'./users/{st.session_state.name}'
    for folder in (user_root, f'{user_root}/likes', f'{user_root}/models'):
        create_dir(folder)

    if 'count' in st.session_state:
        return
    st.session_state.count = 0
    st.session_state.neg = 0
    st.session_state.pos = 0

import requests

def check_image_url_accessible(url):
    """Return True if ``url`` answers with HTTP 200 and an image content type.

    A HEAD request is tried first to save bandwidth; if it does not return
    200, a streamed GET is issued instead. Any ``requests.RequestException``
    (timeout, connection error, ...) is treated as "not accessible".
    """
    try:
        response = requests.head(url, allow_redirects=True, timeout=5)
        if response.status_code != 200:
            # Fall back to GET for servers that do not answer HEAD properly.
            response.close()
            response = requests.get(url, stream=True, timeout=5)
        try:
            if response.status_code != 200:
                return False
            content_type = response.headers.get("Content-Type", "")
            return "image" in content_type
        finally:
            # The body of a streamed response is never read here, so the
            # connection must be released explicitly or it leaks.
            response.close()
    except requests.RequestException:
        return False



def get_filename(max_attempts=100):
    """Pick the next image to show and return its key in ``embeddings.index``.

    When model predictions are stored in the session (``preds``), candidates
    are sampled with probability proportional to ``preds ** 4`` and the
    prediction of the sampled candidate is stored in
    ``st.session_state.pred``. Otherwise a key is drawn uniformly at random.

    Candidates whose S3 URL is not an accessible image are skipped. The
    retry is a bounded loop rather than recursion, so a run of unreachable
    images cannot end in ``RecursionError``.

    Args:
        max_attempts: Maximum number of candidates to try.

    Raises:
        RuntimeError: If no accessible image is found in ``max_attempts`` tries.
    """
    use_model = 'preds' in st.session_state
    if use_model:
        # The sampling weights do not change between retries: compute once.
        weights = st.session_state.preds ** 4
        weights = weights / weights.sum()

    for _ in range(max_attempts):
        if use_model:
            position = np.random.choice(len(weights), p=weights)
            st.session_state.pred = st.session_state.preds[position]
            candidate = embeddings.index[position]
        else:
            candidate = np.random.choice(embeddings.index)
        if check_image_url_accessible(get_s3_url(candidate)):
            return candidate

    raise RuntimeError(
        f'No accessible image found after {max_attempts} attempts'
    )
    
# Page header and alias prompt. The alias is stored in session state on every
# rerun and is later used to build the per-user ./users/<name>/ paths.
st.title('What does attractive mean to you?')
st.session_state.name = st.text_input(label='Invent a unique alias (and remember it)')


def liked(filename, like):
    """Record the user's verdict on an image and update the session counters.

    An empty marker file named ``<image basename>.<T|F>`` is created (or left
    as is) in the user's ``likes`` folder; the suffix is the first character
    of ``str(like)``. ``count`` is incremented, plus ``pos`` or ``neg``
    depending on the truthiness of ``like``.
    """
    image_name = filename.split('/')[-1]
    verdict = str(like)[:1]
    marker_path = f'./users/{st.session_state.name}/likes/{image_name}.{verdict}'
    # Append mode creates the file if missing without truncating an existing one.
    with open(marker_path, 'a'):
        pass

    st.session_state.count += 1
    tally = 'pos' if like else 'neg'
    st.session_state[tally] += 1

def get_train_data():
    """Assemble the training matrix and labels from the user's marker files.

    Files ending in ``.T`` in the user's ``likes`` folder are positives (1),
    files ending in ``.F`` are negatives (0). The labels are also stored in
    ``st.session_state.labels`` as a Series named ``label`` indexed by image
    name.

    Returns:
        ``(X, labels)``: the stacked embedding rows (positives first) and the
        matching 0/1 label array.
    """
    likes_dir = f'./users/{st.session_state.name}/likes'

    def _rated_names(suffix):
        # Normalise Windows separators, keep the basename and drop the
        # two-character ".T" / ".F" suffix.
        return [
            path.replace('\\', '/').split('/')[-1][:-2]
            for path in glob(f'{likes_dir}/*.{suffix}')
        ]

    liked_names = _rated_names('T')
    disliked_names = _rated_names('F')
    liked_vecs = embeddings.loc[liked_names].values
    disliked_vecs = embeddings.loc[disliked_names].values

    labels = np.array([1] * len(liked_vecs) + [0] * len(disliked_vecs))
    st.session_state.labels = pd.Series(
        labels, index=liked_names + disliked_names, name='label'
    )
    features = np.vstack([liked_vecs, disliked_vecs])
    return features, labels

def train_model(X, labels):
    """Fit a LightGBM binary classifier on the user's ratings.

    Training is skipped (a toast is shown and ``None`` is returned) when
    there are fewer than 30 labels, or when the share of positives is above
    90% or below 10%.

    On success the correlation between the in-sample predictions and the
    labels is stored in ``st.session_state.score`` and shown in a toast.

    Returns:
        The trained booster, or ``None`` when training was skipped.
    """
    if len(labels) < 30:
        st.toast('Not enough data')
        return None

    positive_rate = labels.mean()
    if positive_rate > 0.9:
        st.toast('Not enough negatives')
        return None
    if positive_rate < 0.1:
        st.toast('Not enough positives')
        return None

    params = {'num_leaves': 30, 'objective': 'binary', 'metric' : 'binary'}
    dataset = lgb.Dataset(X, label=labels)
    booster = lgb.train(params, dataset, num_boost_round=10)

    # In-sample fit quality: correlation of predicted probability vs. label.
    fitted = booster.predict(X)
    correlation = np.corrcoef(fitted, np.array(labels))[0, 1]
    st.session_state.score = correlation
    st.toast(f'Score = {correlation:.1%}')
    return booster
    
def rank_candidates(bst):
    """Score every row of the ``embeddings`` table with the model ``bst``.

    Returns whatever ``bst.predict`` yields for the full embedding matrix.
    """
    all_features = embeddings.values
    scores = bst.predict(all_features)
    return scores


def train():
    """Train on the user's ratings, save the model and cache predictions.

    Does nothing further when ``train_model`` declines to train (returns
    ``None``). Otherwise the model is written to the user's ``models``
    folder, predictions for all embeddings are stored in
    ``st.session_state.preds`` and balloons are shown.
    """
    features, labels = get_train_data()
    booster = train_model(features, labels)
    if booster is None:
        return

    booster.save_model(f'./users/{st.session_state.name}/models/model.txt')
    st.session_state.preds = rank_candidates(booster)
    st.balloons()



def cleanup():
    """Reset the current user: delete all rating files and session results.

    Removes every file in the user's ``likes`` folder, drops ``preds``,
    ``pred`` and the counters from the session, then re-creates the
    counters at zero.
    """
    for rating_file in glob(f'./users/{st.session_state.name}/likes/*'):
        os.remove(rating_file)

    for key in ('preds', 'pred', 'count', 'pos', 'neg'):
        if key in st.session_state:
            del st.session_state[key]

    for counter in ('count', 'neg', 'pos'):
        st.session_state[counter] = 0
        
    
def get_extremes(n=4):
    """Return the ``n`` highest and ``n`` lowest predicted images.

    Returns:
        ``(best, worst)`` dicts mapping image key to prediction, or ``None``
        when no predictions are stored in the session.
    """
    if 'preds' not in st.session_state:
        return None

    ranked = pd.Series(st.session_state.preds, index=embeddings.index)
    ranked = ranked.sort_values(ascending=False)
    best = ranked.iloc[:n].to_dict()
    worst = ranked.iloc[-n:].to_dict()
    return best, worst
        
def get_strange(n=4):
    """Find the rated images where prediction and label disagree the most.

    ``diff`` is ``pred - label``: large positive values are images the model
    scored high but the user labelled 0, large negative values the reverse.

    Returns:
        ``(surprising_dislikes, surprising_likes)`` dicts mapping image name
        to ``diff``, or ``None`` when no labels are stored in the session.
    """
    if 'labels' not in st.session_state:
        return None

    user_labels = st.session_state.labels
    all_preds = pd.Series(st.session_state.preds, index=embeddings.index)
    rated_preds = all_preds.loc[user_labels.index].rename('pred')

    table = pd.concat([user_labels, rated_preds], axis=1)
    table['diff'] = table['pred'] - table['label']
    gaps = table.sort_values('diff', ascending=False)['diff']

    return gaps.iloc[:n].to_dict(), gaps.iloc[-n:].to_dict()
        
        


def _render_thumbnails(container, items, as_delta=False):
    """Draw one row of 100px thumbnails, one column per ``(file, value)`` pair.

    Args:
        container: Streamlit container (column) that hosts the row.
        items: Dict mapping image key to a numeric value.
        as_delta: When False the value is shown as the metric value; when
            True the negated value is shown as the metric delta.
    """
    if not items:
        # Nothing to draw; skip requesting a zero-column layout.
        return
    columns = container.columns(len(items))
    for column, (file, value) in zip(columns, items.items()):
        if as_delta:
            column.metric("", "", f'{-value:.0%}')
        else:
            column.metric("", f'{value:.0%}')
        column.image(get_s3_url(file), width=100)


# Sanitise the alias BEFORE testing it: an alias made only of stripped
# characters would otherwise become '' and map to the shared ./users/ folder.
st.session_state.name = re.sub(r'[^A-Za-z0-9 ]+', '', st.session_state.name or '')[:100]

if st.session_state.name:
    setup_user()

    # Header: progress towards 40 ratings and the share of likes so far.
    st.subheader(f"Let's start {st.session_state.name}")
    c1, c2 = st.columns(2)
    c1.progress(min(st.session_state.count / 40, 1.))
    p_liked = (st.session_state.pos / st.session_state.count) if st.session_state.count else 0
    c2.metric('%age liked so far', f'{p_liked:.1%}')

    filename = get_filename()

    # Left column: the image to rate with its two verdict buttons.
    cc1, cc2 = st.columns(2)
    c1, c2 = cc1.columns(2)
    c1.button('Why not', on_click=liked, args=[filename, True])
    c2.button('Nope', on_click=liked, args=[filename, False])
    cc1.image(get_s3_url(filename), width=400)

    # Right column: training is offered once there are more than 40 ratings
    # with more than 5 of each kind.
    c1, c2 = cc2.columns(2)
    if st.session_state.count > 40 and st.session_state.pos > 5 and st.session_state.neg > 5:
        c1.button('Train', on_click=train, args=[])
        if st.session_state.count == 41:
            st.balloons()
            st.toast('Ready for training')
    c2.button('Start over', on_click=cleanup, args=[])

    if 'preds' in st.session_state:
        cc1.write('Here is our guess')
        c1, c2 = cc1.columns(2)
        c1.metric("Probability you will like", f'{st.session_state.pred:.1%}')
        c2.metric("Overall model accuracy", f'{st.session_state.score:.0%}')

        best, worst = get_extremes()

        cc2.subheader('Predicted best')
        _render_thumbnails(cc2, best)

        cc2.subheader('Predicted worst')
        _render_thumbnails(cc2, worst)

        cc1.subheader('Where you confused me')

        surprising_dislikes, surprising_likes = get_strange()

        cc1.write("You didn't like my picks")
        _render_thumbnails(cc1, surprising_dislikes, as_delta=True)

        cc1.write("You liked these more than I thought")
        _render_thumbnails(cc1, surprising_likes, as_delta=True)