Akash Garg
adding cube sources
616f571
import logging
from typing import Any, Optional
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model
def load_config(cfg_path: str) -> Any:
"""
Load and resolve a configuration file.
Args:
cfg_path (str): The path to the configuration file.
Returns:
Any: The loaded and resolved configuration object.
Raises:
AssertionError: If the loaded configuration is not an instance of DictConfig.
"""
cfg = OmegaConf.load(cfg_path)
OmegaConf.resolve(cfg)
assert isinstance(cfg, DictConfig)
return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
"""
Parses a configuration dictionary into a structured configuration object.
Args:
cfg_type (Any): The type of the structured configuration object.
cfg (DictConfig): The configuration dictionary to be parsed.
Returns:
Any: The structured configuration object created from the dictionary.
"""
scfg = OmegaConf.structured(cfg_type(**cfg))
return scfg
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
"""
Load a safetensors checkpoint into a PyTorch model.
The model is updated in place.
Args:
model: PyTorch model to load weights into
ckpt_path: Path to the safetensors checkpoint file
Returns:
None
"""
assert ckpt_path.endswith(".safetensors"), (
f"Checkpoint path '{ckpt_path}' is not a safetensors file"
)
load_model(model, ckpt_path)