from typing import List, Optional

from llms import LLM
from utils.remote_client import execute_remote_task

def topic_classification(text: str, model: str, candidate_labels: Optional[List[str]] = None, custom_instructions: str = "", use_llm: bool = True) -> str:
    """
    Classify text into topics using either LLM or traditional (Modal API) method.
    
    Args:
        text: The text to classify
        model: The model to use
        candidate_labels: List of candidate topics/categories
        custom_instructions: Optional instructions for LLM
        use_llm: Whether to use LLM or traditional method
    """
    if not text.strip():
        return ""
    if use_llm:
        return _topic_classification_with_llm(text, model, candidate_labels, custom_instructions)
    else:
        return _topic_classification_with_traditional(text, model, candidate_labels)

def _topic_classification_with_llm(text: str, model: str, candidate_labels: Optional[List[str]] = None, custom_instructions: str = "") -> str:
    """Classify the text by prompting an LLM to choose one of the candidate labels."""
    try:
        llm = LLM(model=model)
        # Fall back to open-ended classification when no candidate labels are given.
        labels_str = ", ".join(candidate_labels) if candidate_labels else "any appropriate topic"
        prompt = (
            f"Classify the following text into ONE of these categories: {labels_str}.\n" +
            f"Return ONLY the most appropriate category name.\n" +
            (f"{custom_instructions}\n" if custom_instructions else "") +
            f"Text: {text}\nCategory:"
        )
        result = llm.generate(prompt)
        return result.strip()
    except Exception as e:
        print(f"Error in LLM topic classification: {str(e)}")
        return "Oops! Something went wrong. Please try again later."

def _topic_classification_with_traditional(text: str, model: str, labels: Optional[List[str]] = None) -> str:
    """Classify the text via the remote (Modal API) classification task."""
    try:
        payload = {
            "text": text, 
            "model": model,
            "task": "topic"
        }
        # Only forward candidate labels when the caller provided them.
        if labels is not None:
            payload["labels"] = labels
        resp = execute_remote_task("classification", payload)
        if "error" in resp:
            return "Oops! Something went wrong. Please try again later."
        return resp.get("labels", "")
    except Exception as e:
        print(f"Error in traditional topic classification: {str(e)}")
        return "Oops! Something went wrong. Please try again later."