ayousanz commited on
Commit
364ef4d
·
verified ·
1 Parent(s): 3066f2e

Upload 2 files

Browse files
Files changed (2) hide show
  1. model.ckpt +3 -0
  2. model.yaml +135 -0
model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af6ab91337d0ef6b518367082ac3f849448c6daaa01fd987678fb25ea44ca184
3
+ size 1839231053
model.yaml ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 1.0.0
2
+
3
+ model:
4
+ base_learning_rate: 0.00001
5
+ target: mug.diffusion.diffusion.DDPM
6
+ params:
7
+ linear_start: 0.0001
8
+ linear_end: 0.02
9
+ log_every_t: 100
10
+ timesteps: 1000
11
+ z_channels: 16
12
+ z_length: 512
13
+ parameterization: eps
14
+ loss_type: smooth_l1
15
+ monitor: val/loss_simple
16
+
17
+ unet_config:
18
+ target: mug.diffusion.unet.UNetModel
19
+ params:
20
+ in_channels: 16
21
+ model_channels: 128
22
+ out_channels: 16
23
+ attention_resolutions: [ 8,4,2 ]
24
+ num_res_blocks: 2
25
+ channel_mult: [ 1,2,3,4 ]
26
+ num_heads: 8
27
+ context_dim: 128
28
+ dropout: 0.0
29
+ lstm_last: false
30
+ lstm_layer: false
31
+ s4_layer: true
32
+ audio_channels: [ 256,512,512,512 ]
33
+ use_checkpoint: false
34
+
35
+ first_stage_config:
36
+ target: mug.firststage.autoencoder.AutoencoderKL
37
+ params:
38
+ monitor: "val/loss"
39
+ kl_weight: 0.000001
40
+ ddconfig:
41
+ x_channels: 16 # key_count * 4
42
+ middle_channels: 64
43
+ z_channels: 16
44
+ num_groups: 8
45
+ channel_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
46
+ num_res_blocks: 1
47
+ lossconfig:
48
+ target: torch.nn.Identity
49
+ # target: mug.firststage.losses.ManiaReconstructLoss
50
+ # params:
51
+ # weight_start_offset: 0.5
52
+ # weight_holding: 0.5
53
+ # weight_end_offset: 0.2
54
+ # label_smoothing: 0.001
55
+
56
+
57
+ cond_stage_config:
58
+ target: mug.cond.feature.BeatmapFeatureEmbedder
59
+ params:
60
+ path_to_yaml: "configs/mug/mania_beatmap_features.yaml"
61
+ embed_dim: 128
62
+
63
+ wave_stage_config:
64
+ target: mug.cond.wave.MelspectrogramScaleEncoder1D
65
+ params:
66
+ n_freq: 128
67
+ middle_channels: 128
68
+ attention_resolutions: [ 128,256,512 ]
69
+ num_res_blocks: 2
70
+ num_heads: 8
71
+ num_groups: 32
72
+ dropout: 0.0
73
+ use_checkpoint: true
74
+ channel_mult: [ 1,1,1,1,2,2,2,4,4,4 ]
75
+
76
+
77
+ data:
78
+ target: main.DataModuleFromConfig
79
+ params:
80
+ batch_size: 48
81
+ wrap: False
82
+ # num_workers: 0
83
+ num_workers: 7
84
+ common_params:
85
+ txt_file: [ ]
86
+ sr: 22050
87
+ n_fft: 512
88
+ max_audio_frame: 32768
89
+ audio_note_window_ratio: 8
90
+ n_mels: 128
91
+ cache_dir: "data/audio_cache/"
92
+ with_audio: true
93
+ with_feature: true
94
+ feature_yaml: "configs/mug/mania_beatmap_features.yaml"
95
+ # audio_window_frame = n_fft / sr / 4 = 0.00580499 s
96
+ # note_window_frame = audio_note_window_ratio * audio_window_frame = 0.04643990 s
97
+ # max_duration = audio_window_frame * max_audio_frame = 190.2179 s = 3 min 10 s
98
+ # max_note_frame = max_audio_frame / audio_note_window_ratio = 4096
99
+
100
+ # old ===========
101
+ # audio_window_frame = n_fft / sr / 4 = 0.02321995 s
102
+ # note_window_frame = audio_note_window_ratio * audio_window_frame = 0.04643990 s
103
+ # max_duration = audio_window_frame * max_audio_frame = 380.4357 s = 6 min 20 s
104
+ # max_note_frame = max_audio_frame / audio_note_window_ratio = 8192
105
+ train:
106
+ target: mug.data.dataset.OsuTrainDataset
107
+ params:
108
+ mirror_p: 0.5
109
+ feature_dropout_p: 0.5
110
+ mirror_at_interval_p: 0
111
+ rate_p: 0.2
112
+ rate: [ 0.75,1.3 ]
113
+ freq_mask_p: 0.0
114
+ freq_mask_num: 15
115
+
116
+ validation:
117
+ target: mug.data.dataset.OsuValidDataset
118
+ params: {}
119
+ # test_txt_file: "data\\mug\\local_mania_4k_test.txt"
120
+
121
+
122
+ lightning:
123
+ callbacks:
124
+ beatmap_logger:
125
+ target: mug.data.dataset.BeatmapLogger
126
+ params:
127
+ log_batch_idx: [ 0 ]
128
+ splits: [ 'val' ]
129
+ count: 16
130
+
131
+ trainer:
132
+ benchmark: True
133
+ accelerator: dp
134
+ accumulate_grad_batches: 1
135
+ # precision: 16