Spaces:
Running
on
L40S
Running
on
L40S
File size: 1,554 Bytes
616f571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import logging
from typing import Any, Optional
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model
def load_config(cfg_path: str) -> Any:
    """
    Load a configuration file with OmegaConf and resolve it.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        Any: The loaded and resolved configuration object (a DictConfig).

    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    # Raised explicitly instead of via an `assert` statement: asserts are
    # stripped under `python -O`, which would let a non-DictConfig slip
    # through. AssertionError is kept so existing callers see the same
    # exception type as before.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Expected a DictConfig from '{cfg_path}', got {type(cfg).__name__}"
        )
    return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured configuration object from a configuration dictionary.

    The entries of ``cfg`` are passed as keyword arguments to ``cfg_type``,
    and the resulting instance is handed to ``OmegaConf.structured``.

    Args:
        cfg_type (Any): The type of the structured configuration object.
        cfg (DictConfig): The configuration dictionary to be parsed.

    Returns:
        Any: The structured configuration object created from the dictionary.
    """
    typed_instance = cfg_type(**cfg)
    return OmegaConf.structured(typed_instance)
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ".safetensors".
    """
    # Raised explicitly instead of via an `assert` statement so the check
    # still runs under `python -O`; the exception type and message are the
    # same as before.
    if not ckpt_path.endswith(".safetensors"):
        raise AssertionError(
            f"Checkpoint path '{ckpt_path}' is not a safetensors file"
        )
    load_model(model, ckpt_path)
|