File size: 2,060 Bytes
df432d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4efdbb5
df432d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Necessary imports
import sys
from typing import Dict
import torch
from transformers import pipeline
import gradio as gr

# Local imports
from src.logger import logging
from src.exception import CustomExceptionHandling


# Build the shared zero-shot classification pipeline once, at import time,
# so every request reuses the same loaded model (weights requested in bfloat16).
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
    torch_dtype=torch.bfloat16,
)


def ZeroShotTextClassification(
    text_input: str, candidate_labels: str, multi_label: bool
) -> Dict[str, float]:
    """
    Performs zero-shot classification on the given text input and candidate labels.

    Args:
        - text_input: The input text to classify.
        - candidate_labels: A comma-separated string of candidate labels.
          Surrounding whitespace is stripped and empty entries (for example
          from a trailing comma) are ignored.
        - multi_label: A boolean indicating whether to allow the model to choose multiple classes.

    Returns:
        Dictionary containing label-score pairs. An empty dictionary is
        returned (after showing a warning) when the text is blank or no
        usable label was provided.

    Raises:
        CustomExceptionHandling: Wraps any exception raised while validating
        the input or running the classifier.
    """
    try:
        # Split and clean the candidate labels, dropping empty fragments so
        # the model is never asked to score a "" label.
        labels = [
            label.strip()
            for label in (candidate_labels or "").split(",")
            if label.strip()
        ]

        # Check if the input and candidate labels are valid.
        # gr.Warning only displays a message and does not stop execution,
        # so return explicitly instead of falling through to the classifier.
        if not text_input or not text_input.strip() or not labels:
            gr.Warning("Please provide valid input and candidate labels")
            return {}

        # Log the classification attempt
        logging.info(f"Attempting classification with {len(labels)} labels")

        # Perform zero-shot classification
        hypothesis_template = "This text is about {}"
        prediction = classifier(
            text_input,
            labels,
            hypothesis_template=hypothesis_template,
            multi_label=multi_label,
        )

        # Return the classification results as a label -> score mapping
        logging.info("Classification completed successfully")
        return dict(zip(prediction["labels"], prediction["scores"]))

    # Handle exceptions that may occur during the process
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e