# /// script
# requires-python = "==3.12"
# dependencies = [
#     "marimo",
#     "polars==1.23.0",
#     "scikit-learn==1.6.1",
#     "numpy==2.1.3",
#     "mohtml==0.1.2",
#     "model2vec==0.4.0",
#     "altair==5.5.0",
# ]
# ///

import marimo

__generated_with = "0.11.14"
app = marimo.App()


@app.cell
def _(mo):
    # Notebook title.
    mo.md("""### Fast labelling demo""")
    return


@app.cell
def _(mo, use_default_switch):
    # Only offer the upload widget when the default dataset is not selected;
    # otherwise `uploaded_file` is None.
    uploaded_file = mo.ui.file(kind="area") if not use_default_switch.value else None
    uploaded_file
    return (uploaded_file,)


@app.cell
def _(mo):
    # Toggle between an uploaded CSV and the bundled "spam.csv".
    use_default_switch = mo.ui.switch(False, label="Use default dataset")
    use_default_switch
    return (use_default_switch,)


@app.cell
def _(mo):
    # Editable names for the two classes; they are used as button captions and
    # as the "label" value stored with each annotation.
    pos_label = mo.ui.text("pos", placeholder="positive label name", label="positive class name")
    neg_label = mo.ui.text("neg", placeholder="negative label name", label="negative class name")
    return neg_label, pos_label


@app.cell
def _(uploaded_file, use_default_switch):
    # Nothing to work with yet: default dataset not chosen and no file uploaded.
    # The `and` short-circuits, so `uploaded_file` (None when the switch is on)
    # is never dereferenced in that case.
    should_stop = not use_default_switch.value and len(uploaded_file.value) == 0
    return (should_stop,)


@app.cell
def _(mo, pl, should_stop, uploaded_file, use_default_switch):
    mo.stop(should_stop, mo.md("**Submit a dataset or use default one to continue.**"))

    if use_default_switch.value:
        df = pl.read_csv("spam.csv")
    else:
        # Only the first uploaded file is read.
        df = pl.read_csv(uploaded_file.value[0].contents)

    # The dataset must provide a "text" column.
    texts = df["text"].to_list()
    return df, texts


@app.cell
def _(StaticModel, mo):
    # Load the embedding model used for both the dataset and the query.
    with mo.status.spinner(subtitle="Loading model ...") as _spinner:
        tfm = StaticModel.from_pretrained("minishlab/potion-retrieval-32M")
    return (tfm,)


@app.cell
def _(mo, should_stop):
    mo.stop(should_stop)

    # Reference text whose nearest neighbours are proposed for annotation.
    text_input = mo.ui.text_area("you will win a free ringtone!", label="Reference sentences")
    form = mo.md("""{text_input}""").batch(text_input=text_input).form()
    form
    return form, text_input


@app.cell
def _(mo, texts, tfm):
    # Embed every dataset text.
    with mo.status.spinner(subtitle="Creating embeddings ...") as _spinner:
        X = tfm.encode(texts)
    return (X,)


@app.cell
def _(add_label, get_example, mo, neg_label, pos_label, undo):
    # Annotation buttons. `btn_spam` records the negative class name and
    # `btn_ham` the positive one for the example currently shown.
    btn_spam = mo.ui.button(
        label=f"Annotate {neg_label.value}",
        on_click=lambda d: add_label(get_example(), neg_label.value),
        keyboard_shortcut="Ctrl-L"
    )
    btn_ham = mo.ui.button(
        label=f"Annotate {pos_label.value}",
        on_click=lambda d: add_label(get_example(), pos_label.value),
        keyboard_shortcut="Ctrl-K"
    )
    btn_undo = mo.ui.button(
        label="Undo",
        on_click=lambda d: undo(),
        keyboard_shortcut="Ctrl-U"
    )
    return btn_ham, btn_spam, btn_undo


@app.cell
def _(gen, get_label, set_example, set_label):
    def add_label(text, lab):
        """Append one ``{"text", "label"}`` record and advance to the next example.

        NOTE(review): ``next(gen)`` raises ``StopIteration`` once the candidate
        generator is exhausted.
        """
        current_labels = get_label()
        set_label(current_labels + [{"text": text, "label": lab}])
        set_example(next(gen))

    def undo():
        """Remove the most recent annotation.

        ``add_label`` appends exactly one record per click, so exactly one
        record is dropped here. Slicing an empty list is a harmless no-op.
        The currently displayed example is left unchanged.
        """
        current_labels = get_label()
        set_label(current_labels[:-1])
    return add_label, undo


@app.cell
def _():
    from mohtml import br
    return (br,)


@app.cell
def _(br, btn_ham, btn_spam, btn_undo, example, mo, neg_label, p, pos_label):
    # Main annotation panel: class-name inputs, action buttons, current example.
    mo.vstack([
        mo.hstack([
            pos_label,
            neg_label
        ]),
        br(),
        mo.hstack([
            btn_ham,
            btn_spam,
            btn_undo
        ]),
        br(),
        p("Current example:", klass="font-bold"),
        example
    ])
    return


@app.cell
def _(mo):
    # Annotations collected so far: a list of {"text": ..., "label": ...} dicts.
    get_label, set_label = mo.state([])
    return get_label, set_label


@app.cell
def _(gen, mo):
    # The example currently shown, seeded with the first candidate.
    # NOTE(review): `next(gen)` raises StopIteration if there are no candidates.
    get_example, set_example = mo.state(next(gen))
    return get_example, set_example


@app.cell
def _():
    from mohtml import tailwind_css, div, p

    tailwind_css()
    return div, p, tailwind_css


@app.cell
def _(get_label, mo):
    import json

    # Serialise the current annotations for download as JSON.
    data = get_label()
    json_download = mo.download(
        data=json.dumps(data).encode("utf-8"),
        filename="data.json",
        mimetype="application/json",
        label="Download JSON",
    )
    return data, json, json_download


@app.cell
def _(X, cosine_similarity, form, get_label, mo, pl, texts, tfm):
    mo.stop(not form.value, "Need a text input to fetch example")
    mo.stop(not form.value.get("text_input", None), "Need a text input to fetch example")

    df_emb = (
        pl.DataFrame({
            "index": range(X.shape[0]),
            "text": texts
        }).with_columns(sim=pl.lit(1))
    )

    # Rank every dataset text by cosine similarity to the reference text.
    query = tfm.encode([form.value["text_input"]])
    similarity = cosine_similarity(query, X)[0]
    df_emb = df_emb.with_columns(sim=similarity).sort(pl.col("sim"), descending=True)

    # Lazily yield the 100 most similar texts, skipping ones already annotated.
    label_texts = [_["text"] for _ in get_label()]
    gen = (_["text"] for _ in df_emb.head(100).to_dicts() if _["text"] not in label_texts)
    return df_emb, gen, label_texts, query, similarity


@app.cell
def _(div, get_example, p):
    # Styled box holding the example awaiting annotation.
    example = div(
        p(get_example()),
        klass="bg-gray-100 p-4 rounded-lg"
    )
    return (example,)


@app.cell
def _(get_label, mo, pl, should_stop):
    mo.stop(should_stop)
    # Show collected annotations, most recent first.
    pl.DataFrame(get_label()).reverse()
    return


@app.cell
def _(mo):
    with mo.status.spinner(subtitle="Loading libraries ...") as _spinner:
        import polars as pl
        import numpy as np
        from sklearn.metrics.pairwise import cosine_similarity
    return cosine_similarity, np, pl


@app.cell
def _(mo):
    with mo.status.spinner(subtitle="Loading model2vec ...") as _spinner:
        from model2vec import StaticModel
    return (StaticModel,)


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()