import torch

from .. import tasks
from .. import models
from .. import losses
from ..datasets import MMDataset
from .. import processors


class Task(object):
    """
    A task refers to one generic training task (e.g., training one model).
    """

    @classmethod
    def config_task(cls, config):
        """
        Decide whether to load a hard-coded task subclass or configure a
        generic one, depending on whether a task name is given in the config.
        """
        if config.task is not None:
            task_cls = getattr(tasks, config.task)
            return task_cls(config)
        else:
            return Task(config)
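
    # Dispatch example (the task name below is hypothetical):
    # with `config.task = "MyTask"`, `Task.config_task(config)` looks up
    # `tasks.MyTask` via `getattr` and instantiates it with the config;
    # with `config.task = None`, it falls back to this generic `Task`.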

    def __init__(self, config):
        self.config = config
        self.train_data = None
        self.val_data = None
        self.test_data = None

        self.model = None
        self.loss_fn = None
        self.eval_fn = None

    def build_dataset(self):
        """TODO (huxu): move processor breakdown to MMDataset.

        Fill in `self.train_data`, `self.val_data` and `self.test_data`.
        """
        meta_processor_cls = getattr(
            processors, self.config.dataset.meta_processor)
        video_processor_cls = getattr(
            processors, self.config.dataset.video_processor)
        text_processor_cls = getattr(
            processors, self.config.dataset.text_processor)
        aligner_cls = getattr(
            processors, self.config.dataset.aligner)

        if self.config.dataset.train_path is not None:
            self.config.dataset.split = "train"
            # the split may be read by the processors below.
            meta_processor = meta_processor_cls(self.config.dataset)
            video_processor = video_processor_cls(self.config.dataset)
            text_processor = text_processor_cls(self.config.dataset)
            aligner = aligner_cls(self.config.dataset)
            self.train_data = MMDataset(
                meta_processor, video_processor, text_processor, aligner
            )
            print("train_len", len(self.train_data))
            output = self.train_data[0]
            self.train_data.print_example(output)

        if self.config.dataset.val_path is not None:
            self.config.dataset.split = "valid"
            meta_processor = meta_processor_cls(self.config.dataset)
            video_processor = video_processor_cls(self.config.dataset)
            text_processor = text_processor_cls(self.config.dataset)
            aligner = aligner_cls(self.config.dataset)
            self.val_data = MMDataset(
                meta_processor, video_processor, text_processor, aligner
            )
            print("val_len", len(self.val_data))
            output = self.val_data[0]
            self.val_data.print_example(output)

        if self.config.dataset.split == "test":
            # the test split is expected to be set by the config file.
            meta_processor = meta_processor_cls(self.config.dataset)
            video_processor = video_processor_cls(self.config.dataset)
            text_processor = text_processor_cls(self.config.dataset)
            # re-instantiate the aligner so this branch does not depend on the
            # train/valid branches having run.
            aligner = aligner_cls(self.config.dataset)
            self.test_data = MMDataset(
                meta_processor, video_processor, text_processor, aligner
            )
            print("test_len", len(self.test_data))
            output = self.test_data[0]
            self.test_data.print_example(output)
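
    # Config fields consumed above (the values shown are illustrative
    # placeholders, not defaults of this module):
    #
    #     dataset:
    #       meta_processor: "MyMetaProcessor"   # class names looked up in `processors`
    #       video_processor: "MyVideoProcessor"
    #       text_processor: "MyTextProcessor"
    #       aligner: "MyAligner"
    #       train_path: "data/train.csv"        # None disables the train split
    #       val_path: "data/valid.csv"          # None disables the valid split
    #       split: "test"                       # set by the config for test-only runs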

    def build_model(self, checkpoint=None):
        if self.model is None:
            model_cls = getattr(models, self.config.model.model_cls)
            self.model = model_cls(self.config)
        if checkpoint is not None:
            self.load_checkpoint(checkpoint)
        return self.model

    def load_checkpoint(self, checkpoint):
        if self.model is None:
            raise ValueError("model is not initialized.")
        state_dict = torch.load(checkpoint)
        state_dict = self._trim_state_dict(state_dict)
        self.model.load_state_dict(state_dict, strict=False)
        # if the loaded weights are fp16, cast the model back to fp32.
        if next(self.model.parameters()).dtype == torch.float16:
            self.model = self.model.float()
        return self.model

    def _trim_state_dict(self, state_dict):
        from collections import OrderedDict

        # unwrap common checkpoint containers.
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        if "model" in state_dict:
            state_dict = state_dict["model"]
        ret_state_dict = OrderedDict()
        for key, value in state_dict.items():
            # strip the wrapper prefix so keys match this task's model.
            if key.startswith("mmmodel"):
                key = key[len("mmmodel."):]
            ret_state_dict[key] = value
        return ret_state_dict
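
    # Example: a checkpoint stored as {"state_dict": {"mmmodel.encoder.weight": w}}
    # is trimmed to {"encoder.weight": w} before `load_state_dict` is called
    # (the key names here are illustrative).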

    def build_loss(self):
        if self.loss_fn is None and self.config.loss is not None:
            loss_cls = getattr(losses, self.config.loss.loss_cls)
            self.loss_fn = loss_cls()
        return self.loss_fn

    def flat_subsample(self, tensor):
        size = tensor.size()
        if len(size) >= 2:
            batch_size = size[0] * size[1]
            expanded_size = (
                (batch_size,) + size[2:] if len(size) > 2
                else (batch_size,)
            )
            tensor = tensor.view(expanded_size)
        return tensor
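
    # Shape example: the first two dimensions are merged, so a
    # (batch, subsample, ...) tensor becomes (batch * subsample, ...),
    # e.g. (8, 2, 64) -> (16, 64) and (8, 2) -> (16,).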

    def reshape_subsample(self, sample):
        if (
            hasattr(self.config.dataset, "subsampling")
            and self.config.dataset.subsampling is not None
            and self.config.dataset.subsampling > 1
        ):
            for key in sample:
                if torch.is_tensor(sample[key]):
                    sample[key] = self.flat_subsample(sample[key])
        return sample
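
    # Assumption behind the flattening above: with `dataset.subsampling > 1`,
    # each tensor in `sample` carries a per-example subsample dimension at
    # dim 1, which is merged into the batch dimension before the forward pass.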

    def __call__(self, model, sample):
        loss = None
        loss_scalar = float("inf")

        sample = self.reshape_subsample(sample)
        # the `model` argument is accepted but unused here; the forward pass
        # runs on the task's own `self.model`.
        outputs = self.model(**sample)
        sample.update(outputs)
        if self.loss_fn is not None:
            loss = self.loss_fn(**sample)
            loss_scalar = loss.item()

        batch_size = sample["caps"].size(0)
        sample_size = 1
        return {
            "loss": loss,
            "loss_scalar": loss_scalar,
            "max_len": self.config.dataset.max_len,
            "batch_size": batch_size,
            "sample_size": sample_size,
        }

    def build_dataloader(self):
        """Only used by a trainer that does not build its own data loaders."""
        raise NotImplementedError
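

# Minimal usage sketch (assumes a config object with the fields referenced
# above; the checkpoint path below is a placeholder):
#
#     task = Task.config_task(config)
#     task.build_dataset()
#     task.build_model(checkpoint="path/to/checkpoint.pt")
#     task.build_loss()
#     metrics = task(task.model, sample)  # {"loss": ..., "batch_size": ..., ...}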