Spaces: Running on L40S
import logging
import os
from typing import Any, Optional, Union

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model
def load_config(cfg_path: str) -> DictConfig:
    """
    Load a configuration file with OmegaConf and resolve its interpolations.

    Args:
        cfg_path (str): Path to the configuration file.

    Returns:
        DictConfig: The loaded configuration, with interpolations resolved
        in place by ``OmegaConf.resolve``.

    Raises:
        AssertionError: If the loaded configuration is not a ``DictConfig``
            (e.g. the file's top level is a list rather than a mapping).
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    # Explicit raise instead of ``assert`` so the check still runs under
    # ``python -O``. AssertionError is kept because it is the documented
    # exception type that callers may already be catching.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Config at '{cfg_path}' must load as a DictConfig, "
            f"got {type(cfg).__name__}"
        )
    return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured OmegaConf object from a plain configuration.

    Every entry of ``cfg`` is passed as a keyword argument to ``cfg_type``.
    The resulting instance is then wrapped with ``OmegaConf.structured``.

    Args:
        cfg_type (Any): Callable used to build the schema instance
            (typically a dataclass type). It is called with the keys of
            ``cfg`` as keyword arguments.
        cfg (DictConfig): Configuration whose entries populate ``cfg_type``.

    Returns:
        Any: Whatever ``OmegaConf.structured`` returns for the instance.
    """
    schema_instance = cfg_type(**cfg)
    return OmegaConf.structured(schema_instance)
def load_model_weights(
    model: torch.nn.Module, ckpt_path: Union[str, os.PathLike]
) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place via ``safetensors.torch.load_model``.

    Args:
        model: PyTorch model to load weights into.
        ckpt_path: Path to the ``.safetensors`` checkpoint file, given as a
            ``str`` or any ``os.PathLike`` (e.g. ``pathlib.Path``).

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    # os.fspath lets pathlib.Path and other PathLike inputs through. Calling
    # ``.endswith`` directly on them would raise AttributeError.
    path = os.fspath(ckpt_path)
    # Explicit raise instead of ``assert`` so the check survives
    # ``python -O``. AssertionError keeps the exception type callers see.
    if not path.endswith(".safetensors"):
        raise AssertionError(
            f"Checkpoint path '{path}' is not a safetensors file"
        )
    load_model(model, path)