from thop import profile
from thop import clever_format
import torch
from tqdm import tqdm
import time
import sys
sys.path.append('./')


def analyze_model(model, inputs):
    # model size
    num_trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Num trainable parameters: {:.2f} M".format(num_trainable_parameters / 1e6))

    # computation cost (thop counts MACs only for modules it has hooks for,
    # so custom layers may be missed and the count can be an underestimate)
    with torch.no_grad():
        model.eval()
        macs, params = profile(model, inputs=inputs)
        macs, params = clever_format([macs, params], "%.3f")
        print("Macs: {}, Params: {}".format(macs, params))

    # eval: time `run_times` forward passes (the first iteration includes CUDA warm-up)
    run_times = 50
    with torch.no_grad():
        model = model.eval().to('cuda')
        inputs = [i.to('cuda') if isinstance(i, torch.Tensor) else i for i in inputs]
        model.init_device_dtype(inputs[0].device, inputs[0].dtype)
        torch.cuda.synchronize()
        st = time.time()
        for i in tqdm(range(run_times)):
            _ = model(*inputs)
        torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
        et = time.time()
        print("Eval forward: {:.03f} secs/iter".format((et - st) / float(run_times)))

    # train: time `run_times` forward + backward + optimizer steps
    model = model.train().to('cuda')
    inputs = [i.to('cuda') if isinstance(i, torch.Tensor) else i for i in inputs]
    model.init_device_dtype(inputs[0].device, inputs[0].dtype)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    torch.cuda.synchronize()
    st = time.time()
    for i in tqdm(range(run_times)):
        # draw fresh random values for the floating-point inputs each step
        inputs = [torch.rand_like(x) if isinstance(x, torch.Tensor) and x.is_floating_point() else x
                  for x in inputs]
        out = model(*inputs)
        optimizer.zero_grad()
        out.mean().backward()
        optimizer.step()
    torch.cuda.synchronize()  # wait for queued GPU work before stopping the clock
    et = time.time()
    print("Train step (forward + backward): {:.03f} secs/iter".format((et - st) / float(run_times)))

def fetch_model_v3_transformer():
    # num params: 326M
    # macs (incorrect): 261G/iter
    # infer: 0.32s/iter
    # train: 2.54s/iter
    from models_transformercond_winorm_ch16_everything_512 import PromptCondAudioDiffusion
    model = PromptCondAudioDiffusion(
        "configs/scheduler/stable_diffusion_2.1_largenoise.json",
        None,
        "configs/models/transformer2D.json",
    )
    # dummy inputs matching the model's forward signature
    inputs = [
        torch.rand(1, 16, 1024 * 3 // 8, 32),
        torch.rand(1, 7, 512),
        torch.tensor([1]),
        torch.tensor([0]),
        False,
    ]
    return model, inputs

def fetch_model_v3_unet():
    # num params: 310M
    # infer: 0.10s/iter
    # train: 0.70s/iter
    from models_musicldm_winorm_ch16_everything_sepnorm import PromptCondAudioDiffusion
    model = PromptCondAudioDiffusion(
        "configs/scheduler/stable_diffusion_2.1_largenoise.json",
        None,
        "configs/diffusion_clapcond_model_config_ch16_everything.json",
    )
    # dummy inputs matching the model's forward signature
    inputs = [
        torch.rand(1, 16, 1024 * 3 // 8, 32),
        torch.rand(1, 7, 512),
        torch.tensor([1]),
        torch.tensor([0]),
        False,
    ]
    return model, inputs

if __name__ == "__main__":
    model, inputs = fetch_model_v3_transformer()
    # model, inputs = fetch_model_v3_unet()
    analyze_model(model, inputs)