fast-bulk / app.py
koaning's picture
Update app.py
3b8b7e8 verified
# /// script
# requires-python = "==3.12"
# dependencies = [
# "marimo",
# "polars==1.23.0",
# "scikit-learn==1.6.1",
# "numpy==2.1.3",
# "mohtml==0.1.2",
# "model2vec==0.4.0",
# "altair==5.5.0",
# ]
# ///
import marimo
# Version string stored with the notebook; presumably the marimo release that generated this file — TODO confirm.
__generated_with = "0.11.14"
# App object on which every `@app.cell` function below is registered.
app = marimo.App()
@app.cell
def _(mo):
    # Notebook heading, rendered from markdown.
    mo.md("### Fast labelling demo")
    return
@app.cell
def _(mo, use_default_switch):
    # An upload area is only created while the default-dataset switch is off;
    # otherwise the name is bound to None.
    if use_default_switch.value:
        uploaded_file = None
    else:
        uploaded_file = mo.ui.file(kind="area")
    uploaded_file
    return (uploaded_file,)
@app.cell
def _(mo):
    # Switch (initially off) that selects the default dataset.
    use_default_switch = mo.ui.switch(
        value=False,
        label="Use default dataset",
    )
    use_default_switch
    return (use_default_switch,)
@app.cell
def _(mo):
    # Text inputs holding the names of the two classes; they start out as
    # "pos" and "neg".
    pos_label = mo.ui.text(
        value="pos",
        label="positive class name",
        placeholder="positive label name",
    )
    neg_label = mo.ui.text(
        value="neg",
        label="negative class name",
        placeholder="negative label name",
    )
    return neg_label, pos_label
@app.cell
def _(uploaded_file, use_default_switch):
    # Downstream cells halt unless the default dataset is selected or at least
    # one file has been uploaded. The switch is checked first, so
    # `uploaded_file.value` is only read when the switch is off.
    should_stop = not (use_default_switch.value or len(uploaded_file.value) > 0)
    return (should_stop,)
@app.cell
def _(mo, pl, should_stop, uploaded_file, use_default_switch):
    # Load the dataset as a polars DataFrame and extract the texts to annotate.
    mo.stop(should_stop , mo.md("**Submit a dataset or use default one to continue.**"))

    if use_default_switch.value:
        # Default dataset, read from a path relative to the working directory.
        df = pl.read_csv("spam.csv")
    else:
        # Only the first uploaded file is read; its `contents` go to polars.
        df = pl.read_csv(uploaded_file.value[0].contents)

    # The texts are read from the "text" column, so halt with a clear message
    # instead of raising when the CSV does not provide that column.
    mo.stop(
        "text" not in df.columns,
        mo.md("**The dataset needs a `text` column to continue.**"),
    )
    texts = df["text"].to_list()
    return df, texts
@app.cell
def _(StaticModel, mo):
    # Load the pretrained static embedding model while a spinner is shown.
    with mo.status.spinner(subtitle="Loading model ..."):
        tfm = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
    return (tfm,)
@app.cell
def _(mo, should_stop):
    mo.stop(should_stop)

    # Text area for the reference sentence, wrapped in a form so its value is
    # exposed through `form.value`.
    text_input = mo.ui.text_area(
        value="you will win a free ringtone!",
        label="Reference sentences",
    )
    _template = mo.md("{text_input}")
    _batched = _template.batch(text_input=text_input)
    form = _batched.form()
    form
    return form, text_input
@app.cell
def _(mo, texts, tfm):
    # Encode every dataset text with the model while a spinner is shown.
    with mo.status.spinner(subtitle="Creating embeddings ..."):
        X = tfm.encode(texts)
    return (X,)
@app.cell
def _(add_label, get_example, mo, neg_label, pos_label, undo):
    # Click handlers. Each receives the button's current value (unused) and
    # returns whatever the underlying helper returns.
    def _annotate_negative(value):
        return add_label(get_example(), neg_label.value)

    def _annotate_positive(value):
        return add_label(get_example(), pos_label.value)

    def _undo_last(value):
        return undo()

    # Labels the current example with the negative class name.
    btn_spam = mo.ui.button(
        on_click=_annotate_negative,
        label=f"Annotate {neg_label.value}",
        keyboard_shortcut="Ctrl-L",
    )
    # Labels the current example with the positive class name.
    btn_ham = mo.ui.button(
        on_click=_annotate_positive,
        label=f"Annotate {pos_label.value}",
        keyboard_shortcut="Ctrl-K",
    )
    btn_undo = mo.ui.button(
        on_click=_undo_last,
        label="Undo",
        keyboard_shortcut="Ctrl-U",
    )
    return btn_ham, btn_spam, btn_undo
@app.cell
def _(gen, get_label, set_example, set_label):
    def add_label(text, lab):
        """Store `lab` for `text` and move on to the next candidate.

        Args:
            text: The example text being annotated.
            lab: The label name recorded alongside it.
        """
        set_label(get_label() + [{"text": text, "label": lab}])
        # `gen` may be exhausted; keep the current example in that case rather
        # than letting StopIteration escape from the button callback.
        try:
            upcoming = next(gen)
        except StopIteration:
            return
        set_example(upcoming)

    def undo():
        """Remove the most recently added annotation, if there is one."""
        # add_label appends one entry per call, so undo drops exactly one.
        # Slicing an empty list yields an empty list, so no guard is needed.
        set_label(get_label()[:-1])

    return add_label, undo
@app.cell
def _():
    # `br` is used by the layout cell as a spacer between rows.
    from mohtml import br

    return (br,)
@app.cell
def _(br, btn_ham, btn_spam, btn_undo, example, mo, neg_label, p, pos_label):
    # Annotation panel, top to bottom: class-name inputs, action buttons,
    # then the example currently being labelled.
    _name_row = mo.hstack([pos_label, neg_label])
    _button_row = mo.hstack([btn_ham, btn_spam, btn_undo])
    _heading = p("Current example:", klass="font-bold")

    mo.vstack([_name_row, br(), _button_row, br(), _heading, example])
    return
@app.cell
def _(mo):
    # State holding the collected annotations; it starts as an empty list.
    _no_labels_yet = []
    get_label, set_label = mo.state(_no_labels_yet)
    return get_label, set_label
@app.cell
def _(gen, mo):
    # State holding the example on display, seeded with the first item that
    # `gen` yields.
    _first_candidate = next(gen)
    get_example, set_example = mo.state(_first_candidate)
    return get_example, set_example
@app.cell
def _():
    # HTML helpers from mohtml; the result of `tailwind_css()` is this cell's
    # output.
    from mohtml import div, p, tailwind_css

    tailwind_css()
    return div, p, tailwind_css
@app.cell
def _(get_label, mo):
    import json

    # Serialise the current annotations to UTF-8 encoded JSON for download.
    data = get_label()
    _payload = json.dumps(data).encode("utf-8")
    # NOTE(review): `json_download` is not the cell's output and no visible
    # cell takes it as an input, so the button may never be rendered — confirm.
    json_download = mo.download(
        data=_payload,
        label="Download JSON",
        filename="data.json",
        mimetype="application/json",
    )
    return data, json, json_download
@app.cell
def _(X, cosine_similarity, form, get_label, mo, pl, texts, tfm):
    # Rank every text by cosine similarity to the submitted reference sentence
    # and expose a generator over the most similar texts that have no label yet.
    mo.stop(not form.value, "Need a text input to fetch example")
    mo.stop(not form.value.get("text_input", None), "Need a text input to fetch example")

    query = tfm.encode([form.value["text_input"]])
    # One query is encoded, so only the first row of the result is kept.
    similarity = cosine_similarity(query, X)[0]

    # The similarity column is attached directly; a placeholder column that is
    # immediately overwritten would only be wasted work.
    df_emb = (
        pl.DataFrame({
            "index": range(X.shape[0]),
            "text": texts
        })
        .with_columns(sim=similarity)
        .sort(pl.col("sim"), descending=True)
    )

    label_texts = [_["text"] for _ in get_label()]
    # Set copy for O(1) membership tests; `label_texts` itself stays a list
    # because it is exported from this cell.
    _already_labelled = set(label_texts)
    # Candidates are limited to the 100 most similar rows.
    gen = (
        _row["text"]
        for _row in df_emb.head(100).to_dicts()
        if _row["text"] not in _already_labelled
    )
    return df_emb, gen, label_texts, query, similarity
@app.cell
def _(div, get_example, p):
    # Wrap the current example text in a styled container.
    _paragraph = p(get_example())
    example = div(_paragraph, klass="bg-gray-100 p-4 rounded-lg")
    return (example,)
@app.cell
def _(get_label, mo, pl, should_stop):
    mo.stop(should_stop)

    # Show the annotations as a table in reversed row order.
    _labels_frame = pl.DataFrame(get_label())
    _labels_frame.reverse()
    return
@app.cell
def _(mo):
    # Third-party imports, performed while a spinner is shown.
    with mo.status.spinner(subtitle="Loading libraries ..."):
        import numpy as np
        import polars as pl
        from sklearn.metrics.pairwise import cosine_similarity
    return cosine_similarity, np, pl
@app.cell
def _(mo):
    # Import the model class separately, under its own spinner.
    with mo.status.spinner(subtitle="Loading model2vec ..."):
        from model2vec import StaticModel
    return (StaticModel,)
@app.cell
def _():
    # Provides the `mo` name that the other cells take as an input.
    import marimo as mo

    return (mo,)
@app.cell
def _():
    # Intentionally empty cell.
    return
if __name__ == "__main__":
    # Run the app when this file is executed directly.
    app.run()