File size: 5,924 Bytes
1fd4e9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#coding=utf-8
import logging
from pathlib import Path
import torch
import torchaudio


from third_party.MMAudio.mmaudio.eval_utils import ModelConfig, all_model_cfg, generate, load_video, make_video, setup_eval_logging
from third_party.MMAudio.mmaudio.model.flow_matching import FlowMatching
from third_party.MMAudio.mmaudio.model.networks import MMAudio, get_my_mmaudio
from third_party.MMAudio.mmaudio.model.utils.features_utils import FeaturesUtils


class V2A_MMAudio:
    """Video-to-audio (V2A) generation wrapper around the MMAudio model.

    Loads an MMAudio variant, its flow-matching sampler, and the feature
    extraction utilities once at construction time; :meth:`generate_audio`
    then synthesizes an audio track for a given video and muxes it back
    into a new video file.
    """

    def __init__(self, 
                variant: str="large_44k",
                num_steps: int=25,
                full_precision: bool=False,):
        """Load the MMAudio network, sampler, and feature utilities.

        Args:
            variant: Key into ``all_model_cfg`` selecting the model
                configuration (size / sampling rate).
            num_steps: Number of Euler integration steps for the
                flow-matching sampler.
            full_precision: If True run in float32; otherwise bfloat16.

        Raises:
            ValueError: If ``variant`` is not a known model configuration.
        """
        self.log = logging.getLogger(self.__class__.__name__)
        self.log.setLevel(logging.INFO)
        # Lazy %-style args: formatting is skipped if the level is disabled.
        self.log.info("The V2A model uses MMAudio %s, init...", variant)

        # Device preference: CUDA, then Apple MPS, then (slow) CPU fallback.
        if torch.cuda.is_available():
            self.device = 'cuda'
        elif torch.backends.mps.is_available():
            self.device = 'mps'
        else:
            self.device = 'cpu'
            self.log.warning('CUDA/MPS are not available, running on CPU')
        self.dtype = torch.float32 if full_precision else torch.bfloat16

        if variant not in all_model_cfg:
            raise ValueError(f'Unknown model variant: {variant}')
        self.model: ModelConfig = all_model_cfg[variant]
        self.model.download_if_needed()

        self.net: MMAudio = get_my_mmaudio(self.model.model_name).to(self.device, self.dtype).eval()
        self.net.load_weights(torch.load(self.model.model_path, map_location=self.device, weights_only=True))

        # Flow-matching ODE sampler (Euler integration, num_steps steps).
        self.fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)

        # Feature extractors (VAE decoder, Synchformer, optional vocoder).
        # The VAE encoder is skipped: inference only decodes latents.
        self.feature_utils = FeaturesUtils(tod_vae_ckpt=self.model.vae_path,
                                           synchformer_ckpt=self.model.synchformer_ckpt,
                                           enable_conditions=True,
                                           mode=self.model.mode,
                                           bigvgan_vocoder_ckpt=self.model.bigvgan_16k_path,
                                           need_vae_encoder=False)
        self.feature_utils = self.feature_utils.to(self.device, self.dtype).eval()


    @torch.no_grad()
    def generate_audio(self, 
                       video_path,
                       output_dir,
                       prompt: str='', 
                       negative_prompt: str='',
                       duration: int=10,
                       seed: int=42,
                       cfg_strength: float=4.5,
                       mask_away_clip: bool=False,
                       is_postp=False,):
        """Generate audio for a video and save both the .wav and a muxed .mp4.

        Args:
            video_path: Path to the input video (str or Path).
            output_dir: Directory for the generated files; created if missing.
            prompt: Text prompt steering the generated audio.
            negative_prompt: Text prompt describing content to avoid.
            duration: Requested duration in seconds; may be adjusted by
                ``load_video`` to the actual clip length.
            seed: RNG seed for reproducible sampling.
            cfg_strength: Classifier-free guidance strength.
            mask_away_clip: If True, drop the CLIP frame conditioning.
            is_postp: If True, name outputs with a ``.neg`` suffix
                (post-processing / negative pass) instead of ``.step1``.

        Returns:
            Tuple of (audio_save_path, video_save_path) as Path objects.
        """
        video_path = Path(video_path).expanduser()
        output_dir = Path(output_dir).expanduser()
        self.log.info("Loading video: %s", video_path)
        output_dir.mkdir(parents=True, exist_ok=True)

        video_info = load_video(video_path, duration)
        clip_frames = video_info.clip_frames
        sync_frames = video_info.sync_frames
        # Use the actual loaded duration (may differ from the request).
        duration = video_info.duration_sec

        # Seeded generator for reproducible sampling.
        rng = torch.Generator(device=self.device)
        rng.manual_seed(seed)

        if mask_away_clip:
            clip_frames = None
        else:
            clip_frames = clip_frames.unsqueeze(0)  # add batch dim
        sync_frames = sync_frames.unsqueeze(0)

        # Resize the network's sequence lengths to match this clip's duration.
        seq_cfg = self.model.seq_cfg
        seq_cfg.duration = duration
        self.net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)

        self.log.info('Prompt: %s', prompt)
        self.log.info('Negative prompt: %s', negative_prompt)

        self.log.info("Generating Audio...")
        audios = generate(
            clip_frames,
            sync_frames,
            [prompt],
            negative_text=[negative_prompt],
            feature_utils=self.feature_utils,
            net=self.net,
            fm=self.fm,
            rng=rng,
            cfg_strength=cfg_strength)
        audio = audios.float().cpu()[0]

        # ".neg" marks the post-processing (negative) pass, ".step1" the
        # first-pass output; both share the input video's stem.
        suffix = 'neg' if is_postp else 'step1'
        audio_save_path = output_dir / f'{video_path.stem}.{suffix}.wav'
        video_save_path = output_dir / f'{video_path.stem}.{suffix}.mp4'

        self.log.info("Saving generated audio and video to %s", output_dir)
        torchaudio.save(str(audio_save_path), audio, seq_cfg.sampling_rate)
        self.log.info('Audio saved to %s', audio_save_path)
        make_video(video_info, str(video_save_path), audio, sampling_rate=seq_cfg.sampling_rate)
        self.log.info('Video saved to %s', video_save_path)

        return audio_save_path, video_save_path



# def main():
#     # One-time logging initialization (if you have a logger.py, do it only once there)
#     setup_eval_logging()

#     # Initialize the model
#     v2a_model = V2A_MMAudio(
#         variant="large_44k_v2",     # model variant name
#         num_steps=25,               # number of sampling steps
#         full_precision=False        # whether to use full precision
#     )

#     # Video path (replace with your real path)
#     video_path = "ZxiXftx2EMg_000477.mp4"

#     # Output directory
#     output_dir = "outputs"

#     # Prompts (steer the generated content)
#     prompt = ""
#     negative_prompt = ""

#     # Generate audio + video
#     audio_save_path, video_save_path = v2a_model.generate_audio(
#         video_path=video_path,
#         output_dir=output_dir,
#         prompt=prompt,
#         negative_prompt=negative_prompt,
#         duration=10,            # seconds
#         seed=42,                # random seed (generate_audio parameter, not __init__)
#         cfg_strength=4.5,       # guidance strength
#         mask_away_clip=False    # whether to mask away the CLIP conditioning
#     )

# if __name__ == "__main__":
#     main()