Transformers
English
code
AI-DrivenExploitGeneration / src /model_inference.py
Canstralian's picture
Create src/model_inference.py
f678e87 verified
raw
history blame contribute delete
877 Bytes
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
def load_model(model_path):
    """Return the sequence-classification model stored at ``model_path``.

    ``model_path`` is passed unchanged to
    ``AutoModelForSequenceClassification.from_pretrained``.
    """
    return AutoModelForSequenceClassification.from_pretrained(model_path)
def load_tokenizer(model_path):
    """Return the tokenizer stored at ``model_path``.

    ``model_path`` is passed unchanged to ``AutoTokenizer.from_pretrained``.
    """
    return AutoTokenizer.from_pretrained(model_path)
def predict(model, tokenizer, text, device='cpu'):
    """Predict the class index for one text or for a batch of texts.

    Args:
        model: Callable model whose output exposes a ``logits`` attribute.
            It is moved to ``device`` in place via ``model.to(device)``.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors='pt', padding=True,
            truncation=True)``; it must return a mapping of tensors.
        text: A single string, or a sequence of strings to classify as one
            padded batch.
        device: Device the model and the tokenized inputs are moved to.

    Returns:
        An ``int`` class index when exactly one prediction is produced (a
        single string, or a one-element batch -- kept for backward
        compatibility), otherwise a ``list`` of ``int`` class indices, one
        per input text. An empty sequence yields ``[]``.
    """
    model.to(device)

    # Nothing to classify: skip the tokenizer/model instead of failing on an
    # empty batch.
    if not isinstance(text, str) and len(text) == 0:
        return []

    inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Dropout and similar layers must be disabled for deterministic
    # predictions. Only toggle the mode when the model is actually in
    # training mode, and restore it afterwards so the caller's state is
    # left as it was.
    was_training = getattr(model, 'training', False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            logits = model(**inputs).logits
    finally:
        if was_training:
            model.train()

    predicted = torch.argmax(logits, dim=-1)
    # ``.item()`` only works for a single element; batches need ``.tolist()``.
    if predicted.numel() == 1:
        return predicted.item()
    return predicted.tolist()