File size: 3,162 Bytes
385dc2f
 
 
 
 
 
 
 
 
6fcc580
 
 
 
 
 
 
 
6a7dca8
6fcc580
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f08be9
6fcc580
 
385dc2f
 
 
6fcc580
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import html

import streamlit as st
from transformers import pipeline

# Load the zero-shot classification model once per server process.
# Streamlit re-executes this whole script on every widget interaction, so an
# uncached module-level ``pipeline(...)`` call would rebuild the model on every
# rerun. ``st.cache_resource`` keeps one shared instance across reruns/sessions.
# ``show_spinner=False`` keeps this call from emitting a UI element before
# ``st.set_page_config`` runs inside ``main``.
@st.cache_resource(show_spinner=False)
def _load_classifier():
    """Build and return the zero-shot pipeline backed by facebook/bart-large-mnli."""
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


# Module-level name kept so existing references to ``classifier`` keep working.
classifier = _load_classifier()

# Define Streamlit app
def _parse_candidate_labels(raw_labels):
    """Split a comma-separated string into stripped, non-empty, unique labels.

    Order of first appearance is preserved, so "a, b,, a ," -> ["a", "b"].
    Empty entries (from trailing or doubled commas) are dropped rather than
    being sent to the model as "" labels.
    """
    stripped = (label.strip() for label in raw_labels.split(","))
    # dict.fromkeys de-duplicates while keeping insertion order.
    return list(dict.fromkeys(label for label in stripped if label))


def _top_prediction(result):
    """Return ``(label, score)`` for the highest-scoring entry of a pipeline result.

    ``result`` is read via its ``"labels"`` and ``"scores"`` keys, which are
    expected to be parallel lists. On ties the first maximal entry wins.
    """
    return max(zip(result["labels"], result["scores"]), key=lambda pair: pair[1])


def main():
    """Render the zero-shot text classification Streamlit app.

    The left pane collects a sentence, comma-separated candidate labels, a
    confidence threshold and a "Classify" button. When the button is pressed
    with valid input, the module-level ``classifier`` pipeline is run and the
    top-scoring label is shown in the right pane. If that label's score is
    below the threshold, it is shown greyed out and marked as below threshold.
    """
    # Page configuration; Streamlit expects this before other elements are rendered.
    st.set_page_config(
        page_title="Zero-Shot Text Classification",
        page_icon=":rocket:",
        layout="wide",  # Wide layout gives the two columns more room
        initial_sidebar_state="expanded",  # Expand sidebar by default
    )

    # App title and description
    st.title("Zero-Shot Text Classification")
    st.markdown(
        """
        This app performs zero-shot text classification using the Facebook BART-Large-MNLI model.
        Enter a sentence and candidate labels, and the model will predict the most relevant label.
        """
    )

    # Two-column layout: narrow input pane (1) next to a wider results pane (2).
    col1, col2 = st.columns([1, 2])

    # Left pane: input elements
    with col1:
        sequence_to_classify = st.text_input("Enter the sentence to classify:")

        st.text("Enter candidate labels separated by commas (e.g., travel, cooking, dancing):")
        # Raw comma-separated string; parsed into a list only after "Classify" is pressed.
        raw_labels = st.text_input("Candidate Labels:")

        confidence_threshold = st.slider(
            "Confidence Threshold:",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.01,
            key="confidence_threshold",
            help="Move the slider to adjust the confidence threshold."
        )

        classify_button = st.button(
            "Classify",
            key="classify_button",
            help="Click the button to classify the input text with the provided labels."
        )

    # Right pane: results
    with col2:
        if not classify_button:
            return

        candidate_labels = _parse_candidate_labels(raw_labels)
        # Previously a missing input made the button silently do nothing.
        if not sequence_to_classify.strip() or not candidate_labels:
            st.warning("Please enter a sentence and at least one candidate label.")
            return

        with st.spinner("Classifying..."):
            classification_result = classifier(sequence_to_classify, candidate_labels)

        max_label, max_score = _top_prediction(classification_result)
        # Labels are user input rendered with unsafe_allow_html=True: escape them
        # so typed HTML/script markup is displayed as text, not interpreted.
        safe_label = html.escape(max_label)

        # Display only the label with the highest score
        st.subheader("Classification Result:")
        if max_score >= confidence_threshold:
            st.write(f"- **{safe_label}**: {max_score:.2f}", unsafe_allow_html=True)
        else:
            st.write(
                f"- <span style='color: #888;'>{safe_label}:</span> Below threshold ({max_score:.2f})",
                unsafe_allow_html=True,
            )

# Entry point: render the app only when this file is executed as a script
# (NOTE(review): assumes launch via ``streamlit run <file>``, which runs the
# script as ``__main__``), not when the module is merely imported.
if __name__ == "__main__":
    main()