""" | |
This script is a gradio web ui. | |
The script takes an image and an audio clip, and lets you configure all the | |
variables such as cfg_scale, pose_weight, face_weight, lip_weight, etc. | |
Usage: | |
This script can be run from the command line with the following command: | |
python scripts/app.py | |
""" | |

import argparse
import copy
import logging
import math
import os
import random
import sys
import time
import uuid
import warnings
from datetime import datetime
from typing import List, Tuple

import diffusers
import gradio as gr
import mlflow
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from omegaconf import OmegaConf
from torch import nn
from tqdm.auto import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))

from joyhallo.animate.face_animate import FaceAnimatePipeline
from joyhallo.datasets.audio_processor import AudioProcessor
from joyhallo.datasets.image_processor import ImageProcessor
from joyhallo.datasets.talk_video import TalkingVideoDataset
from joyhallo.models.audio_proj import AudioProjModel
from joyhallo.models.face_locator import FaceLocator
from joyhallo.models.image_proj import ImageProjModel
from joyhallo.models.mutual_self_attention import ReferenceAttentionControl
from joyhallo.models.unet_2d_condition import UNet2DConditionModel
from joyhallo.models.unet_3d import UNet3DConditionModel
from joyhallo.utils.util import (compute_snr, delete_additional_ckpt,
                                 import_filename, init_output_dir,
                                 load_checkpoint, save_checkpoint,
                                 seed_everything, tensor_to_video)

warnings.filterwarnings("ignore")

# Will raise an error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
    """
    The Net class defines a neural network that combines a reference UNet2DConditionModel,
    a denoising UNet3DConditionModel, a face locator, and the projection modules used to
    animate a face in a static image.

    Args:
        reference_unet (UNet2DConditionModel): The reference UNet used for face animation.
        denoising_unet (UNet3DConditionModel): The denoising UNet used for face animation.
        face_locator (FaceLocator): The face locator model used for face animation.
        reference_control_writer: The reference control writer component.
        reference_control_reader: The reference control reader component.
        imageproj: The image projection model.
        audioproj: The audio projection model.

    Forward arguments:
        noisy_latents (torch.Tensor): The noisy latents tensor.
        timesteps (torch.Tensor): The timesteps tensor.
        ref_image_latents (torch.Tensor): The reference image latents tensor.
        face_emb (torch.Tensor): The face embeddings tensor.
        audio_emb (torch.Tensor): The audio embeddings tensor.
        mask (torch.Tensor): The hard face mask for the face locator.
        full_mask (torch.Tensor): The pose mask.
        face_mask (torch.Tensor): The face mask.
        lip_mask (torch.Tensor): The lip mask.
        uncond_img_fwd (bool): Whether to run the reference-image-unconditional forward pass.
        uncond_audio_fwd (bool): Whether to run the audio-unconditional forward pass.

    Returns:
        torch.Tensor: The output tensor of the neural network model.
    """

    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        face_locator: FaceLocator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
        audioproj,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.face_locator = face_locator
        self.reference_control_writer = reference_control_writer
        self.reference_control_reader = reference_control_reader
        self.imageproj = imageproj
        self.audioproj = audioproj

    def forward(
        self,
        noisy_latents: torch.Tensor,
        timesteps: torch.Tensor,
        ref_image_latents: torch.Tensor,
        face_emb: torch.Tensor,
        audio_emb: torch.Tensor,
        mask: torch.Tensor,
        full_mask: torch.Tensor,
        face_mask: torch.Tensor,
        lip_mask: torch.Tensor,
        uncond_img_fwd: bool = False,
        uncond_audio_fwd: bool = False,
    ):
""" | |
simple docstring to prevent pylint error | |
""" | |
face_emb = self.imageproj(face_emb) | |
mask = mask.to(device=device) | |
mask_feature = self.face_locator(mask) | |
audio_emb = audio_emb.to( | |
device=self.audioproj.device, dtype=self.audioproj.dtype) | |
audio_emb = self.audioproj(audio_emb) | |
# condition forward | |
if not uncond_img_fwd: | |
ref_timesteps = torch.zeros_like(timesteps) | |
ref_timesteps = repeat( | |
ref_timesteps, | |
"b -> (repeat b)", | |
repeat=ref_image_latents.size(0) // ref_timesteps.size(0), | |
) | |
self.reference_unet( | |
ref_image_latents, | |
ref_timesteps, | |
encoder_hidden_states=face_emb, | |
return_dict=False, | |
) | |
self.reference_control_reader.update(self.reference_control_writer) | |
if uncond_audio_fwd: | |
audio_emb = torch.zeros_like(audio_emb).to( | |
device=audio_emb.device, dtype=audio_emb.dtype | |
) | |
model_pred = self.denoising_unet( | |
noisy_latents, | |
timesteps, | |
mask_cond_fea=mask_feature, | |
encoder_hidden_states=face_emb, | |
audio_embedding=audio_emb, | |
full_mask=full_mask, | |
face_mask=face_mask, | |
lip_mask=lip_mask | |
).sample | |
return model_pred | |
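# Note: Net.forward() is never called in this script; the Net instance serves as a
# container whose submodules are unwrapped again in log_validation().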


def get_attention_mask(mask: torch.Tensor, weight_dtype: torch.dtype) -> torch.Tensor:
    """
    Rearrange the mask tensors into the format required by the attention layers.

    Args:
        mask (torch.Tensor): The input mask tensor, or a list of mask tensors.
        weight_dtype (torch.dtype): The data type for the mask tensor.

    Returns:
        torch.Tensor: The rearranged mask tensor (or a list of rearranged mask tensors).
    """
    if isinstance(mask, list):
        _mask = []
        for m in mask:
            _mask.append(
                rearrange(m, "b f 1 h w -> (b f) (h w)").to(weight_dtype))
        return _mask
    mask = rearrange(mask, "b f 1 h w -> (b f) (h w)").to(weight_dtype)
    return mask
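# Note: get_attention_mask is not referenced anywhere else in this script; it appears
# to be kept for parity with the training code.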


def get_noise_scheduler(cfg: argparse.Namespace) -> Tuple[DDIMScheduler, DDIMScheduler]:
    """
    Create the noise schedulers for training and validation.

    Args:
        cfg (argparse.Namespace): Configuration object.

    Returns:
        Tuple[DDIMScheduler, DDIMScheduler]: Train noise scheduler and validation noise scheduler.
    """
    sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})
    train_noise_scheduler = DDIMScheduler(**sched_kwargs)
    return train_noise_scheduler, val_noise_scheduler
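# Only the validation scheduler is used by this app: get_model() discards the training
# scheduler, since no training happens here.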


def process_audio_emb(audio_emb: torch.Tensor) -> torch.Tensor:
    """
    Process the audio embedding so that each frame carries its temporal context.

    Parameters:
        audio_emb (torch.Tensor): The audio embedding tensor to process.

    Returns:
        torch.Tensor: The stacked tensor in which every frame is concatenated with its
        two preceding and two following frames (clamped at the sequence boundaries).
    """
    concatenated_tensors = []
    for i in range(audio_emb.shape[0]):
        vectors_to_concat = [
            audio_emb[max(min(i + j, audio_emb.shape[0] - 1), 0)] for j in range(-2, 3)]
        concatenated_tensors.append(torch.stack(vectors_to_concat, dim=0))
    audio_emb = torch.stack(concatenated_tensors, dim=0)
    return audio_emb
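# Shape note for process_audio_emb (illustrative, assuming wav2vec features of shape
# (num_frames, blocks, channels)): the output has shape (num_frames, 5, blocks, channels),
# matching the seq_len=5 expected by AudioProjModel below.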


def log_validation(
    accelerator: Accelerator,
    vae: AutoencoderKL,
    net: Net,
    scheduler: DDIMScheduler,
    width: int,
    height: int,
    clip_length: int = 24,
    generator: torch.Generator = None,
    cfg: dict = None,
    save_dir: str = None,
    global_step: int = 0,
    times: int = None,
    face_analysis_model_path: str = "",
):
    """
    Build the inference pipeline and the image/audio processors from the prepared network.

    Args:
        accelerator (Accelerator): The accelerator for distributed training.
        vae (AutoencoderKL): The autoencoder model.
        net (Net): The main neural network model.
        scheduler (DDIMScheduler): The noise scheduler.
        width (int): The width of the input images.
        height (int): The height of the input images.
        clip_length (int): The length of the video clips. Defaults to 24.
        generator (torch.Generator): The random number generator. Defaults to None.
        cfg (dict): The configuration dictionary. Defaults to None.
        save_dir (str): The directory to save validation results. Defaults to None.
        global_step (int): The current global step in training. Defaults to 0.
        times (int): The number of inference times. Defaults to None.
        face_analysis_model_path (str): The path to the face analysis model. Defaults to "".

    Returns:
        Tuple: (cfg, image_processor, audio_processor, pipeline, audioproj, save_dir,
        global_step, clip_length), i.e. everything inference() needs.
    """
    ori_net = accelerator.unwrap_model(net)
    reference_unet = ori_net.reference_unet
    denoising_unet = ori_net.denoising_unet
    face_locator = ori_net.face_locator
    imageproj = ori_net.imageproj
    audioproj = ori_net.audioproj

    tmp_denoising_unet = copy.deepcopy(denoising_unet)

    pipeline = FaceAnimatePipeline(
        vae=vae,
        reference_unet=reference_unet,
        denoising_unet=tmp_denoising_unet,
        face_locator=face_locator,
        image_proj=imageproj,
        scheduler=scheduler,
    )
    pipeline = pipeline.to(device)

    image_processor = ImageProcessor((width, height), face_analysis_model_path)
    audio_processor = AudioProcessor(
        cfg.data.sample_rate,
        cfg.data.fps,
        cfg.wav2vec_config.model_path,
        cfg.wav2vec_config.features == "last",
        os.path.dirname(cfg.audio_separator.model_path),
        os.path.basename(cfg.audio_separator.model_path),
        os.path.join(save_dir, '.cache', "audio_preprocess"),
        device=device,
    )

    return cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length
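

# inference() below runs the full generation for one request: it preprocesses the
# reference image and the audio, then synthesizes the video clip by clip, feeding the
# last frames of each finished clip back in as motion frames for the next one, and
# finally muxes the result with the audio track.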
def inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length):
    """
    Generate the talking-head video for cfg.ref_img_path and cfg.audio_path and write it
    to cfg.output. Returns the output file path.
    """
    ref_img_path = cfg.ref_img_path
    audio_path = cfg.audio_path

    source_image_pixels, \
        source_image_face_region, \
        source_image_face_emb, \
        source_image_full_mask, \
        source_image_face_mask, \
        source_image_lip_mask = image_processor.preprocess(
            ref_img_path, os.path.join(save_dir, '.cache'), cfg.face_expand_ratio)
    audio_emb, audio_length = audio_processor.preprocess(
        audio_path, clip_length)

    audio_emb = process_audio_emb(audio_emb)

    source_image_pixels = source_image_pixels.unsqueeze(0)
    source_image_face_region = source_image_face_region.unsqueeze(0)
    source_image_face_emb = source_image_face_emb.reshape(1, -1)
    source_image_face_emb = torch.tensor(source_image_face_emb)

    source_image_full_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_full_mask
    ]
    source_image_face_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_face_mask
    ]
    source_image_lip_mask = [
        (mask.repeat(clip_length, 1))
        for mask in source_image_lip_mask
    ]

    times = audio_emb.shape[0] // clip_length
    tensor_result = []
    generator = torch.manual_seed(42)
    for t in range(times):
        print(f"[{t + 1}/{times}]")
        if len(tensor_result) == 0:
            # The first iteration: pad the motion frames with copies of the reference image.
            motion_zeros = source_image_pixels.repeat(
                cfg.data.n_motion_frames, 1, 1, 1)
            motion_zeros = motion_zeros.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_zeros], dim=0)  # concat the ref image and the first motion frames
        else:
            # Later iterations: reuse the last n_motion_frames of the previous clip.
            motion_frames = tensor_result[-1][0]
            motion_frames = motion_frames.permute(1, 0, 2, 3)
            motion_frames = motion_frames[0 - cfg.data.n_motion_frames:]
            motion_frames = motion_frames * 2.0 - 1.0
            motion_frames = motion_frames.to(
                dtype=source_image_pixels.dtype, device=source_image_pixels.device)
            pixel_values_ref_img = torch.cat(
                [source_image_pixels, motion_frames], dim=0)  # concat the ref image and the motion frames

        pixel_values_ref_img = pixel_values_ref_img.unsqueeze(0)

        audio_tensor = audio_emb[
            t * clip_length: min((t + 1) * clip_length, audio_emb.shape[0])
        ]
        audio_tensor = audio_tensor.unsqueeze(0)
        audio_tensor = audio_tensor.to(
            device=audioproj.device, dtype=audioproj.dtype)
        audio_tensor = audioproj(audio_tensor)

        pipeline_output = pipeline(
            ref_image=pixel_values_ref_img,
            audio_tensor=audio_tensor,
            face_emb=source_image_face_emb,
            face_mask=source_image_face_region,
            pixel_values_full_mask=source_image_full_mask,
            pixel_values_face_mask=source_image_face_mask,
            pixel_values_lip_mask=source_image_lip_mask,
            width=cfg.data.train_width,
            height=cfg.data.train_height,
            video_length=clip_length,
            num_inference_steps=cfg.inference_steps,
            guidance_scale=cfg.cfg_scale,
            generator=generator,
        )
        tensor_result.append(pipeline_output.videos)

    # Stitch the clips, trim to the audio length, and write the video with its audio track.
    tensor_result = torch.cat(tensor_result, dim=2)
    tensor_result = tensor_result.squeeze(0)
    tensor_result = tensor_result[:, :audio_length]
    output_file = cfg.output
    tensor_to_video(tensor_result, output_file, audio_path)
    return output_file
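# Note: inference() reseeds the generator with a fixed seed (42) on every call, so
# repeated requests with identical inputs produce the same video.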


def get_model(cfg: argparse.Namespace):
    """
    Build and prepare all models needed for inference using the given configuration (cfg).

    Args:
        cfg (argparse.Namespace): The configuration containing the model and solver parameters.

    Notes:
        - This function mirrors the training setup: it creates the UNets, projection modules,
          face locator, optimizer, LR scheduler, and data loader, and prepares them with the
          accelerator.
        - Only the prepared network, the VAE, the validation noise scheduler, the config, and
          the validation directory are returned and used by this app.
    """
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if cfg.seed is not None:
        seed_everything(cfg.seed)

    # create output dir for training
    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    validation_dir = save_dir
    if accelerator.is_main_process:
        init_output_dir([save_dir])

    accelerator.wait_for_everyone()

    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Do not support weight dtype: {cfg.weight_dtype} during training"
        )
    if not torch.cuda.is_available():
        weight_dtype = torch.float32

    # Create Models
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        device=device, dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained(
        cfg.base_model_path,
        subfolder="unet",
    ).to(device=device, dtype=weight_dtype)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        cfg.mm_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            cfg.unet_additional_kwargs),
        use_landmark=False,
    ).to(device=device, dtype=weight_dtype)
    imageproj = ImageProjModel(
        cross_attention_dim=denoising_unet.config.cross_attention_dim,
        clip_embeddings_dim=512,
        clip_extra_context_tokens=4,
    ).to(device=device, dtype=weight_dtype)
    face_locator = FaceLocator(
        conditioning_embedding_channels=320,
    ).to(device=device, dtype=weight_dtype)
    audioproj = AudioProjModel(
        seq_len=5,
        blocks=12,
        channels=768,
        intermediate_dim=512,
        output_dim=768,
        context_tokens=32,
    ).to(device=device, dtype=weight_dtype)

    # Freeze
    vae.requires_grad_(False)
    imageproj.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    face_locator.requires_grad_(False)
    audioproj.requires_grad_(True)

    # Set motion module learnable
    trainable_modules = cfg.trainable_para
    for name, module in denoising_unet.named_modules():
        if any(trainable_mod in name for trainable_mod in trainable_modules):
            for params in module.parameters():
                params.requires_grad_(True)

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )

    net = Net(
        reference_unet,
        denoising_unet,
        face_locator,
        reference_control_writer,
        reference_control_reader,
        imageproj,
        audioproj,
    ).to(dtype=weight_dtype)

    # Load the trained weights; load_state_dict returns (missing_keys, unexpected_keys).
    m, u = net.load_state_dict(
        torch.load(
            cfg.audio_ckpt_dir,
            map_location="cpu",
        ),
    )
    assert len(m) == 0 and len(u) == 0, "Failed to load the checkpoint correctly."
    print("Loaded weights from", cfg.audio_ckpt_dir)

    # get noise scheduler
    _, val_noise_scheduler = get_noise_scheduler(cfg)

    if cfg.solver.enable_xformers_memory_efficient_attention and torch.cuda.is_available():
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    # Initialize the optimizer
    optimizer_cls = torch.optim.AdamW
    trainable_params = list(
        filter(lambda p: p.requires_grad, net.parameters()))
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    # Scheduler
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    # get data loader
    train_dataset = TalkingVideoDataset(
        img_size=(cfg.data.train_width, cfg.data.train_height),
        sample_rate=cfg.data.sample_rate,
        n_sample_frames=cfg.data.n_sample_frames,
        n_motion_frames=cfg.data.n_motion_frames,
        audio_margin=cfg.data.audio_margin,
        data_meta_paths=cfg.data.train_meta_paths,
        wav2vec_cfg=cfg.wav2vec_config,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=16
    )

    # Prepare everything with our `accelerator`.
    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )

    return accelerator, vae, net, val_noise_scheduler, cfg, validation_dir


def load_config(config_path: str) -> dict:
    """
    Loads the configuration file.

    Args:
        config_path (str): Path to the configuration file.

    Returns:
        dict: The configuration dictionary.
    """
    if config_path.endswith(".yaml"):
        return OmegaConf.load(config_path)
    if config_path.endswith(".py"):
        return import_filename(config_path).cfg
    raise ValueError("Unsupported format for config file")
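

# Module-level initialization: the config is loaded and the models are built once at
# startup, so every Gradio request served by predict() reuses the same pipeline instead
# of reloading the weights.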
args = argparse.Namespace()
_config = load_config('configs/inference/inference.yaml')
for key, value in _config.items():
    setattr(args, key, value)

accelerator, vae, net, val_noise_scheduler, cfg, validation_dir = get_model(args)
cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length = log_validation(
    accelerator=accelerator,
    vae=vae,
    net=net,
    scheduler=val_noise_scheduler,
    width=cfg.data.train_width,
    height=cfg.data.train_height,
    clip_length=cfg.data.n_sample_frames,
    cfg=cfg,
    save_dir=validation_dir,
    global_step=0,
    times=cfg.single_inference_times,
    face_analysis_model_path=cfg.face_analysis_model_path,
)
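

# predict() is the callback wired to the Gradio UI: it copies the per-request settings
# onto the shared cfg and runs inference(), returning the path of the generated video.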
def predict(image, audio, pose_weight, face_weight, lip_weight, face_expand_ratio, progress=gr.Progress(track_tqdm=True)):
    """
    Run inference for the Gradio interface with the user-supplied inputs.
    """
    _ = progress
    unique_id = uuid.uuid4()
    config = {
        'ref_img_path': image,
        'audio_path': audio,
        'pose_weight': pose_weight,
        'face_weight': face_weight,
        'lip_weight': lip_weight,
        'face_expand_ratio': face_expand_ratio,
        'config': 'configs/inference/inference.yaml',
        'checkpoint': None,
        'output': f'output-{unique_id}.mp4',
    }
    global cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length
    for key, value in config.items():
        setattr(cfg, key, value)
    return inference(cfg, image_processor, audio_processor, pipeline, audioproj, save_dir, global_step, clip_length)
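

# The Gradio interface wiring is not included above. A minimal sketch of how predict()
# could be exposed is given below; the component labels and default values are
# assumptions, not taken verbatim from the original app, and only the argument order
# (matching predict()'s signature) is fixed by the code above.
_demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Reference image", type="filepath"),
        gr.Audio(label="Driving audio", type="filepath"),
        gr.Number(label="pose_weight", value=1.0),
        gr.Number(label="face_weight", value=1.0),
        gr.Number(label="lip_weight", value=1.0),
        gr.Number(label="face_expand_ratio", value=1.2),
    ],
    outputs=gr.Video(label="Generated video"),
)

if __name__ == "__main__":
    _demo.launch()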