File size: 1,132 Bytes
903ea30
 
e0602a9
f7a2263
cbbf039
e0602a9
d2e7f27
ce34d64
cbbf039
 
d2e7f27
 
e0602a9
0f74464
 
 
 
903ea30
 
d2e7f27
 
 
f7a2263
 
 
 
d2e7f27
cbbf039
 
 
 
903ea30
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""Module to load prompt strategies."""

import importlib
import inspect
import logging

from axolotl.prompt_strategies.user_defined import UserDefinedDatasetConfig

# Logger for this module; `load` reports strategy-loading failures through it.
LOG = logging.getLogger(name="axolotl.prompt_strategies")


def _is_missing_strategy_module(exc, module_name):
    """Return True if ``exc`` reports ``module_name`` itself (or a parent package) as missing.

    Used to tell "this strategy does not exist" apart from a ModuleNotFoundError
    raised by something the strategy module (or its loader) tries to import.
    """
    missing = exc.name
    if not missing or not module_name:
        # Name unknown: we cannot prove it is the strategy module, so do not
        # treat it as a plain "strategy not found".
        return False
    return module_name == missing or module_name.startswith(f"{missing}.")


def load(strategy, tokenizer, cfg, ds_cfg):
    """Resolve a prompt strategy by name and call its loader function.

    ``strategy`` is a (possibly dotted) module path relative to
    ``axolotl.prompt_strategies``. If its last dotted component starts with
    ``load_``, that component is used as the loader function name and the
    remainder as the module path; otherwise the module's ``load`` function is
    used.

    The loader is called as ``func(tokenizer, cfg, **load_kwargs)``. For the
    ``user_defined`` strategy, ``ds_cfg`` is wrapped in
    ``UserDefinedDatasetConfig`` and passed as ``ds_cfg``. For any other
    strategy, ``ds_cfg`` is passed only if the loader's signature declares a
    ``ds_cfg`` parameter.

    Returns:
        Whatever the loader returns. Returns ``None`` without logging when the
        strategy module itself does not exist. Returns ``None`` after logging
        the error with its traceback when loading fails for any other reason.
    """
    from importlib.util import (  # pylint: disable=import-outside-toplevel
        resolve_name,
    )

    package = "axolotl.prompt_strategies"
    # Absolute name of the strategy module; set before the import so the
    # ModuleNotFoundError handler can compare it against the missing name.
    module_name = None
    try:
        load_fn = "load"
        module_path, _, last = strategy.rpartition(".")
        if last.startswith("load_"):
            # "pkg.mod.load_xyz" -> import "pkg.mod", call "load_xyz".
            load_fn = last
        else:
            module_path = strategy
        relative_name = f".{module_path}"
        module_name = resolve_name(relative_name, package)
        mod = importlib.import_module(relative_name, package)
        func = getattr(mod, load_fn)
        load_kwargs = {}
        if module_path == "user_defined":
            load_kwargs["ds_cfg"] = UserDefinedDatasetConfig(**ds_cfg)
        elif "ds_cfg" in inspect.signature(func).parameters:
            # Only forward ds_cfg to loaders that accept it.
            load_kwargs["ds_cfg"] = ds_cfg
        return func(tokenizer, cfg, **load_kwargs)
    except ModuleNotFoundError as exc:
        if _is_missing_strategy_module(exc, module_name):
            # No such strategy module: an expected "not found" outcome.
            return None
        # Something the strategy depends on is missing; don't hide it.
        LOG.exception("Failed to load prompt strategy `%s`: %s", strategy, exc)
    except Exception as exc:  # pylint: disable=broad-exception-caught
        LOG.exception("Failed to load prompt strategy `%s`: %s", strategy, exc)
    return None