pandrei7's picture
Save the first version of the model
56dfd9c
raw
history blame contribute delete
628 Bytes
from transformers import PretrainedConfig
class AutextificationMTLConfig(PretrainedConfig):
    """Configuration for the Autextification multi-task-learning classifier.

    Stores the underlying transformer name, the hidden-layer width and a
    decision threshold. All remaining keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    Args:
        transformer_name: Name of the transformer backbone
            (default ``"xlm-roberta-base"``). Presumably a Hugging Face Hub
            model id; confirm against the model code that consumes it.
        hidden_nodes: Number of hidden nodes; must be strictly positive
            (default ``64``).
        threshold: Threshold value stored on the config (default ``0.9919``).
            How it is applied is not visible here; see the model code.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``hidden_nodes`` is not strictly positive.
    """

    # Identifier stored on the config class; transformers uses ``model_type``
    # to associate serialized configs with this class.
    model_type = "custom-text-classifier"

    def __init__(
        self,
        transformer_name: str = "xlm-roberta-base",
        hidden_nodes: int = 64,
        threshold: float = 0.9919,
        **kwargs,
    ):
        # Fail fast on an invalid layer width. The message names the actual
        # parameter (it previously referred to a non-existent `hidden_size`).
        if hidden_nodes <= 0:
            raise ValueError(
                f"`hidden_nodes` must be a positive number, got {hidden_nodes}."
            )
        self.transformer_name = transformer_name
        self.hidden_nodes = hidden_nodes
        self.threshold = threshold
        # Let the base class consume the standard config kwargs
        # (e.g. id2label, name_or_path) passed in by callers or from_pretrained.
        super().__init__(**kwargs)