import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, AutoConfig


class MultiLabelEmotionClassifier(nn.Module):
    """Transformer encoder with a linear head for multi-label emotion classification."""

    def __init__(self, model_name, num_labels, dropout_rate=0.3):
        super().__init__()

        # Load the pretrained encoder together with its config (for hidden_size)
        self.config = AutoConfig.from_pretrained(model_name)
        self.transformer = AutoModel.from_pretrained(model_name)

        # Classification head: dropout for regularization, then one linear
        # layer that produces an independent logit per emotion label
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        self._init_weights()

    def _init_weights(self):
        """Initialize classifier weights (BERT-style: small normal, zero bias)."""
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids, attention_mask):
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # Pool by taking the hidden state of the first ([CLS]) token
        pooled_output = outputs.last_hidden_state[:, 0]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits
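
# Shape sanity check (standalone sketch; "bert-base-uncased" is only an example
# base model, not necessarily the one this head was trained with): the head maps
# [batch, hidden_size] CLS embeddings to [batch, num_labels] logits.
#
#   model = MultiLabelEmotionClassifier("bert-base-uncased", num_labels=6)
#   ids = torch.ones(2, 16, dtype=torch.long)
#   mask = torch.ones(2, 16, dtype=torch.long)
#   assert model(ids, mask).shape == (2, 6)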


def load_model(model_path="."):
    """Load the custom model and its config from model_path."""
    import json
    import os

    # config.json is expected to hold the keys written at training time:
    # base_model, num_labels, dropout_rate, emotion_labels, max_position_embeddings
    with open(os.path.join(model_path, "config.json"), "r") as f:
        config = json.load(f)

    model = MultiLabelEmotionClassifier(
        model_name=config["base_model"],
        num_labels=config["num_labels"],
        dropout_rate=config["dropout_rate"]
    )

    # The checkpoint is a dict whose "model_state_dict" entry holds the weights
    checkpoint = torch.load(os.path.join(model_path, "pytorch_model.bin"), map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])

    return model, config
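

# Matching save-side sketch (hypothetical helper, not part of the original
# training code): it writes the exact layout load_model above expects, i.e.
# a config.json carrying the keys read there and a checkpoint dict keyed by
# "model_state_dict".
def save_model(model, config, model_path="."):
    import json
    import os

    with open(os.path.join(model_path, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
    torch.save(
        {"model_state_dict": model.state_dict()},
        os.path.join(model_path, "pytorch_model.bin"),
    )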


def predict_emotions(text, model_path=".", threshold=0.5):
    """Predict emotions for the given text, returning {label: bool}."""
    model, config = load_model(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Pad/truncate to the model's maximum sequence length recorded in config.json
    encoding = tokenizer(
        text,
        truncation=True,
        padding='max_length',
        max_length=config["max_position_embeddings"],
        return_tensors='pt'
    )

    model.eval()
    with torch.no_grad():
        logits = model(encoding['input_ids'], encoding['attention_mask'])
        # Sigmoid (not softmax): each label gets an independent probability,
        # so several emotions can be active for the same text
        probabilities = torch.sigmoid(logits)
        predictions = (probabilities > threshold).int()

    emotion_labels = config["emotion_labels"]
    result = {emotion: bool(pred) for emotion, pred in zip(emotion_labels, predictions[0])}
    return result
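

if __name__ == "__main__":
    # Usage sketch: assumes a trained checkpoint, config.json, and tokenizer
    # files already exist in the current directory, and that the label set in
    # config.json matches the head's num_labels. The sample text is illustrative.
    sample = "I can't believe we finally won the championship!"
    emotions = predict_emotions(sample, threshold=0.5)
    print({label: flag for label, flag in emotions.items() if flag})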