# Copyright 2018 Google, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Script that iteratively applies the unsupervised update rule and evaluates the

meta-objective performance.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

from learning_unsupervised_learning import evaluation
from learning_unsupervised_learning import datasets
from learning_unsupervised_learning import architectures
from learning_unsupervised_learning import summary_utils
from learning_unsupervised_learning import meta_objective

import tensorflow as tf
import sonnet as snt

from tensorflow.contrib.framework.python.framework import checkpoint_utils

flags.DEFINE_string("checkpoint_dir", None, "Dir to load pretrained update rule from")
flags.DEFINE_string("train_log_dir", None, "Training log directory")

FLAGS = flags.FLAGS


def train(train_log_dir, checkpoint_dir, eval_every_n_steps=10, num_steps=3000):
  dataset_fn = datasets.mnist.TinyMnist
  w_learner_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateWLearner
  theta_process_fn = architectures.more_local_weight_update.MoreLocalWeightUpdateProcess

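  # Meta-objectives probe the quality of the learned representation: a
  # closed-form linear regression fit and an sklearn logistic regression
  # classifier trained on top of it.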
  meta_objectives = []
  meta_objectives.append(
      meta_objective.linear_regression.LinearRegressionMetaObjective)
  meta_objectives.append(meta_objective.sklearn.LogisticRegression)

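  # Wire the update rule, base model, dataset, and meta-objectives into one
  # graph; train_one_step_op applies the unsupervised update rule once.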
  checkpoint_vars, train_one_step_op, (
      base_model, dataset) = evaluation.construct_evaluation_graph(
          theta_process_fn=theta_process_fn,
          w_learner_fn=w_learner_fn,
          dataset_fn=dataset_fn,
          meta_objectives=meta_objectives)
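  # Connect the base model to the graph; Sonnet creates the model variables on
  # this first call (the returned tensors are otherwise unused here).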
  batch = dataset()
  pre_logit, outputs = base_model(batch)

  global_step = tf.train.get_or_create_global_step()

  tf.logging.info("all vars")
  for v in tf.all_variables():
    tf.logging.info("   %s" % str(v))
  global_step = tf.train.get_global_step()
  accumulate_global_step = global_step.assign_add(1)
  reset_global_step = global_step.assign(0)

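  # A single train step applies the update rule once and bumps the global step.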
  train_op = tf.group(
      train_one_step_op, accumulate_global_step, name="train_op")

  summary_op = tf.summary.merge_all()

  file_writer = summary_utils.LoggingFileWriter(train_log_dir, regexes=[".*"])
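  # When resuming from a pretrained update rule, build a Saver over only the
  # checkpoint variables that also exist in this graph, matched by name, and
  # check that no theta (update rule) variable was left out.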
  if checkpoint_dir:
    str_var_list = checkpoint_utils.list_variables(checkpoint_dir)
    name_to_v_map = {v.op.name: v for v in tf.global_variables()}
    var_list = [
        name_to_v_map[vn] for vn, _ in str_var_list if vn in name_to_v_map
    ]
    saver = tf.train.Saver(var_list)
    missed_variables = [
        v.op.name for v in set(
            snt.get_variables_in_scope("LocalWeightUpdateProcess",
                                       tf.GraphKeys.GLOBAL_VARIABLES)) -
        set(var_list)
    ]
    assert not missed_variables, "Missed a theta variable."

  hooks = []

  with tf.train.SingularMonitoredSession(master="", hooks=hooks) as sess:

    # The global step is restored from this eval job's checkpoint, or is zero
    # for a fresh run.
    step = sess.run(global_step)

    if step == 0 and checkpoint_dir:
      tf.logging.info("force restore")
      saver.restore(sess, checkpoint_dir)
      tf.logging.info("force restore done")
      sess.run(reset_global_step)
      step = sess.run(global_step)

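    # Run for num_steps applications of the update rule, writing summaries
    # every eval_every_n_steps.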
    while step < num_steps:
      if step % eval_every_n_steps == 0:
        s, _, step = sess.run([summary_op, train_op, global_step])
        file_writer.add_summary(s, step)
      else:
        _, step = sess.run([train_op, global_step])


def main(argv):
  train(FLAGS.train_log_dir, FLAGS.checkpoint_dir)


if __name__ == "__main__":
  app.run(main)