Spaces:
Sleeping
Sleeping
File size: 4,541 Bytes
ce7bf5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# Copyright Generate Biomedicines, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utilities to save and load models with metadata.
"""
import os
import os.path as osp
import tempfile
from pathlib import Path
from urllib.parse import parse_qs, urlparse
from uuid import uuid4
import torch
import chroma.utility.api as api
from chroma.constants.named_models import NAMED_MODELS
def save_model(model, weight_file, metadata=None):
    """Save model, including optional metadata.

    Args:
        model (nn.Module): The model to save. Details about the model needed
            for initialization, such as layer sizes, should be in model.kwargs.
        weight_file (str): The destination path for saving model weights.
        metadata (dict): A dictionary of additional metadata to add to the model
            weights. For example, when saving models during training it can be
            useful to store `args` representing the CLI args, the date and time
            of training, etc.

    Raises:
        NotImplementedError: If `weight_file` is an s3 URI; uploading to s3
            is not supported.
    """
    # Fail fast on unsupported s3 destinations. The previous implementation
    # first serialized the full checkpoint to a uuid-named temp file and only
    # then raised, wasting work and leaking the temp file on every call.
    if str(weight_file).startswith("s3:"):
        raise NotImplementedError("Uploading to an s3 link not supported.")

    save_dict = {"init_kwargs": model.kwargs, "model_state_dict": model.state_dict()}
    if metadata is not None:
        # Metadata keys are merged at the top level of the checkpoint dict,
        # alongside "init_kwargs" and "model_state_dict".
        save_dict.update(metadata)
    torch.save(save_dict, str(weight_file))
def load_model(
    weights,
    model_class,
    device="cpu",
    strict=False,
    strict_unexpected=True,
    verbose=True,
):
    """Load model saved with save_model.

    Args:
        weights (str): The destination path of the model weights to load.
            Compatible with files saved by `save_model`. May also be a
            `named:<alias>` registry key, an `https:` download URL, or an
            (unsupported) `s3:` URI.
        model_class: Name of model class.
        device (str, optional): Pytorch device specification, e.g. `'cuda'` for
            GPU. Default is `'cpu'`.
        strict (bool): Whether to require that the keys match between the
            input file weights and the model created from the parameters stored
            in the model kwargs.
        strict_unexpected (bool): Whether to require that there are no
            unexpected keys when loading model weights, as distinct from the
            strict option which doesn't allow for missing keys either. By
            default, we use this option rather than strict for ease of
            development when adding model features.
        verbose (bool, optional): Show outputs from download and loading. Default True.
            NOTE(review): not referenced in this function body — presumably
            consumed by downstream download helpers; kept for API compatibility.

    Returns:
        model (nn.Module): Torch model with loaded weights.

    Raises:
        Exception: If a `named:` alias is unknown, or if `strict_unexpected`
            is set and the checkpoint contains unexpected state-dict keys.
        NotImplementedError: If `weights` is an s3 URI.
    """
    # Normalize once so all prefix checks and string ops below are safe even
    # when a pathlib.Path (or similar) is passed in.
    weights = str(weights)

    # Resolve "named:<alias>" entries via the model registry.
    if weights.startswith("named:"):
        # Strip only the leading prefix; the old `split("named:")[1]` would
        # truncate an alias that itself contained the substring "named:".
        name = weights[len("named:"):]
        if name not in NAMED_MODELS[model_class.__name__]:
            raise Exception(f"Unknown {model_class.__name__} model name: {name}")
        weights = NAMED_MODELS[model_class.__name__][name]["s3_uri"]

    # resolve s3 paths
    if weights.startswith("s3:"):
        raise NotImplementedError("Loading Models from an S3 link not supported.")

    # download public models from generate
    if weights.startswith("https:"):
        # Decompose into arguments: base URL plus the `weights` query param
        # identifying which named weight set to fetch.
        parsed_url = urlparse(weights)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}"
        model_name = parse_qs(parsed_url.query).get("weights", [None])[0]
        weights = api.download_from_generate(
            base_url, model_name, force=False, exist_ok=True
        )

    # load model weights
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    params = torch.load(weights, map_location="cpu")
    model = model_class(**params["init_kwargs"]).to(device)
    missing_keys, unexpected_keys = model.load_state_dict(
        params["model_state_dict"], strict=strict
    )
    # Missing keys are tolerated (new model features may add parameters), but
    # unexpected checkpoint keys usually indicate a wrong/old model class.
    if strict_unexpected and len(unexpected_keys) > 0:
        raise Exception(
            f"Error loading model from checkpoint file: {weights} contains {len(unexpected_keys)} unexpected keys: {unexpected_keys}"
        )
    return model
|