# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
from typing import Optional

import torch.optim
from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats


logger = logging.getLogger(__name__)
class ModelFactoryBase(ReplaceableBase):
    """
    Replaceable base class for factories that construct an Implicitron
    model, optionally restoring it from a previously saved state.
    """

    # If True, attempt to restore the model from its last checkpoint.
    resume: bool = True

    def __call__(self, **kwargs) -> ImplicitronModelBase:
        """
        Build the model (possibly loading a previously saved state).

        Returns:
            An instance of ImplicitronModelBase.
        """
        raise NotImplementedError()

    def load_stats(self, **kwargs) -> Stats:
        """
        Create a fresh Stats object or load an existing one.
        """
        raise NotImplementedError()
class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
    """
    A factory class that initializes an implicit rendering model, optionally
    restoring its weights from the latest (or a specific) checkpoint found
    in an experiment directory.

    Members:
        model: An ImplicitronModelBase object.
        model_class_type: Registry name of the concrete model class to build.
        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a model with
            initial weights unless `force_resume` is True.
        resume_epoch: If `resume` is True: Resume a model at this epoch, or if
            `resume_epoch` <= 0, then resume from the latest checkpoint.
        force_resume: If True, throw a FileNotFoundError if `resume` is True
            but a model checkpoint cannot be found.
    """

    model: ImplicitronModelBase
    model_class_type: str = "GenericModel"
    resume: bool = True
    resume_epoch: int = -1
    force_resume: bool = False

    def __post_init__(self):
        # Instantiates the replaceable `model` member from
        # `model_class_type` via the Implicitron config system.
        run_auto_creation(self)

    def __call__(
        self,
        exp_dir: str,
        accelerator: Optional[Accelerator] = None,
    ) -> ImplicitronModelBase:
        """
        Returns an instance of `ImplicitronModelBase`, possibly loaded from a
        checkpoint (if self.resume, self.resume_epoch specify so).

        Args:
            exp_dir: Root experiment directory holding the checkpoints.
            accelerator: An Accelerator object; when given and this is not the
                local main process, loaded tensors are remapped from cuda:0 to
                the local device.

        Returns:
            model: The model with optionally loaded weights from checkpoint.

        Raises:
            ValueError: if `resume_epoch` > 0 but no checkpoint exists for
                that epoch.
            FileNotFoundError: if `force_resume` is True but no checkpoint
                can be found in `exp_dir`.
        """
        # Determine the network outputs that should be logged; preserved
        # across the (non-strict) state-dict load below.
        if hasattr(self.model, "log_vars"):
            log_vars = list(self.model.log_vars)
        else:
            log_vars = ["objective"]

        if self.resume_epoch > 0:
            # Resume from a certain epoch
            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
            if not os.path.isfile(model_path):
                raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
        else:
            # Retrieve the last checkpoint; None when the directory has none.
            model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            logger.info(f"Found previous model {model_path}")
            if self.force_resume or self.resume:
                logger.info("Resuming.")

                map_location = None
                if accelerator is not None and not accelerator.is_local_main_process:
                    # Checkpoints are saved from the main process on cuda:0;
                    # remap tensors onto this process's local GPU.
                    map_location = {
                        "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                    }
                model_state_dict = torch.load(
                    model_io.get_model_path(model_path), map_location=map_location
                )

                try:
                    self.model.load_state_dict(model_state_dict, strict=True)
                except RuntimeError as e:
                    # Architecture mismatch (e.g. config drift): fall back to
                    # a partial, non-strict load rather than failing outright.
                    logger.error(e)
                    logger.info(
                        "Cannot load state dict in strict mode! -> trying non-strict"
                    )
                    self.model.load_state_dict(model_state_dict, strict=False)
                # Restore the logged-variable list computed before loading.
                self.model.log_vars = log_vars
            else:
                logger.info("Not resuming -> starting from scratch.")
        elif self.force_resume:
            # NOTE(review): this raises even when `resume` is False, although
            # the class docstring ties `force_resume` to `resume` being True.
            raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")

        return self.model