File size: 900 Bytes
d1ceb73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# mypy: allow-untyped-defs
r"""This file provides a location for operators that help exporting models via onnx.

E.g. `shape_as_tensor` and `reshape_from_tensor_shape`
are to make all dynamic sizes operations traceable.

NOTE: at one point these functions were implemented differently.
Since then we have implemented these directly in ATen, so this
file is kept purely for backward-compatibility.
"""

import torch
import torch.onnx


def shape_as_tensor(x):
    """Get the shape of a tensor as a tensor.

    Args:
        x (Tensor): The input tensor.

    Returns:
        Tensor: A tensor of shape [len(x.shape)] containing the size of each dimension of x.

    Example:
        >>> x = torch.randn(2, 3)
        >>> shape_as_tensor(x)
        tensor([2, 3])

    """
    return torch._shape_as_tensor(x)


def reshape_from_tensor_shape(x, shape):
    return torch._reshape_from_tensor(x, shape)