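# NOTE: "Spaces: Configuration error" appears to be a status banner from the
# hosting page that was captured along with this file; it is not part of the code.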
import functools
from typing import Any, Dict, Optional, Union

import torch
from transformers import pipeline

from constants import HUGGINGFACE_MODEL_PATH
def set_device() -> torch.device:
    """
    Pick the compute device to run on.

    Preference order: CUDA when ``torch.cuda.is_available()`` reports it,
    otherwise MPS when the MPS backend is both available and built,
    otherwise the CPU.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        # The MPS backend is only consulted when CUDA is not usable.
        mps = torch.backends.mps
        backend = "mps" if (mps.is_available() and mps.is_built()) else "cpu"
    return torch.device(backend)
@functools.lru_cache(maxsize=8)
def _load_classifier(model_path, batch_size, device):
    """Build the text-classification pipeline once per (model_path, batch_size, device)."""
    return pipeline(
        task="text-classification",
        model=model_path,
        batch_size=batch_size,
        device=device,
        top_k=None,  # Keep the score of every label, not only the best one
    )


def food_not_food_classifier(
    text: Union[str, list],
    model_path: str,
    batch_size: int = 32,
    device: Optional[Union[str, torch.device]] = None,
    get_classifier: bool = False,
) -> Union[Dict[str, float], Any]:
    """
    Classifies whether the given text is related to food or not.

    The underlying pipeline is cached per (model_path, batch_size, device), so
    repeated calls reuse the already-loaded model instead of rebuilding it.

    Args:
        text (Union[str, list]): The input text or list of texts to classify.
            Ignored when ``get_classifier`` is True.
        model_path (str): The path to the Hugging Face model for classification.
        batch_size (int): The batch size for processing. Default is 32.
        device: The device to run inference on (e.g., 'cuda', 'cpu', or a
            ``torch.device``). Default is None, in which case ``set_device()``
            chooses one.
        get_classifier (bool): When True, return the pipeline object itself
            instead of running it on ``text``. Default is False.

    Returns:
        When ``get_classifier`` is False: a dictionary mapping each label to its
        score. NOTE: if ``text`` is a list, only the scores for the first
        element are returned; an empty list yields an empty dictionary.
        When ``get_classifier`` is True: the text-classification pipeline.
    """
    if device is None:
        device = set_device()

    # Positional call so equivalent requests always map to the same cache entry.
    classifier = _load_classifier(model_path, batch_size, device)
    if get_classifier:
        return classifier

    # Nothing to classify: avoid indexing into an empty result below.
    if isinstance(text, list) and not text:
        return {}

    # Shape: [[{'label': 'food', 'score': 0.95...}, {'label': 'not_food', 'score': 0.04...}]]
    results = classifier(text)
    return {prediction["label"]: prediction["score"] for prediction in results[0]}
def gradio_food_classifier(text: str) -> dict:
    """
    A wrapper function for Gradio to classify text using food_not_food_classifier.

    Runs the model located at ``HUGGINGFACE_MODEL_PATH`` with the classifier's
    default batch size and device selection.

    Args:
        text (str): The input text to classify.

    Returns:
        dict: Classification results as a dictionary of label and score.
    """
    # food_not_food_classifier already runs the pipeline and flattens the first
    # result into {label: score}, so delegate rather than repeat that loop here.
    return food_not_food_classifier(text=text, model_path=HUGGINGFACE_MODEL_PATH)