wi-lab committed
Commit 0431525 · verified · 1 Parent(s): 6a75443

Update utils/pretraining.py

Files changed (1)
  1. utils/pretraining.py +273 -150
utils/pretraining.py CHANGED
@@ -1,150 +1,273 @@
- #%% PACKAGES & MODULES
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.optim.lr_scheduler import StepLR
- from inference import prepare_for_lwm
- from input_preprocess import tokenizer
- from lwm_model import lwm
- import numpy as np
-
- #%% PARAMETERS
- n_epochs = 100
- n_layers = 12
- n_heads = 12
- d_model = 64
- d_ff = d_model * 4
- d_k = d_model // n_heads
- d_v = d_model // n_heads
- dropout = 0.1
- max_len = 129
- element_length = 16
- batch_size = 64
- train_ratio = 0.7
- val_ratio = 0.2
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
- #%% PRE-TRAINING DATA GENERATION
- # The following DeepMIMO scenarios are not enough for pre-training a
- # Transformer-based foundation model like LWM. Add more scenarios for
- # more effective pre-training. Instructions for reproducing the actual
- # dataset used for pre-training LWM can be found on the Hugging Face forum.
- scenario_names = np.array([
-     "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
-     "city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
- ])
-
- scenario_idxs = np.array([0, 1, 2, 3, 4, 5])
- selected_scenario_names = scenario_names[scenario_idxs]
-
- preprocessed_chs = tokenizer(
-     selected_scenario_names=selected_scenario_names,
-     manual_data=None,
-     gen_raw=False)
-
- #%% DATALOADER
- train_size = int(train_ratio * len(preprocessed_chs))
- val_size = int(val_ratio * len(preprocessed_chs))
- test_size = len(preprocessed_chs) - val_size - train_size
-
- train_data, val_data, test_data = torch.utils.data.random_split(
-     preprocessed_chs, [train_size, val_size, test_size]
- )
-
- train_loader = prepare_for_lwm(train_data, device, batch_size=batch_size, shuffle=True)
- val_loader = prepare_for_lwm(val_data, device, batch_size=batch_size, shuffle=True)
- test_loader = prepare_for_lwm(test_data, device, batch_size=batch_size, shuffle=True)
-
- # %% Model
- load_model = False
-
- model = lwm()
- model.to(device)
-
- if load_model:
-     model_name = 'models/pretrained_model.pth'
-     model.load_state_dict(torch.load(model_name))
-     print(f"Model loaded from {model_name}")
-
- # Loss function
- criterionMLM = nn.MSELoss()
-
- # %% Optimizer and Scheduler
- adaptive_lr = False
-
- optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
- scheduler = (
-     optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
-     if adaptive_lr
-     else StepLR(optimizer, step_size=10, gamma=0.9)
- )
-
- # %% Training
- training_loss = []
- validation_loss = []
-
- def train(model, dataloader, optimizer, scheduler=None, device="cuda"):
-
-     model.train()
-     running_loss = 0.0
-     criterionMCM = nn.MSELoss()
-
-     for idx, batch in enumerate(dataloader):
-         input_ids = batch[0].to(device)
-         masked_tokens = batch[1].to(device)
-         masked_pos = batch[2].to(device)
-
-         optimizer.zero_grad()
-
-         logits_lm, _ = model(input_ids, masked_pos)
-         loss_lm = criterionMCM(logits_lm, masked_tokens)
-         loss = loss_lm / torch.var(masked_tokens)
-
-         loss.backward()
-         optimizer.step()
-
-         if scheduler is not None:
-             scheduler.step()
-
-         running_loss += loss.item()
-
-     average_loss = running_loss / len(dataloader)
-
-     return average_loss
-
- def validate(model, dataloader, device="cuda"):
-     model.eval()
-     running_loss = 0.0
-     criterionMCM = nn.MSELoss()
-
-     with torch.no_grad():
-         for idx, batch in enumerate(dataloader):
-             input_ids = batch[0].to(device)
-             masked_tokens = batch[1].to(device)
-             masked_pos = batch[2].to(device)
-
-             logits_lm, _ = model(input_ids, masked_pos)
-
-             loss_lm = criterionMCM(logits_lm, masked_tokens)
-             loss = loss_lm / torch.var(masked_tokens)
-
-             running_loss += loss.item()
-
-     average_loss = running_loss / len(dataloader)
-
-     return average_loss
-
- # %% Training Loop
- for epoch in range(n_epochs):
-     print(f"Epoch {epoch + 1}/{n_epochs}")
-
-     # Training step
-     train_loss = train(model, train_loader, optimizer, scheduler, device)
-     training_loss.append(train_loss)
-     print(f"Training Loss: {train_loss:.4f}")
-
-     # Validation step
-     if val_loader is not None:
-         val_loss = validate(model, val_loader, device)
-         validation_loss.append(val_loss)
-         print(f"Validation Loss: {val_loss:.4f}")
+ #%% PACKAGES & MODULES
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.optim.lr_scheduler import StepLR
+ from inference import prepare_for_lwm
+ from input_preprocess import tokenizer
+ from lwm_model import lwm
+ import numpy as np
+ import DeepMIMOv3
+
+ #%% PRE-TRAINING SCENARIO CONFIG
+ def get_parameters(scenario):
+
+     n_ant_bs = 32
+     n_ant_ue = 1
+     n_subcarriers = 32
+     scs = 30e3  # subcarrier spacing in Hz
+
+     # User-grid size (number of rows, users per row) for each DeepMIMO scenario
+     row_column_users = {
+         'asu_campus1': {'n_rows': 321, 'n_per_row': 411},
+         'Boston5G_3p5': {'n_rows': [812, 1622], 'n_per_row': 595},
+         'city_0_newyork': {'n_rows': 44, 'n_per_row': 117},
+         'city_1_losangeles': {'n_rows': 57, 'n_per_row': 81},
+         'city_2_chicago': {'n_rows': 56, 'n_per_row': 80},
+         'city_3_houston': {'n_rows': 62, 'n_per_row': 81},
+         'city_4_phoenix': {'n_rows': 79, 'n_per_row': 86},
+         'city_5_philadelphia': {'n_rows': 96, 'n_per_row': 66},
+         'city_6_miami': {'n_rows': 80, 'n_per_row': 87},
+         'city_8_dallas': {'n_rows': 83, 'n_per_row': 76},
+         'city_9_sanfrancisco': {'n_rows': 79, 'n_per_row': 83},
+         'city_10_austin': {'n_rows': 102, 'n_per_row': 55},
+         'city_13_columbus': {'n_rows': 71, 'n_per_row': 96},
+         'city_17_seattle': {'n_rows': 74, 'n_per_row': 82},
+         'O1_3p5': {'n_rows': 5203, 'n_per_row': 181},
+         'city_18_denver': {'n_rows': 85, 'n_per_row': 82},
+         'city_15_indianapolis': {'n_rows': 80, 'n_per_row': 79},
+         'city_19_oklahoma': {'n_rows': 82, 'n_per_row': 75},
+         'city_12_fortworth': {'n_rows': 86, 'n_per_row': 72},
+         'city_11_santaclara': {'n_rows': 47, 'n_per_row': 114},
+         'city_7_sandiego': {'n_rows': 71, 'n_per_row': 83}
+     }
+
+     parameters = DeepMIMOv3.default_params()
+     parameters['dataset_folder'] = './scenarios'
+     parameters['scenario'] = scenario
+
+     if scenario == 'O1_3p5':
+         parameters['active_BS'] = np.array([4])
+     elif scenario in ['city_14_charlotte', 'city_18_denver', 'city_15_indianapolis']:
+         parameters['active_BS'] = np.array([3])
+     else:
+         parameters['active_BS'] = np.array([1])
+
+     if scenario == 'Boston5G_3p5':
+         parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'][0],
+                                             row_column_users[scenario]['n_rows'][1])
+     else:
+         parameters['user_rows'] = np.arange(row_column_users[scenario]['n_rows'])
+
+     parameters['bs_antenna']['shape'] = np.array([n_ant_bs, 1])  # Horizontal, Vertical
+     parameters['bs_antenna']['rotation'] = np.array([0, 0, -135])  # (x, y, z)
+     parameters['ue_antenna']['shape'] = np.array([n_ant_ue, 1])
+     parameters['enable_BS2BS'] = False
+     parameters['OFDM']['subcarriers'] = n_subcarriers
+     parameters['OFDM']['selected_subcarriers'] = np.arange(n_subcarriers)
+
+     parameters['OFDM']['bandwidth'] = scs * n_subcarriers / 1e9  # scs * n_subcarriers Hz, expressed in GHz
+     parameters['num_paths'] = 20
+
+     return parameters, row_column_users, n_ant_bs, n_ant_ue, n_subcarriers
+
+ #%% PARAMETERS
+ n_epochs = 100
+ n_layers = 12
+ n_heads = 12
+ d_model = 64
+ d_ff = d_model * 4
+ d_k = d_model // n_heads
+ d_v = d_model // n_heads
+ dropout = 0.1
+ max_len = 129
+ element_length = 16
+ batch_size = 64
+ train_ratio = 0.7
+ val_ratio = 0.2
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ #%% PRE-TRAINING DATA GENERATION
+ # The following DeepMIMO scenarios are not enough for pre-training a
+ # Transformer-based foundation model like LWM. Add more scenarios for
+ # more effective pre-training. Instructions for reproducing the actual
+ # dataset used for pre-training LWM can be found on the Hugging Face forum.
+ scenario_names = np.array([
+     "city_18_denver", "city_15_indianapolis", "city_19_oklahoma",
+     "city_12_fortworth", "city_11_santaclara", "city_7_sandiego"
+ ])
+
+ scenario_idxs = np.array([0, 1, 2, 3, 4, 5])
+ selected_scenario_names = scenario_names[scenario_idxs]
+
+ preprocessed_chs = tokenizer(
+     selected_scenario_names=selected_scenario_names,
+     manual_data=None,
+     gen_raw=False)
+
+ #%% DATALOADER
+ train_size = int(train_ratio * len(preprocessed_chs))
+ val_size = int(val_ratio * len(preprocessed_chs))
+ test_size = len(preprocessed_chs) - val_size - train_size
+
+ train_data, val_data, test_data = torch.utils.data.random_split(
+     preprocessed_chs, [train_size, val_size, test_size]
+ )
+
+ train_loader = prepare_for_lwm(train_data, device, batch_size=batch_size, shuffle=True)
+ val_loader = prepare_for_lwm(val_data, device, batch_size=batch_size, shuffle=True)
+ test_loader = prepare_for_lwm(test_data, device, batch_size=batch_size, shuffle=True)
+
+ # %% Model
+ load_model = False
+
+ model = lwm()
+ model.to(device)
+
+ if load_model:
+     model_name = 'models/pretrained_model.pth'
+     model.load_state_dict(torch.load(model_name))
+     print(f"Model loaded from {model_name}")
+
+ # Loss function
+ criterionMLM = nn.MSELoss()
+
+ # %% Optimizer and Scheduler
+ adaptive_lr = False
+
+ optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
+ scheduler = (
+     optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
+     if adaptive_lr
+     else StepLR(optimizer, step_size=10, gamma=0.9)
+ )
+
+ # %% Training
+ training_loss = []
+ validation_loss = []
+
+ def train(model, dataloader, optimizer, scheduler=None, device="cuda"):
+
+     model.train()
+     running_loss = 0.0
+     criterionMCM = nn.MSELoss()
+
+     for idx, batch in enumerate(dataloader):
+         input_ids = batch[0].to(device)
+         masked_tokens = batch[1].to(device)
+         masked_pos = batch[2].to(device)
+
+         optimizer.zero_grad()
+
+         logits_lm, _ = model(input_ids, masked_pos)
+         loss_lm = criterionMCM(logits_lm, masked_tokens)
+         # Normalize the masked-token MSE by the target variance (NMSE-style)
+         loss = loss_lm / torch.var(masked_tokens)
+
+         loss.backward()
+         optimizer.step()
+
+         # ReduceLROnPlateau requires the monitored metric and is stepped once
+         # per epoch in the training loop; only batch-wise schedulers step here.
+         if scheduler is not None and not isinstance(
+                 scheduler, optim.lr_scheduler.ReduceLROnPlateau):
+             scheduler.step()
+
+         running_loss += loss.item()
+
+     average_loss = running_loss / len(dataloader)
+
+     return average_loss
+
+ def validate(model, dataloader, device="cuda"):
+     model.eval()
+     running_loss = 0.0
+     criterionMCM = nn.MSELoss()
+
+     with torch.no_grad():
+         for idx, batch in enumerate(dataloader):
+             input_ids = batch[0].to(device)
+             masked_tokens = batch[1].to(device)
+             masked_pos = batch[2].to(device)
+
+             logits_lm, _ = model(input_ids, masked_pos)
+
+             loss_lm = criterionMCM(logits_lm, masked_tokens)
+             loss = loss_lm / torch.var(masked_tokens)
+
+             running_loss += loss.item()
+
+     average_loss = running_loss / len(dataloader)
+
+     return average_loss
+
+ # %% Training Loop
+ for epoch in range(n_epochs):
+     print(f"Epoch {epoch + 1}/{n_epochs}")
+
+     # Training step
+     train_loss = train(model, train_loader, optimizer, scheduler, device)
+     training_loss.append(train_loss)
+     print(f"Training Loss: {train_loss:.4f}")
+
+     # Validation step
+     if val_loader is not None:
+         val_loss = validate(model, val_loader, device)
+         validation_loss.append(val_loss)
+         print(f"Validation Loss: {val_loss:.4f}")
+
+         # Plateau scheduler is stepped once per epoch on the validation loss
+         if adaptive_lr:
+             scheduler.step(val_loss)
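
For context, the new get_parameters helper only builds a DeepMIMOv3 parameter dict; below is a minimal sketch of how it would typically be consumed to generate raw channels. The scenario name and variable names are illustrative, not part of this commit; the generate_data call and dataset layout follow the DeepMIMO v3 Python API.

    import DeepMIMOv3

    # Build the parameter dict for one scenario and generate the dataset
    params, row_col, n_ant_bs, n_ant_ue, n_sc = get_parameters('city_18_denver')
    dataset = DeepMIMOv3.generate_data(params)  # one entry per active BS

    # Per the DeepMIMO docs, user channels have shape
    # (n_users, n_ant_ue, n_ant_bs, n_subcarriers)
    channels = dataset[0]['user']['channel']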
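
The objective in train() and validate() is the masked-token MSE divided by the variance of the targets, i.e. a normalized MSE that is insensitive to the overall channel scale. A small self-contained illustration with stand-in tensors (shapes and values are arbitrary):

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    pred = torch.randn(64, 16)    # stand-in for logits_lm
    target = torch.randn(64, 16)  # stand-in for masked_tokens

    for c in (1.0, 10.0):
        mse = nn.MSELoss()(c * pred, c * target)
        nmse = mse / torch.var(c * target)
        # mse grows with c**2; the normalized loss is unchanged
        print(c, mse.item(), nmse.item())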
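
One behavioral note on the scheduler selection: ReduceLROnPlateau.step() requires the monitored metric as an argument, so it cannot be stepped blindly inside the batch loop the way StepLR can, which is why the loop above steps it once per epoch on the validation loss. A compact sketch of the two stepping patterns, using a toy optimizer for illustration only:

    import torch
    import torch.optim as optim

    w = torch.zeros(1, requires_grad=True)
    opt = optim.Adam([w], lr=1e-4)

    step_lr = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.9)
    step_lr.step()  # fixed decay schedule; no metric needed

    plateau = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min')
    plateau.step(0.42)  # must pass the monitored metric (e.g., val_loss)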