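"""Inference wrapper around the Scenic Vid2Seq dense video captioning model.

`ScenicModel` restores a pretrained checkpoint, builds a pmapped inference
step, and decodes the model's mixed text/time token sequences into
"[start_s, end_s] caption" strings. `ScenicCall` is a small helper that runs
an arbitrary `main` function under the same absl/JAX setup.
"""
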
import functools
import sys
from pathlib import Path

from absl import app
from absl import flags
from absl import logging

from clu import metric_writers
from clu import platform
import flax.linen as nn
from flax import jax_utils
import jax
import jax.numpy as jnp
import ml_collections
from ml_collections import config_flags
import numpy as np
import tensorflow as tf

# Make the sibling `scenic` checkout importable before loading its modules.
sys.path.append(str(Path(__file__).parent.parent.parent / "scenic"))

from scenic.projects.vid2seq import dvc_eval
from scenic.projects.vid2seq import models, trainer
from scenic.projects.vid2seq.datasets.dense_video_captioning_tfrecord_dataset import get_datasets
from scenic.train_lib_deprecated import train_utils

# Fixed buffer lengths used when packing variable-length strings into uint8
# arrays so they can travel through JAX as numeric data.
MAX_CAPTION_STR_LEN = 200
MAX_KEY_STR_LEN = 400

class ScenicModel:
    """Loads a Vid2Seq dense video captioning model and runs inference with it."""

    def __init__(self, flags):
        self.FLAGS = flags
        # Let absl parse JAX flags such as --jax_backend_target.
        jax.config.config_with_absl()
        self._run_main([], main=self._init_model)

    def _run_main(self, argv, *, main):
        """Runs the `main` method after some initial setup."""
        del argv
        # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory
        # and make it unavailable to JAX.
        tf.config.experimental.set_visible_devices([], 'GPU')

        # Enable wrapping of all module calls in a named_call for easier profiling:
        nn.enable_named_call()

        logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count())
        logging.info('JAX devices: %r', jax.devices())

        # Add a note so that we can tell which task is which JAX host.
        # (task 0 is not guaranteed to be the host 0)
        platform.work_unit().set_task_status(
            f'host_id: {jax.process_index()}, host_count: {jax.process_count()}')
        if jax.process_index() == 0:
            platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                                self.FLAGS.workdir, 'Workdir')
        # Point the config at the caller-supplied data and checkpoint paths.
        self.FLAGS.config.dataset_configs.base_dir = self.FLAGS.data_dir
        self.FLAGS.config.init_from.checkpoint_path = self.FLAGS.ckpt_dir
        rng = jax.random.PRNGKey(self.FLAGS.config.rng_seed)
        logging.info('RNG: %s', rng)

        writer = metric_writers.create_default_writer(
            self.FLAGS.workdir, just_logging=jax.process_index() > 0, asynchronous=True)

        return main(rng=rng, config=self.FLAGS.config, workdir=self.FLAGS.workdir, writer=writer)


    def _init_model(self, rng: jnp.ndarray, config: ml_collections.ConfigDict,
                    workdir: str, writer: metric_writers.MetricWriter):
        data_rng, rng = jax.random.split(rng)
        dataset_dict = get_datasets(config, data_rng=data_rng)

        datasets_metadata = {
            name: ds.meta_data
            for name, ds in dataset_dict.items()
        }
        all_datasets = []
        all_datasets_num_train_examples = []
        for name, metadata in datasets_metadata.items():
            all_datasets.append(name)
            all_datasets_num_train_examples.append(
                metadata.get('num_train_examples', 0))
        dataset = dataset_dict[all_datasets[0]]

        model_cls = models.DenseVideoCaptioningModel
        model = model_cls(config, dataset.meta_data)
        train_state, start_step = trainer.init_state(model, dataset, config,
                                                     workdir, rng)

        self.train_state = jax_utils.replicate(train_state)
        logging.info('Number of processes is %s', jax.process_count())
        del rng

        # Parallelize the inference step across local devices; 'batch' is the
        # pmapped axis name used inside trainer.infer_step.
        self.infer_step_pmapped = jax.pmap(
            functools.partial(
                trainer.infer_step,
                model=model,
                config=config,
                debug=config.debug_eval),
            axis_name='batch',
        )
        
        self.tokenizer = trainer.get_tokenizer(config)
        
        self.config = config
        self.data_rng = data_rng
    
    def __call__(self, data_dir=None):
        if data_dir is not None:
            self.config.dataset_configs.base_dir = data_dir
        dataset_dict = get_datasets(self.config, data_rng=self.data_rng)
        self.iterator = dataset_dict["youcook"].valid_iter['validation']
        batch = next(self.iterator)
        
        train_state = train_utils.sync_model_state_across_replicas(self.train_state)
        keys = []
        # Pack the ground-truth strings into fixed-size uint8 arrays so every
        # field of eval_pack is a numeric array.
        eval_pack = {
            'gts': dvc_eval.convert_strings_to_uint8_arrays(
                batch['caption_strings'], MAX_CAPTION_STR_LEN),
            'key': dvc_eval.convert_strings_to_uint8_arrays(
                batch['videoid'], MAX_KEY_STR_LEN),
            'batch_mask': batch['batch_mask'],
            'duration': batch['duration'],
            'gts_start': batch['timestamp_start'],
            'gts_end': batch['timestamp_end'],
            'split': batch['split'] if 'split' in batch
                     else np.ones_like(batch['timestamp_start']),
        }
        # Remove fields the pmapped infer step cannot consume.
        to_del = ['caption_strings', 'key', 'videoid', 'timestamp_start',
                  'timestamp_end', 'split']  # 'duration' stays in the batch
        for x in to_del:
            if x in batch:
                del batch[x]
        
        _, preds = self.infer_step_pmapped(train_state, batch)
        eval_pack['pred'] = preds
        # Fold the (num_devices, per_device_batch) leading axes into a single
        # batch axis.
        eval_pack = jax.tree_map(
            lambda x: x.reshape((np.prod(x.shape[:2]),) + x.shape[2:]), eval_pack)
        
        vocabulary_size = self.config.dataset_configs.vocabulary_size
        format_outputs = []
        for i, valid in enumerate(eval_pack['batch_mask']):
            print("===============video[", str(i), "]====================")
            if not valid:  # padded (masked-out) example
                continue
            key = dvc_eval.convert_uint8_array_to_string(eval_pack['key'][i])
            if key in keys:  # redundant video
                continue
            keys.append(key)

            pred, pred_timestamps = [], []
            # get indexes in the predicted seq that delimit the pred segments
            indexes = [
                j for j in range(len(eval_pack['pred'][i]) - 1)
                if eval_pack['pred'][i][j] >= vocabulary_size and
                eval_pack['pred'][i][j + 1] >= vocabulary_size
            ]  # pylint: disable=g-complex-comprehension
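            # (Time tokens have ids >= vocabulary_size; Vid2Seq emits two of
            # them back-to-back as an event's start/end pair, which is what
            # the comprehension above detects.)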

            last_processed = -2
            order = self.config.dataset_configs.order

            # iterate over predicted segments and decode them
            for j in range(len(indexes)):
                if indexes[j] == last_processed + 1:  # 3 timestamps != 2 events
                    continue
                
                # get predicted tokens and transform to string
                if order == 'ld':
                    start_idx = indexes[j] + 2
                    end_idx = (indexes[j + 1] if j < len(indexes) - 1
                               else len(eval_pack['pred'][i]))
                else:
                    start_idx = indexes[j - 1] + 2 if j > 0 else 0
                    end_idx = indexes[j]
                pred_seq = [int(eval_pack['pred'][i][k]) for k in range(start_idx, end_idx)]
                pred_text = trainer.decode_tokens(pred_seq, self.tokenizer, vocabulary_size)

                # get start and end: time tokens are offsets above the text
                # vocabulary, quantized into num_bins bins over the duration
                num_bins = 100  # from config
                max_offset = num_bins - 1
                pred_time = [
                    (int(eval_pack['pred'][i][indexes[j]]) - vocabulary_size)
                    * eval_pack['duration'][i] / max_offset,
                    (int(eval_pack['pred'][i][indexes[j] + 1]) - vocabulary_size)
                    * eval_pack['duration'][i] / max_offset,
                ]
                
                # if pred_time[1] <= pred_time[0]:  # remove end < start
                #   continue
                last_processed = indexes[j]

                pred.append(pred_text)
                pred_timestamps.append(pred_time)
                
                # durations are in microseconds; report seconds rounded to 2 dp
                format_output = "[{x}s, {y}s] ".format(
                    x=np.around(pred_time[0][0] / 1e6, decimals=2),
                    y=np.around(pred_time[1][0] / 1e6, decimals=2))
                format_output += pred_text
                format_outputs.append(format_output)
        print(format_outputs)
        print("===============================================")
        return format_outputs

class ScenicCall:
    """Runs an arbitrary `main` function under the same absl/JAX setup."""

    def __init__(self, main, flags):
        self.main = main
        self.FLAGS = flags
        
    def __call__(self):
        return self.run()

    def run(self):
        # Provide access to --jax_backend_target and --jax_xla_backend flags.
        jax.config.config_with_absl()
        return self._run_main([], main=self.main)
        
    def _run_main(self, argv, *, main):
        """Runs the `main` method after some initial setup."""
        del argv
        # Hide any GPUs from TensorFlow. Otherwise, TF might reserve memory
        # and make it unavailable to JAX.
        tf.config.experimental.set_visible_devices([], 'GPU')

        # Enable wrapping of all module calls in a named_call for easier profiling:
        nn.enable_named_call()

        logging.info('JAX host: %d / %d', jax.process_index(), jax.process_count())
        logging.info('JAX devices: %r', jax.devices())

        # Add a note so that we can tell which task is which JAX host.
        # (task 0 is not guaranteed to be the host 0)
        platform.work_unit().set_task_status(
            f'host_id: {jax.process_index()}, host_count: {jax.process_count()}')
        if jax.process_index() == 0:
            platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,
                                                self.FLAGS.workdir, 'Workdir')
        self.FLAGS.config.dataset_configs.base_dir = self.FLAGS.data_dir
        rng = jax.random.PRNGKey(self.FLAGS.config.rng_seed)
        logging.info('RNG: %s', rng)

        writer = metric_writers.create_default_writer(
            self.FLAGS.workdir, just_logging=jax.process_index() > 0, asynchronous=True)

        return main(rng=rng, config=self.FLAGS.config, workdir=self.FLAGS.workdir, writer=writer)
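

# Example usage -- a minimal sketch, not part of the module. `flags_ns` is a
# hypothetical stand-in for any object exposing the attributes read above
# (`config`, `workdir`, `data_dir`, `ckpt_dir`); the config module name is an
# assumption based on the Scenic Vid2Seq project layout:
#
#   import types
#   from scenic.projects.vid2seq.configs import youcook2
#
#   flags_ns = types.SimpleNamespace(
#       config=youcook2.get_config(),
#       workdir='/tmp/vid2seq_workdir',
#       data_dir='/path/to/youcook2/tfrecords',
#       ckpt_dir='/path/to/vid2seq/checkpoint',
#   )
#   model = ScenicModel(flags_ns)
#   captions = model()  # e.g. ['[0.5s, 12.3s] add oil to the pan', ...]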