LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import tensorflow as tf
# Module-level state backing global_mutable_counting(); starts at zero.
_counter = 0


def global_mutable_counting():
    """Increment the module-wide counter and return its new value.

    The first call returns 1; each later call returns one more than the
    previous call.
    """
    global _counter
    _counter = _counter + 1
    return _counter
class AverageMeter:
    """Keep the latest value and the running mean of a named quantity."""

    def __init__(self, name):
        """Create an empty meter.

        Parameters
        ----------
        name
            Label used as the prefix of the formatted output.
        """
        self.name = name
        # Latest value, running mean, running total and number of updates.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val):
        """Record ``val`` as the latest value and refresh the running mean."""
        self.val = val
        self.sum += val
        self.count += 1
        self.avg = self.sum / self.count

    def __str__(self):
        """Return ``"<name> <latest> (<mean>)"``."""
        template = '{name} {val:4f} ({avg:4f})'
        # Field names in the template are looked up among instance attributes.
        return template.format_map(vars(self))

    def summary(self):
        """Return ``"<name>: <mean>"``."""
        template = '{name}: {avg:4f}'
        return template.format_map(vars(self))
class AverageMeterGroup:
    """A collection of ``AverageMeter`` objects addressed by key."""

    def __init__(self):
        """Start with no meters; they are created on first use in ``update``."""
        self.meters = {}

    def update(self, data):
        """Feed each ``key -> value`` pair of ``data`` into the matching meter.

        A meter is created (named after the key) the first time a key is seen.
        """
        for key, value in data.items():
            try:
                meter = self.meters[key]
            except KeyError:
                meter = self.meters[key] = AverageMeter(key)
            meter.update(value)

    def __str__(self):
        """Return the ``str()`` of every meter, joined by single spaces."""
        return ' '.join(map(str, self.meters.values()))

    def summary(self):
        """Return the ``summary()`` of every meter, joined by single spaces."""
        parts = [meter.summary() for meter in self.meters.values()]
        return ' '.join(parts)
class StructuredMutableTreeNode:
    """Tree node holding one ``mutable`` object and a list of child nodes.

    A node's ``mutable`` may be ``None``; such a node yields nothing for itself
    during traversal. Non-``None`` mutables must expose a ``key`` attribute,
    which is what traversal uses for de-duplication.
    """

    def __init__(self, mutable):
        """Wrap ``mutable`` in a node with no children."""
        self.mutable = mutable
        self.children = []

    def add_child(self, mutable):
        """Append a new child node wrapping ``mutable`` and return that node."""
        node = StructuredMutableTreeNode(mutable)
        self.children.append(node)
        return self.children[-1]

    def type(self):
        """Return the class of the wrapped ``mutable``."""
        return type(self.mutable)

    def __iter__(self):
        """Iterate with the default ``traverse()`` settings."""
        return self.traverse()

    def _own_mutable(self, deduplicate, seen):
        """Yield this node's mutable unless it is ``None`` or filtered out.

        The key is recorded in ``seen`` before the mutable is yielded, whether
        or not de-duplication is enabled.
        """
        mutable = self.mutable
        if mutable is None:
            return
        if deduplicate and mutable.key in seen:
            return
        seen.add(mutable.key)
        yield mutable

    def traverse(self, order="pre", deduplicate=True, memo=None):
        """Lazily yield the mutables of this subtree.

        Parameters
        ----------
        order
            ``"pre"`` yields a node before its children, ``"post"`` after them.
            Any other value trips an assertion when iteration starts.
        deduplicate
            When true, a mutable whose ``key`` has already been recorded in
            ``memo`` is skipped.
        memo
            Set of keys already seen; a fresh set is used when ``None``. The
            same set is handed down to every child.
        """
        seen = set() if memo is None else memo
        assert order in ["pre", "post"]
        if order == "pre":
            yield from self._own_mutable(deduplicate, seen)
        for node in self.children:
            yield from node.traverse(order=order, deduplicate=deduplicate, memo=seen)
        if order == "post":
            yield from self._own_mutable(deduplicate, seen)
def fill_zero_grads(grads, weights):
    """Return ``grads`` as a list with every ``None`` entry replaced.

    Each ``None`` gradient is swapped for ``tf.zeros_like`` of the weight at
    the same position; other gradients are passed through unchanged. The two
    inputs are paired with ``zip``, so the result is as long as the shorter
    of the two.
    """
    return [
        tf.zeros_like(weight) if grad is None else grad
        for grad, weight in zip(grads, weights)
    ]