import torch
from transformers import PreTrainedModel

from .configuration_proto import ProtoConfig
from .proto import ProtoModule


class ProtoForMultiLabelClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`ProtoModule`.

    Adapts the prototype-based multi-label classifier to the standard
    ``transformers`` model interface so it can be loaded/saved with
    ``from_pretrained`` / ``save_pretrained`` via :class:`ProtoConfig`.
    """

    config_class = ProtoConfig

    def __init__(self, config: ProtoConfig):
        """Build the wrapped ``ProtoModule`` from ``config``.

        Every hyperparameter is forwarded one-to-one from the config
        object; this class adds no parameters of its own.
        """
        super().__init__(config)
        self.proto_module = ProtoModule(
            pretrained_model=config.pretrained_model_name_or_path,
            num_classes=config.num_classes,
            label_order_path=config.label_order_path,
            use_sigmoid=config.use_sigmoid,
            use_cuda=config.use_cuda,
            lr_prototypes=config.lr_prototypes,
            lr_features=config.lr_features,
            lr_others=config.lr_others,
            num_training_steps=config.num_training_steps,
            num_warmup_steps=config.num_warmup_steps,
            loss=config.loss,
            save_dir=config.save_dir,
            use_attention=config.use_attention,
            dot_product=config.dot_product,
            normalize=config.normalize,
            final_layer=config.final_layer,
            reduce_hidden_size=config.reduce_hidden_size,
            use_prototype_loss=config.use_prototype_loss,
            prototype_vector_path=config.prototype_vector_path,
            attention_vector_path=config.attention_vector_path,
            eval_buckets=config.eval_buckets,
            seed=config.seed,
        )
        # Standard HF weight initialization hook for any weights not
        # already initialized inside ProtoModule.
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
        """Run one forward pass through the wrapped ``ProtoModule``.

        The three encoder inputs are repacked into the dict layout that
        ``ProtoModule`` expects. NOTE(review): the module's key is the
        plural ``"attention_masks"`` even though the HF argument is
        singular ``attention_mask`` — this mismatch is intentional and
        must be preserved. Extra ``**kwargs`` (e.g. ``labels``) are
        accepted for HF-API compatibility but ignored here.

        Returns a dict with ``"logits"`` and ``"metadata"`` as produced
        by ``ProtoModule`` (exact types/shapes not visible from here).
        """
        encoded = {
            "input_ids": input_ids,
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
        }
        logits, metadata = self.proto_module(encoded)
        return {"logits": logits, "metadata": metadata}