# Copyright 2016 The TensorFlow Authors All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== r"""Generaly Utilities. """ import numpy as np, cPickle, os, time from six.moves import xrange import src.file_utils as fu import logging class Timer(): def __init__(self): self.calls = 0. self.start_time = 0. self.time_per_call = 0. self.total_time = 0. self.last_log_time = 0. def tic(self): self.start_time = time.time() def toc(self, average=True, log_at=-1, log_str='', type='calls'): if self.start_time == 0: logging.error('Timer not started by calling tic().') t = time.time() diff = time.time() - self.start_time self.total_time += diff self.calls += 1. self.time_per_call = self.total_time/self.calls if type == 'calls' and log_at > 0 and np.mod(self.calls, log_at) == 0: _ = [] logging.info('%s: %f seconds.', log_str, self.time_per_call) elif type == 'time' and log_at > 0 and t - self.last_log_time >= log_at: _ = [] logging.info('%s: %f seconds.', log_str, self.time_per_call) self.last_log_time = t if average: return self.time_per_call else: return diff class Foo(object): def __init__(self, **kwargs): self.__dict__.update(kwargs) def __str__(self): str_ = '' for v in vars(self).keys(): a = getattr(self, v) if True: #isinstance(v, object): str__ = str(a) str__ = str__.replace('\n', '\n ') else: str__ = str(a) str_ += '{:s}: {:s}'.format(v, str__) str_ += '\n' return str_ def dict_equal(dict1, dict2): assert(set(dict1.keys()) == set(dict2.keys())), "Sets of keys between 2 dictionaries are different." for k in dict1.keys(): assert(type(dict1[k]) == type(dict2[k])), "Type of key '{:s}' if different.".format(k) if type(dict1[k]) == np.ndarray: assert(dict1[k].dtype == dict2[k].dtype), "Numpy Type of key '{:s}' if different.".format(k) assert(np.allclose(dict1[k], dict2[k])), "Value for key '{:s}' do not match.".format(k) else: assert(dict1[k] == dict2[k]), "Value for key '{:s}' do not match.".format(k) return True def subplot(plt, Y_X, sz_y_sz_x = (10, 10)): Y,X = Y_X sz_y, sz_x = sz_y_sz_x plt.rcParams['figure.figsize'] = (X*sz_x, Y*sz_y) fig, axes = plt.subplots(Y, X) plt.subplots_adjust(wspace=0.1, hspace=0.1) return fig, axes def tic_toc_print(interval, string): global tic_toc_print_time_old if 'tic_toc_print_time_old' not in globals(): tic_toc_print_time_old = time.time() print(string) else: new_time = time.time() if new_time - tic_toc_print_time_old > interval: tic_toc_print_time_old = new_time; print(string) def mkdir_if_missing(output_dir): if not fu.exists(output_dir): fu.makedirs(output_dir) def save_variables(pickle_file_name, var, info, overwrite = False): if fu.exists(pickle_file_name) and overwrite == False: raise Exception('{:s} exists and over write is false.'.format(pickle_file_name)) # Construct the dictionary assert(type(var) == list); assert(type(info) == list); d = {} for i in xrange(len(var)): d[info[i]] = var[i] with fu.fopen(pickle_file_name, 'w') as f: cPickle.dump(d, f, cPickle.HIGHEST_PROTOCOL) def load_variables(pickle_file_name): if fu.exists(pickle_file_name): with fu.fopen(pickle_file_name, 'r') as f: d = cPickle.load(f) return d else: raise Exception('{:s} does not exists.'.format(pickle_file_name)) def voc_ap(rec, prec): rec = rec.reshape((-1,1)) prec = prec.reshape((-1,1)) z = np.zeros((1,1)) o = np.ones((1,1)) mrec = np.vstack((z, rec, o)) mpre = np.vstack((z, prec, z)) for i in range(len(mpre)-2, -1, -1): mpre[i] = max(mpre[i], mpre[i+1]) I = np.where(mrec[1:] != mrec[0:-1])[0]+1; ap = 0; for i in I: ap = ap + (mrec[i] - mrec[i-1])*mpre[i]; return ap def tight_imshow_figure(plt, figsize=None): fig = plt.figure(figsize=figsize) ax = plt.Axes(fig, [0,0,1,1]) ax.set_axis_off() fig.add_axes(ax) return fig, ax def calc_pr(gt, out, wt=None): if wt is None: wt = np.ones((gt.size,1)) gt = gt.astype(np.float64).reshape((-1,1)) wt = wt.astype(np.float64).reshape((-1,1)) out = out.astype(np.float64).reshape((-1,1)) gt = gt*wt tog = np.concatenate([gt, wt, out], axis=1)*1. ind = np.argsort(tog[:,2], axis=0)[::-1] tog = tog[ind,:] cumsumsortgt = np.cumsum(tog[:,0]) cumsumsortwt = np.cumsum(tog[:,1]) prec = cumsumsortgt / cumsumsortwt rec = cumsumsortgt / np.sum(tog[:,0]) ap = voc_ap(rec, prec) return ap, rec, prec