from transformers import PreTrainedModel

from .configuration_proto import ProtoConfig
from .proto import ProtoModule


class ProtoForMultiLabelClassification(PreTrainedModel):
    """Hugging Face `PreTrainedModel` wrapper around `ProtoModule` for
    multi-label classification, so the model can be used with the standard
    `from_pretrained` / `save_pretrained` machinery."""

    config_class = ProtoConfig

    def __init__(self, config: ProtoConfig):
        super().__init__(config)
        # Build the underlying prototype network from the config values.
        self.proto_module = ProtoModule(
            pretrained_model=config.pretrained_model_name_or_path,
            num_classes=config.num_classes,
            label_order_path=config.label_order_path,
            use_sigmoid=config.use_sigmoid,
            use_cuda=config.use_cuda,
            lr_prototypes=config.lr_prototypes,
            lr_features=config.lr_features,
            lr_others=config.lr_others,
            num_training_steps=config.num_training_steps,
            num_warmup_steps=config.num_warmup_steps,
            loss=config.loss,
            save_dir=config.save_dir,
            use_attention=config.use_attention,
            dot_product=config.dot_product,
            normalize=config.normalize,
            final_layer=config.final_layer,
            reduce_hidden_size=config.reduce_hidden_size,
            use_prototype_loss=config.use_prototype_loss,
            prototype_vector_path=config.prototype_vector_path,
            attention_vector_path=config.attention_vector_path,
            eval_buckets=config.eval_buckets,
            seed=config.seed,
        )
        # Standard Hugging Face initialization hook (weight init / tying).
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
        # ProtoModule expects the plural key "attention_masks", so the
        # conventional `attention_mask` argument is renamed here on purpose.
        batch = {
            "input_ids": input_ids,
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
        }
        logits, metadata = self.proto_module(batch)
        return {"logits": logits, "metadata": metadata}
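

# Minimal usage sketch, not part of the module itself. It assumes that
# ProtoConfig accepts its fields as keyword arguments with usable defaults for
# the ones omitted here, and that "bert-base-uncased" is a valid backbone for
# `pretrained_model_name_or_path`; all values below are illustrative only.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = ProtoConfig(
        pretrained_model_name_or_path="bert-base-uncased",  # assumed backbone
        num_classes=4,  # illustrative label count
    )
    model = ProtoForMultiLabelClassification(config)
    tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_name_or_path)

    # BERT-style tokenizers return input_ids, attention_mask, token_type_ids.
    encoded = tokenizer("an example document", return_tensors="pt")
    outputs = model(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        token_type_ids=encoded["token_type_ids"],
    )
    print(outputs["logits"].shape)  # expected: (batch_size, num_classes)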