import logging
import os
from typing import Any, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model

logger = logging.getLogger(__name__)


def load_config(cfg_path: str) -> DictConfig:
    """
    Load a configuration file and resolve its interpolations in place.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        DictConfig: The loaded and resolved configuration object.

    Raises:
        AssertionError: If the loaded configuration is not an instance of
            DictConfig (e.g. the file's top level is a list).
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    # Explicit raise rather than `assert`: asserts are stripped under
    # `python -O`. AssertionError is kept so existing callers that catch it
    # continue to work.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Config '{cfg_path}' must be a DictConfig, got {type(cfg).__name__}"
        )
    return cfg


def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Parse a configuration dictionary into a structured configuration object.

    The top-level keys of ``cfg`` are passed as keyword arguments to
    ``cfg_type``; the resulting instance is then wrapped by
    ``OmegaConf.structured``.

    Args:
        cfg_type (Any): The type of the structured configuration object.
            Presumably a dataclass (or attrs class) accepted by
            ``OmegaConf.structured`` -- confirm against callers.
        cfg (DictConfig): The configuration dictionary to be parsed. Its
            top-level keys must match ``cfg_type``'s constructor parameters.

    Returns:
        Any: The structured configuration object created from the dictionary.
    """
    instance = cfg_type(**cfg)
    return OmegaConf.structured(instance)


def load_model_weights(
    model: torch.nn.Module,
    ckpt_path: Union[str, os.PathLike],
    *,
    strict: bool = True,
) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model. The model is updated in place.

    Args:
        model: PyTorch model to load weights into.
        ckpt_path: Path to the safetensors checkpoint file (str or path-like).
        strict: Forwarded to ``safetensors.torch.load_model``. When True
            (default), a mismatch between checkpoint and model keys raises.
            When False, mismatched keys are skipped and logged as a warning.

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    # Normalize path-like objects to str so the suffix check works for
    # pathlib.Path inputs as well as plain strings.
    path = os.fspath(ckpt_path)
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not path.endswith(".safetensors"):
        raise AssertionError(f"Checkpoint path '{ckpt_path}' is not a safetensors file")

    missing, unexpected = load_model(model, path, strict=strict)
    # Only reachable with strict=False: in strict mode load_model raises on
    # any mismatch instead of returning it.
    if missing or unexpected:
        logger.warning(
            "Loaded '%s' non-strictly: %d missing key(s) %s, %d unexpected key(s) %s",
            path,
            len(missing),
            missing,
            len(unexpected),
            unexpected,
        )