# ProtoPatient / proto_model/modeling_proto.py
# Author: row56
# Commit 832f402 (verified): "Upload proto_model/modeling_proto.py with huggingface_hub"
from transformers import PreTrainedModel
import torch
from .proto import ProtoModule
from .configuration_proto import ProtoConfig
class ProtoForMultiLabelClassification(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a ``ProtoModule``.

    Every hyperparameter is taken from the ``ProtoConfig`` handed to the
    constructor and passed through to ``ProtoModule`` unchanged. The only
    renamed setting is ``pretrained_model_name_or_path``, which the module
    receives as ``pretrained_model``. ``forward`` adapts the conventional
    ``input_ids`` / ``attention_mask`` / ``token_type_ids`` call signature
    to the single batch dict that ``ProtoModule`` is called with.
    """

    config_class = ProtoConfig

    def __init__(self, config: ProtoConfig):
        super().__init__(config)

        # Config attributes handed to ProtoModule under the very same keyword.
        same_name_fields = (
            "num_classes",
            "label_order_path",
            "use_sigmoid",
            "use_cuda",
            "lr_prototypes",
            "lr_features",
            "lr_others",
            "num_training_steps",
            "num_warmup_steps",
            "loss",
            "save_dir",
            "use_attention",
            "dot_product",
            "normalize",
            "final_layer",
            "reduce_hidden_size",
            "use_prototype_loss",
            "prototype_vector_path",
            "attention_vector_path",
            "eval_buckets",
            "seed",
        )

        # The checkpoint name is the one setting whose keyword differs from
        # its config attribute name.
        module_kwargs = {"pretrained_model": config.pretrained_model_name_or_path}
        module_kwargs.update(
            (field, getattr(config, field)) for field in same_name_fields
        )
        self.proto_module = ProtoModule(**module_kwargs)

        # NOTE(review): recent transformers releases recommend calling
        # ``self.post_init()`` (which also performs weight initialization)
        # instead of ``init_weights()`` directly; kept as-is pending a check
        # of the transformers version this project pins.
        self.init_weights()

    def forward(self, input_ids, attention_mask, token_type_ids, **kwargs):
        """Run ``ProtoModule`` on one tokenized batch.

        Args:
            input_ids: token ids for the batch.
            attention_mask: attention mask for the batch.
            token_type_ids: segment ids for the batch.
            **kwargs: accepted for call-signature compatibility and ignored.

        Returns:
            A plain dict with the ``"logits"`` and ``"metadata"`` values that
            ``ProtoModule`` returns.
        """
        # The mask goes under the plural key "attention_masks" (presumably the
        # key ProtoModule reads -- confirm), not HF's usual "attention_mask".
        model_batch = dict(
            input_ids=input_ids,
            attention_masks=attention_mask,
            token_type_ids=token_type_ids,
        )
        logits, metadata = self.proto_module(model_batch)
        return dict(logits=logits, metadata=metadata)