|
import gradio as gr |
|
from utils.ner_helpers import is_llm_model |
|
from typing import Dict, List, Any |
|
from tasks.topic_classification import topic_classification |
|
|
|
def topic_ui() -> None:
    """Build the topic classification UI components.

    Lays out a two-column row: the left column holds the input text box,
    example inputs, the custom-topics toggle and text area, the model
    dropdown, the custom-instructions box and the "Classify Topic" button;
    the right column holds the result text box. Clicking the button runs
    ``topic_classification`` on the current inputs and writes the outcome
    to the result box.

    NOTE(review): presumably called inside an active ``gr.Blocks`` context
    owned by the caller -- confirm against the call site.
    """
    TOPIC_MODELS = [
        "gemini-2.0-flash",
    ]
    DEFAULT_MODEL = "gemini-2.0-flash"
    DEFAULT_LABELS = [
        "Sports", "Economy", "Politics", "Entertainment", "Technology", "Education", "Law",
    ]

    def classify(text, model, use_custom, labels, custom_instructions):
        """Validate the inputs and run topic classification.

        Args:
            text: Text to classify; ``None``, empty or whitespace-only text
                is rejected with a message.
            model: Model identifier selected in the dropdown.
            use_custom: Whether the user-supplied topics should be used.
            labels: Newline-separated candidate topics (only read when
                ``use_custom`` is true).
            custom_instructions: Extra instructions forwarded unchanged to
                ``topic_classification``.

        Returns:
            The stripped classification result, or a short message when the
            input text or the custom topic list is empty.
        """
        # Guard against None as well as blank text: the original called
        # ``text.strip()`` directly, which raises AttributeError on None.
        if not text or not text.strip():
            return "No text provided"

        label_list = None
        if use_custom:
            # ``labels or ""`` keeps a None payload from crashing ``split``.
            label_list = [
                line.strip()
                for line in (labels or "").split('\n')
                if line.strip()
            ]
            if not label_list:
                return "Please provide at least one category"

        result = topic_classification(
            text=text,
            model=model,
            candidate_labels=label_list,
            custom_instructions=custom_instructions,
            use_llm=is_llm_model(model),
        )
        # Assumes topic_classification returns a string -- TODO confirm.
        return result.strip()

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                lines=6,
                placeholder="Enter text to classify...",
                elem_id="topic-input-text",
            )
            gr.Examples(
                examples=[
                    ["Apple has announced the release of a new iPhone model this fall."],
                    ["The United Nations held a climate summit to discuss global warming solutions."],
                ],
                inputs=[input_text],
                label="Examples",
            )
            use_custom_topics = gr.Checkbox(
                label="Use custom topics",
                value=True,
                elem_id="topic-use-custom-topics",
            )
            topics_area = gr.TextArea(
                label="Candidate Topics (one per line)",
                value='\n'.join(DEFAULT_LABELS),
                lines=5,
                visible=True,
                elem_id="topic-candidate-topics",
            )

            def toggle_topics_area(use_custom):
                """Show the topics text area only while custom topics are enabled."""
                return gr.update(visible=use_custom)

            use_custom_topics.change(
                toggle_topics_area,
                inputs=use_custom_topics,
                outputs=topics_area,
            )

            model = gr.Dropdown(
                TOPIC_MODELS,
                value=DEFAULT_MODEL,
                label="Model",
                interactive=True,
                elem_id="topic-model-dropdown",
            )
            custom_instructions = gr.Textbox(
                label="Custom Instructions (optional)",
                lines=2,
                placeholder="Add any custom instructions for the model...",
                elem_id="topic-custom-instructions",
            )
            classify_btn = gr.Button(
                "Classify Topic",
                variant="primary",
                elem_id="topic-classify-btn",
            )

        with gr.Column():
            output_box = gr.Textbox(
                label="Classification Result",
                lines=2,
                elem_id="topic-output",
            )

    # Kept as a separately named wrapper (rather than passing ``classify``
    # directly) so the callback name registered with Gradio stays the same.
    def run_topic_classification(text, model, use_custom, topics, custom_instructions):
        """Button callback: delegate to ``classify`` with the current UI values."""
        return classify(text, model, use_custom, topics, custom_instructions)

    classify_btn.click(
        run_topic_classification,
        inputs=[input_text, model, use_custom_topics, topics_area, custom_instructions],
        outputs=output_box,
    )
|
|