"""
Test classes for checking the functionality of cfg normalization
(normalize_config and normalize_cfg_datasets).
"""
import unittest | |
from unittest.mock import patch | |
from axolotl.utils.config import normalize_cfg_datasets, normalize_config | |
from axolotl.utils.dict import DictDefault | |
class NormalizeConfigTestCase(unittest.TestCase):
    """
    test class for normalize_config checks
    """

    def _get_base_cfg(self):
        """Build a new minimal config dict shared by the normalize_config tests."""
        return DictDefault(
            {
                "base_model": "JackFram/llama-68m",
                "base_model_config": "JackFram/llama-68m",
                "tokenizer_type": "LlamaTokenizer",
                "num_epochs": 1,
                "micro_batch_size": 1,
                "gradient_accumulation_steps": 1,
            }
        )

    def test_lr_as_float(self):
        """A learning rate given as the string "5e-5" normalizes to the float 5e-5."""
        cfg = (
            self._get_base_cfg()
            | DictDefault(  # pylint: disable=unsupported-binary-operation
                {
                    "learning_rate": "5e-5",
                }
            )
        )

        normalize_config(cfg)

        assert cfg.learning_rate == 0.00005

    def test_base_model_config_set_when_empty(self):
        """When base_model_config is absent, it is filled in from base_model."""
        cfg = self._get_base_cfg()
        del cfg.base_model_config

        normalize_config(cfg)

        assert cfg.base_model_config == cfg.base_model

    def test_chat_template_chatml(self):
        """An explicit per-dataset conversation is kept as-is.

        The dataset without one ends up with "chatml", presumably taken from
        the top-level chat_template.
        """
        cfg = DictDefault(
            {
                "chat_template": "chatml",
                "datasets": [
                    {
                        "path": "lorem/ipsum",
                        "type": "sharegpt",
                        "conversation": "vicuna_v1.1",
                    },
                    {
                        "path": "sit/amet",
                        "type": "sharegpt",
                    },
                ],
            }
        )

        normalize_cfg_datasets(cfg)

        assert cfg.datasets[0].conversation == "vicuna_v1.1"
        assert cfg.datasets[1].conversation == "chatml"

    # The bf16 tests receive `mock_bf16_avail` from @patch, which fakes the
    # bf16 availability probe so results do not depend on the host GPU.
    # NOTE(review): the target assumes normalize_config looks the probe up
    # through the axolotl.utils.config namespace -- confirm if it moves.
    @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
    def test_bf16_auto_setter_available(self, mock_bf16_avail):
        """bf16="auto" becomes True and fp16 stays falsy when bf16 is available."""
        cfg = self._get_base_cfg()
        cfg.bf16 = "auto"
        mock_bf16_avail.return_value = True

        normalize_config(cfg)

        self.assertTrue(cfg.bf16)
        self.assertFalse(cfg.fp16)

    @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
    def test_bf16_auto_setter_not_available(self, mock_bf16_avail):
        """bf16="auto" becomes falsy and an unset fp16 becomes True without bf16."""
        cfg = self._get_base_cfg()
        cfg.bf16 = "auto"
        # fp16 explicitly unset so normalize_config is free to choose it
        cfg.fp16 = None
        mock_bf16_avail.return_value = False

        normalize_config(cfg)

        self.assertFalse(cfg.bf16)
        self.assertTrue(cfg.fp16)

    @patch("axolotl.utils.config.is_torch_bf16_gpu_available")
    def test_bf16_disables_fp16(self, mock_bf16_avail):
        """With bf16 explicitly enabled, bf16 stays True and fp16 stays falsy."""
        cfg = self._get_base_cfg()
        cfg.bf16 = True
        # NOTE(review): fp16 already starts False here, so this only checks
        # that it is not switched on, not that bf16 actively turns it off.
        cfg.fp16 = False
        mock_bf16_avail.return_value = True

        normalize_config(cfg)

        self.assertTrue(cfg.bf16)
        self.assertFalse(cfg.fp16)