File size: 382 Bytes
3e2038a
d678220
3e2038a
 
6ee82fe
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# All supported agents
from MCAgent import MCAgent
from DPAgent import DPAgent

# Registry of loadable agents: the key is the name accepted by load_agent(),
# the value is the class that load_agent() instantiates.
AGENTS_MAP = dict(
    MCAgent=MCAgent,
    DPAgent=DPAgent,
)


def load_agent(agent_name: str, **kwargs):
    """Instantiate the agent registered under ``agent_name``.

    Args:
        agent_name: Key into ``AGENTS_MAP`` identifying which agent class to build.
        **kwargs: Forwarded unchanged to the selected agent class's constructor.

    Returns:
        A new instance of the class registered under ``agent_name``.

    Raises:
        ValueError: If ``agent_name`` is not a key of ``AGENTS_MAP``.
    """
    try:
        agent_cls = AGENTS_MAP[agent_name]
    except KeyError:
        # List the valid names as plain text rather than a ``dict_keys([...])`` repr.
        valid = ", ".join(AGENTS_MAP)
        # ``from None`` hides the internal KeyError; the message already explains it.
        raise ValueError(
            f"ERROR: Agent '{agent_name}' not valid. Must be one of: {valid}"
        ) from None

    # The constructor call sits outside the try block so that a KeyError raised
    # inside an agent's own __init__ is not misreported as an unknown agent name.
    return agent_cls(**kwargs)