lunde's picture
Initial commit
bd65e34
raw
history blame contribute delete
425 Bytes
import polars as pl
def balance_labels(
    df: pl.DataFrame,
    fraction: float = 0.5,
    *,
    seed: "int | None" = None,
) -> pl.DataFrame:
    """Down-sample the negative rows of *df* to rebalance its labels.

    A random ``fraction`` of the rows whose ``label`` column equals 0 is
    selected and removed from the frame via an anti-join on the ``path``
    column. The share of positive labels before and after is printed.

    Args:
        df: Frame containing at least a ``label`` and a ``path`` column.
        fraction: Fraction of the ``label == 0`` rows to drop.
        seed: Optional seed forwarded to ``DataFrame.sample`` so the same
            rows are dropped on every run. ``None`` keeps the sampling
            non-deterministic, as before.

    Returns:
        The frame without the sampled negative rows.

    Note:
        Rows are removed by matching on ``path``, so every row sharing a
        path with a sampled row is dropped too.
        NOTE(review): this presumably relies on ``path`` being unique per
        row -- confirm against the data.
    """

    def _highlight_share(frame: pl.DataFrame) -> float:
        # Guard against empty frames: dividing by len(frame) == 0 used to
        # raise ZeroDivisionError (empty input, or everything removed).
        total = len(frame)
        if total == 0:
            return 0.0
        # `or 0` covers a sum that comes back as None.
        return (frame["label"].sum() or 0) / total

    prev_share = _highlight_share(df)

    to_remove = df.filter(pl.col("label") == 0).sample(fraction=fraction, seed=seed)
    df = df.join(to_remove, on="path", how="anti")

    print(
        f"Previously {prev_share:.1%} highlights, now {_highlight_share(df):.1%}"
    )
    return df