# Suraj-Yadav's picture
# Uploading food not food text classifier demo app
# 494f436 verified
import torch
from typing import Union,Dict
from transformers import pipeline
from constants import HUGGINGFACE_MODEL_PATH
def set_device() -> torch.device:
    """
    Pick the compute device to run on.

    Preference order is CUDA, then Apple MPS (only when it is both available
    and built into this PyTorch install), then CPU as the fallback.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        backend_name = "cuda"
    else:
        # MPS is only probed when CUDA is not usable.
        mps_backend = torch.backends.mps
        mps_usable = mps_backend.is_available() and mps_backend.is_built()
        backend_name = "mps" if mps_usable else "cpu"
    return torch.device(backend_name)
def food_not_food_classifier(
    text: Union[str, list],
    model_path: str,
    batch_size: int = 32,
    device: Union[str, torch.device, None] = None,
    get_classifier: bool = False
) -> Dict[str, float]:
    """
    Classify whether text is about food, using a Hugging Face text-classification pipeline.

    A new pipeline is built from ``model_path`` on every call. Callers that
    classify repeatedly should request the pipeline once with
    ``get_classifier=True`` and reuse it.

    Args:
        text (Union[str, list]): The input text or list of texts to classify.
            Not used when ``get_classifier`` is True.
        model_path (str): The path to the Hugging Face model for classification.
        batch_size (int): The batch size for processing. Default is 32.
        device (Union[str, torch.device, None]): The device to run inference on
            (e.g., 'cuda', 'cpu'). Default is None, in which case the device
            is chosen by ``set_device()``.
        get_classifier (bool): When True, skip inference and return the
            constructed pipeline object instead of classification scores.
            Default is False.

    Returns:
        Dict[str, float]: A dictionary mapping each label to its score.
            NOTE: when ``text`` is a list, only the scores of the FIRST text
            are returned (existing behaviour, kept so the return shape stays
            a single dict). An empty dict is returned when the pipeline
            produces no results (e.g. ``text`` is an empty list).
            When ``get_classifier`` is True the pipeline object itself is
            returned instead of a dict.
    """
    if device is None:
        device = set_device()

    classifier = pipeline(
        task="text-classification",
        model=model_path,
        batch_size=batch_size,
        device=device,
        top_k=None  # Keep all predictions
    )

    if get_classifier:
        return classifier

    # Expected shape for a single string:
    # [[{'label': 'food', 'score': 0.95...}, {'label': 'not_food', 'score': 0.04...}]]
    results = classifier(text)
    if not results:
        # Nothing was classified; avoid an IndexError on results[0].
        return {}

    return {output['label']: output['score'] for output in results[0]}
# Pipeline shared by all Gradio requests; built lazily on the first call.
_CACHED_CLASSIFIER = None


def _load_classifier():
    """Return the shared classification pipeline, building it on first use."""
    global _CACHED_CLASSIFIER
    if _CACHED_CLASSIFIER is None:
        # With get_classifier=True the text argument is not used for inference,
        # so an empty placeholder is passed.
        _CACHED_CLASSIFIER = food_not_food_classifier(
            text="",
            model_path=HUGGINGFACE_MODEL_PATH,
            get_classifier=True,
        )
    return _CACHED_CLASSIFIER


def gradio_food_classifier(text: str) -> dict:
    """
    A wrapper function for Gradio to classify text using the pipeline built by
    the food_not_food_classifier function.

    The pipeline for HUGGINGFACE_MODEL_PATH is created once and reused for
    subsequent requests instead of being rebuilt on every call.

    Args:
        text (str): The input text to classify.
    Returns:
        dict: Classification results as a dictionary of label and score.
    """
    classifier = _load_classifier()
    results = classifier(text)
    # results[0] holds the per-label entries for the single input text.
    return {output['label']: output['score'] for output in results[0]}