from .TriplaneVAE import TriplaneVAE | |
from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM | |
from .Triplane_Diffusion import EDMLoss_MultiImgCond | |
#from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug | |
def get_model(model_args):
    """Build the model selected by ``model_args['type']``.

    Args:
        model_args: Configuration mapping. Its ``'type'`` entry chooses the
            model class, and the whole mapping is handed to that class's
            constructor unchanged.

    Returns:
        A ``TriplaneVAE`` when the type is ``"TriVAE"``, or a
        ``Triplane_Diff_MultiImgCond_EDM`` when the type is
        ``"triplane_diff_multiimg_cond"``.

    Raises:
        NotImplementedError: If the type is not one of the names above. The
            message names the rejected value and the supported ones.
    """
    model_type = model_args['type']
    if model_type == "TriVAE":
        return TriplaneVAE(model_args)
    if model_type == "triplane_diff_multiimg_cond":
        return Triplane_Diff_MultiImgCond_EDM(model_args)
    # Same exception type as before so existing handlers keep working; the
    # message makes a mistyped config value easy to spot.
    raise NotImplementedError(
        f"Unknown model type {model_type!r}; expected 'TriVAE' or "
        "'triplane_diff_multiimg_cond'"
    )
def get_criterion(cri_args):
    """Build the loss criterion selected by ``cri_args['type']``.

    Args:
        cri_args: Configuration mapping. Its ``'type'`` entry chooses the
            criterion; for ``"EDMLoss_MultiImgCond"`` the ``'use_par'`` entry
            is also read and forwarded as the ``use_par`` keyword argument.

    Returns:
        An ``EDMLoss_MultiImgCond`` instance when the type is
        ``"EDMLoss_MultiImgCond"``.

    Raises:
        NotImplementedError: If the type is not a supported criterion name.
            The message names the rejected value and the supported one.
    """
    criterion_type = cri_args['type']
    if criterion_type == "EDMLoss_MultiImgCond":
        return EDMLoss_MultiImgCond(use_par=cri_args['use_par'])
    # Same exception type as before so existing handlers keep working; the
    # message makes a mistyped config value easy to spot.
    raise NotImplementedError(
        f"Unknown criterion type {criterion_type!r}; expected "
        "'EDMLoss_MultiImgCond'"
    )