File size: 1,425 Bytes
e986ee1
 
 
 
38f87b5
 
 
 
e986ee1
 
 
38f87b5
 
 
 
 
 
 
 
 
 
 
 
e986ee1
 
 
 
38f87b5
e986ee1
 
38f87b5
e986ee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from functools import partial

import numpy as np


def gather(a, index):
    """Pick one entry per row of ``a``.

    For every position ``i`` along the first axis this returns
    ``a[i, index[i]]`` via NumPy integer-array indexing.

    Args:
        a: Object that supports indexing with a tuple of index arrays
            (e.g. an ``np.ndarray``) and has at least two axes.
        index: Integer, or array of integers broadcastable against the row
            numbers, selecting along the second axis.

    Returns:
        The selected entries; axes beyond the second are kept as they are.
    """
    row_count = np.shape(a)[0]
    rows = np.arange(row_count)
    return a[rows, index]


import raven_utils.group as group


def def_init_image(shape=(10, 64, 64, 3), mode="uniform", min=0, max=1):
    """Create an array of the given shape, filled according to ``mode``.

    Args:
        shape: Shape of the returned array.
        mode: Fill strategy.

            * ``"normal"`` or ``"n"``: ``np.random.normal(min, max, shape)``,
              i.e. ``min`` is used as the mean and ``max`` as the standard
              deviation.
            * ``"zero"`` or ``0``: all zeros.
            * ``"one"`` or ``1``: all ones.
            * ``"int"`` or any other ``int``: ``np.random.randint(min, max,
              shape)`` (upper bound exclusive).
            * anything else: ``np.random.uniform(min, max, shape)``.
        min: First distribution argument (see ``mode``).
        max: Second distribution argument (see ``mode``).

    Returns:
        A new ``np.ndarray`` of shape ``shape``.
    """
    if mode in ("normal", "n"):
        return np.random.normal(min, max, shape)
    if mode in ("zero", 0):
        return np.zeros(shape)
    if mode in ("one", 1):
        return np.ones(shape)
    # 0 and 1 were consumed above, so only the remaining ints reach this test.
    if mode == "int" or isinstance(mode, int):
        return np.random.randint(min, max, shape)
    return np.random.uniform(min, max, shape)


# Pre-configured variant of def_init_image whose default shape is
# (16, 8, 80, 80, 1). The shape is bound as a keyword argument, so override it
# with init_image(shape=...); passing a shape positionally raises TypeError
# (multiple values for 'shape').
init_image = partial(def_init_image, shape=(16, 8, 80, 80, 1))


def get_val_index(no=group.NO, base=3, add_end=False):
    """Return the indices ``base, 2000 + base, ..., (no - 1) * 2000 + base``.

    Args:
        no: Number of indices to generate (defaults to ``group.NO``).
        base: Offset added to every multiple of 2000.
        add_end: When true, ``no * 2000`` (without ``base``) is appended as a
            final element.

    Returns:
        A 1-D ``np.ndarray`` of ``no`` indices, or ``no + 1`` when
        ``add_end`` is true.

    NOTE(review): the stride 2000 is hard-coded; presumably it is the number
    of samples per group -- confirm against the dataset layout.
    """
    indexes = np.arange(no) * 2000 + base
    if add_end:
        # np.concatenate rejects zero-dimensional inputs, so the scalar end
        # marker is wrapped in a one-element list before joining.
        indexes = np.concatenate([indexes, [no * 2000]])
    return indexes


def get_matrix(inputs, index):
    """Join the first eight entries of each row with one selected entry.

    Args:
        inputs: Array with at least two axes; axis 0 is the row axis and
            axis 1 holds the entries to choose from.
        index: 2-D integer array; column 0 gives, for each row, the position
            along axis 1 of ``inputs`` to append.

    Returns:
        ``inputs[:, :8]`` with the selected entry appended along axis 1.

    NOTE(review): presumably the eight leading entries are the context panels
    and the selected entry is the answer panel -- confirm with the caller.
    """
    leading = inputs[:, :8]
    # The per-row selection drops axis 1; restore it so the pieces line up.
    chosen = gather(inputs, index[:, 0])[:, None]
    return np.concatenate([leading, chosen], axis=1)


def get_matrix_from_data(x):
    """Build the matrix for a sample mapping.

    Args:
        x: Mapping providing the keys ``"inputs"`` and ``"index"``.

    Returns:
        The result of ``get_matrix`` applied to those two entries.
    """
    return get_matrix(x["inputs"], x["index"])


def compare_from_result(result, data):
    """Check, per sample, whether the prediction matches the target answer.

    Args:
        result: Mapping with the keys ``"predict"`` and ``"predict_mask"``.
        data: Object whose ``.data.data`` attribute is a mapping with
            ``"target"`` and ``"index"`` entries, each exposing a ``.data``
            attribute.

    Returns:
        The output of ``rv.decode.compare`` reduced with ``np.all`` over the
        last axis.

    NOTE(review): this function used to call ``D.gather``, but no name ``D``
    is imported or defined in this module, so every call raised ``NameError``.
    The module-level ``gather`` is used instead; it is assumed to perform the
    same per-row selection -- confirm if ``D`` was meant to be another module.
    """
    # Imported here rather than at module level, as in the original code.
    import raven_utils as rv

    data = data.data.data
    answer = gather(data['target'].data, data['index'].data[:, 0])
    predict = result['predict']
    predict_mask = result['predict_mask']
    # Only the first len(predict) answers have a corresponding prediction.
    matches = rv.decode.compare(answer[:len(predict)], predict, predict_mask)
    return np.all(matches, axis=-1)