File size: 1,249 Bytes
02fd376 a7fc3b1 02fd376 a7fc3b1 02fd376 a7fc3b1 02fd376 a7fc3b1 5a3da59 02fd376 a7fc3b1 02fd376 a7fc3b1 02fd376 a7fc3b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
from transformers import DistilBertModel, DistilBertTokenizer
# Load the tokenizer matching the pretrained "distilbert-base-uncased" checkpoint.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# NOTE(review): DistilBertCNN is neither defined nor imported in this file as
# shown; presumably it is a project-local torch module adding custom
# classification layers on top of DistilBERT -- confirm it is in scope here.
model = DistilBertCNN(num_labels=3)
# All inference in this script runs on the CPU.
device = torch.device("cpu")
model.to(device)
# Restore the trained weights. map_location remaps any saved storages onto the
# CPU. weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict contains; it stops a tampered checkpoint from
# executing arbitrary pickled code and matches torch's newer default.
model.load_state_dict(torch.load("model.pt", map_location=device, weights_only=True))
# Put the model in evaluation mode (for a torch nn.Module this disables
# training-only behaviour such as dropout).
model.eval()
def classify_tweet(tweet, *, max_length=128):
    """Return the predicted class index for a single tweet.

    Tokenizes *tweet* with the module-level ``tokenizer``, moves the tensors
    to the module-level ``device``, runs the module-level ``model`` with
    gradient tracking disabled, and returns the arg-max over the class
    logits.

    Args:
        tweet: Raw tweet text.
        max_length: Token length the input is padded/truncated to. Defaults
            to 128, the value this function previously hard-coded.

    Returns:
        The predicted class as a Python ``int``.
    """
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``;
    # it accepts the same keyword arguments and returns the same BatchEncoding.
    inputs = tokenizer(
        tweet,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # NOTE(review): assumes the custom model returns either a tuple whose
    # first element is the (1, num_labels) logits, or a (1, num_labels) tensor
    # whose first row is this example's logits -- confirm against DistilBertCNN.
    logits = outputs[0]
    # Reduce over the class (last) dimension explicitly instead of the
    # flattened tensor; for a single example the index is the same.
    return torch.argmax(logits, dim=-1).item()
# Example usage: classify one hard-coded sample tweet and print the predicted
# class index returned by classify_tweet.
# NOTE(review): this runs at import time because it is not behind an
# `if __name__ == "__main__":` guard -- consider adding one if this module is
# ever imported elsewhere.
tweet = "This is a sample tweet."
predicted_class = classify_tweet(tweet)
print(f"Predicted Class: {predicted_class}")
|