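"""Batch inference script for Auffusion: reads text captions from a JSON test
set, generates audio for each caption with AuffusionPipeline, and writes the
results as WAV files under the output directory."""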

import os
import argparse
import torch
from diffusers.utils.import_utils import is_xformers_available
from datasets import load_dataset
from tqdm.auto import tqdm
from scipy.io.wavfile import write
from auffusion_pipeline import AuffusionPipeline


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of an inference script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default="auffusion/auffusion",
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--test_data_dir",
        type=str,
        default="./data/test_audiocaps.raw.json",
        help="Path to test dataset in json file",
    )    
    parser.add_argument(
        "--audio_column", type=str, default="audio_path", help="The column of the dataset containing the audio path."
    )
    parser.add_argument(
        "--caption_column", type=str, default="text", help="The column of the dataset containing a caption."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output/auffusion_hf",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--sample_rate", type=int, default=16000, help="The sample rate of the generated audio."
    )
    parser.add_argument(
        "--duration", type=int, default=10, help="The duration of the generated audio in seconds."
    )
    parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible inference.")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )    
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--guidance_scale", type=float, default=7.5, help="The scale of guidance."
    )
    parser.add_argument(
        "--num_inference_steps", type=int, default=100, help="Number of inference steps to perform."
    )
    parser.add_argument(
        "--width", type=int, default=1024, help="Width of the generated spectrogram image."
    )
    parser.add_argument(
        "--height", type=int, default=256, help="Height of the generated spectrogram image."
    )
    args = parser.parse_args()
    
    return args


def main():
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)   

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # map the mixed-precision choice to a dtype; the original silently treated
    # bf16 as float32 even though it is an accepted choice
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    else:
        weight_dtype = torch.float32

    pipeline = AuffusionPipeline.from_pretrained(args.pretrained_model_name_or_path)
    pipeline = pipeline.to(device, weight_dtype)
    pipeline.set_progress_bar_config(disable=True)
    
    if is_xformers_available() and args.enable_xformers_memory_efficient_attention:
        pipeline.enable_xformers_memory_efficient_attention()

    generator = torch.Generator(device=device).manual_seed(args.seed)

    # load dataset
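    # The test file is expected to contain records shaped like
    #   {"audio_path": "path/to/clip.wav", "text": "a dog barking"}
    # (field names follow the --audio_column/--caption_column defaults; the
    # example values here are illustrative, not taken from the dataset).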
    audio_column, caption_column = args.audio_column, args.caption_column
    data_files = {"test": args.test_data_dir}
    dataset = load_dataset("json", data_files=data_files, split="test")

    # output dir
    audio_output_dir = os.path.join(args.output_dir, "audios")
    os.makedirs(audio_output_dir, exist_ok=True)

    # generate audio for each caption in the test set
    audio_length = args.sample_rate * args.duration
    for i in tqdm(range(len(dataset)), desc="Generating"):

        prompt = dataset[i][caption_column]        
        audio_name = os.path.basename(dataset[i][audio_column])

        audio_path = os.path.join(audio_output_dir, audio_name)

        # skip clips that were already generated in a previous run
        if os.path.exists(audio_path):
            continue

        # autocast with the runtime device type so CPU-only runs don't break
        with torch.autocast(device):
            output = pipeline(
                prompt=prompt,
                num_inference_steps=args.num_inference_steps,
                guidance_scale=args.guidance_scale,
                generator=generator,
                width=args.width,
                height=args.height,
            )

        # trim to the requested duration (sample_rate * duration samples)
        audio = output.audios[0][:audio_length]

        write(audio_path, args.sample_rate, audio)


if __name__ == "__main__":
    main()
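

# A minimal usage sketch (the script filename is illustrative; every flag
# below maps to an argument defined in parse_args above):
#
#   python inference.py \
#       --pretrained_model_name_or_path auffusion/auffusion \
#       --test_data_dir ./data/test_audiocaps.raw.json \
#       --output_dir ./output/auffusion_hf \
#       --mixed_precision fp16 \
#       --num_inference_steps 100 \
#       --guidance_scale 7.5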