Spaces:
Runtime error
Runtime error
File size: 1,054 Bytes
1a80035 06ba6f1 1a80035 8f8b328 81b6ae4 4791a7a 06ba6f1 81b6ae4 1a80035 79b6ab8 8f8b328 79b6ab8 06ba6f1 2d6cd9e 2945a40 1256802 06ba6f1 db840ed 4791a7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import streamlit as st
import json
from transformers import pipeline
# Zero-shot text classification demo: the user enters a text and a
# comma-separated list of candidate labels, and the app scores each label
# against the text with an MNLI model.

MODEL_NAME = "valhalla/distilbart-mnli-12-1"


def _cache_model(func):
    """Cache *func*'s return value across Streamlit reruns.

    Streamlit re-executes the whole script on every widget interaction, so an
    uncached ``pipeline(...)`` call would reload the model each time.
    ``st.cache_resource`` is used when available; older Streamlit releases
    only provide ``st.cache``.
    """
    if hasattr(st, "cache_resource"):
        return st.cache_resource(func)
    # allow_output_mutation avoids hashing the (large, unhashable) pipeline.
    return st.cache(allow_output_mutation=True)(func)


@_cache_model
def load_classifier():
    """Return the zero-shot-classification pipeline (built once, then cached)."""
    return pipeline("zero-shot-classification", model=MODEL_NAME)


def parse_labels(raw):
    """Split a comma-separated string into stripped, non-empty labels.

    Empty entries (e.g. from ``"a, b,"`` or ``"a,,b"``) are dropped so they
    are not sent to the model as blank candidate labels.

    >>> parse_labels(" support, help,, important ,")
    ['support', 'help', 'important']
    """
    return [label.strip() for label in raw.split(",") if label.strip()]


classifier = load_classifier()

with st.form("inputs"):
    input_text = st.text_area("Input text")
    input_label = st.text_input("Labels", placeholder="support, help, important")
    input_multi = st.checkbox("Allow multiple true classes", value=False)
    submit_button = st.form_submit_button(label="Submit")

if submit_button:
    labels = parse_labels(input_label)
    if not input_text.strip() or not labels:
        # Nothing meaningful to classify; tell the user instead of calling
        # the model with blank input.
        st.warning("Please enter some text and at least one label.")
    else:
        # `multi_label` is the current name of the deprecated `multi_class`
        # keyword: when True each label is scored independently.
        pred = classifier(input_text, labels, multi_label=input_multi)
        if input_multi:
            # Chart only labels/scores; the prediction also carries the full
            # input `sequence`, which would otherwise be repeated per row.
            chart_data = {"labels": pred["labels"], "scores": pred["scores"]}
            st.vega_lite_chart(
                chart_data,
                {
                    "mark": {"type": "bar", "tooltip": False},
                    "encoding": {
                        "x": {"field": "scores", "type": "quantitative"},
                        "y": {"field": "labels", "type": "nominal"},
                    },
                },
                use_container_width=True,
            )
        else:
            # Labels come back sorted by descending score; show the top two.
            out = f"Top predicted labels are {', '.join(pred['labels'][:2])}"
            st.success(out)
|