|
|
|
|
|
|
|
import tensorflow as tf |
|
|
|
_counter = 0 |
|
|
|
def global_mutable_counting(): |
|
global _counter |
|
_counter += 1 |
|
return _counter |
|
|
|
|
|
class AverageMeter: |
|
def __init__(self, name): |
|
self.name = name |
|
self.val = 0 |
|
self.avg = 0 |
|
self.sum = 0 |
|
self.count = 0 |
|
|
|
def update(self, val): |
|
self.val = val |
|
self.sum += val |
|
self.count += 1 |
|
self.avg = self.sum / self.count |
|
|
|
def __str__(self): |
|
return '{name} {val:4f} ({avg:4f})'.format(**self.__dict__) |
|
|
|
def summary(self): |
|
return '{name}: {avg:4f}'.format(**self.__dict__) |
|
|
|
|
|
class AverageMeterGroup: |
|
def __init__(self): |
|
self.meters = {} |
|
|
|
def update(self, data): |
|
for k, v in data.items(): |
|
if k not in self.meters: |
|
self.meters[k] = AverageMeter(k) |
|
self.meters[k].update(v) |
|
|
|
def __str__(self): |
|
return ' '.join(str(v) for v in self.meters.values()) |
|
|
|
def summary(self): |
|
return ' '.join(v.summary() for v in self.meters.values()) |
|
|
|
|
|
class StructuredMutableTreeNode: |
|
def __init__(self, mutable): |
|
self.mutable = mutable |
|
self.children = [] |
|
|
|
def add_child(self, mutable): |
|
self.children.append(StructuredMutableTreeNode(mutable)) |
|
return self.children[-1] |
|
|
|
def type(self): |
|
return type(self.mutable) |
|
|
|
def __iter__(self): |
|
return self.traverse() |
|
|
|
def traverse(self, order="pre", deduplicate=True, memo=None): |
|
if memo is None: |
|
memo = set() |
|
assert order in ["pre", "post"] |
|
if order == "pre": |
|
if self.mutable is not None: |
|
if not deduplicate or self.mutable.key not in memo: |
|
memo.add(self.mutable.key) |
|
yield self.mutable |
|
for child in self.children: |
|
for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo): |
|
yield m |
|
if order == "post": |
|
if self.mutable is not None: |
|
if not deduplicate or self.mutable.key not in memo: |
|
memo.add(self.mutable.key) |
|
yield self.mutable |
|
|
|
|
|
def fill_zero_grads(grads, weights): |
|
ret = [] |
|
for grad, weight in zip(grads, weights): |
|
if grad is not None: |
|
ret.append(grad) |
|
else: |
|
ret.append(tf.zeros_like(weight)) |
|
return ret |
|
|