File size: 511 Bytes
a6c03b1
c7bffdb
 
 
a6c03b1
 
ce56284
 
a6c03b1
 
ce56284
 
c664e6a
a6c03b1
c664e6a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import streamlit as st
from transformers import pipeline
# Streamlit re-executes this whole script on every widget interaction. Cache
# the loaded pipeline so the (large) model is built once and reused across
# reruns instead of being reloaded each time. st.cache_resource only exists in
# newer Streamlit releases (1.18+), so fall back to an uncached load otherwise.
_cache_resource = getattr(st, "cache_resource", lambda fn: fn)


@_cache_resource
def _load_classifier():
    """Build the zero-shot classification pipeline backed by facebook/bart-large-mnli."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


classifier = _load_classifier()

st.title("What's the category?")


def _parse_labels(raw):
    """Split a comma-separated string into trimmed, non-empty labels.

    Args:
        raw: The text typed into the categories box (may be empty or None).

    Returns:
        A list of labels with surrounding whitespace removed; blank entries
        (e.g. from "a,,b" or a trailing comma) are dropped.
    """
    return [label.strip() for label in (raw or "").split(",") if label.strip()]


cats = st.text_input("Enter categories (comma separated)")
text = st.text_input("Enter words")

# Only classify once the user has entered some non-blank text.
if text and text.strip():
    candidate_labels = _parse_labels(cats)
    if not candidate_labels:
        # Without usable categories the classifier would be run on blank
        # labels and display meaningless scores; prompt the user instead.
        st.warning("Enter at least one category.")
    else:
        res = classifier(text, candidate_labels)
        # res["labels"] and res["scores"] are parallel lists. Scale to a
        # percentage *before* formatting: the old round(score, 3) * 100
        # printed float artifacts such as 57.99999999999999%.
        for name, score in zip(res["labels"], res["scores"]):
            st.text(f"{name} : {score * 100:.1f}%")