import torch
from typing import Dict, Optional, Union
from transformers import Pipeline, pipeline
from constants import HUGGINGFACE_MODEL_PATH


def set_device() -> torch.device:
    """

    Set the device to the best available option: CUDA (if available), MPS (if available on Mac),

    or CPU as a fallback. Provides a robust selection mechanism for production environments.



    Returns:

        torch.device: The best available device for computation.

    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        return torch.device("mps")
    else:
        return torch.device("cpu")


def food_not_food_classifier(
    text: Union[str, list],
    model_path: str,
    batch_size: int = 32,
    device: Optional[Union[str, torch.device]] = None,
    get_classifier: bool = False,
) -> Union[Dict[str, float], Pipeline]:
    """

    Classifies whether the given text is related to food or not, returning a dictionary of labels and their scores.



    Args:

        text (Union[str, list]): The input text or list of texts to classify.

        model_path (str): The path to the Hugging Face model for classification.

        batch_size (int): The batch size for processing. Default is 32.

        device (str): The device to run inference on (e.g., 'cuda', 'cpu'). Default is None (auto-detect best available).



    Returns:

        Dict[str, float]: A dictionary where the keys are the labels and the values are the classification scores.

    """

    if device is None:
        device = set_device()

    classifier = pipeline(
        task="text-classification",
        model=model_path,
        batch_size=batch_size,
        device=device,
        top_k=None  # Keep all predictions
    )

    if get_classifier:
        return classifier

    # Example raw output for a single input:
    # [[{'label': 'food', 'score': 0.9500328898429871}, {'label': 'not_food', 'score': 0.04996709153056145}]]
    results = classifier(text)

    # Flatten the predictions for the first input into a label -> score mapping.
    output_dict = {}
    for output in results[0]:
        output_dict[output['label']] = output['score']

    return output_dict


def gradio_food_classifier(text: str) -> dict:
    """

    A wrapper function for Gradio to classify text using the classify_food_text function.



    Args:

        text (str): The input text to classify.



    Returns:

        dict: Classification results as a dictionary of label and score.

    """
    classifier = food_not_food_classifier(text=text,
                                          model_path=HUGGINGFACE_MODEL_PATH,
                                          get_classifier=True)

    results = classifier(text)

    output_dict = {}
    for output in results[0]:
        output_dict[output['label']] = output['score']

    return output_dict
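

# Usage sketch: a minimal, illustrative example of calling these helpers.
# It assumes `constants.HUGGINGFACE_MODEL_PATH` points at a valid
# text-classification checkpoint and that `gradio` is installed; the sample
# sentence below is made up for demonstration.
if __name__ == "__main__":
    # Direct classification: returns e.g. {'food': 0.95, 'not_food': 0.05}.
    scores = food_not_food_classifier(
        text="A steaming bowl of ramen with a soft-boiled egg",
        model_path=HUGGINGFACE_MODEL_PATH,
    )
    print(scores)

    # Optionally serve the classifier through a simple Gradio interface;
    # gr.Label renders the label -> score dictionary returned by the wrapper.
    import gradio as gr

    demo = gr.Interface(
        fn=gradio_food_classifier,
        inputs=gr.Textbox(label="Text to classify"),
        outputs=gr.Label(num_top_classes=2),
        title="Food / not-food classifier",
    )
    demo.launch()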