from transformers import PretrainedConfig
class AutextificationMTLConfig(PretrainedConfig):
    """Configuration for the AuTexTification multi-task text classifier.

    Holds the hyperparameters needed to rebuild the model on top of a
    Hugging Face transformer backbone; serialized/deserialized via the
    standard ``PretrainedConfig`` machinery.
    """

    # Registered model type identifier used by the transformers Auto* classes.
    model_type = "custom-text-classifier"

    def __init__(
        self,
        transformer_name: str = "xlm-roberta-base",
        hidden_nodes: int = 64,
        threshold: float = 0.9919,
        **kwargs,
    ):
        """Initialize the configuration.

        Args:
            transformer_name: Name or path of the backbone transformer
                checkpoint.
            hidden_nodes: Width of the classifier's hidden layer; must be
                a positive number.
            threshold: Decision threshold applied to model scores.
                NOTE(review): 0.9919 looks like a tuned value — confirm
                provenance before changing.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            ValueError: If ``hidden_nodes`` is not positive.
        """
        # Validate before storing so a bad config fails fast.
        if hidden_nodes <= 0:
            # Bug fix: the message previously said `hidden_size`, but the
            # parameter being validated is named `hidden_nodes`.
            raise ValueError(
                f"`hidden_nodes` must be a positive number, got {hidden_nodes}."
            )
        self.transformer_name = transformer_name
        self.hidden_nodes = hidden_nodes
        self.threshold = threshold
        super().__init__(**kwargs)