|
import torch |
|
from transformers import DistilBertModel, DistilBertTokenizer |
|
|
|
|
|
# Inference setup: device, tokenizer and model weights are prepared once at
# module import so that classify_tweet() can reuse them for every call.
device = torch.device("cpu")

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# NOTE(review): DistilBertCNN is neither defined nor imported in the visible
# source; it must be provided elsewhere (custom classification layers on top
# of DistilBERT). num_labels=3 presumably matches the checkpoint — confirm.
model = DistilBertCNN(num_labels=3)
model.to(device)

# NOTE(review): torch.load unpickles "model.pt", which can execute arbitrary
# code for an untrusted file — only load checkpoints you trust. If the file is
# a plain state_dict and PyTorch >= 1.13 is guaranteed, consider passing
# weights_only=True.
model.load_state_dict(torch.load("model.pt", map_location=device))

# Put the model in evaluation mode for inference (affects layers such as
# dropout / batch-norm, per nn.Module.eval semantics).
model.eval()
|
|
|
|
|
def classify_tweet(tweet):
    """Return the predicted class index for a single tweet.

    The text is tokenized with the module-level ``tokenizer`` into one
    sequence of exactly 128 tokens (truncated or padded to that length).
    It is moved to ``device`` and run through ``model`` with gradients
    disabled.

    Args:
        tweet: The raw tweet text to classify.

    Returns:
        int: Index of the highest-scoring class.
    """
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``.
    # For a single string it produces the same encoding.
    encoded = tokenizer(
        tweet,
        add_special_tokens=True,
        max_length=128,
        padding="max_length",  # fixed length; the model may rely on it — keep
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # NOTE(review): assumes the first element of ``outputs`` holds the class
    # logits. That holds whether the model returns a (1, num_labels) tensor or
    # a tuple/ModelOutput whose first entry is the logits — confirm against
    # DistilBertCNN.forward.
    logits = outputs[0]

    # Take the argmax over the class axis explicitly rather than over the
    # flattened tensor. The result is identical for this single-example batch.
    return torch.argmax(logits, dim=-1).item()
|
|
|
|
|
# Demo run: classify one hard-coded example tweet and print the class index.
tweet = "This is a sample tweet."

predicted_class = classify_tweet(tweet)
print(f"Predicted Class: {predicted_class}")
|
|