File size: 4,317 Bytes
ea99abb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import gradio as gr
from utils.ner_helpers import is_llm_model
from typing import Dict, List, Any
from tasks.topic_classification import topic_classification

def topic_ui():
    """Build the topic-classification UI components.

    Lays out an input textbox (with examples), a "use custom topics" toggle,
    a candidate-topics text area, a model dropdown, an optional
    custom-instructions box, a "Classify Topic" button and a result textbox,
    and wires the button to ``topic_classification``.

    This function does not create a ``gr.Blocks`` itself; it is meant to be
    called while an enclosing Gradio layout context is active.

    Returns:
        None. Components and event handlers are registered on the enclosing
        Gradio context as a side effect.
    """

    # Models offered in the dropdown; only gemini-2.0-flash is enabled for now.
    # Every entry carries a trailing comma so re-enabling a commented-out model
    # cannot silently concatenate adjacent string literals
    # (e.g. "gemini-2.0-flash" "gpt-4" -> "gemini-2.0-flashgpt-4").
    TOPIC_MODELS = [
        "gemini-2.0-flash",
        # "gpt-4",
        # "claude-2",
        # "facebook/bart-large-mnli",
        # "joeddav/xlm-roberta-large-xnli",
    ]
    DEFAULT_MODEL = "gemini-2.0-flash"
    DEFAULT_LABELS = [
        "Sports", "Economy", "Politics", "Entertainment", "Technology", "Education", "Law",
    ]

    def classify(text, model, use_custom, labels, custom_instructions):
        """Classify *text* into a topic and return the result as a string.

        Args:
            text: Text to classify.
            model: Model name selected in the dropdown.
            use_custom: Whether to restrict classification to *labels*.
            labels: Newline-separated candidate topics; only read when
                *use_custom* is true. Blank lines are ignored.
            custom_instructions: Optional extra instructions, forwarded
                unchanged to ``topic_classification``.

        Returns:
            The stripped classification result, or a user-facing message
            when the input text or the custom topic list is empty.
        """
        # Treat a None value the same as blank text instead of crashing on .strip().
        if not text or not text.strip():
            return "No text provided"

        label_list = None
        if use_custom:
            label_list = [
                label.strip() for label in (labels or "").split("\n") if label.strip()
            ]
            if not label_list:
                return "Please provide at least one category"

        result = topic_classification(
            text=text,
            model=model,
            candidate_labels=label_list,
            custom_instructions=custom_instructions,
            use_llm=is_llm_model(model),
        )
        # NOTE(review): assumes topic_classification returns a str — confirm.
        return result.strip()

    # UI Components
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                lines=6,
                placeholder="Enter text to classify...",
                elem_id="topic-input-text"
            )
            gr.Examples(
                examples=[
                    ["Apple has announced the release of a new iPhone model this fall."],
                    ["The United Nations held a climate summit to discuss global warming solutions."]
                ],
                inputs=[input_text],
                label="Examples"
            )
            use_custom_topics = gr.Checkbox(
                label="Use custom topics",
                value=True,
                elem_id="topic-use-custom-topics"
            )
            topics_area = gr.TextArea(
                label="Candidate Topics (one per line)",
                value='\n'.join(DEFAULT_LABELS),
                lines=5,
                visible=True,
                elem_id="topic-candidate-topics"
            )

            def toggle_topics_area(use_custom):
                """Show the candidate-topics area only while custom topics are enabled."""
                return gr.update(visible=use_custom)

            use_custom_topics.change(toggle_topics_area, inputs=use_custom_topics, outputs=topics_area)
            model = gr.Dropdown(
                TOPIC_MODELS,
                value=DEFAULT_MODEL,
                label="Model",
                interactive=True,
                elem_id="topic-model-dropdown"
            )
            custom_instructions = gr.Textbox(
                label="Custom Instructions (optional)",
                lines=2,
                placeholder="Add any custom instructions for the model...",
                elem_id="topic-custom-instructions"
            )
            classify_btn = gr.Button("Classify Topic", variant="primary", elem_id="topic-classify-btn")
        with gr.Column():
            output_box = gr.Textbox(
                label="Classification Result",
                lines=2,
                elem_id="topic-output"
            )
        # Inputs are passed positionally in the order of classify's parameters.
        classify_btn.click(
            classify,
            inputs=[input_text, model, use_custom_topics, topics_area, custom_instructions],
            outputs=output_box,
        )

    return None