skanderovitch commited on
Commit
b1f21a7
·
verified ·
1 Parent(s): d450e59

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+ import pandas as pd
4
+ import os
5
+ import re
6
+ import numpy as np
7
+ from glob import glob
8
+ import lightgbm as lgb
9
+ import pickle
10
+
11
+ os.environ['S3_BUCKET'] = 'seriouslyusers'
12
+ os.environ['S3_BUCKET'] = "seriouslytestfaces"
13
+
14
+ import io
15
+
16
+ def get_s3_url(key):
17
+ url = 'https://s3.amazonaws.com/%s/%s' % (os.environ['S3_BUCKET'],key.replace(' ','+'))
18
+ return url
19
+
20
+ embeddings = pd.read_parquet('./embeddings.parquet')
21
+
22
+ def create_dir(directory):
23
+ if not os.path.exists(directory):
24
+ os.makedirs(directory)
25
+
26
+ def setup_user():
27
+ create_dir(f'./users/{st.session_state.name}')
28
+ create_dir(f'./users/{st.session_state.name}/likes')
29
+ create_dir(f'./users/{st.session_state.name}/models')
30
+
31
+ def get_filename():
32
+ if 'preds' in st.session_state:
33
+ p = st.session_state.preds**4
34
+ p /= sum(p)
35
+ choice = np.random.choice(range(len(p)),p=p)
36
+ st.session_state.pred = st.session_state.preds[choice]
37
+ return embeddings.index[choice]
38
+ st.toast('Random for now')
39
+ return np.random.choice(embeddings.index)
40
+
41
+ st.title('What does attractive mean to you?')
42
+ st.session_state.name = st.text_input(label='Invent a unique alias (and remember it)')
43
+
44
+
45
+ def liked(filename,like):
46
+ filename = f'./users/{st.session_state.name}/likes/' + filename.split('/')[-1] + '.' + str(like)[:1]
47
+ open(filename, 'a').close()
48
+
49
+ def get_train_data():
50
+ clean = lambda file : file.replace('\\','/').split('/')[-1][:-2]
51
+ true_files = list(map(clean,glob(f'./users/{st.session_state.name}/likes/*.T')))
52
+ false_files = list(map(clean,glob(f'./users/{st.session_state.name}/likes/*.F')))
53
+ true_embeddings = embeddings.loc[true_files].values
54
+ false_embeddings = embeddings.loc[false_files].values
55
+ st.toast(f'Found {len(true_files)} positives and {len(false_files)} negatives')
56
+ labels = np.array([1 for _ in true_embeddings] + [0 for _ in false_embeddings])
57
+ st.session_state.labels = pd.Series(labels,index=true_files+false_files).rename('label')
58
+ X = np.vstack([true_embeddings,false_embeddings])
59
+ return X,labels
60
+
61
+ def train_model(X,labels):
62
+ if len(labels) < 30:
63
+ st.toast('Not enough data')
64
+ return
65
+ if labels.mean() > 0.9:
66
+ st.toast('Not enough negatives')
67
+ return
68
+ if labels.mean() < 0.1:
69
+ st.toast('Not enough positives')
70
+ return
71
+ train_data = lgb.Dataset(X, label=labels)
72
+ num_round = 10
73
+ param = {'num_leaves':10, 'objective': 'binary', 'metric' : 'auc'}
74
+ bst = lgb.train(param, train_data, num_round)
75
+ in_sample_preds = bst.predict(X)
76
+ in_sample_score = np.corrcoef([in_sample_preds,np.array(labels)])[0][1]
77
+ st.session_state.score = in_sample_score
78
+ st.toast(f'Score = {in_sample_score:.1%}')
79
+ return bst
80
+
81
+ def rank_candidates(bst):
82
+ return bst.predict(embeddings.values)
83
+
84
+
85
+ def train():
86
+ X,labels = get_train_data()
87
+ bst = train_model(X,labels)
88
+ if bst is None:
89
+ return
90
+ filename = f'./users/{st.session_state.name}/models/model.txt'
91
+ bst.save_model(filename)
92
+ preds = rank_candidates(bst)
93
+ st.session_state.preds = preds
94
+
95
+
96
+ def cleanup():
97
+ files = glob(f'./users/{st.session_state.name}/likes/*')
98
+ for f in files:
99
+ os.remove(f)
100
+ if 'preds' in st.session_state:
101
+ del st.session_state.preds
102
+ del st.session_state.pred
103
+
104
+
105
+ def get_extremes(n=4):
106
+ if 'preds' in st.session_state:
107
+ preds = pd.Series(st.session_state.preds,index=embeddings.index).sort_values(ascending=False)
108
+ return preds.iloc[:n].to_dict(),preds.iloc[-n:].to_dict()
109
+
110
+ def get_strange(n=4):
111
+ if 'labels' in st.session_state:
112
+ labels = st.session_state.labels
113
+ preds = pd.Series(st.session_state.preds,index=embeddings.index).loc[labels.index].rename('pred')
114
+ data = pd.concat([labels, preds],axis=1)
115
+ st.toast(data.columns)
116
+
117
+ data['diff'] = data['pred'] - data['label']
118
+ data = data.sort_values('diff')['diff']
119
+ return data.iloc[:n].to_dict(),data.iloc[-n:].to_dict()
120
+
121
+
122
+
123
+
124
+ if st.session_state.name:
125
+ st.session_state.name = re.sub(r'[^A-Za-z0-9 ]+', '', st.session_state.name)[:100]
126
+ setup_user()
127
+ st.subheader(f"Let's start {st.session_state.name}")
128
+ filename = get_filename()
129
+
130
+
131
+ cc1, cc2 = st.columns(2)
132
+ c1,c2 = cc1.columns(2)
133
+ c1.button('Why not', on_click=liked, args=[filename,True])
134
+ c2.button('Nope', on_click=liked, args=[filename,False])
135
+ key = get_s3_url(filename)
136
+ cc1.image(key, width = 400)
137
+ c1,c2 = cc2.columns(2)
138
+ c1.button('Train',on_click=train,args=[])
139
+ c2.button('Start over',on_click=cleanup,args=[])
140
+ if 'preds' in st.session_state:
141
+
142
+ cc1.write('Here is our guess')
143
+ cc1.metric("Probability you will like", f'{st.session_state.pred:.1%}')
144
+
145
+ best,worst = get_extremes()
146
+
147
+ cc2.subheader('Predicted best')
148
+ cs = cc2.columns(len(best))
149
+ for c,(file,pred) in zip(cs,best.items()):
150
+ c.metric("", f'{pred:.0%}')
151
+ c.image(get_s3_url(file), width = 100)
152
+
153
+ cc2.subheader('Predicted worst')
154
+ cs = cc2.columns(len(worst))
155
+ for c,(file,pred) in zip(cs,worst.items()):
156
+ c.metric("", f'{pred:.0%}')
157
+ c.image(get_s3_url(file), width = 100)
158
+
159
+ cc1.metric("Overall model accuracy", f'{st.session_state.score:.0%}')
160
+
161
+ cc1.subheader('Where you confused me')
162
+
163
+ best,worst = get_strange()
164
+
165
+ cc1.write("You didn't like my picks")
166
+ cs = cc1.columns(len(best))
167
+ for c,(file,pred) in zip(cs,best.items()):
168
+ c.metric("", "",f'{pred:.0%}')
169
+ c.image(get_s3_url(file), width = 100)
170
+
171
+ cc1.write("You liked these more than I thought")
172
+ cs = cc1.columns(len(worst))
173
+ for c,(file,pred) in zip(cs,worst.items()):
174
+ c.metric("","", f'{pred:.0%}')
175
+ c.image(get_s3_url(file), width = 100)
176
+
177
+
178
+
179
+
180
+