|
import logging |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
from diffusers.models import AutoencoderKLTemporalDecoder |
|
from diffusers.schedulers import EulerDiscreteScheduler |
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection |
|
|
|
from ..modules.unet import UNetSpatioTemporalConditionModel |
|
from ..modules.pose_net import PoseNet |
|
from ..pipelines.pipeline_mimicmotion import MimicMotionPipeline |
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
class MimicMotionModel(torch.nn.Module):
    """Container module that groups every MimicMotion component as attributes.

    Holding the components on a single ``torch.nn.Module`` lets them be moved
    to a device, switched to eval mode, and restored from one checkpoint
    with a single call.
    """

    def __init__(self, base_model_path):
        """Construct the model components from a pretrained SVD model directory.

        Loaded with pretrained weights/configs from ``base_model_path``:
        the VAE (cast to float16), the CLIP image encoder, the Euler discrete
        noise scheduler and the CLIP image feature extractor.

        NOT loaded with pretrained weights:
          * the UNet, which is only instantiated from its config (so its
            parameters start randomly initialized), and
          * the pose net, which is created from scratch.
        Presumably both are filled in afterwards from a MimicMotion checkpoint.

        Args:
            base_model_path (str): path of the pretrained SVD model. It is read
                through the ``unet``, ``vae``, ``image_encoder``, ``scheduler``
                and ``feature_extractor`` subfolders.
        """
        super().__init__()

        # Architecture only: from_config builds the network without loading weights.
        unet_config = UNetSpatioTemporalConditionModel.load_config(base_model_path, subfolder="unet")
        self.unet = UNetSpatioTemporalConditionModel.from_config(unet_config)

        # The VAE is the only component explicitly cast to float16 here.
        self.vae = AutoencoderKLTemporalDecoder.from_pretrained(
            base_model_path, subfolder="vae").half()
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            base_model_path, subfolder="image_encoder")
        self.noise_scheduler = EulerDiscreteScheduler.from_pretrained(
            base_model_path, subfolder="scheduler")
        self.feature_extractor = CLIPImageProcessor.from_pretrained(
            base_model_path, subfolder="feature_extractor")

        # The pose net's output channel count is tied to the UNet's first block
        # width, so its features can be combined with the UNet's first-stage
        # latents (exact combination happens outside this class).
        self.pose_net = PoseNet(noise_latent_channels=self.unet.config.block_out_channels[0])
|
|
|
def create_pipeline(infer_config, device):
    """Create the MimicMotion pipeline and load the pretrained checkpoint into it.

    Args:
        infer_config: inference configuration object. Only two attributes are
            read here: ``base_model_path`` (pretrained SVD model path, passed to
            ``MimicMotionModel``) and ``ckpt_path`` (MimicMotion checkpoint
            file, read with ``torch.load``).
        device (str or torch.device): "cpu" or "cuda:{device_id}"

    Returns:
        MimicMotionPipeline: pipeline wired to the model components after the
        checkpoint has been loaded into them.
    """
    mimicmotion_models = MimicMotionModel(infer_config.base_model_path).to(device=device).eval()

    # Deserialize the checkpoint on the CPU. load_state_dict copies every tensor
    # into the parameters that already live on `device`, so mapping the file
    # straight onto an accelerator would only add a transient second copy of the
    # weights in device memory.
    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints from trusted sources.
    state_dict = torch.load(infer_config.ckpt_path, map_location="cpu")
    # strict=False: the checkpoint is not required to cover every submodule
    # (presumably it omits the VAE / image encoder, which are already loaded
    # pretrained from base_model_path) - TODO confirm against the checkpoint format.
    incompatible = mimicmotion_models.load_state_dict(state_dict, strict=False)
    del state_dict  # release the CPU copy before building the pipeline

    # The UNet is only built from its config and the pose net from scratch, so
    # any of their entries missing from the checkpoint keep their random
    # initialization. strict=False would otherwise hide that silently.
    untrained = [
        key for key in incompatible.missing_keys
        if key.startswith(("unet.", "pose_net."))
    ]
    if untrained:
        logger.warning(
            "Checkpoint %s is missing %d unet/pose_net entries (e.g. %s); "
            "they keep their random initialization.",
            infer_config.ckpt_path, len(untrained), untrained[:5],
        )
    if incompatible.unexpected_keys:
        logger.warning(
            "Checkpoint %s has %d keys that match no model parameter (e.g. %s); "
            "they were ignored.",
            infer_config.ckpt_path, len(incompatible.unexpected_keys),
            incompatible.unexpected_keys[:5],
        )

    pipeline = MimicMotionPipeline(
        vae=mimicmotion_models.vae,
        image_encoder=mimicmotion_models.image_encoder,
        unet=mimicmotion_models.unet,
        scheduler=mimicmotion_models.noise_scheduler,
        feature_extractor=mimicmotion_models.feature_extractor,
        pose_net=mimicmotion_models.pose_net,
    )
    return pipeline
|
|
|
|