File size: 657 Bytes
67a9b5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np


def assert_and_normalize_shape(x, length):
    """Validate that ``x`` has a shape compatible with ``length``; flatten column vectors.

    Accepted inputs and what is returned:

    * 0-d input (numpy scalar or 0-d array) -> returned unchanged.
    * 1-d array of shape ``(1,)`` or ``(length,)`` -> returned unchanged.
    * 2-d array of shape ``(1, 1)`` or ``(length, 1)`` -> the trailing axis is
      squeezed, giving shape ``(1,)`` or ``(length,)``.

    Args:
        x: ndarray or numpy scalar. Any other array_like (list, tuple, Python
            number) is first converted with ``np.asarray``.
        length: int, the expected number of entries along the first axis.

    Returns:
        ``x`` itself for 0-d / 1-d ndarray inputs, otherwise the result of
        ``np.squeeze(x, axis=-1)``.

    Raises:
        ValueError: if ``x`` has a 1-d or 2-d shape other than the accepted
            ones above, or has more than two dimensions.
    """
    # Only convert foreign inputs; ndarrays (incl. subclasses) and numpy
    # scalars are kept as-is so callers still get back the very same object.
    if not isinstance(x, (np.ndarray, np.generic)):
        x = np.asarray(x)

    if x.ndim == 0:
        return x

    if x.ndim == 1:
        if x.shape[0] in (1, length):
            return x
        raise ValueError(
            f'Incompatible shape! got {x.shape}, expected (1,) or ({length},)'
        )

    if x.ndim == 2:
        # Only single-column inputs are accepted; (1, 1) is the scalar-like case.
        if x.shape[1] == 1 and x.shape[0] in (1, length):
            return np.squeeze(x, axis=-1)
        raise ValueError(
            f'Incompatible shape! got {x.shape}, expected (1, 1) or ({length}, 1)'
        )

    raise ValueError(f'Incompatible ndim! got {x.ndim}, expected at most 2')