Spaces:
Running
on
T4
Running
on
T4
File size: 1,748 Bytes
b03a8f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# configuration_auto.py
from collections import OrderedDict
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from ...models.camn_audio.configuration_camn_audio import CamnAudioConfig
logger = logging.get_logger(__name__)

# Registry from the `model_type` string stored in a config.json to the configuration
# class used to load it. An OrderedDict keeps entries in registration order.
# Add other model configurations here if needed
CONFIG_MAPPING = OrderedDict()
CONFIG_MAPPING["camn_audio"] = CamnAudioConfig
class AutoConfig:
    r"""
    Generic configuration factory that picks the right registered config class.

    It is designed to be used through the `from_pretrained` class method:

        config = AutoConfig.from_pretrained("model_name_or_path")

    The concrete class is looked up in `CONFIG_MAPPING`. The lookup first uses the
    `model_type` field of the loaded configuration dictionary. If that field is absent,
    it looks for a registered model-type string inside `pretrained_model_name_or_path`.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load a configuration dictionary and build the matching registered config class.

        Args:
            pretrained_model_name_or_path: Anything accepted by
                `PretrainedConfig.get_config_dict`, given as a `str` or an
                `os.PathLike`.
            **kwargs: Forwarded to `PretrainedConfig.get_config_dict`. The keyword
                arguments it hands back as unused are then passed to the selected
                class's `from_dict`.

        Returns:
            Whatever the selected class's `from_dict` returns, normally a config
            instance of that class.

        Raises:
            ValueError: If `model_type` names a type not in `CONFIG_MAPPING`, or if
                `model_type` is missing and no registered type string occurs in
                `pretrained_model_name_or_path`.
        """
        # `get_config_dict` returns the kwargs it did not consume. Forward only those
        # to `from_dict`, rather than re-sending options that were only meant for loading.
        config_dict, unused_kwargs = PretrainedConfig.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )

        if "model_type" in config_dict:
            model_type = config_dict["model_type"]
            if model_type in CONFIG_MAPPING:
                config_class = CONFIG_MAPPING[model_type]
                return config_class.from_dict(config_dict, **unused_kwargs)
            raise ValueError(f"Unrecognized model type {model_type} in config.json.")

        # Fallback: there is no `model_type`, so look for a registered type string in the
        # name/path. `str()` is needed because `in` on an `os.PathLike` raises TypeError.
        # Longer patterns are tried first so a short key cannot shadow a longer one
        # that contains it.
        name_or_path = str(pretrained_model_name_or_path)
        for pattern in sorted(CONFIG_MAPPING, key=len, reverse=True):
            if pattern in name_or_path:
                return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)

        # If still not found
        raise ValueError(
            "Config file is missing the `model_type` field and no default model type could be inferred."
        )
|