import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr


def run_inference(review_text: str) -> str:
    """
    Perform inference on the given wine review text and return the predicted wine variety.

    Args:
        review_text (str): Wine review text in the format "country [SEP] description".

    Returns:
        str: The predicted wine variety using the model's id2label mapping if available.
    """
    # Define model and tokenizer identifiers
    model_id = "spawn99/modernbert-wine-classification"
    tokenizer_id = "answerdotai/ModernBERT-base"

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)

    # Tokenize the input text
    inputs = tokenizer(
        review_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=256
    )

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Determine prediction and map to label if available
    pred = torch.argmax(logits, dim=-1).item()
    variety = (
        model.config.id2label.get(pred, str(pred))
        if hasattr(model.config, "id2label") and model.config.id2label
        else str(pred)
    )
    return variety


def predict_wine_variety(country: str, description: str) -> dict:
    """
    Combine the provided country and description, then perform inference.
    Enforces a maximum character limit of 750 on the description.

    Args:
        country (str): The country of wine origin.
        description (str): The wine review description.

    Returns:
        dict: Dictionary containing the predicted wine variety, or an error message if the limit is exceeded.
    """
    # Validate description length
    if len(description) > 750:
        return {"error": "Description exceeds 750 character limit. Please shorten your input."}

    # Capitalize input values and format the review text accordingly.
    review_text = f"{country.capitalize()} [SEP] {description.capitalize()}"
    predicted_variety = run_inference(review_text)
    return {"Variety": predicted_variety}


if __name__ == "__main__":
    iface = gr.Interface(
        fn=predict_wine_variety,
        inputs=[
            gr.Textbox(label="Country", placeholder="Enter country of origin..."),
            gr.Textbox(label="Description", placeholder="Enter wine review description...")
        ],
        outputs=gr.JSON(label="Prediction"),
        title="Wine Variety Predictor",
        description="Predict the wine variety based on country and description."
    )
    iface.launch()
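
# Illustrative usage sketch (not part of the original Space): predict_wine_variety
# can also be called directly, bypassing the Gradio UI, e.g. for a quick local test.
# The country and description values below are made-up sample inputs, and the first
# call will download the model and tokenizer from the Hugging Face Hub.
#
#     result = predict_wine_variety(
#         country="Italy",
#         description="Aromas of ripe cherry, dried herbs and leather with firm tannins.",
#     )
#     print(result)  # e.g. {"Variety": "<predicted variety label>"}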