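"""ComfyUI custom node that loads the MEMO avatar model stack: the reference
and diffusion UNets, VAE, image/audio projectors, and the audio emotion
classifier."""
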
import os
import logging

import torch
from diffusers import AutoencoderKL
from diffusers.utils import is_xformers_available
from packaging import version
from safetensors.torch import load_file

from memo.models.unet_2d_condition import UNet2DConditionModel
from memo.models.unet_3d import UNet3DConditionModel
from memo.models.image_proj import ImageProjModel
from memo.models.audio_proj import AudioProjModel
from memo.models.emotion_classifier import AudioEmotionClassifierModel
from memo_model_manager import MemoModelManager

logger = logging.getLogger("memo")

class IF_MemoCheckpointLoader:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "enable_xformers": ("BOOLEAN", {"default": True}),
            }
        }

    RETURN_TYPES = ("MODEL", "MODEL", "VAE", "IMAGE_PROJ", "AUDIO_PROJ", "EMOTION_CLASSIFIER")
    RETURN_NAMES = ("reference_net", "diffusion_net", "vae", "image_proj", "audio_proj", "emotion_classifier")
    FUNCTION = "load_checkpoint"
    CATEGORY = "ImpactFrames💥🎞️/MemoAvatar"

    def __init__(self):
        # Initialize model manager to set up all paths and auxiliary models
        self.model_manager = MemoModelManager()
        self.paths = self.model_manager.get_model_paths()

    def load_checkpoint(self, enable_xformers=True):
        try:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
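            # Prefer fp16 on CUDA to roughly halve VRAM; keep fp32 on CPU, where fp16 kernel support is limited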
            dtype = torch.float16 if device.type == "cuda" else torch.float32
            
            logger.info("Loading models")
            
            # Load VAE
            vae = AutoencoderKL.from_pretrained(
                self.paths["vae"]
            ).to(device=device, dtype=dtype)
            vae.requires_grad_(False)
            vae.eval()

            # Load reference net
            reference_net = UNet2DConditionModel.from_pretrained(
                self.paths["memo_base"],
                subfolder="reference_net",
                use_safetensors=True
            )
            reference_net.requires_grad_(False)
            reference_net.eval()

            # Load diffusion net
            diffusion_net = UNet3DConditionModel.from_pretrained(
                self.paths["memo_base"],
                subfolder="diffusion_net",
                use_safetensors=True
            )
            diffusion_net.requires_grad_(False)
            diffusion_net.eval()

            # Load projectors
            image_proj = ImageProjModel.from_pretrained(
                self.paths["memo_base"],
                subfolder="image_proj",
                use_safetensors=True
            )
            image_proj.requires_grad_(False)
            image_proj.eval()

            audio_proj = AudioProjModel.from_pretrained(
                self.paths["memo_base"],
                subfolder="audio_proj",
                use_safetensors=True
            )
            audio_proj.requires_grad_(False)
            audio_proj.eval()

            # Enable xformers
            if enable_xformers and is_xformers_available():
                try:
                    import xformers
                    xformers_version = version.parse(xformers.__version__)
                    if xformers_version == version.parse("0.0.16"):
                        logger.warning("xFormers 0.0.16 cannot be used for training in some GPUs.")
                    reference_net.enable_xformers_memory_efficient_attention()
                    diffusion_net.enable_xformers_memory_efficient_attention()
                except Exception as e:
                    logger.warning(f"Could not enable xformers: {e}")

            # Move models to device
            for model in [reference_net, diffusion_net, image_proj, audio_proj]:
                model.to(device=device, dtype=dtype)

            # Load emotion classifier
            emotion_classifier = AudioEmotionClassifierModel()
            emotion_classifier_path = os.path.join(
                self.paths["memo_base"],
                "misc",
                "audio_emotion_classifier",
                "diffusion_pytorch_model.safetensors",
            )
            emotion_classifier.load_state_dict(load_file(emotion_classifier_path))
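            emotion_classifier.requires_grad_(False)  # freeze to match the other inference-only models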
            emotion_classifier.to(device=device, dtype=dtype)
            emotion_classifier.eval()

            logger.info(f"Models loaded successfully to {device} with dtype {dtype}")
            return (reference_net, diffusion_net, vae, image_proj, audio_proj, emotion_classifier)

        except Exception as e:
            logger.error(f"Error loading models: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"Failed to load models: {e}") from e

    @classmethod
    def IS_CHANGED(s, **kwargs):
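        # NaN never compares equal, even to itself, so ComfyUI treats this node as always changed and re-runs it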
        return float("nan")

NODE_CLASS_MAPPINGS = {
    "IF_MemoCheckpointLoader": IF_MemoCheckpointLoader
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "IF_MemoCheckpointLoader": "IF Memo Checkpoint Loader 🎬"
}