koaning committed on
Commit
c1bba49
·
1 Parent(s): a2cb3cf
Files changed (1) hide show
  1. app.py +76 -138
app.py CHANGED
@@ -10,6 +10,7 @@
10
  # "scikit-learn==1.6.1",
11
  # "numpy==2.1.3",
12
  # "mohtml==0.1.2",
 
13
  # ]
14
  # ///
15
 
@@ -66,124 +67,105 @@ def _(mo, pl, should_stop, uploaded_file, use_default_switch):
66
 
67
 
68
  @app.cell
69
- def _(SentenceTransformer, mo, texts):
 
 
 
 
 
 
 
70
  with mo.status.spinner(subtitle="Creating embeddings ...") as _spinner:
71
- tfm = SentenceTransformer("all-MiniLM-L6-v2")
72
  X = tfm.encode(texts)
73
- return X, tfm
74
 
75
 
76
  @app.cell
77
- def _(X, mo):
78
- with mo.status.spinner(subtitle="Running UMAP ...") as _spinner:
79
- from umap import UMAP
80
-
81
- umap_tfm = UMAP()
82
- X_tfm = umap_tfm.fit_transform(X)
83
- return UMAP, X_tfm, umap_tfm
84
 
85
 
86
  @app.cell
87
- def _(add_label, mo, neg_label, pos_label, undo):
88
- btn_spam = mo.ui.button(label=f"Annotate {neg_label.value}", on_click=lambda d: add_label(neg_label.value))
89
- btn_ham = mo.ui.button(label=f"Annotate {pos_label.value}", on_click=lambda d: add_label(pos_label.value))
90
- btn_undo = mo.ui.button(label="Undo", on_click=lambda d: undo())
91
- return btn_ham, btn_spam, btn_undo
 
 
 
 
 
92
 
93
 
94
  @app.cell
95
- def _(chart, get_label, neg_label, pos_label, set_label):
96
- def add_label(lab):
97
  current_labels = get_label()
98
- if lab == neg_label.value:
99
- new_ham = list(set(current_labels[pos_label.value]).difference(chart.value["index"]))
100
- new_spam = list(set(current_labels[neg_label.value]).union(chart.value["index"]))
101
- if lab == pos_label.value:
102
- new_ham = list(set(current_labels[pos_label.value]).union(chart.value["index"]))
103
- new_spam = list(set(current_labels[neg_label.value]).difference(chart.value["index"]))
104
-
105
- set_label({neg_label.value: new_spam, pos_label.value: new_ham})
106
  return (add_label,)
107
 
108
 
109
  @app.cell
110
- def _(
111
- br,
112
- btn_ham,
113
- btn_spam,
114
- btn_undo,
115
- chart,
116
- form,
117
- json_download,
118
- mo,
119
- neg_label,
120
- pos_label,
121
- switch,
122
- ):
123
- mo.vstack([
124
- mo.md("Assign label names"),
125
- mo.hstack([pos_label, neg_label]),
126
- mo.md("Explore the data"),
127
- mo.hstack([btn_ham, btn_spam, btn_undo, switch, json_download]),
128
- br(),
129
- form if switch.value else "",
130
- br() if switch.value else "",
131
- chart
132
- ])
133
- return
134
 
135
 
136
  @app.cell
137
- def _(chart):
138
- chart.value["text"]
139
- return
140
 
141
 
142
  @app.cell
143
- def _(chart, get_label, neg_label, pos_label, set_label):
144
- def undo():
145
- current_labels = get_label()
146
- new_spam = set(current_labels[neg_label.value]).difference(chart.value["index"])
147
- new_ham = set(current_labels[pos_label.value]).difference(chart.value["index"])
148
- set_label({neg_label.value: list(new_spam), pos_label.value: list(new_ham)})
149
- return (undo,)
150
 
151
 
152
  @app.cell
153
- def _():
154
- from mohtml import br
155
- return (br,)
 
 
 
156
 
157
 
158
  @app.cell
159
- def _(mo, neg_label, pos_label):
160
- get_label, set_label = mo.state({pos_label.value: [], neg_label.value: []})
161
- return get_label, set_label
 
 
 
 
 
 
 
 
 
 
162
 
163
 
164
  @app.cell
165
  def _(mo):
166
  text_input = mo.ui.text_area(label="Reference sentences")
167
  form = mo.md("""{text_input}""").batch(text_input=text_input).form()
 
168
  return form, text_input
169
 
170
 
171
  @app.cell
172
- def _(df_emb, labels, mo):
173
- from collections import Counter
174
-
175
- with mo.status.spinner(subtitle="Starting UI ...") as _spinner:
176
- df_emb
177
-
178
- Counter(labels)
179
- return (Counter,)
180
-
181
-
182
- @app.cell
183
- def _(df_emb, mo, pl):
184
  import json
185
 
186
- data = df_emb.filter(pl.col("label") != "unlabeled").select("text", "label").to_dicts()
187
 
188
  json_download = mo.download(
189
  data=json.dumps(data).encode("utf-8"),
@@ -195,47 +177,9 @@ def _(df_emb, mo, pl):
195
 
196
 
197
  @app.cell
198
- def _(df_emb, mo, scatter):
199
- chart = mo.ui.altair_chart(scatter(df_emb))
200
- return (chart,)
201
 
202
-
203
- @app.cell
204
- def _(mo):
205
- switch = mo.ui.switch(False, label="Use search")
206
- return (switch,)
207
-
208
-
209
- @app.cell
210
- def _(alt, neg_label, pos_label, switch):
211
- def scatter(df):
212
- return (alt.Chart(df)
213
- .mark_circle()
214
- .encode(
215
- x=alt.X("x:Q"),
216
- y=alt.Y("y:Q"),
217
- color=alt.Color("sim:Q") if switch.value else alt.Color("label:N", scale=alt.Scale(
218
- domain=['unlabeled', pos_label.value, neg_label.value],
219
- range=['steelblue', 'green', 'red']
220
- ))
221
- ).properties(width=500, height=500))
222
- return (scatter,)
223
-
224
-
225
- @app.cell
226
- def _(
227
- X,
228
- X_tfm,
229
- cosine_similarity,
230
- form,
231
- get_label,
232
- neg_label,
233
- np,
234
- pl,
235
- pos_label,
236
- texts,
237
- tfm,
238
- ):
239
  df_emb = (
240
  pl.DataFrame({
241
  "x": X_tfm[:, 0],
@@ -245,25 +189,18 @@ def _(
245
  }).with_columns(sim=pl.lit(1))
246
  )
247
 
248
- if form.value:
249
- query = tfm.encode([form.value["text_input"]])
250
- similarity = cosine_similarity(query, X)[0]
251
- df_emb = df_emb.with_columns(sim=similarity)
252
 
253
- spam = set(get_label()[neg_label.value])
254
- ham = set(get_label()[pos_label.value])
 
 
 
255
 
256
- labels = []
257
- for i in range(df_emb.shape[0]):
258
- if i in spam:
259
- labels.append(neg_label.value)
260
- elif i in ham:
261
- labels.append(pos_label.value)
262
- else:
263
- labels.append("unlabeled")
264
 
265
- df_emb = df_emb.with_columns(label=np.array(labels))
266
- return df_emb, ham, i, labels, query, similarity, spam
 
 
267
 
268
 
269
  @app.cell
@@ -274,14 +211,15 @@ def _(mo):
274
  import numpy as np
275
  from sklearn.metrics.pairwise import cosine_similarity
276
  from sklearn.linear_model import LogisticRegression
277
- return LogisticRegression, alt, cosine_similarity, np, pl
 
278
 
279
 
280
  @app.cell
281
  def _(mo):
282
- with mo.status.spinner(subtitle="Loading SBERT ...") as _spinner:
283
- from sentence_transformers import SentenceTransformer
284
- return (SentenceTransformer,)
285
 
286
 
287
  @app.cell
 
10
  # "scikit-learn==1.6.1",
11
  # "numpy==2.1.3",
12
  # "mohtml==0.1.2",
13
+ # "model2vec==0.4.0",
14
  # ]
15
  # ///
16
 
 
67
 
68
 
69
  @app.cell
70
+ def _(StaticModel, mo):
71
+ with mo.status.spinner(subtitle="Loading model ...") as _spinner:
72
+ tfm = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
73
+ return (tfm,)
74
+
75
+
76
+ @app.cell
77
+ def _(mo, texts, tfm):
78
  with mo.status.spinner(subtitle="Creating embeddings ...") as _spinner:
 
79
  X = tfm.encode(texts)
80
+ return (X,)
81
 
82
 
83
  @app.cell
84
+ def _(PCA, X, mo):
85
+ with mo.status.spinner(subtitle="Running PCA ...") as _spinner:
86
+ pca_tfm = PCA()
87
+ X_tfm = pca_tfm.fit_transform(X)
88
+ return X_tfm, pca_tfm
 
 
89
 
90
 
91
  @app.cell
92
+ def _(add_label, get_example, mo, neg_label, pos_label):
93
+ btn_spam = mo.ui.button(
94
+ label=f"Annotate {neg_label.value}",
95
+ on_click=lambda d: add_label(get_example(), neg_label.value)
96
+ )
97
+ btn_ham = mo.ui.button(
98
+ label=f"Annotate {pos_label.value}",
99
+ on_click=lambda d: add_label(get_example(), pos_label.value)
100
+ )
101
+ return btn_ham, btn_spam
102
 
103
 
104
  @app.cell
105
+ def _(gen, get_label, set_example, set_label):
106
+ def add_label(text, lab):
107
  current_labels = get_label()
108
+ set_label(current_labels + [{"text": text, "label": lab}])
109
+ set_example(next(gen))
 
 
 
 
 
 
110
  return (add_label,)
111
 
112
 
113
  @app.cell
114
+ def _():
115
+ from mohtml import br
116
+ return (br,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
 
119
  @app.cell
120
+ def _(mo):
121
+ get_label, set_label = mo.state([])
122
+ return get_label, set_label
123
 
124
 
125
  @app.cell
126
+ def _(gen, mo):
127
+ get_example, set_example = mo.state(next(gen))
128
+ return get_example, set_example
 
 
 
 
129
 
130
 
131
  @app.cell
132
+ def _(div, get_example, p):
133
+ div(
134
+ p(get_example()),
135
+ klass="bg-gray-100 p-4 rounded-lg"
136
+ )
137
+ return
138
 
139
 
140
  @app.cell
141
+ def _(btn_ham, btn_spam, mo):
142
+ mo.hstack([
143
+ btn_ham, btn_spam
144
+ ])
145
+ return
146
+
147
+
148
+ @app.cell
149
+ def _():
150
+ from mohtml import tailwind_css, div, p
151
+
152
+ tailwind_css()
153
+ return div, p, tailwind_css
154
 
155
 
156
  @app.cell
157
  def _(mo):
158
  text_input = mo.ui.text_area(label="Reference sentences")
159
  form = mo.md("""{text_input}""").batch(text_input=text_input).form()
160
+ form
161
  return form, text_input
162
 
163
 
164
  @app.cell
165
+ def _(get_label, mo):
 
 
 
 
 
 
 
 
 
 
 
166
  import json
167
 
168
+ data = get_label()
169
 
170
  json_download = mo.download(
171
  data=json.dumps(data).encode("utf-8"),
 
177
 
178
 
179
  @app.cell
180
+ def _(X, X_tfm, cosine_similarity, form, mo, pl, texts, tfm):
181
+ mo.stop(not form.value["text_input"], "Need a text input to fetch example")
 
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  df_emb = (
184
  pl.DataFrame({
185
  "x": X_tfm[:, 0],
 
189
  }).with_columns(sim=pl.lit(1))
190
  )
191
 
 
 
 
 
192
 
193
+ query = tfm.encode([form.value["text_input"]])
194
+ similarity = cosine_similarity(query, X)[0]
195
+ df_emb = df_emb.with_columns(sim=similarity).sort(pl.col("sim"), descending=True)
196
+ gen = (_["text"] for _ in df_emb.head(100).to_dicts())
197
+ return df_emb, gen, query, similarity
198
 
 
 
 
 
 
 
 
 
199
 
200
+ @app.cell
201
+ def _(get_label, pl):
202
+ pl.DataFrame(get_label())
203
+ return
204
 
205
 
206
  @app.cell
 
211
  import numpy as np
212
  from sklearn.metrics.pairwise import cosine_similarity
213
  from sklearn.linear_model import LogisticRegression
214
+ from sklearn.decomposition import PCA
215
+ return LogisticRegression, PCA, alt, cosine_similarity, np, pl
216
 
217
 
218
  @app.cell
219
  def _(mo):
220
+ with mo.status.spinner(subtitle="Loading model2vec ...") as _spinner:
221
+ from model2vec import StaticModel
222
+ return (StaticModel,)
223
 
224
 
225
  @app.cell