"""Converts UniRep model checkpoints to weights.

Restores a TensorFlow checkpoint (graph structure plus variable values) and
writes every trainable variable to its own ``.npy`` file in ``target_dir``.
"""
import tensorflow.compat.v1 as tf
import numpy as np
import glob
import os

# Despite its name this is a checkpoint *prefix*, not a directory: the graph
# definition is read from ``checkpoint_dir + '.meta'`` and the same prefix is
# handed to ``saver.restore``.
checkpoint_dir = 'weights/evotuned_release_ckpt/model-13560'
# Directory that receives one ``.npy`` file per trainable variable.
target_dir = 'weights/evotuned_release'


def dump_weights(sess, dir_name):
    """
    Saves the weights of the model in dir_name in the format required
    for loading in this module. Must be called within a tf.Session
    for which the weights are already initialized.

    Each variable returned by ``tf.trainable_variables()`` is evaluated in
    ``sess`` and written with ``np.save``. The file name is the variable's
    name with every ``'/'`` replaced by ``'_'`` and ``'.npy'`` appended; all
    other characters of the name are kept as they are. The variable name is
    printed as each one is processed.

    Args:
        sess: Session used to evaluate the variables.
        dir_name: Output directory. It is created (including missing
            parents) if it does not exist yet.
    """
    # np.save does not create missing directories, so make sure the target
    # exists up front instead of failing on the first variable.
    os.makedirs(dir_name, exist_ok=True)

    for variable in tf.trainable_variables():
        name = variable.name
        value = sess.run(variable)
        print(name)
        # Variable scopes use '/' separators, which would otherwise be
        # interpreted as sub-directories of dir_name.
        file_name = name.replace('/', '_') + ".npy"
        np.save(os.path.join(dir_name, file_name), np.array(value))


with tf.Session() as sess:
    # Restore variables from disk: rebuild the graph from the .meta file,
    # then load the variable values from the checkpoint into the session.
    saver = tf.train.import_meta_graph(checkpoint_dir + '.meta')
    saver.restore(sess, checkpoint_dir)
    print("Variables restored from %s, writing to target dir %s." % (checkpoint_dir, target_dir))
    print("Saved variables:")
    dump_weights(sess, dir_name=target_dir)