File size: 706 Bytes
1fd4e9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from typing import Optional

from nitrous_ema import PostHocEMA
from omegaconf import DictConfig

from mmaudio.model.networks import get_my_mmaudio


def synthesize_ema(cfg: DictConfig,
                   sigma: float,
                   step: Optional[int] = None,
                   device: str = 'cpu') -> dict:
    """Synthesize a post-hoc EMA of the MMAudio network and return its weights.

    A network is built from ``cfg.model`` and wrapped in a ``PostHocEMA``
    configured from ``cfg.ema``. The EMA is then synthesized for the requested
    relative width ``sigma``, and the synthesized model's state dict is returned.

    Args:
        cfg: Full config. ``cfg.model`` is passed to ``get_my_mmaudio``.
            ``cfg.ema`` must provide ``sigma_rels``, ``update_every``,
            ``checkpoint_every`` and ``checkpoint_folder``.
        sigma: Relative EMA width, passed as ``sigma_rel`` to
            ``synthesize_ema_model``.
        step: Training step to synthesize at. It is forwarded unchanged to
            ``synthesize_ema_model``. NOTE(review): ``None`` presumably means
            "use all available EMA checkpoints"; confirm against nitrous_ema.
        device: Device passed to ``synthesize_ema_model``. Defaults to
            ``'cpu'``, matching the previous hard-coded behaviour.

    Returns:
        ``state_dict()`` of the synthesized EMA model.
    """
    # The freshly built network presumably serves only as an architecture
    # template. The EMA weights are expected to come from the checkpoints in
    # cfg.ema.checkpoint_folder. TODO confirm against nitrous_ema.
    network = get_my_mmaudio(cfg.model)
    emas = PostHocEMA(network,
                      sigma_rels=cfg.ema.sigma_rels,
                      update_every=cfg.ema.update_every,
                      checkpoint_every_num_steps=cfg.ema.checkpoint_every,
                      checkpoint_folder=cfg.ema.checkpoint_folder)

    synthesized_ema = emas.synthesize_ema_model(sigma_rel=sigma, step=step, device=device)
    return synthesized_ema.ema_model.state_dict()