File size: 22,116 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93528c6
5672777
 
 
 
 
 
93528c6
 
 
 
 
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training utils."""

import dataclasses
import inspect
import json
import os
import pprint
from typing import Any, Callable, Dict, List, Optional, Union

from absl import logging
import gin
import numpy as np
import orbit
import tensorflow as tf, tf_keras

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import exp_factory
from official.modeling import hyperparams


BEST_CHECKPOINT_NAME = 'best_ckpt'


def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Any:
  """Gets a leaf value from a dictionary of arbitrary depth.

  Args:
    d: The (possibly nested) dictionary to extract the value from.
    keys: The keys to follow, outermost first.

  Returns:
    The value of the leaf.

  Raises:
    KeyError: If the path described by `keys` does not exist in `d`, or if the
      value found at the end of the path is itself a dictionary.
  """
  leaf = d
  for k in keys:
    if not isinstance(leaf, dict) or k not in leaf:
      # `keys` is wrapped in a 1-tuple so that a tuple of keys is formatted as
      # one value rather than being unpacked as separate %-format arguments
      # (which would raise TypeError instead of the intended KeyError).
      raise KeyError(
          'Path not exist while traversing the dictionary: d with keys'
          ': %s.' % (keys,))
    leaf = leaf[k]

  # Stopping on a sub-dictionary means `keys` was too short to reach a leaf.
  if isinstance(leaf, dict):
    raise KeyError('The value extracted with keys: %s is not a leaf of the '
                   'dictionary: %s.' % (keys, d))
  return leaf


def cast_leaf_nested_dict(d: Dict[str, Any],
                          cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
  """Applies `cast_fn` to every leaf of a nested dictionary, in place.

  Args:
    d: The dictionary whose leaves are cast. It is modified in place;
      sub-dictionaries are processed recursively.
    cast_fn: The casting function, called once on every non-dict value.

  Returns:
    The same dictionary object `d`, with every leaf replaced by its cast value.
  """
  for key in d:
    current = d[key]
    # Only existing keys are re-assigned, so the dictionary never changes size
    # while it is being iterated.
    d[key] = (
        cast_leaf_nested_dict(current, cast_fn)
        if isinstance(current, dict) else cast_fn(current))
  return d


def _filter_leaf_nested_dict(
    d: Dict[str, Any], predicate: Callable[[Any], bool]
) -> Dict[str, Any]:
  """Builds a filtered copy of a dictionary with arbitrary depth.

  Args:
    d: The dictionary to filter. It is not modified.
    predicate: A function called on every leaf (non-dict) item. A leaf is kept
      when the function returns a truthy value and dropped otherwise.

  Returns:
    A new dictionary holding only the accepted leaves. Sub-dictionaries are
    always kept (filtered recursively), even when they end up empty.
  """
  return {
      key: (_filter_leaf_nested_dict(value, predicate)
            if isinstance(value, dict) else value)
      for key, value in d.items()
      if isinstance(value, dict) or predicate(value)
  }


def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
                                    data_dir: str) -> Any:
  """Builds a BestCheckpointExporter when the config asks for one.

  Args:
    params: The experiment config. Its `trainer` section supplies the export
      sub-directory, the metric name and the metric comparison mode.
    data_dir: The base directory under which the export sub-directory lives.

  Returns:
    A BestCheckpointExporter, or None when `data_dir`, the export sub-directory
    or the metric name is empty/unset.
  """
  trainer_params = params.trainer
  export_subdir = trainer_params.best_checkpoint_export_subdir
  metric_name = trainer_params.best_checkpoint_eval_metric
  metric_comp = trainer_params.best_checkpoint_metric_comp

  # All three pieces are required; otherwise exporting is disabled.
  if not (data_dir and export_subdir and metric_name):
    return None

  exporter = BestCheckpointExporter(
      os.path.join(data_dir, export_subdir), metric_name, metric_comp)
  logging.info(
      'Created the best checkpoint exporter. '
      'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
      export_subdir, metric_name)
  return exporter


class BestCheckpointExporter:
  """Remembers the best evaluation result and saves the matching checkpoint.

  Orbit will support an API for checkpoint exporter. This class will be used
  together with orbit once this functionality is ready.
  """

  def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
    """Initialization.

    Creates the export directory if needed and reloads the previously exported
    best metrics from `info.json` when that file already exists.

    Args:
      export_dir: The directory that will contain exported checkpoints.
      metric_name: Indicates which metric to look at, when determining which
        result is better. If eval_logs being passed to maybe_export_checkpoint
        is a nested dictionary, use `|` as a separator for different layers.
      metric_comp: Indicates how to compare results. Either `lower` or `higher`.

    Raises:
      ValueError: If `metric_comp` is neither `lower` nor `higher`.
    """
    self._export_dir = export_dir
    # Stored as the list of keys leading to the metric in nested eval logs.
    self._metric_name = metric_name.split('|')
    self._metric_comp = metric_comp
    if metric_comp not in ('lower', 'higher'):
      raise ValueError('best checkpoint metric comp must be one of '
                       'higher, lower. Got: {}'.format(self._metric_comp))
    logs_dir = os.path.dirname(self.best_ckpt_logs_path)
    tf.io.gfile.makedirs(logs_dir)
    self._best_ckpt_logs = self._maybe_load_best_eval_metric()
    self._checkpoint_manager = None

  def _get_checkpoint_manager(self, checkpoint):
    """Returns the cached checkpoint manager, or builds one for `checkpoint`."""
    manager = self._checkpoint_manager
    # A manager is only reusable for the checkpoint object it was built with.
    if manager is None or manager.checkpoint != checkpoint:
      logging.info('Creates a new checkpoint manager.')
      manager = tf.train.CheckpointManager(
          checkpoint,
          directory=self._export_dir,
          max_to_keep=1,
          checkpoint_name=BEST_CHECKPOINT_NAME)
      self._checkpoint_manager = manager

    return manager

  def maybe_export_checkpoint(
      self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
    """Exports `checkpoint` when `eval_logs` beats the best logs seen so far.

    Args:
      checkpoint: The checkpoint object to save through a checkpoint manager.
      eval_logs: The (possibly nested) dictionary of evaluation results.
      global_step: The step at which `eval_logs` were produced.
      write_logs: Whether to also write the new best logs to the json file.

    Returns:
      True if the checkpoint was exported, False otherwise.
    """
    logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
                 eval_logs, global_step)
    previous_best = self._best_ckpt_logs
    if previous_best is not None and not self._new_metric_is_better(
        previous_best, eval_logs):
      return False

    self._best_ckpt_logs = eval_logs
    if write_logs:
      self.export_best_eval_metric(self._best_ckpt_logs, global_step)
    self._get_checkpoint_manager(checkpoint).save()
    return True

  def _maybe_load_best_eval_metric(self):
    """Loads the previously exported best logs, or None if there are none."""
    logs_path = self.best_ckpt_logs_path
    if tf.io.gfile.exists(logs_path):
      with tf.io.gfile.GFile(logs_path, 'r') as reader:
        return json.loads(reader.read())
    return None

  def _new_metric_is_better(self, old_logs, new_logs):
    """Checks whether the tracked metric improved from old_logs to new_logs."""

    def _metric_value(logs):
      # Follows the `|`-separated key path and converts the leaf to a float.
      return float(
          orbit.utils.get_value(get_leaf_nested_dict(logs, self._metric_name)))

    old_value = _metric_value(old_logs)
    new_value = _metric_value(new_logs)

    logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
                 old_value, new_value)
    if self._metric_comp == 'higher':
      improved = new_value > old_value
      if improved:
        logging.info('[BestCheckpointExporter] '
                     'the new number is better since it is higher.')
    else:  # self._metric_comp == 'lower'
      improved = new_value < old_value
      if improved:
        logging.info('[BestCheckpointExporter] '
                     'the new number is better since it is lower.')
    return improved

  def export_best_eval_metric(self, eval_logs, global_step):
    """Exports evaluation results of the best checkpoint into a json file.

    Args:
      eval_logs: The (possibly nested) dictionary of evaluation results.
      global_step: The step recorded under the `best_ckpt_global_step` key.
    """

    def _has_low_rank(leaf):
      return tf.rank(leaf) <= 1

    def _as_float(leaf):
      return float(orbit.utils.get_value(leaf))

    # eval_logs may contain non-scalar tensors, such as image data when
    # `allow_image_summary` is True; leaves of rank greater than 1 are dropped
    # before everything that remains is converted to a float.
    exported_logs = _filter_leaf_nested_dict(eval_logs, _has_low_rank)
    exported_logs['best_ckpt_global_step'] = global_step
    exported_logs = cast_leaf_nested_dict(exported_logs, _as_float)
    # Saving json file is very fast.
    with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
      writer.write(json.dumps(exported_logs, indent=4) + '\n')

  @property
  def best_ckpt_logs(self):
    """The eval logs of the best result so far, or None if there is none."""
    return self._best_ckpt_logs

  @property
  def best_ckpt_logs_path(self):
    """Path of the `info.json` file inside the export directory."""
    return os.path.join(self._export_dir, 'info.json')

  @property
  def best_ckpt_path(self):
    """Returns the best ckpt path or None if there is no ckpt yet."""
    return tf.train.latest_checkpoint(self._export_dir)


def create_optimizer(task: base_task.Task,
                     params: config_definitions.ExperimentConfig
                     ) -> tf_keras.optimizers.Optimizer:
  """Creates the optimizer, supporting tasks with and without `dp_config`.

  Args:
    task: The task whose `create_optimizer` method builds the optimizer.
    params: The experiment config supplying the optimizer config, the runtime
      config and, optionally, `task.differential_privacy_config`.

  Returns:
    The optimizer returned by `task.create_optimizer`.

  Raises:
    ValueError: If a differential privacy config is set but
      `task.create_optimizer` has no `dp_config` parameter.
  """
  accepts_dp_config = (
      'dp_config' in inspect.signature(task.create_optimizer).parameters)
  # A missing attribute is treated the same as an unset (None) config.
  dp_config = getattr(params.task, 'differential_privacy_config', None)

  if accepts_dp_config:
    return task.create_optimizer(
        params.trainer.optimizer_config, params.runtime, dp_config=dp_config)

  if dp_config is not None:
    raise ValueError('Differential privacy config is specified but '
                     'task.create_optimizer api does not accept it.')
  return task.create_optimizer(params.trainer.optimizer_config, params.runtime)


@gin.configurable
def create_trainer(params: config_definitions.ExperimentConfig,
                   task: base_task.Task,
                   train: bool,
                   evaluate: bool,
                   checkpoint_exporter: Optional[BestCheckpointExporter] = None,
                   trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
  """Builds the model and optimizer for `task` and wraps them in a trainer.

  Args:
    params: The experiment config.
    task: The task used to build the model and the optimizer.
    train: Forwarded to `trainer_cls` as `train`.
    evaluate: Forwarded to `trainer_cls` as `evaluate`.
    checkpoint_exporter: Optional exporter forwarded to `trainer_cls`.
    trainer_cls: The trainer class (or factory) to instantiate.

  Returns:
    The object returned by `trainer_cls`.
  """
  logging.info('Running default trainer.')
  # Keyword arguments are evaluated left to right, so the model is still built
  # before the optimizer.
  return trainer_cls(
      params,
      task,
      model=task.build_model(),
      optimizer=create_optimizer(task, params),
      train=train,
      evaluate=evaluate,
      checkpoint_exporter=checkpoint_exporter)


@dataclasses.dataclass
class ParseConfigOptions:
  """A flags-like container for customizing parse_configuration().

  Use this dataclass instead of FLAGS. It supports the `name in options`
  check, which is true for every declared field name.
  """
  experiment: str
  config_file: List[str]
  tpu: str = ''
  tf_data_service: str = ''
  params_override: str = ''

  def __contains__(self, name):
    field_names = {field.name for field in dataclasses.fields(self)}
    return name in field_names


class ExperimentParser:
  """Builds the Experiment config from flags or an equivalent object.

  In most cases, users only need to call the `parse()` function:
  ```
  builder = ExperimentParser(FLAGS)
  params = builder.parse()
  ```

  Advanced users can customize the flow by calling `base_experiment()` and the
  parse_*() functions separately.
  """

  def __init__(self, flags_obj):
    self._flags_obj = flags_obj

  def parse(self):
    """Runs the overall process of constructing the Experiment config."""
    params = self.base_experiment()
    # Later steps override values set by earlier ones.
    override_steps = (
        self.parse_config_file,
        self.parse_runtime,
        self.parse_data_service,
        self.parse_params_override,
    )
    for apply_step in override_steps:
      params = apply_step(params)
    return params

  def base_experiment(self):
    """Gets the base experiment config named by the --experiment field."""
    experiment_name = self._flags_obj.experiment
    if experiment_name is None:
      raise ValueError('The flag --experiment must be specified.')
    return exp_factory.get_exp_config(self._flags_obj.experiment)

  def parse_config_file(self, params):
    """Overrides the configs of params with each config file, in order."""
    config_files = self._flags_obj.config_file
    if not config_files:
      return params
    for config_path in config_files:
      params = hyperparams.override_params_dict(
          params, config_path, is_strict=True)
    return params

  def parse_runtime(self, params):
    """Overrides the runtime TPU address of params from flags."""
    runtime_override = {'tpu': self._flags_obj.tpu}
    params.override({'runtime': runtime_override})
    return params

  def parse_data_service(self, params):
    """Overrides the tf.data service address of params from flags."""
    flags = self._flags_obj
    if 'tf_data_service' not in flags:
      return params
    service_address = flags.tf_data_service
    if not service_address:
      return params
    if not isinstance(params.task, config_definitions.TaskConfig):
      return params

    # The same address is applied to both the train and validation datasets.
    params.override({
        'task': {
            'train_data': {
                'tf_data_service_address': service_address,
            },
            'validation_data': {
                'tf_data_service_address': service_address,
            }
        }
    })
    return params

  def parse_params_override(self, params):
    """Applies the second level of override from `--params_override`."""
    # `--params_override` is typically used as a further override over the
    # template. For example, one may define a particular template for training
    # ResNet50 on ImageNet in a config file and pass it via `--config_file`,
    # then define different learning rates and pass it via `--params_override`.
    override = self._flags_obj.params_override
    if not override:
      return params
    return hyperparams.override_params_dict(params, override, is_strict=True)


def parse_configuration(flags_obj, lock_return=True, print_return=True):
  """Parses, validates and optionally locks/logs the ExperimentConfig.

  Args:
    flags_obj: The flags (or flags-like object) handed to ExperimentParser.
    lock_return: Whether to call `lock()` on the config before returning it.
    print_return: Whether to log the pretty-printed final config.

  Returns:
    The parsed and validated experiment config.
  """
  parser = ExperimentParser(flags_obj)
  params = parser.parse()
  params.validate()

  if lock_return:
    params.lock()

  if print_return:
    formatted_params = pprint.pformat(params.as_dict())
    logging.info('Final experiment parameters:\n%s', formatted_params)

  return params


def serialize_config(params: config_definitions.ExperimentConfig,
                     model_dir: str):
  """Saves the experiment config as `params.yaml` inside `model_dir`.

  Args:
    params: The experiment config to save.
    model_dir: The output directory; it is created if needed.

  Raises:
    ValueError: If `model_dir` is None.
  """
  if model_dir is None:
    raise ValueError('model_dir must be specified, but got None')

  yaml_path = os.path.join(model_dir, 'params.yaml')
  logging.info('Saving experiment configuration to %s', yaml_path)
  tf.io.gfile.makedirs(model_dir)
  hyperparams.save_params_dict_to_yaml(params, yaml_path)


def save_gin_config(filename_suffix: str, model_dir: str):
  """Saves the gin operative config string to a file in `model_dir`.

  The file is named `operative_config.<filename_suffix>.gin`.

  Args:
    filename_suffix: The middle part of the output file name.
    model_dir: The output directory; it is created if needed.
  """
  gin_file_name = 'operative_config.{}.gin'.format(filename_suffix)
  gin_file_path = os.path.join(model_dir, gin_file_name)
  logging.info('Saving gin configurations to %s', gin_file_path)
  tf.io.gfile.makedirs(model_dir)
  with tf.io.gfile.GFile(gin_file_path, 'w') as gin_file:
    gin_file.write(gin.operative_config_str())


def read_global_step_from_checkpoint(ckpt_file_path):
  """Reads the `global_step` value stored in a checkpoint.

  Args:
    ckpt_file_path: The checkpoint path to restore `global_step` from.

  Returns:
    The restored global step value.

  Raises:
    ValueError: If `global_step` could not be restored from the checkpoint.
  """
  # -1 acts as a sentinel: it survives only if the restore did not set a value.
  step_variable = tf.Variable(-1, dtype=tf.int64)
  checkpoint = tf.train.Checkpoint(global_step=step_variable)
  try:
    checkpoint.restore(ckpt_file_path).expect_partial()
    restored_step = step_variable.numpy()
  except tf.errors.InvalidArgumentError:
    restored_step = -1

  if restored_step == -1:
    raise ValueError('global_step not found in checkpoint {}. '
                     'If you want to run finetune eval jobs, you need to '
                     'make sure that your pretrain model writes '
                     'global_step in its checkpoints.'.format(ckpt_file_path))
  logging.info('get global_step %d from checkpoint %s', restored_step,
               ckpt_file_path)
  return restored_step


def write_json_summary(log_dir, global_step, eval_metrics):
  """Dumps evaluation metrics to `metrics-<global_step>.json` in `log_dir`.

  Args:
    log_dir: The directory the json file is written to.
    global_step: The step used in the file name and the log message.
    eval_metrics: A dictionary of metric name to value. Values exposing a
      `numpy` attribute are converted through it; every value is stored as a
      string.
  """
  serializable_dict = {
      name: str(value.numpy()) if hasattr(value, 'numpy') else str(value)
      for name, value in eval_metrics.items()
  }
  json_path = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
  logging.info('Evaluation results at pretrain step %d: %s', global_step,
               serializable_dict)
  with tf.io.gfile.GFile(json_path, 'w') as json_file:
    json_file.write(json.dumps(serializable_dict, indent=4) + '\n')


def write_summary(summary_writer, global_step, eval_metrics):
  """Writes evaluation metrics as scalar TF summaries.

  Args:
    summary_writer: The summary writer used as default while writing; it is
      flushed afterwards.
    global_step: The step attached to every scalar summary.
    eval_metrics: A dictionary of metric name to value.
  """
  # Every metric is converted before anything is written, so a conversion
  # failure cannot leave a partially written set of summaries.
  scalars = {
      name: float(orbit.utils.get_value(value))
      for name, value in eval_metrics.items()
  }
  with summary_writer.as_default():
    for name, scalar in scalars.items():
      tf.summary.scalar(name, scalar, step=global_step)
    summary_writer.flush()


def remove_ckpts(model_dir):
  """Removes model checkpoints from `model_dir`, so we can restart.

  Deletes every path matching `ckpt-*` and, if present, the `checkpoint` file.

  Args:
    model_dir: The directory holding the checkpoints.
  """
  ckpt_pattern = os.path.join(model_dir, 'ckpt-*')
  logging.info('removing checkpoint files %s', ckpt_pattern)
  for matched_path in tf.io.gfile.glob(ckpt_pattern):
    tf.io.gfile.rmtree(matched_path)

  state_file = os.path.join(model_dir, 'checkpoint')
  if tf.io.gfile.exists(state_file):
    tf.io.gfile.remove(state_file)


def write_model_params(model: Union[tf.Module, tf_keras.Model],
                       output_path: str) -> None:
  """Writes each model variable's name and shape, plus the total, to a file.

  One line of the form `<name> <shape list>` is written per variable, followed
  by a blank line and a `Total params: <count>` line.

  Args:
    model: A model instance.
    output_path: Output file path.
  """
  param_count = 0
  with tf.io.gfile.GFile(output_path, 'w') as report:
    for variable in model.variables:
      dims = tf.shape(variable)
      # The number of elements in a variable is the product of its dimensions.
      param_count += tf.math.reduce_prod(dims).numpy()
      report.write(f'{variable.name} {dims.numpy().tolist()}\n')
    report.write(f'\nTotal params: {param_count}\n')


def try_count_params(
    model: Union[tf.Module, tf_keras.Model],
    trainable_only: bool = False):
  """Counts the number of parameters of the model, if possible.

  Args:
    model: Try to count the number of params in this model.
    trainable_only: Whether to calculate trainable params only. This flag is
      not used when the model has `count_params` attribute.

  Returns:
    The number of parameters, or None when `model.count_params()` raises a
    ValueError.
  """
  if not hasattr(model, 'count_params'):
    # No built-in counter: sum the element counts of the variables directly.
    if trainable_only:
      counted_variables = model.trainable_variables
    else:
      counted_variables = model.variables
    param_count = 0
    for variable in counted_variables:
      param_count += tf.math.reduce_prod(tf.shape(variable)).numpy()
    return param_count

  try:
    return model.count_params()
  except ValueError:
    logging.info('Number of trainable params unknown, because the build() '
                 'methods in keras layers were not called. This is probably '
                 'because the model was not feed any input, e.g., the max '
                 'train step already reached before this run.')
    return None


def try_count_flops(model: Union[tf.Module, tf_keras.Model],
                    inputs_kwargs: Optional[Dict[str, Any]] = None,
                    output_path: Optional[str] = None):
  """Counts and returns model FLOPs.

  Args:
    model: A model instance.
    inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
      shape specifications to getting corresponding concrete function. Only
      used when `model.inputs` is empty.
    output_path: A file path to write the profiling results to.

  Returns:
    The model's FLOPs, or None when the model has no `inputs` attribute or
    when counting fails for any reason.
  """
  if not hasattr(model, 'inputs'):
    return None

  try:
    model_inputs = model.inputs
    if model_inputs:
      # Get input shape and set batch size to 1.
      input_specs = [
          tf.TensorSpec([1] + model_input.shape[1:], model_input.dtype)
          for model_input in model_inputs
      ]
      traced_func = tf.function(model).get_concrete_function(input_specs)
    else:
      # If model.inputs is invalid, try to use the input to get concrete
      # function for model.call (subclass model).
      traced_func = tf.function(model.call).get_concrete_function(
          **inputs_kwargs)
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(traced_func)

    # Calculate FLOPs.
    profile_opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    if output_path is None:
      profile_opts['output'] = 'none'
    else:
      profile_opts['output'] = f'file:outfile={output_path}'
    profile = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph,
        run_meta=tf.compat.v1.RunMetadata(),
        options=profile_opts)
    return profile.total_float_ops
  except Exception as e:  # pylint: disable=broad-except
    # FLOPs counting is best effort: any failure is logged, not raised.
    logging.info(
        'Failed to count model FLOPs with error %s, because the build() '
        'methods in keras layers were not called. This is probably because '
        'the model was not feed any input, e.g., the max train step already '
        'reached before this run.', e)
    return None


@ops.RegisterStatistics('Einsum', 'flops')
def _einsum_flops(graph, node):
  """Calculates the compute resources needed for a two-operand Einsum.

  The FLOPs are counted as two operations (multiply and add) for every
  combination of output index and contracted index.

  Args:
    graph: The graph containing `node`.
    node: The Einsum node def; it must have exactly two inputs with fully
      defined shapes.

  Returns:
    An `ops.OpStats` holding the `flops` statistic.
  """
  assert len(node.input) == 2
  lhs_name, rhs_name = node.input[0], node.input[1]
  x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
      graph, lhs_name)
  y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
      graph, rhs_name)
  x_shape.assert_is_fully_defined()
  y_shape.assert_is_fully_defined()
  x_dims = x_shape.as_list()
  y_dims = y_shape.as_list()

  # Strip the text-format decoration around the attribute's string value so
  # that only the bare equation remains.
  equation = str(node.attr['equation'])
  for decoration in ('s:', '"', ' ', '\n'):
    equation = equation.replace(decoration, '')

  operands = equation.split(',')
  x_str = operands[0]
  rhs_and_result = operands[1].split('->')
  y_str = rhs_and_result[0]
  r_str = rhs_and_result[1]

  dim_of = {}
  contracted = set()
  for index in x_str + y_str:
    # The first operand takes precedence when an index appears in both.
    x_pos = x_str.find(index)
    if x_pos >= 0:
      dim_of[index] = x_dims[x_pos]
    else:
      y_pos = y_str.find(index)
      if y_pos < 0:
        raise ValueError('indice {} not found in inputs'.format(index))
      dim_of[index] = y_dims[y_pos]
    # Indices missing from the result are summed over.
    if index not in r_str:
      contracted.add(index)

  result_size = np.prod([dim_of[index] for index in r_str])
  contracted_size = np.prod([dim_of[index] for index in contracted])
  madds = result_size * contracted_size
  return ops.OpStats('flops', 2 * madds)