import functools
import json
import os

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoTokenizer


class MultiLabelEmotionClassifier(nn.Module):
    """Pre-trained transformer encoder with a linear multi-label classification head.

    The head reads the encoder's last-layer hidden state at sequence position 0
    (the [CLS] token for BERT-style tokenizers), applies dropout, and projects it
    to one raw logit per label. Apply ``torch.sigmoid`` to the logits to obtain
    independent per-label probabilities.
    """

    def __init__(self, model_name, num_labels, dropout_rate=0.3):
        """Build the encoder and the classification head.

        Args:
            model_name: Model id or local path passed to ``AutoConfig.from_pretrained``
                and ``AutoModel.from_pretrained``.
            num_labels: Number of output logits (one per emotion label).
            dropout_rate: Dropout probability applied to the pooled vector.
        """
        super().__init__()
        # Pre-trained encoder; its config provides hidden_size for the head.
        self.config = AutoConfig.from_pretrained(model_name)
        self.transformer = AutoModel.from_pretrained(model_name)

        # Classification head.
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        self._init_weights()

    def _init_weights(self):
        """Initialize the classifier head: weights ~ N(0, 0.02), bias = 0."""
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids, attention_mask):
        """Return raw per-label logits for a batch of token ids.

        Args:
            input_ids: Token ids as produced by the matching tokenizer.
            attention_mask: Mask marking real (1) vs padding (0) tokens.

        Returns:
            Tensor of logits whose last dimension is ``num_labels``.
        """
        outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        # Hidden state of the first token of each sequence ([CLS]).
        # Assumes last_hidden_state is (batch, seq_len, hidden_size), the
        # Hugging Face convention -- TODO confirm for non-standard encoders.
        pooled_output = outputs.last_hidden_state[:, 0]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


def load_model(model_path="."):
    """Load a saved ``MultiLabelEmotionClassifier`` and its JSON config.

    ``model_path`` must be a local directory containing ``config.json`` (with at
    least ``base_model`` and ``num_labels``) and ``pytorch_model.bin``.

    Args:
        model_path: Directory holding the saved artifacts.

    Returns:
        ``(model, config)``: the model with weights loaded on CPU (left in
        whatever train/eval mode the constructor produced), and the parsed
        config dict.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(os.path.join(model_path, "config.json"), "r", encoding="utf-8") as f:
        config = json.load(f)

    model = MultiLabelEmotionClassifier(
        model_name=config["base_model"],
        num_labels=config["num_labels"],
        # Dropout has no effect in eval mode, so a missing value is harmless.
        dropout_rate=config.get("dropout_rate", 0.3),
    )

    # SECURITY: torch.load unpickles the file and can execute arbitrary code.
    # Only load checkpoints from trusted sources (on torch >= 1.13, consider
    # weights_only=True if the checkpoint holds only tensors and primitives).
    checkpoint = torch.load(
        os.path.join(model_path, "pytorch_model.bin"), map_location="cpu"
    )
    # Accept both a training checkpoint ({"model_state_dict": ...}) and a bare
    # state dict, which is the usual layout for pytorch_model.bin.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    return model, config


@functools.lru_cache(maxsize=4)
def _load_inference_resources(model_path):
    """Load (model, config, tokenizer) once per absolute path, with the model in eval mode."""
    model, config = load_model(model_path)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, config, tokenizer


def predict_emotions(text, model_path=".", threshold=0.5):
    """Predict which emotion labels apply to ``text``.

    The model and tokenizer are loaded on first use for a given directory and
    cached (bounded) for later calls. Edits to the files on disk are therefore
    not picked up until the process restarts.

    Args:
        text: A single string, or a list/sequence of strings for batch prediction.
        model_path: Local directory containing ``config.json``,
            ``pytorch_model.bin`` and the tokenizer files.
        threshold: A label is predicted when its sigmoid probability is
            strictly greater than this value.

    Returns:
        For a string, a ``{label: bool}`` dict. For a sequence of strings, a
        list of such dicts in input order (an empty input gives ``[]``).

    Raises:
        ValueError: if ``config["emotion_labels"]`` does not have one entry per
            model output.
    """
    # Resolve the path so "." keys the cache by the actual directory.
    model, config, tokenizer = _load_inference_resources(os.path.abspath(model_path))

    single = isinstance(text, str)
    texts = [text] if single else list(text)
    if not texts:
        return []

    encoding = tokenizer(
        texts,
        truncation=True,
        padding="max_length",
        max_length=config["max_position_embeddings"],
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(encoding["input_ids"], encoding["attention_mask"])
    probabilities = torch.sigmoid(logits)

    emotion_labels = config["emotion_labels"]
    # zip() would silently drop labels or outputs on a mismatch; fail loudly instead.
    if len(emotion_labels) != probabilities.shape[-1]:
        raise ValueError(
            f"config has {len(emotion_labels)} emotion_labels but the model "
            f"produces {probabilities.shape[-1]} outputs"
        )

    # Boolean tensor -> nested Python lists of bools, one row per input text.
    predictions = (probabilities > threshold).tolist()
    results = [dict(zip(emotion_labels, row)) for row in predictions]
    return results[0] if single else results


# Example usage:
# emotions = predict_emotions("I am so happy and excited!")
# print(emotions)