File size: 3,185 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from typing import Any

import numpy as np

from tianshou.data import (
    HERReplayBuffer,
    HERReplayBufferManager,
    PrioritizedReplayBuffer,
    PrioritizedReplayBufferManager,
    ReplayBuffer,
    ReplayBufferManager,
)


class VectorReplayBuffer(ReplayBufferManager):
    """Manage ``buffer_num`` equally sized :class:`~tianshou.data.ReplayBuffer` objects.

    It is meant for storing transitions coming from different environments while
    preserving their temporal order.

    :param total_size: the requested overall capacity. It is split evenly, each
        sub-buffer receiving ``ceil(total_size / buffer_num)`` slots, so the
        effective capacity can be slightly larger than ``total_size``.
    :param buffer_num: how many sub-buffers to create; must be positive. All of
        them share the same configuration.

    Any further keyword arguments (stack_num/ignore_obs_next/save_only_last_obs/
    sample_avail) are forwarded unchanged to every
    :class:`~tianshou.data.ReplayBuffer`.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        assert buffer_num > 0
        # Ceiling division so that the sub-buffers together hold at least total_size items.
        per_buffer_size = int(np.ceil(total_size / buffer_num))
        super().__init__(
            [ReplayBuffer(per_buffer_size, **kwargs) for _ in range(buffer_num)],
        )


class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
    """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffers of the same size.

    It is used for storing transitions from different environments while keeping
    their temporal order.

    :param total_size: the total size of PrioritizedVectorReplayBuffer. Each
        sub-buffer receives ``ceil(total_size / buffer_num)`` slots, so the effective
        capacity can be slightly larger than ``total_size``.
    :param buffer_num: the number of PrioritizedReplayBuffer it uses (must be
        positive), all under the same configuration.

    Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/
    sample_avail) are the same as :class:`~tianshou.data.PrioritizedReplayBuffer`
    and are forwarded to every sub-buffer.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        assert buffer_num > 0
        # Ceiling division so that the sub-buffers together hold at least total_size items.
        size = int(np.ceil(total_size / buffer_num))
        buffer_list = [PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
        super().__init__(buffer_list)

    def set_beta(self, beta: float) -> None:
        """Set the importance-sampling exponent ``beta``.

        The new value is applied to this manager itself (through the inherited
        ``set_beta``) and to every sub-buffer, so all of them stay in sync.

        :param beta: the new importance-sampling exponent.
        """
        # This method overrides the inherited set_beta. The previous version only
        # updated the sub-buffers and never the manager's own beta.
        # NOTE(review): assumes, as in tianshou's PrioritizedReplayBufferManager,
        # that the manager computes sampling weights from its own beta. Without the
        # super() call, annealing beta would not affect sampled weights.
        super().set_beta(beta)
        for buffer in self.buffers:
            buffer.set_beta(beta)


class HERVectorReplayBuffer(HERReplayBufferManager):
    """Manage ``buffer_num`` equally sized :class:`~tianshou.data.HERReplayBuffer` objects.

    It is meant for storing transitions coming from different environments while
    preserving their temporal order.

    :param total_size: the requested overall capacity. It is split evenly, each
        sub-buffer receiving ``ceil(total_size / buffer_num)`` slots, so the
        effective capacity can be slightly larger than ``total_size``.
    :param buffer_num: how many sub-buffers to create; must be positive. All of
        them share the same configuration.

    Any further keyword arguments are forwarded unchanged to every
    :class:`~tianshou.data.HERReplayBuffer`.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        assert buffer_num > 0
        # Ceiling division so that the sub-buffers together hold at least total_size items.
        per_buffer_size = int(np.ceil(total_size / buffer_num))
        super().__init__(
            [HERReplayBuffer(per_buffer_size, **kwargs) for _ in range(buffer_num)],
        )