from transformers import PretrainedConfig


class TotalClassifierConfig(PretrainedConfig):
    """Hyperparameter container for the ``total_classifier`` model type.

    Each constructor argument is stored on the instance under the same
    name; any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    NOTE(review): the model that consumes this config is not visible from
    this file, so the parameter descriptions below are inferred from the
    parameter names and defaults -- confirm against the model code.

    Args:
        backbone: Name of the CNN backbone (presumably a timm model
            identifier -- TODO confirm).
        feature_dim: Size of the feature vector (presumably the CNN
            output / RNN input width).
        cnn_dropout: Dropout rate associated with the CNN part.
        in_chans: Number of input image channels.
        rnn_type: Name of the recurrent layer type (default ``"GRU"``).
        rnn_num_layers: Number of recurrent layers.
        rnn_dropout: Dropout rate associated with the RNN part.
        num_classes: Number of output classes.
        seq_len: Sequence length.
        linear_dropout: Dropout rate associated with the linear part.
        image_size: Two-element image size. A ``list`` is accepted and
            converted to a ``tuple`` (see the comment in ``__init__``).
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "total_classifier"

    def __init__(
        self,
        backbone: str = "tf_efficientnetv2_b0",
        feature_dim: int = 192,
        cnn_dropout: float = 0.1,
        in_chans: int = 1,
        rnn_type: str = "GRU",
        rnn_num_layers: int = 1,
        rnn_dropout: float = 0.0,
        num_classes: int = 117,
        seq_len: int = 512,
        linear_dropout: float = 0.1,
        image_size: tuple[int, int] = (256, 256),
        **kwargs,
    ):
        # JSON has no tuple type: a config that was saved and loaded again
        # hands image_size back as a list. Normalise it so the attribute
        # matches its annotation and compares equal to the tuple default.
        if isinstance(image_size, list):
            image_size = tuple(image_size)

        self.backbone = backbone
        self.feature_dim = feature_dim
        self.cnn_dropout = cnn_dropout
        self.in_chans = in_chans
        self.rnn_type = rnn_type
        self.rnn_num_layers = rnn_num_layers
        self.rnn_dropout = rnn_dropout
        self.num_classes = num_classes
        self.seq_len = seq_len
        self.linear_dropout = linear_dropout
        self.image_size = image_size
        # Base-class initialisation runs last, after the model-specific
        # attributes are set, exactly as in the original ordering.
        super().__init__(**kwargs)