File size: 5,972 Bytes
9b19c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import pickle
from copy import deepcopy
from numbers import Number
from typing import Any, Union, no_type_check

import h5py
import numpy as np
import torch

from tianshou.data.batch import Batch, _parse_value


# TODO: confusing name, could actually return a batch...
#  Overrides and generic types should be added
# todo check for ActBatchProtocol
@no_type_check
def to_numpy(x: Any) -> Batch | np.ndarray:
    """Convert ``x`` into an equivalent object that holds no torch.Tensor.

    The conversion depends on the type of ``x`` (checked in this order):

    * ``torch.Tensor``: detached, moved to CPU and exposed as a numpy array;
    * ``np.ndarray``: returned unchanged (the very same object);
    * numpy / python numbers and booleans: wrapped by ``np.asanyarray``;
    * ``None``: an object-dtype array holding ``None``;
    * ``dict``: a new ``Batch`` built from it, then converted with ``to_numpy_()``;
    * ``Batch``: a deep copy of it, then converted with ``to_numpy_()``;
    * ``list`` / ``tuple``: passed through ``_parse_value`` and converted again;
    * anything else: handed to ``np.asanyarray``.

    :param x: the object to convert.
    :return: a numpy array, or a ``Batch`` when ``x`` is a dict or a ``Batch``.
    """
    scalar_types = (np.number, np.bool_, Number)
    if isinstance(x, torch.Tensor):  # most frequent case
        result = x.detach().cpu().numpy()
    elif isinstance(x, np.ndarray):  # second most frequent case
        result = x
    elif isinstance(x, scalar_types):
        result = np.asanyarray(x)
    elif x is None:
        result = np.array(None, dtype=object)
    elif isinstance(x, dict):
        result = Batch(x)
        result.to_numpy_()
    elif isinstance(x, Batch):
        # work on a deep copy so that the caller's Batch is left untouched
        result = deepcopy(x)
        result.to_numpy_()
    elif isinstance(x, (list, tuple)):
        result = to_numpy(_parse_value(x))
    else:  # fallback
        result = np.asanyarray(x)
    return result


@no_type_check
def to_torch(
    x: Any,
    dtype: torch.dtype | None = None,
    device: str | int | torch.device = "cpu",
) -> Batch | torch.Tensor:
    """Convert ``x`` into an equivalent object that holds no np.ndarray.

    The conversion depends on the type of ``x`` (checked in this order):

    * boolean / numeric ``np.ndarray``: turned into a tensor with
      ``torch.from_numpy``, moved to ``device`` and then cast to ``dtype``;
    * ``torch.Tensor``: cast to ``dtype`` and then moved to ``device``;
    * numpy / python numbers and booleans: wrapped by ``np.asanyarray`` and
      converted again;
    * ``dict``: a new ``Batch(x, copy=True)``, then converted with ``to_torch_``;
    * ``Batch``: a deep copy of it, then converted with ``to_torch_``;
    * ``list`` / ``tuple``: passed through ``_parse_value`` and converted again.

    :param x: the object to convert.
    :param dtype: target dtype; when ``None`` no cast is performed.
    :param device: device the result is placed on.
    :return: a tensor, or a ``Batch`` when ``x`` is a dict or a ``Batch``.
    :raises TypeError: if ``x`` matches none of the cases above (this includes
        arrays whose dtype is neither boolean nor numeric).
    """
    array_scalar_types = (np.bool_, np.number)
    if isinstance(x, np.ndarray) and issubclass(x.dtype.type, array_scalar_types):
        # most frequent case
        tensor = torch.from_numpy(x).to(device)
        return tensor if dtype is None else tensor.type(dtype)
    if isinstance(x, torch.Tensor):  # second most frequent case
        tensor = x if dtype is None else x.type(dtype)
        return tensor.to(device)
    if isinstance(x, (np.number, np.bool_, Number)):
        return to_torch(np.asanyarray(x), dtype, device)
    if isinstance(x, dict):
        batch = Batch(x, copy=True)
        batch.to_torch_(dtype, device)
        return batch
    if isinstance(x, Batch):
        # work on a deep copy so that the caller's Batch is left untouched
        batch = deepcopy(x)
        batch.to_torch_(dtype, device)
        return batch
    if isinstance(x, (list, tuple)):
        return to_torch(_parse_value(x), dtype, device)
    # fallback
    raise TypeError(f"object {x} cannot be converted to torch.")


@no_type_check
def to_torch_as(x: Any, y: torch.Tensor) -> Batch | torch.Tensor:
    """Convert ``x`` to torch using the dtype and device of the tensor ``y``.

    Shorthand for ``to_torch(x, dtype=y.dtype, device=y.device)``.

    :param x: the object to convert.
    :param y: reference tensor; an ``assert`` checks that it is a ``torch.Tensor``.
    :return: whatever ``to_torch`` returns for ``x``.
    """
    assert isinstance(y, torch.Tensor)
    like_y = {"dtype": y.dtype, "device": y.device}
    return to_torch(x, **like_y)


# Types of the values that may appear in a dict handed to `to_hdf5`.
# Note: object is used as a proxy for objects that can be pickled.
#  Since `object` is a member of the union, a type checker accepts any value here.
# Note: mypy does not support cyclic definition currently.
#  The union refers to `Hdf5ConvertibleType` (defined just below) through a
#  string forward reference, which makes the two aliases mutually recursive.
Hdf5ConvertibleValues = Union[
    int,
    float,
    Batch,
    np.ndarray,
    torch.Tensor,
    object,
    "Hdf5ConvertibleType",
]

# A dict with string keys whose values are any of the types listed above
# (including further dicts of the same kind).
Hdf5ConvertibleType = dict[str, Hdf5ConvertibleValues]


def to_hdf5(x: Hdf5ConvertibleType, y: h5py.Group, compression: str | None = None) -> None:
    """Write every entry of ``x`` into the HDF5 group ``y``.

    How an entry is stored depends on the type of its value:

    * ``Batch``: a sub-group named after the key, tagged ``"Batch"`` and filled
      recursively from ``value.__getstate__()``;
    * ``dict``: a sub-group named after the key, filled recursively;
    * ``torch.Tensor``: a dataset holding ``to_numpy(value)``, tagged ``"Tensor"``;
    * ``np.ndarray``: a dataset tagged ``"ndarray"``; if writing it raises
      ``TypeError`` the array is pickled instead and tagged ``"pickled_ndarray"``;
    * ``int`` / ``float``: an attribute of ``y``;
    * anything else: a pickled dataset tagged with the name of the value's class.

    Tags are written to the ``"__data_type__"`` attribute of the group / dataset.

    :param x: mapping from entry names to the values to store.
    :param y: the group that receives the entries.
    :param compression: passed on to every ``create_dataset`` call.
    :raises RuntimeError: if the pickle fallback for an ndarray fails.
    :raises NotImplementedError: if pickling a value of any other type fails.
    """
    tag_key = "__data_type__"

    def store_pickled(obj: object, group: h5py.Group, name: str) -> None:
        """Pickle ``obj`` and write the bytes to dataset ``name`` as a ``np.byte`` array."""
        raw = np.frombuffer(pickle.dumps(obj), dtype=np.byte)
        group.create_dataset(name, data=raw, compression=compression)

    for name, value in x.items():
        # Batches and dicts are both represented by groups.
        if isinstance(value, Batch):
            child = y.create_group(name)
            state = value.__getstate__()
            child.attrs[tag_key] = "Batch"
            to_hdf5(state, child, compression=compression)
            continue
        if isinstance(value, dict):
            to_hdf5(value, y.create_group(name), compression=compression)
            continue

        # PyTorch tensors are written to datasets.
        if isinstance(value, torch.Tensor):
            y.create_dataset(name, data=to_numpy(value), compression=compression)
            y[name].attrs[tag_key] = "Tensor"
            continue

        # NumPy arrays are written to datasets as well.
        if isinstance(value, np.ndarray):
            try:
                y.create_dataset(name, data=value, compression=compression)
                y[name].attrs[tag_key] = "ndarray"
            except TypeError:
                # The data type is not supported by HDF5, so fall back to pickle.
                # This happens if dtype=object (e.g. due to entries being None)
                # and possibly in other cases like structured arrays.
                try:
                    store_pickled(value, y, name)
                except Exception as exception:
                    raise RuntimeError(
                        f"Attempted to pickle {value.__class__.__name__} due to "
                        "data type not supported by HDF5 and failed.",
                    ) from exception
                y[name].attrs[tag_key] = "pickled_ndarray"
            continue

        # ints and floats are stored as attributes of the group itself.
        if isinstance(value, (int, float)):
            y.attrs[name] = value
            continue

        # Any other type of object is pickled.
        try:
            store_pickled(value, y, name)
        except Exception as exception:
            raise NotImplementedError(
                f"No conversion to HDF5 for object of type '{type(value)}' "
                "implemented and fallback to pickle failed.",
            ) from exception
        y[name].attrs[tag_key] = value.__class__.__name__


def from_hdf5(x: h5py.Group, device: str | None = None) -> Hdf5ConvertibleValues:
    """Rebuild a python object from an HDF5 group or dataset.

    A dataset is decoded according to its ``"__data_type__"`` attribute:
    ``"ndarray"`` gives a numpy array, ``"Tensor"`` gives a torch tensor created
    on ``device``, and any other value means the content is unpickled.

    A group becomes a dict made of its attributes (without ``"__data_type__"``)
    plus its members, each restored recursively. If the group is tagged
    ``"Batch"``, the dict is wrapped in a ``Batch``.

    :param x: the group (or dataset) to read.
    :param device: device for restored tensors, passed to ``torch.tensor``.
    :return: the restored object.
    """
    if not isinstance(x, h5py.Dataset):
        # groups represent either a dict or a Batch
        restored = dict(x.attrs.items())
        group_tag = restored.pop("__data_type__", None)
        restored.update((name, from_hdf5(member, device)) for name, member in x.items())
        if group_tag == "Batch":
            return Batch(restored)
        return restored

    # datasets
    dataset_tag = x.attrs["__data_type__"]
    if dataset_tag == "ndarray":
        return np.array(x)
    if dataset_tag == "Tensor":
        return torch.tensor(x, device=device)
    # NOTE(review): unpickling can execute arbitrary code, so this should only
    # be used on files from a trusted source.
    return pickle.loads(x[()])