File size: 2,368 Bytes
46b0409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bb60fd
46b0409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# All supported agents
import os
from MCAgent import MCAgent
from DPAgent import DPAgent
import warnings

# Registry of supported agents: maps the name accepted by load_agent (also the first
# "_"-separated field of a policy file name) to the agent class to instantiate.
AGENTS_MAP = dict(MCAgent=MCAgent, DPAgent=DPAgent)


def load_agent(agent_key, **kwargs):
    """
    Loads an agent from a policy file or from the AGENTS_MAP.

    :param agent_key: Which agent to load. Either a key in AGENTS_MAP, or a path to a policy
                      file ending with ".npy" (a ``str`` or any path-like object). For a policy
                      file, the agent name, environment name, and other parameters are parsed
                      from the file name, which must look like
                      'AgentName_EnvName_key:value_key:value....npy'. Parsed values are strings.
    :param kwargs: Additional arguments to pass to the agent constructor. If loading from a
                   policy file, any conflicting arguments are overwritten by the parsed ones
                   (including ``env``, which is always set from the file name).
    :return: The constructed agent; when loaded from a file, its policy has also been loaded
             via ``agent.load_policy``.
    :raises ValueError: If the policy file name cannot be parsed, or the agent key is not a
                        key of AGENTS_MAP.
    """
    # Accept pathlib.Path (and other os.PathLike objects) as well as plain strings.
    agent_key = os.fspath(agent_key)
    agent_policy_file = agent_key if agent_key.endswith(".npy") else None

    # If loading from a policy file, parse the agent key and constructor args from its name.
    if agent_policy_file is not None:
        agent_key, parsed_args = _parse_policy_filename(agent_policy_file)
        # Overwrite any conflicting arguments with those from the file name
        kwargs.update(parsed_args)
        print("agent_args:", kwargs)

    # Check if agent key is valid
    if agent_key not in AGENTS_MAP:
        raise ValueError(
            f"ERROR: Agent '{agent_key}' not valid. Must be one of: {AGENTS_MAP.keys()}"
        )

    # Load agent based on key and arguments
    agent = AGENTS_MAP[agent_key](**kwargs)
    # If loading from a policy file, load the policy into the agent
    if agent_policy_file is not None:
        agent.load_policy(agent_policy_file)

    return agent


def _parse_policy_filename(policy_file):
    """
    Parse a policy file path into ``(agent_key, constructor_kwargs)``.

    :param policy_file: Path string ending with ".npy".
    :return: The agent key (first "_" field) and a dict of parsed ``key:value`` properties,
             with ``"env"`` set to the second "_" field.
    :raises ValueError: If the file name has fewer than two "_"-separated fields.
    """
    # Drop the directory and the ".npy" extension so the suffix does not leak into the
    # last field (e.g. "gamma:0.9.npy" would otherwise yield the value "0.9.npy").
    stem = os.path.basename(policy_file)[: -len(".npy")]
    props = stem.split("_")
    if len(props) < 2:
        raise ValueError(
            "ERROR: Could not parse agent properties. Must be of the format 'AgentName_EnvName_key:value_key:value...'."
        )

    agent_key, env_key = props[0], props[1]
    parsed_args = {}
    for prop in props[2:]:
        props_split = prop.split(":")
        if len(props_split) == 2:
            parsed_args[props_split[0]] = props_split[1]
        else:
            warnings.warn(
                f"Skipping property {prop} as it does not have the format 'key:value'.",
                UserWarning,
            )
    # The environment always comes from the file name.
    parsed_args["env"] = env_key
    return agent_key, parsed_args