# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utilities to save and load models with metadata.
"""

import os
import os.path as osp
import tempfile
from pathlib import Path
from urllib.parse import parse_qs, urlparse
from uuid import uuid4

import torch

import chroma.utility.api as api
from chroma.constants.named_models import NAMED_MODELS


def save_model(model, weight_file, metadata=None):
    """Save a model, including optional metadata.

    Args:
        model (nn.Module): The model to save. Details about the model needed
            for initialization, such as layer sizes, should be in `model.kwargs`.
        weight_file (str): The destination path for saving model weights.
        metadata (dict): A dictionary of additional metadata to add to the model
            weights. For example, when saving models during training it can be
            useful to store `args` representing the CLI args, the date and time
            of training, etc.
    """
    save_dict = {"init_kwargs": model.kwargs, "model_state_dict": model.state_dict()}
    if metadata is not None:
        save_dict.update(metadata)

    # Write s3 destinations to a temporary local file; uploading is not yet
    # supported and raises below.
    local_path = str(
        Path(tempfile.gettempdir(), str(uuid4())[:8])
        if weight_file.startswith("s3:")
        else weight_file
    )
    torch.save(save_dict, local_path)
    if weight_file.startswith("s3:"):
        raise NotImplementedError("Uploading to an s3 link is not supported.")


def load_model(
    weights,
    model_class,
    device="cpu",
    strict=False,
    strict_unexpected=True,
    verbose=True,
):
    """Load a model saved with `save_model`.

    Args:
        weights (str): The source path of the model weights to load. Compatible
            with files saved by `save_model`.
        model_class: The model class to instantiate; its `__name__` is also used
            to resolve `named:` model tags.
        device (str, optional): PyTorch device specification, e.g. `'cuda'` for
            GPU. Default is `'cpu'`.
        strict (bool): Whether to require that the keys match between the
            input file weights and the model created from the parameters stored
            in the model kwargs.
        strict_unexpected (bool): Whether to require that there are no
            unexpected keys when loading model weights, as distinct from
            `strict`, which also disallows missing keys. By default, this option
            is used rather than `strict` for ease of development when adding
            model features.
        verbose (bool, optional): Show outputs from download and loading.
            Default True.

    Returns:
        model (nn.Module): Torch model with loaded weights.
""" # Process weights path if str(weights).startswith("named:"): weights = weights.split("named:")[1] if weights not in NAMED_MODELS[model_class.__name__]: raise Exception(f"Unknown {model_class.__name__} model name: {weights},") weights = NAMED_MODELS[model_class.__name__][weights]["s3_uri"] # resolve s3 paths if str(weights).startswith("s3:"): raise NotImplementedError("Loading Models from an S3 link not supported.") # download public models from generate if str(weights).startswith("https:"): # Decompose into arguments parsed_url = urlparse(weights) base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}" model_name = parse_qs(parsed_url.query).get("weights", [None])[0] weights = api.download_from_generate( base_url, model_name, force=False, exist_ok=True ) # load model weights params = torch.load(weights, map_location="cpu") model = model_class(**params["init_kwargs"]).to(device) missing_keys, unexpected_keys = model.load_state_dict( params["model_state_dict"], strict=strict ) if strict_unexpected and len(unexpected_keys) > 0: raise Exception( f"Error loading model from checkpoint file: {weights} contains {len(unexpected_keys)} unexpected keys: {unexpected_keys}" ) return model