File size: 877 Bytes
f678e87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
def load_model(model_path):
    """Restore a fine-tuned sequence-classification model from *model_path*."""
    return AutoModelForSequenceClassification.from_pretrained(model_path)
def load_tokenizer(model_path):
    """Restore the tokenizer saved alongside the model at *model_path*."""
    return AutoTokenizer.from_pretrained(model_path)
def predict(model, tokenizer, text, device='cpu'):
    """Predict the class index of *text* with a sequence-classification model.

    Args:
        model: Model exposing ``.to()``, ``.eval()`` and a call returning an
            object with a ``.logits`` tensor (e.g. a Hugging Face
            ``AutoModelForSequenceClassification``).
        tokenizer: Matching tokenizer; called with
            ``return_tensors='pt', padding=True, truncation=True``.
        text: Raw input string to classify. NOTE: ``.item()`` below assumes a
            single example — a list of several texts would raise.
        device: Torch device spec to run inference on (default ``'cpu'``).

    Returns:
        int: Index of the highest-scoring class.
    """
    model.to(device)
    # Switch off dropout / batch-norm training behavior; without this,
    # inference results can be nondeterministic and systematically wrong.
    model.eval()
    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    # Move every input tensor to the same device as the model.
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():  # no autograd bookkeeping needed for inference
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=-1).item()
    return predicted_class
|