zeroshotcat / app.py
Davide Fiocco
Try out better caching for the pipeline
57fe04c
raw
history blame
1.75 kB
import pandas as pd
import streamlit as st
import tokenizers
import torch
from transformers import Pipeline, pipeline
# Page-level settings. Streamlit expects this to be the first Streamlit
# command the script runs, so it stays ahead of every other `st.*` call.
st.set_page_config(
    layout="wide",
    page_title="Zero-shot classification from tabular data",
    initial_sidebar_state="auto",
    page_icon=None,
    menu_items=None,
)
def _skip_hash(_obj: object) -> None:
    """Return a constant so the cache hasher never inspects ``_obj``."""
    return None


@st.cache(
    allow_output_mutation=True,
    hash_funcs={
        tokenizers.AddedToken: _skip_hash,
        tokenizers.Tokenizer: _skip_hash,
        torch.nn.parameter.Parameter: _skip_hash,
    },
)
def load_classifier() -> Pipeline:
    """Create the zero-shot classification pipeline used by the app.

    The pipeline is built for the ``facebook/bart-large-mnli`` model and is
    wrapped in ``st.cache`` so Streamlit can reuse it instead of rebuilding
    it each time the script runs. Model parameters and tokenizer objects are
    mapped to a constant hash via ``hash_funcs`` so the cache does not try
    to hash them.

    Returns:
        The ``zero-shot-classification`` pipeline.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
# Obtain the classification pipeline (load_classifier is cache-decorated).
classifier = load_classifier()

st.title("Zero-shot classification from tabular data")
st.text(
    "Upload an Excel table and perform zero-shot classification on a set of custom labels"
)
data = st.file_uploader("Upload Excel file (it should contain a `text` column):")
labels = st.text_input("Enter comma-separated labels:")

# Classify only the first N snippets for faster inference.
N = 100

if st.button("Calculate labels"):
    try:
        # Parse the labels once: trim surrounding whitespace and drop empty
        # entries so that input such as "a, b," yields ["a", "b"].
        labels_list = [label.strip() for label in labels.split(",") if label.strip()]
        table = pd.read_excel(data)
        # Keep only rows whose `text` cell is a string longer than 10
        # characters. Blank or numeric cells are skipped instead of making
        # `len` raise and aborting the whole run.
        is_long_text = table["text"].map(
            lambda value: isinstance(value, str) and len(value) > 10
        )
        table = table.loc[is_long_text.astype(bool)].reset_index(drop=True).head(N)
        total = len(table)
        prog_bar = st.progress(0)
        preds = []
        # A plain loop (not a comprehension) because the progress bar is
        # updated after every prediction.
        for done, text in enumerate(table["text"], start=1):
            # Take the first entry of the returned "labels" list as the
            # predicted label (the pipeline is documented to order labels
            # by likelihood).
            preds.append(classifier(text, labels_list)["labels"][0])
            prog_bar.progress(done / total)
        table["label"] = preds
        st.table(table[["text", "label"]])
    except Exception:
        # Deliberately broad: any failure (no file, missing `text` column,
        # no labels, model error) is reported to the user with one message.
        # `Exception` rather than a bare `except` so that BaseException
        # subclasses such as KeyboardInterrupt/SystemExit are not swallowed.
        st.error(
            "Something went wrong. Make sure you upload an Excel file containing a `text` column and a set of comma-separated labels is provided"
        )