File size: 2,468 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import tensorflow as tf

_counter = 0  # module-wide state backing global_mutable_counting()


def global_mutable_counting():
    """Advance the module-level counter and return its new value.

    The first call returns 1, the next 2, and so on; the value is stored in
    the module global ``_counter`` so it persists across calls.
    """
    global _counter
    _counter = _counter + 1
    return _counter


class AverageMeter:
    """Track the latest value and the running (weighted) mean of a metric.

    Parameters
    ----------
    name : str
        Label used when the meter is rendered by ``str()`` or ``summary()``.
    """

    def __init__(self, name):
        self.name = name
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running mean: sum / count
        self.sum = 0    # weighted sum of all values recorded so far
        self.count = 0  # total weight (number of samples) recorded so far

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted as ``n`` samples.

        Parameters
        ----------
        val : number
            New observation; it becomes ``self.val``.
        n : int, optional
            Weight of this observation (e.g. a batch size). The default of 1
            reproduces the one-sample-per-call behaviour.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Only recompute the mean when at least one sample has been counted,
        # so update(x, n=0) on a fresh meter cannot divide by zero.
        if self.count:
            self.avg = self.sum / self.count

    def __str__(self):
        # ``.4f`` means four decimal places. The previous spec ``4f`` only set
        # a minimum field width of 4 and still printed six decimals.
        return '{name} {val:.4f} ({avg:.4f})'.format(**self.__dict__)

    def summary(self):
        """Return ``'<name>: <avg>'`` with the average to four decimals."""
        return '{name}: {avg:.4f}'.format(**self.__dict__)


class AverageMeterGroup:
    """A name-keyed collection of ``AverageMeter`` objects.

    Meters are created on demand the first time a name appears in the data
    passed to ``update``; insertion order of names is preserved for output.
    """

    def __init__(self):
        self.meters = {}

    def update(self, data):
        """Feed every ``name -> value`` pair of the mapping ``data`` to its meter."""
        for name, value in data.items():
            meter = self.meters.get(name)
            if meter is None:
                meter = AverageMeter(name)
                self.meters[name] = meter
            meter.update(value)

    def __str__(self):
        # Each meter's full rendering, separated by two spaces.
        rendered = [str(meter) for meter in self.meters.values()]
        return '  '.join(rendered)

    def summary(self):
        # Each meter's summary, separated by two spaces.
        rendered = [meter.summary() for meter in self.meters.values()]
        return '  '.join(rendered)


class StructuredMutableTreeNode:
    """A tree node that wraps one mutable and keeps an ordered list of children.

    ``mutable`` may be ``None``; such nodes contribute nothing to a traversal
    but their children are still visited. Wrapped mutables must expose a
    hashable ``key`` attribute, which traversal records in a memo set.
    """

    def __init__(self, mutable):
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """Wrap ``mutable`` in a new node, append it as the last child, and return it."""
        child = StructuredMutableTreeNode(mutable)
        self.children.append(child)
        return child

    def type(self):
        """Return the Python type of the wrapped mutable."""
        return type(self.mutable)

    def __iter__(self):
        # Plain iteration is a de-duplicated pre-order traversal.
        return self.traverse()

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """Lazily yield the mutables in this subtree.

        Parameters
        ----------
        order : str
            ``"pre"`` yields a node's mutable before its children's,
            ``"post"`` after them. Anything else fails an assertion on the
            first ``next()``.
        deduplicate : bool
            When true, a mutable whose ``key`` was already seen is skipped.
        memo : set or None
            Set of keys seen so far. It is shared across the recursive calls,
            so de-duplication spans the whole subtree. A new set is made when
            ``None``.
        """
        memo = set() if memo is None else memo
        assert order in ("pre", "post")

        def emit_own():
            # Yield this node's mutable at most once, honouring deduplication.
            current = self.mutable
            if current is None:
                return
            if deduplicate and current.key in memo:
                return
            memo.add(current.key)
            yield current

        if order == "pre":
            yield from emit_own()
        for child in self.children:
            yield from child.traverse(order=order, deduplicate=deduplicate, memo=memo)
        if order == "post":
            yield from emit_own()


def fill_zero_grads(grads, weights):
    """Return ``grads`` with every ``None`` entry replaced by zeros.

    ``grads`` and ``weights`` are walked in lockstep with ``zip``, so the
    result has the length of the shorter input. A non-None gradient is kept
    unchanged. A ``None`` gradient is replaced by ``tf.zeros_like`` of the
    weight at the same position.
    NOTE(review): presumably ``grads`` comes from a gradient computation that
    yields ``None`` for variables not connected to the loss; confirm at call
    sites.
    """
    return [
        tf.zeros_like(weight) if grad is None else grad
        for grad, weight in zip(grads, weights)
    ]