from transformers import PretrainedConfig | |
class SSLConfig(PretrainedConfig):
    """Configuration for the ``ssl-aasist`` model.

    Stores four architecture hyper-parameter lists on the instance and
    forwards all remaining keyword arguments to ``PretrainedConfig.__init__``.

    Args:
        filts: Filter specification. ``None`` (the default) selects
            ``[128, [1, 32], [32, 32], [32, 64], [64, 64]]``. Presumably a
            leading channel count followed by (in, out) channel pairs for the
            encoder blocks -- TODO confirm against the model code.
        gat_dims: ``None`` selects ``[64, 32]``. Presumably graph-attention
            layer dimensions -- TODO confirm.
        pool_ratios: ``None`` selects ``[0.5, 0.5, 0.5, 0.5]``. Presumably
            graph-pooling ratios -- TODO confirm.
        temperatures: ``None`` selects ``[2.0, 2.0, 100.0, 100.0]``.
            Presumably graph-attention temperatures -- TODO confirm.
        **kwargs: Passed unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "ssl-aasist"

    def __init__(
        self,
        filts=None,
        gat_dims=None,
        pool_ratios=None,
        temperatures=None,
        **kwargs,
    ):
        # Build the defaults inside the call rather than as signature literals.
        # Python evaluates default-argument expressions once, so list defaults
        # would be shared (and mutated) across every config instance.
        if filts is None:
            filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
        if gat_dims is None:
            gat_dims = [64, 32]
        if pool_ratios is None:
            pool_ratios = [0.5, 0.5, 0.5, 0.5]
        if temperatures is None:
            temperatures = [2.0, 2.0, 100.0, 100.0]

        self.filts = filts
        self.gat_dims = gat_dims
        self.pool_ratios = pool_ratios
        self.temperatures = temperatures
        super().__init__(**kwargs)