PreMode / analysis /Hsu.et.al.git /src /convert_unirep_checkpoint_to_weights.py
gzhong's picture
Upload folder using huggingface_hub
7718235 verified
'''
Converts UniRep model checkpoints to weights
'''
import tensorflow.compat.v1 as tf
import numpy as np
import glob
import os
checkpoint_dir = 'weights/evotuned_release_ckpt/model-13560'
target_dir = 'weights/evotuned_release'
def dump_weights(sess, dir_name):
"""
Saves the weights of the model in dir_name in the format required
for loading in this module. Must be called within a tf.Session
For which the weights are already initialized.
"""
vs = tf.trainable_variables()
for v in vs:
name = v.name
value = sess.run(v)
print(name)
np.save(os.path.join(dir_name,name.replace('/', '_') + ".npy"), np.array(value))
with tf.Session() as sess:
# Restore variables from disk.
saver = tf.train.import_meta_graph(checkpoint_dir + '.meta')
saver.restore(sess, checkpoint_dir)
print("Variables restored from %s, writing to target dir %s." % (checkpoint_dir, target_dir))
print("Saved variables:")
dump_weights(sess, dir_name=target_dir)