from typing import Optional

from nitrous_ema import PostHocEMA
from omegaconf import DictConfig

from mmaudio.model.networks import get_my_mmaudio


def synthesize_ema(cfg: DictConfig, sigma: float, step: Optional[int]):
    """Synthesize an EMA model and return its state dict.

    A fresh network is built from ``cfg.model`` and wrapped in a
    ``PostHocEMA`` configured from ``cfg.ema``. The wrapper is then asked to
    synthesize an EMA model for the requested relative sigma and step, and
    the resulting model's ``state_dict()`` is returned.

    Args:
        cfg: Configuration providing ``model`` (passed to ``get_my_mmaudio``)
            and ``ema`` with the fields ``sigma_rels``, ``update_every``,
            ``checkpoint_every`` and ``checkpoint_folder``.
        sigma: Forwarded to ``synthesize_ema_model`` as ``sigma_rel``.
        step: Forwarded to ``synthesize_ema_model`` as ``step``; may be
            ``None``.

    Returns:
        The value of ``ema_model.state_dict()`` on the synthesized EMA.
    """
    network = get_my_mmaudio(cfg.model)

    ema_cfg = cfg.ema
    # NOTE(review): presumably the synthesis reads previously saved EMA
    # checkpoints from ``checkpoint_folder`` -- confirm against PostHocEMA.
    post_hoc_ema = PostHocEMA(
        network,
        sigma_rels=ema_cfg.sigma_rels,
        update_every=ema_cfg.update_every,
        checkpoint_every_num_steps=ema_cfg.checkpoint_every,
        checkpoint_folder=ema_cfg.checkpoint_folder,
    )

    # Synthesis is requested on the CPU device.
    synthesized = post_hoc_ema.synthesize_ema_model(
        sigma_rel=sigma,
        step=step,
        device='cpu',
    )
    return synthesized.ema_model.state_dict()