"""Model construction functions.""" |
|
|
|
import torch |
|
from fvcore.common.registry import Registry |
|
|
|
MODEL_REGISTRY = Registry("MODEL") |
|
MODEL_REGISTRY.__doc__ = """
Registry for video models.

The registered object will be called with `obj(cfg)`.
The call should return a `torch.nn.Module` object.
"""
|
|
def build_model(cfg, gpu_id=None):
    """
    Builds the video model.

    Args:
        cfg (configs): configs that contain the hyper-parameters to build the
            backbone. Details can be seen in slowfast/config/defaults.py.
        gpu_id (Optional[int]): the GPU index on which to build the model. If
            None, the current CUDA device is used.
    """
|
    if torch.cuda.is_available():
        assert (
            cfg.NUM_GPUS <= torch.cuda.device_count()
        ), "Cannot use more GPU devices than available"
    else:
        assert (
            cfg.NUM_GPUS == 0
        ), "CUDA is not available. Please set `NUM_GPUS: 0` for running on CPUs."
    name = cfg.MODEL.MODEL_NAME
    model = MODEL_REGISTRY.get(name)(cfg)
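
    # Place the model on a GPU: the current CUDA device unless an explicit
    # gpu_id is given.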
    if cfg.NUM_GPUS:
        if gpu_id is None:
            cur_device = torch.cuda.current_device()
        else:
            cur_device = gpu_id
        model = model.cuda(device=cur_device)
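
    # For multi-GPU jobs, wrap the model in DistributedDataParallel so that
    # gradients are synchronized across processes.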
    if cfg.NUM_GPUS > 1:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device
        )
    return model
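

# A minimal usage sketch (not part of the original file). It assumes a config
# object with the fields used above (see slowfast/config/defaults.py); `get_cfg`
# is assumed to be the config constructor from that module, and the model name
# must match one registered in MODEL_REGISTRY (e.g. the hypothetical
# `MyVideoModel` above).
#
#   from slowfast.config.defaults import get_cfg
#
#   cfg = get_cfg()
#   cfg.MODEL.MODEL_NAME = "MyVideoModel"
#   cfg.NUM_GPUS = 0  # CPU-only build; set > 0 to place the model on GPUs
#   model = build_model(cfg)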
|
|