# LASA/models/__init__.py
# Scraped file-view header (author: HaolinLiu; commit cc9780d,
# "first commit of codes and update readme.md"; original size 706 bytes).
from .TriplaneVAE import TriplaneVAE
from .Triplane_Diffusion import Triplane_Diff_MultiImgCond_EDM
from .Triplane_Diffusion import EDMLoss_MultiImgCond
#from .Point_Diffusion_EDM import PointEDM,EDMLoss_PointAug
def get_model(model_args):
    """Build the network selected by ``model_args['type']``.

    Args:
        model_args: Model config mapping (presumably a dict parsed from a
            config file — confirm). Its ``'type'`` entry selects the model,
            and the whole mapping is forwarded to that model's constructor.

    Returns:
        ``TriplaneVAE(model_args)`` when the type is ``"TriVAE"``, or
        ``Triplane_Diff_MultiImgCond_EDM(model_args)`` when the type is
        ``"triplane_diff_multiimg_cond"``.

    Raises:
        KeyError: If ``model_args`` has no ``'type'`` entry (dict lookup).
        NotImplementedError: If the type is not one of the values above.
    """
    model_type = model_args['type']
    if model_type == "TriVAE":
        return TriplaneVAE(model_args)
    if model_type == "triplane_diff_multiimg_cond":
        return Triplane_Diff_MultiImgCond_EDM(model_args)
    # Name the rejected value so a config typo is obvious from the traceback.
    raise NotImplementedError(
        f"Unsupported model type {model_type!r}; "
        "expected 'TriVAE' or 'triplane_diff_multiimg_cond'"
    )
def get_criterion(cri_args):
    """Build the training loss selected by ``cri_args['type']``.

    Args:
        cri_args: Criterion config mapping (presumably a dict parsed from a
            config file — confirm). ``'type'`` selects the loss; for
            ``"EDMLoss_MultiImgCond"`` the ``'use_par'`` entry is passed
            through as the ``use_par`` keyword argument.

    Returns:
        ``EDMLoss_MultiImgCond(use_par=cri_args['use_par'])`` when the type
        is ``"EDMLoss_MultiImgCond"``.

    Raises:
        KeyError: If ``'type'`` is missing, or ``'use_par'`` is missing for
            the ``"EDMLoss_MultiImgCond"`` type (dict lookups).
        NotImplementedError: If the type is not supported.
    """
    criterion_type = cri_args['type']
    if criterion_type == "EDMLoss_MultiImgCond":
        return EDMLoss_MultiImgCond(use_par=cri_args['use_par'])
    # Name the rejected value so a config typo is obvious from the traceback.
    raise NotImplementedError(
        f"Unsupported criterion type {criterion_type!r}; "
        "expected 'EDMLoss_MultiImgCond'"
    )