File size: 706 Bytes
cc9780d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from .TriplaneVAE import TriplaneVAE
from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM
from .Triplane_Diffusion import EDMLoss_MultiImgCond
#from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug
def get_model(model_args):
    """Build the model selected by ``model_args['type']``.

    Args:
        model_args: Mapping with a ``'type'`` key naming the model to
            build. The whole mapping is handed to the selected model's
            constructor unchanged.

    Returns:
        A ``TriplaneVAE`` for ``"TriVAE"``, or a
        ``Triplane_Diff_MultiImgCond_EDM`` for
        ``"triplane_diff_multiimg_cond"``.

    Raises:
        KeyError: If ``model_args`` has no ``'type'`` entry.
        NotImplementedError: If the requested type is not one of the
            names listed above; the message includes the rejected value.
    """
    model_type = model_args['type']
    if model_type == "TriVAE":
        return TriplaneVAE(model_args)
    if model_type == "triplane_diff_multiimg_cond":
        return Triplane_Diff_MultiImgCond_EDM(model_args)
    # Name the rejected value so a typo in the config is easy to spot.
    raise NotImplementedError(
        "Unknown model type {!r}; expected 'TriVAE' or "
        "'triplane_diff_multiimg_cond'".format(model_type)
    )
def get_criterion(cri_args):
    """Build the training criterion selected by ``cri_args['type']``.

    Args:
        cri_args: Mapping with a ``'type'`` key naming the criterion.
            For ``"EDMLoss_MultiImgCond"`` it must also contain
            ``'use_par'``, which is forwarded as the ``use_par`` keyword
            argument of the loss constructor.

    Returns:
        An ``EDMLoss_MultiImgCond`` instance for
        ``"EDMLoss_MultiImgCond"``.

    Raises:
        KeyError: If ``'type'`` is missing, or ``'use_par'`` is missing
            when the EDM loss is requested.
        NotImplementedError: If the requested type is not supported; the
            message includes the rejected value.
    """
    criterion_type = cri_args['type']
    if criterion_type == "EDMLoss_MultiImgCond":
        return EDMLoss_MultiImgCond(use_par=cri_args['use_par'])
    # Name the rejected value so a typo in the config is easy to spot.
    raise NotImplementedError(
        "Unknown criterion type {!r}; expected "
        "'EDMLoss_MultiImgCond'".format(criterion_type)
    )
|