File size: 706 Bytes
cc9780d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from .TriplaneVAE import TriplaneVAE
from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM
from .Triplane_Diffusion import EDMLoss_MultiImgCond
#from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug

def get_model(model_args):
    """Build the model selected by ``model_args['type']``.

    Args:
        model_args: Mapping of model options. Its ``'type'`` entry picks the
            class to build; the whole mapping is then passed unchanged to
            that class's constructor.

    Returns:
        A ``TriplaneVAE`` when the type is ``"TriVAE"``, or a
        ``Triplane_Diff_MultiImgCond_EDM`` when the type is
        ``"triplane_diff_multiimg_cond"``.

    Raises:
        NotImplementedError: If the type is not one of the names above.
    """
    model_type = model_args['type']
    if model_type == "TriVAE":
        return TriplaneVAE(model_args)
    if model_type == "triplane_diff_multiimg_cond":
        return Triplane_Diff_MultiImgCond_EDM(model_args)
    # Same exception type as before, but the message now names the offending
    # value and the accepted ones so a config typo is obvious from the
    # traceback.
    raise NotImplementedError(
        "Unknown model type {!r}; expected 'TriVAE' or "
        "'triplane_diff_multiimg_cond'".format(model_type)
    )

def get_criterion(cri_args):
    """Build the training criterion selected by ``cri_args['type']``.

    Args:
        cri_args: Mapping of criterion options. Its ``'type'`` entry picks
            the criterion class. For ``"EDMLoss_MultiImgCond"`` the
            ``'use_par'`` entry is also read and forwarded as the
            ``use_par`` keyword argument.

    Returns:
        An ``EDMLoss_MultiImgCond`` instance when the type is
        ``"EDMLoss_MultiImgCond"``.

    Raises:
        NotImplementedError: If the type is not ``"EDMLoss_MultiImgCond"``.
    """
    criterion_type = cri_args['type']
    if criterion_type == "EDMLoss_MultiImgCond":
        return EDMLoss_MultiImgCond(use_par=cri_args['use_par'])
    # Same exception type as before, but the message now names the offending
    # value so a config typo is obvious from the traceback.
    raise NotImplementedError(
        "Unknown criterion type {!r}; expected "
        "'EDMLoss_MultiImgCond'".format(criterion_type)
    )