File size: 1,944 Bytes
2a33798
5f98914
2a33798
d731338
 
1aa4792
d731338
 
 
 
 
 
 
 
1aa4792
d731338
 
 
2a33798
d731338
5f98914
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import ClassVar, List, Optional

from pydantic import BaseModel, Field, validator

class DisActionModel(BaseModel):
    """Base model for a single discrete action identified by an integer.

    The base class itself places no range constraint on ``action``; concrete
    subclasses attach a range check built by :meth:`create_validator`.
    """

    action: int = Field(description="the chosen action to perform")

    @classmethod
    def create_validator(cls, max_action):
        """Build a validator restricting ``action`` to the inclusive range ``[1, max_action]``.

        Args:
            max_action: Largest permitted action value (inclusive).

        Returns:
            A pydantic ``validator`` for the ``action`` field, meant to be placed
            in the class namespace of a dynamically created subclass.
        """

        # allow_reuse=True: every validator produced here shares the same
        # qualified name, which pydantic would otherwise flag as a duplicate.
        @validator('action', allow_reuse=True)
        def action_is_valid(cls, value):
            # Accept only values inside [1, max_action]; anything else is rejected.
            if value in range(1, max_action + 1):
                return value
            raise ValueError(f"Action is not valid ([1, {max_action}])!")

        return action_is_valid

# Generate classes dynamically
def generate_action_class(max_action):
    return type(f"{max_action}Action", (DisActionModel,), {'action_is_valid': DisActionModel.create_validator(max_action)})
    
# Dictionary of parsers with dynamic class generation
DISPARSERS = {num: generate_action_class(num) for num in [2, 3, 4, 6, 9, 18]}

class ContinuousActionBase(BaseModel):
    """Base model for a vector of continuous actions, each within ``[-1, 1]``.

    The required vector length is stored on the class as ``expected_length``.
    It is configured per subclass through :meth:`set_expected_length`. While it
    is unset (``None``, e.g. on this base class itself), no length constraint is
    enforced. This avoids an ``AttributeError`` escaping the validator.
    """

    # ClassVar so pydantic does not treat it as a model field; None means
    # "no length constraint configured".
    expected_length: ClassVar[Optional[int]] = None

    action: List[float] = Field(description="the chosen continuous actions to perform")

    @classmethod
    def set_expected_length(cls, length):
        """Set the exact number of action dimensions this class accepts.

        Args:
            length: Required length of the ``action`` list.
        """
        cls.expected_length = length

    @validator('action', pre=True)
    def validate_length(cls, action):
        """Reject inputs whose length differs from ``cls.expected_length``.

        Runs before type coercion (``pre=True``), so ``action`` is the raw input.

        Raises:
            ValueError: If a length is configured and ``action`` does not match it.
        """
        expected = cls.expected_length
        if expected is not None and len(action) != expected:
            raise ValueError(f"The action list must have exactly {expected} items.")
        return action

    @validator('action', each_item=True)
    def action_is_valid(cls, item):
        """Ensure each coerced action component lies in ``[-1, 1]``.

        Raises:
            ValueError: If the component is outside the range (NaN also fails).
        """
        if not -1 <= item <= 1:
            raise ValueError("Each action dimension must be in the range [-1, 1]!")
        return item

# Generate classes dynamically
def generate_continuous_action_class(expected_length):
    NewClass = type(
        f"{expected_length}DContinuousAction",
        (ContinuousActionBase,),
        {}
    )
    NewClass.set_expected_length(expected_length)
    return NewClass


# Dictionary of parsers with dynamic class generation
CONPARSERS = {length: generate_continuous_action_class(length) for length in range(1, 17)}