|
from llms import LLM |
|
from utils.remote_client import execute_remote_task |
|
|
|
def text_classification(
    text: str,
    model: str,
    task: str = "topic",
    candidate_labels=None,
    custom_instructions: str = "",
    use_llm: bool = True,
) -> str:
    """Classify ``text`` with either an LLM prompt or the traditional (Modal API) backend.

    Blank or whitespace-only input short-circuits to ``""`` without calling
    either backend.

    Args:
        text: The text to classify.
        model: Model identifier passed through to the chosen backend.
        task: ``"sentiment"`` or ``"topic"``; only the LLM backend uses it.
        candidate_labels: Candidate labels for topic classification.
        custom_instructions: Extra prompt instructions; only the LLM backend
            uses them.
        use_llm: True selects the LLM backend, False the traditional one.

    Returns:
        Whatever the selected backend returns, or ``""`` for blank input.
    """
    if not text.strip():
        return ""

    if not use_llm:
        # The traditional backend is given neither ``task`` nor
        # ``custom_instructions`` -- only the text, model and labels.
        return _classification_with_traditional(text, model, candidate_labels)

    return _classification_with_llm(
        text, model, task, candidate_labels, custom_instructions
    )
|
|
|
def _classification_with_llm(text: str, model: str, task: str, candidate_labels=None, custom_instructions: str = "") -> str:
    """Classify ``text`` by prompting an LLM and return its stripped reply.

    Args:
        text: The text to classify.
        model: Model identifier forwarded to ``LLM(model=...)``.
        task: ``"sentiment"`` selects the positive/negative/neutral prompt;
            any other value falls through to topic classification.
        candidate_labels: Categories listed in the topic prompt. May be an
            iterable of labels (each converted with ``str``) or a single,
            already-formatted string. When empty/None the prompt asks for
            "any appropriate topic". Ignored for sentiment.
        custom_instructions: Extra instructions inserted on their own line
            before the text when non-empty.

    Returns:
        The model's reply with surrounding whitespace removed, or a generic
        user-facing error message if anything in the call fails (the error
        is printed).
    """
    try:
        llm = LLM(model=model)

        if task == "sentiment":
            header = [
                "Analyze the sentiment of the following text. "
                "Return ONLY one value: 'positive', 'negative', or 'neutral'."
            ]
            answer_cue = "Sentiment:"
        else:
            if isinstance(candidate_labels, str):
                # A plain string is a ready-made label spec; passing it to
                # ", ".join would split it into individual characters.
                labels_str = candidate_labels.strip() or "any appropriate topic"
            elif candidate_labels:
                # str() each label so non-string labels (ints, enums, ...)
                # don't make join raise and force the error fallback.
                labels_str = ", ".join(str(label) for label in candidate_labels)
            else:
                labels_str = "any appropriate topic"
            header = [
                f"Classify the following text into ONE of these categories: {labels_str}.",
                "Return ONLY the most appropriate category name.",
            ]
            answer_cue = "Category:"

        parts = list(header)
        if custom_instructions:
            parts.append(custom_instructions)
        parts.append(f"Text: {text}")
        parts.append(answer_cue)
        prompt = "\n".join(parts)

        result = llm.generate(prompt)
        return result.strip()

    except Exception as e:
        # Deliberate best-effort boundary: report and return a friendly message
        # instead of propagating to the caller.
        print(f"Error in LLM classification: {e}")
        return "Oops! Something went wrong. Please try again later."
|
|
|
def _classification_with_traditional(text: str, model: str, labels=None) -> str:
    """Classify ``text`` through the remote ``"classification"`` task.

    Args:
        text: The text to classify.
        model: Model identifier sent in the request payload.
        labels: Optional candidate labels; added to the payload only when
            not None.

    Returns:
        The response's ``"labels"`` value (``""`` if the key is absent), or a
        generic user-facing error message when the response contains an
        ``"error"`` key or the call raises. In both failure cases the error is
        printed.
    """
    fallback = "Oops! Something went wrong. Please try again later."
    try:
        payload = {"text": text, "model": model}
        if labels is not None:
            payload["labels"] = labels

        resp = execute_remote_task("classification", payload)

        if "error" in resp:
            # Report the remote error instead of discarding it silently,
            # matching how raised exceptions are reported below.
            print(f"Error in traditional classification: {resp['error']}")
            return fallback

        # NOTE(review): annotated -> str, but the "labels" value is returned
        # unchanged and may not be a str (e.g. a list) -- confirm with callers.
        return resp.get("labels", "")

    except Exception as e:
        print(f"Error in traditional classification: {e}")
        return fallback
|
|