shredder-31 committed on
Commit
9ca7f1a
·
verified ·
1 Parent(s): 9b046d8

Upload baseline 4 model files and outputs

Browse files
Files changed (1) hide show
  1. baseline4/trainer.py +283 -0
baseline4/trainer.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import yaml
4
+ import torch
5
+ import random
6
+ import numpy as np
7
+ import torch.nn as nn
8
+ import albumentations as A
9
+ import torch.optim as optim
10
+ import torch.multiprocessing as mp
11
+ from datetime import datetime
12
+ from albumentations.pytorch import ToTensorV2
13
+ from torch.cuda.amp import autocast, GradScaler
14
+ from torch.utils.data import DataLoader
15
+ from torch.utils.tensorboard.writer import SummaryWriter
16
+ from model import Group_Activity_Temporal_Classifer , collate_fn
17
+
18
# Filesystem anchors used by this script; the values point into a /kaggle tree.
ROOT = "/kaggle/"
PROJECT_ROOT = "/kaggle/working/Group-Activity-Recognition"
# Same location as before, expressed relative to PROJECT_ROOT.
CONFIG_FILE_PATH = f"{PROJECT_ROOT}/modeling/configs/Baseline B4.yml"

# Put the project checkout on the import path before the project-local
# imports that follow (data_utils, eval_utils, helper_utils).
_project_dir = os.path.abspath(PROJECT_ROOT)
sys.path.append(_project_dir)
23
+
24
+ from data_utils import Group_Activity_DataSet, group_activity_labels
25
+ from eval_utils import get_f1_score , plot_confusion_matrix
26
+ from helper_utils import load_config, setup_logging, save_checkpoint
27
+
28
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU, current CUDA device,
    and all CUDA devices), then switches cuDNN to deterministic mode with
    benchmarking turned off.

    Args:
        seed: Integer seed handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
36
+
37
def train_one_epoch(model, train_loader, criterion, optimizer, scaler, device, epoch, writer, logger):
    """Run one mixed-precision training pass over ``train_loader``.

    Args:
        model: Network to train; it is put into training mode.
        train_loader: Iterable with ``len()`` yielding ``(inputs, targets)``
            batches. ``targets`` must be 2-D because the class index is
            recovered with ``argmax(1)``.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepped through ``scaler``.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.
        device: Device the batches are moved to.
        epoch: Epoch index, used for the TensorBoard step and log messages.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        logger: Object exposing ``info(message)``.

    Returns:
        tuple: ``(epoch_loss, epoch_acc)``, the mean per-batch loss and the
        accuracy in percent. Both are ``0.0`` when the loader yields no
        batches (previously a ``ZeroDivisionError``).
    """
    model.train()
    num_batches = len(train_loader)
    total_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        with autocast(dtype=torch.float16):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Read the scalar once: every .item() call is a device-to-host sync.
        batch_loss = loss.item()
        total_loss += batch_loss

        predicted = outputs.argmax(1)
        target_class = targets.argmax(1)
        total += targets.size(0)
        correct += predicted.eq(target_class).sum().item()

        # NOTE(review): indentation was lost in the source this was recovered
        # from; the per-batch log line is assumed to share the every-10-batches
        # cadence of the TensorBoard scalars. Confirm against the repository.
        if batch_idx % 10 == 0:
            running_acc = 100. * correct / max(total, 1)
            step = epoch * num_batches + batch_idx
            writer.add_scalar('Training/BatchLoss', batch_loss, step)
            writer.add_scalar('Training/BatchAccuracy', running_acc, step)

            log_msg = f'Epoch: {epoch} | Batch: {batch_idx}/{num_batches} | Loss: {batch_loss:.4f} | Acc: {running_acc:.2f}%'
            logger.info(log_msg)

    # max(..., 1) keeps an empty loader from dividing by zero.
    epoch_loss = total_loss / max(num_batches, 1)
    epoch_acc = 100. * correct / max(total, 1)

    logger.info(f"Epoch {epoch} - Train Loss: {epoch_loss:.4f} - Accuracy: {epoch_acc:.2f}%")

    writer.add_scalar('Training/EpochLoss', epoch_loss, epoch)
    writer.add_scalar('Training/EpochAccuracy', epoch_acc, epoch)

    return epoch_loss, epoch_acc
79
+
80
def validate_model(model, val_loader, criterion, device, epoch, writer, logger, class_names):
    """Evaluate ``model`` on ``val_loader`` and report the results.

    The model is switched to eval mode and run without gradient tracking.
    The weighted F1 score, a confusion-matrix figure, the mean loss and the
    accuracy are written to ``writer``, and a summary line goes to ``logger``.

    Args:
        model: Network to evaluate.
        val_loader: Iterable with ``len()`` yielding ``(inputs, targets)``
            batches; ``targets`` must be 2-D (class index via ``argmax(1)``).
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        device: Device the batches are moved to.
        epoch: Epoch index used as the TensorBoard step.
        writer: Object exposing ``add_scalar`` and ``add_figure``.
        logger: Object exposing ``info(message)``.
        class_names: Passed through to ``plot_confusion_matrix``.

    Returns:
        tuple: ``(avg_loss, accuracy)``, the mean per-batch loss and the
        accuracy in percent.
    """
    model.eval()
    running_loss = 0
    num_correct = 0
    num_samples = 0
    y_true, y_pred = [], []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()

            predicted = outputs.argmax(1)
            target_class = targets.argmax(1)
            num_samples += targets.size(0)
            num_correct += predicted.eq(target_class).sum().item()

            # Keep per-sample labels for the F1 score and confusion matrix.
            y_true.extend(target_class.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    avg_loss = running_loss / len(val_loader)
    accuracy = 100. * num_correct / num_samples

    f1_score = get_f1_score(y_true, y_pred, average="weighted")
    writer.add_scalar('Validation/F1Score', f1_score, epoch)

    confusion_fig = plot_confusion_matrix(y_true, y_pred, class_names)
    writer.add_figure('Validation/ConfusionMatrix', confusion_fig, epoch)

    writer.add_scalar('Validation/Loss', avg_loss, epoch)
    writer.add_scalar('Validation/Accuracy', accuracy, epoch)

    logger.info(f"Epoch {epoch} | Valid Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}% | F1 Score: {f1_score:.4f}")

    return avg_loss, accuracy
122
+
123
def _build_transforms():
    """Return the ``(train, validation)`` albumentations pipelines."""
    # Channel statistics commonly used for ImageNet-pretrained backbones.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    train_transforms = A.Compose([
        A.Resize(224, 224),
        A.OneOf([
            A.GaussianBlur(blur_limit=(3, 7)),
            A.ColorJitter(brightness=0.2),
            A.RandomBrightnessContrast(),
            A.GaussNoise()
        ], p=0.5),
        A.OneOf([
            A.HorizontalFlip(),
            A.VerticalFlip(),
        ], p=0.05),
        A.Normalize(mean=mean, std=std),
        ToTensorV2()
    ])

    val_transforms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(mean=mean, std=std),
        ToTensorV2()
    ])

    return train_transforms, val_transforms


def _build_optimizer(model, training_cfg):
    """Create the optimizer named by ``training_cfg['optimizer']``.

    Raises:
        ValueError: If the name is neither ``"AdamW"`` nor ``"SGD"``.
    """
    optimizer_classes = {"AdamW": optim.AdamW, "SGD": optim.SGD}
    name = training_cfg['optimizer']
    if name not in optimizer_classes:
        # Previously this fell through and failed later with an unbound
        # `optimizer` name, far from the actual configuration mistake.
        raise ValueError(
            f"Unsupported optimizer {name!r}; expected one of {sorted(optimizer_classes)}"
        )
    return optimizer_classes[name](
        model.parameters(),
        lr=training_cfg['learning_rate'],
        weight_decay=training_cfg['weight_decay']
    )


def train_model(config_path):
    """Train and validate the temporal group-activity classifier.

    Loads the experiment configuration, creates a timestamped experiment
    directory with logging and TensorBoard output, builds the datasets,
    loaders, model, optimizer, loss, gradient scaler and LR scheduler, then
    runs the train/validate loop. A checkpoint is saved every epoch and a
    final checkpoint is written when the loop completes.

    Args:
        config_path: Path handed to ``load_config``.

    Raises:
        ValueError: If ``config.training['optimizer']`` is not ``"AdamW"``
            or ``"SGD"``.
    """
    config = load_config(config_path)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    exp_dir = os.path.join(
        f"{PROJECT_ROOT}/modeling/baseline 4/{config.experiment['output_dir']}",
        f"{config.experiment['name']}_V{config.experiment['version']}_{timestamp}"
    )
    os.makedirs(exp_dir, exist_ok=True)

    logger = setup_logging(exp_dir)
    logger.info(f"Starting experiment: {config.experiment['name']}_V{config.experiment['version']}")

    writer = SummaryWriter(log_dir=os.path.join(exp_dir, 'tensorboard'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    set_seed(config.experiment['seed'])
    logger.info(f"Set random seed: {config.experiment['seed']}")

    train_transforms, val_transforms = _build_transforms()

    train_dataset = Group_Activity_DataSet(
        videos_path=f"{ROOT}/{config.data['videos_path']}",
        annot_path=f"{ROOT}/{config.data['annot_path']}",
        split=config.data['video_splits']['train'],
        crops=False,
        seq=True,
        labels=group_activity_labels,
        transform=train_transforms
    )

    val_dataset = Group_Activity_DataSet(
        videos_path=f"{ROOT}/{config.data['videos_path']}",
        annot_path=f"{ROOT}/{config.data['annot_path']}",
        split=config.data['video_splits']['validation'],
        crops=False,
        seq=True,
        labels=group_activity_labels,
        transform=val_transforms
    )

    logger.info(f"Training dataset size: {len(train_dataset)}")
    logger.info(f"Validation dataset size: {len(val_dataset)}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.training['batch_size']['train'],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True
    )

    # The validation metrics aggregate over the whole split, so sample order
    # does not matter; shuffling is disabled to keep evaluation order fixed.
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.training['batch_size']['val'],
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True
    )

    model = Group_Activity_Temporal_Classifer(
        num_classes=config.model['num_classes'],
        input_size=config.model['input_size'],
        hidden_size=config.model['hidden_size'],
        num_layers=config.model['num_layers']
    )

    model = model.to(device)

    optimizer = _build_optimizer(model, config.training)

    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()

    # mode='min' because the scheduler is stepped with the validation loss.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5,
    )

    config_save_path = os.path.join(exp_dir, 'config.yml')
    with open(config_save_path, 'w') as f:
        yaml.dump(config, f)
    logger.info(f"Configuration saved to: {config_save_path}")

    # Stays None if the configured epoch count is 0, so the final checkpoint
    # can still be written instead of failing on an unbound name.
    val_acc = None

    logger.info("Starting training...")
    try:
        for epoch in range(config.training['epochs']):
            logger.info(f'\nEpoch {epoch+1}/{config.training["epochs"]}')

            train_loss, train_acc = train_one_epoch(
                model, train_loader, criterion, optimizer, scaler, device, epoch, writer, logger
            )

            val_loss, val_acc = validate_model(
                model, val_loader, criterion, device, epoch, writer, logger,
                config.model['num_clases_label']
            )
            scheduler.step(val_loss)

            current_lr = optimizer.param_groups[0]['lr']
            writer.add_scalar('Training/LearningRate', current_lr, epoch)
            logger.info(f'Current learning rate: {current_lr}')
            save_checkpoint(model, optimizer, epoch, val_acc, config, exp_dir)
    finally:
        # Close the event file even when an epoch raises.
        writer.close()

    final_model_path = os.path.join(exp_dir, 'final_model.pth')
    torch.save({
        'epoch': config.training['epochs'],
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
        'config': config,
    }, final_model_path)

    logger.info(f"Training completed. Final model saved to: {final_model_path}")
279
+
280
if __name__ == "__main__":
    # force=True replaces any start method that was already set.
    # NOTE(review): 'spawn' is presumably chosen so the DataLoader worker
    # processes start from a fresh interpreter; confirm the reason.
    mp.set_start_method('spawn', force=True)
    train_model(CONFIG_FILE_PATH)
    # Inspect a finished run with:
    #   tensorboard --logdir="<experiment directory>/tensorboard"