Upload baseline 4 model files and outputs
baseline4/trainer.py
ADDED
@@ -0,0 +1,283 @@
import os
import sys
import yaml
import torch
import random
import numpy as np
import torch.nn as nn
import albumentations as A
import torch.optim as optim
import torch.multiprocessing as mp
from datetime import datetime
from albumentations.pytorch import ToTensorV2
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

ROOT = "/kaggle"
PROJECT_ROOT = "/kaggle/working/Group-Activity-Recognition"
CONFIG_FILE_PATH = "/kaggle/working/Group-Activity-Recognition/modeling/configs/Baseline B4.yml"

# Make the project modules importable before importing from them.
sys.path.append(os.path.abspath(PROJECT_ROOT))

from model import Group_Activity_Temporal_Classifer, collate_fn
from data_utils import Group_Activity_DataSet, group_activity_labels
from eval_utils import get_f1_score, plot_confusion_matrix
from helper_utils import load_config, setup_logging, save_checkpoint

def set_seed(seed):
    """Seed every RNG in use so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def train_one_epoch(model, train_loader, criterion, optimizer, scaler, device, epoch, writer, logger):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        # Mixed-precision forward pass.
        with autocast(dtype=torch.float16):
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        # Scale the loss, backprop, and step through the GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

        # Targets are one-hot, so recover class indices via argmax.
        predicted = outputs.argmax(1)
        target_class = targets.argmax(1)
        total += targets.size(0)
        correct += predicted.eq(target_class).sum().item()

        if batch_idx % 10 == 0:
            step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Training/BatchLoss', loss.item(), step)
            writer.add_scalar('Training/BatchAccuracy', 100. * correct / total, step)
            logger.info(
                f'Epoch: {epoch} | Batch: {batch_idx}/{len(train_loader)} | '
                f'Loss: {loss.item():.4f} | Acc: {100. * correct / total:.2f}%'
            )

    epoch_loss = total_loss / len(train_loader)
    epoch_acc = 100. * correct / total

    logger.info(f"Epoch {epoch} - Train Loss: {epoch_loss:.4f} - Accuracy: {epoch_acc:.2f}%")

    writer.add_scalar('Training/EpochLoss', epoch_loss, epoch)
    writer.add_scalar('Training/EpochAccuracy', epoch_acc, epoch)

    return epoch_loss, epoch_acc

def validate_model(model, val_loader, criterion, device, epoch, writer, logger, class_names):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    y_true = []
    y_pred = []

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item()

            # Targets are one-hot, so recover class indices via argmax.
            predicted = outputs.argmax(1)
            target_class = targets.argmax(1)
            total += targets.size(0)
            correct += predicted.eq(target_class).sum().item()

            # Collect per-sample labels for the F1 score and confusion matrix.
            y_true.extend(target_class.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    avg_loss = total_loss / len(val_loader)
    accuracy = 100. * correct / total

    f1_score = get_f1_score(y_true, y_pred, average="weighted")
    writer.add_scalar('Validation/F1Score', f1_score, epoch)

    fig = plot_confusion_matrix(y_true, y_pred, class_names)
    writer.add_figure('Validation/ConfusionMatrix', fig, epoch)

    writer.add_scalar('Validation/Loss', avg_loss, epoch)
    writer.add_scalar('Validation/Accuracy', accuracy, epoch)

    logger.info(f"Epoch {epoch} | Valid Loss: {avg_loss:.4f} | Accuracy: {accuracy:.2f}% | F1 Score: {f1_score:.4f}")

    return avg_loss, accuracy

def train_model(config_path):
    config = load_config(config_path)

    # Each run gets its own timestamped experiment directory.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    exp_dir = os.path.join(
        f"{PROJECT_ROOT}/modeling/baseline 4/{config.experiment['output_dir']}",
        f"{config.experiment['name']}_V{config.experiment['version']}_{timestamp}"
    )
    os.makedirs(exp_dir, exist_ok=True)

    logger = setup_logging(exp_dir)
    logger.info(f"Starting experiment: {config.experiment['name']}_V{config.experiment['version']}")

    writer = SummaryWriter(log_dir=os.path.join(exp_dir, 'tensorboard'))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    set_seed(config.experiment['seed'])
    logger.info(f"Set random seed: {config.experiment['seed']}")

    # Augment only the training split; validation gets resize + normalize.
    train_transforms = A.Compose([
        A.Resize(224, 224),
        A.OneOf([
            A.GaussianBlur(blur_limit=(3, 7)),
            A.ColorJitter(brightness=0.2),
            A.RandomBrightnessContrast(),
            A.GaussNoise()
        ], p=0.5),
        A.OneOf([
            A.HorizontalFlip(),
            A.VerticalFlip(),
        ], p=0.05),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
        ToTensorV2()
    ])

    val_transforms = A.Compose([
        A.Resize(224, 224),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
        ToTensorV2()
    ])

    train_dataset = Group_Activity_DataSet(
        videos_path=f"{ROOT}/{config.data['videos_path']}",
        annot_path=f"{ROOT}/{config.data['annot_path']}",
        split=config.data['video_splits']['train'],
        crops=False,
        seq=True,
        labels=group_activity_labels,
        transform=train_transforms
    )

    val_dataset = Group_Activity_DataSet(
        videos_path=f"{ROOT}/{config.data['videos_path']}",
        annot_path=f"{ROOT}/{config.data['annot_path']}",
        split=config.data['video_splits']['validation'],
        crops=False,
        seq=True,
        labels=group_activity_labels,
        transform=val_transforms
    )

    logger.info(f"Training dataset size: {len(train_dataset)}")
    logger.info(f"Validation dataset size: {len(val_dataset)}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.training['batch_size']['train'],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config.training['batch_size']['val'],
        shuffle=False,  # evaluation does not need shuffling
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True
    )

    model = Group_Activity_Temporal_Classifer(
        num_classes=config.model['num_classes'],
        input_size=config.model['input_size'],
        hidden_size=config.model['hidden_size'],
        num_layers=config.model['num_layers']
    )
    model = model.to(device)

    if config.training['optimizer'] == "AdamW":
        optimizer = optim.AdamW(
            model.parameters(),
            lr=config.training['learning_rate'],
            weight_decay=config.training['weight_decay']
        )
    elif config.training['optimizer'] == "SGD":
        optimizer = optim.SGD(
            model.parameters(),
            lr=config.training['learning_rate'],
            weight_decay=config.training['weight_decay']
        )
    else:
        raise ValueError(f"Unsupported optimizer: {config.training['optimizer']}")

    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()

    # Cut the learning rate 10x when validation loss plateaus for 5 epochs.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=5,
    )

    config_save_path = os.path.join(exp_dir, 'config.yml')
    with open(config_save_path, 'w') as f:
        yaml.dump(config, f)
    logger.info(f"Configuration saved to: {config_save_path}")

    logger.info("Starting training...")
    for epoch in range(config.training['epochs']):
        logger.info(f'\nEpoch {epoch + 1}/{config.training["epochs"]}')

        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, scaler, device, epoch, writer, logger
        )

        val_loss, val_acc = validate_model(
            model, val_loader, criterion, device, epoch, writer, logger,
            config.model['num_clases_label']
        )
        scheduler.step(val_loss)

        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('Training/LearningRate', current_lr, epoch)
        logger.info(f'Current learning rate: {current_lr}')
        save_checkpoint(model, optimizer, epoch, val_acc, config, exp_dir)

    writer.close()

    final_model_path = os.path.join(exp_dir, 'final_model.pth')
    torch.save({
        'epoch': config.training['epochs'],
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
        'config': config,
    }, final_model_path)

    logger.info(f"Training completed. Final model saved to: {final_model_path}")

if __name__ == "__main__":
    # DataLoader workers plus CUDA need the spawn start method.
    mp.set_start_method('spawn', force=True)
    train_model(CONFIG_FILE_PATH)

# tensorboard --logdir="/teamspace/studios/this_studio/Group-Activity-Recognition/modeling/baseline 1/outputs/Baseline_B1_tuned_V1_20241117_044805/tensorboard"
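
For reference, the trainer assumes `load_config` exposes the YAML sections as attributes (config.experiment, config.data, config.model, config.training), each behaving like a dict. Below is a minimal sketch of the key layout it reads from "Baseline B4.yml"; the values are placeholders, not the actual baseline-4 settings.

# Hypothetical sketch, not the real "Baseline B4.yml": only the key layout
# matches what trainer.py reads; every value here is a placeholder.
import yaml

example_config = yaml.safe_load("""
experiment: {name: Baseline_B4, version: 1, output_dir: outputs, seed: 42}
data:
  videos_path: placeholder/videos
  annot_path: placeholder/annotations
  video_splits:
    train: []
    validation: []
model:
  num_classes: 8
  input_size: 2048
  hidden_size: 512
  num_layers: 1
  num_clases_label: []   # class-name list (key name kept as spelled in trainer.py)
training:
  optimizer: AdamW       # the trainer also accepts "SGD"
  learning_rate: 0.0001
  weight_decay: 0.0001
  epochs: 30
  batch_size: {train: 8, val: 8}
""")

assert set(example_config) == {"experiment", "data", "model", "training"}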
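
Because the final checkpoint bundles the weights with the config object, a downstream script can rebuild the model without re-reading the YAML. A minimal reload sketch, assuming final_model.pth has been copied into the working directory and model.py is importable:

# Minimal reload sketch; the checkpoint path is a placeholder. On recent
# PyTorch, weights_only=False is needed because the checkpoint pickles the
# custom config object alongside the tensors.
import torch
from model import Group_Activity_Temporal_Classifer

ckpt = torch.load("final_model.pth", map_location="cpu", weights_only=False)
model = Group_Activity_Temporal_Classifer(
    num_classes=ckpt['config'].model['num_classes'],
    input_size=ckpt['config'].model['input_size'],
    hidden_size=ckpt['config'].model['hidden_size'],
    num_layers=ckpt['config'].model['num_layers'],
)
model.load_state_dict(ckpt['model_state_dict'])
model.eval()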