import gc
import time
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import wandb
from newsclassifier.config.config import Cfg, logger
from newsclassifier.data import NewsDataset, data_split, load_dataset, preprocess
from newsclassifier.models import CustomModel
from newsclassifier.train import eval_step, train_step
from newsclassifier.utils import read_yaml
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def tune_loop(config=None):
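    """Run one sweep trial: train and evaluate a single hyperparameter config.

    Invoked by wandb.agent; sweep-selected values override the defaults
    passed in through `config` once wandb.init replaces it with wandb.config.
    """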
    # ====================================================
    # loader
    # ====================================================
    logger.info("Starting Tuning.")
    with wandb.init(project="NewsClassifier", config=config):
        config = wandb.config

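        # Load the raw data, encode it, and split into train/validation sets.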
        df = load_dataset(Cfg.dataset_loc)
        ds, headlines_df, class_to_index, index_to_class = preprocess(df)
        train_ds, val_ds = data_split(ds, test_size=Cfg.test_size)
        train_dataset = NewsDataset(train_ds)
        valid_dataset = NewsDataset(val_ds)

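        # drop_last=True keeps training batch shapes constant; pin_memory
        # speeds up host-to-GPU transfers when CUDA is available.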
        train_loader = DataLoader(
            train_dataset, batch_size=config.batch_size, shuffle=True,
            num_workers=4, pin_memory=True, drop_last=True,
        )
        valid_loader = DataLoader(
            valid_dataset, batch_size=config.batch_size, shuffle=False,
            num_workers=4, pin_memory=True, drop_last=False,
        )

        # ====================================================
        # model
        # ====================================================
        num_classes = Cfg.num_classes
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = CustomModel(num_classes=num_classes, dropout_pb=config.dropout_pb)
        model.to(device)

        # ====================================================
        # Training components
        # ====================================================
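        # BCEWithLogitsLoss expects float one-hot (multi-hot) targets, so
        # train_step/eval_step are assumed to encode the labels accordingly.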
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
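        # Cut the learning rate by lr_reduce_factor once val_loss has
        # plateaued for lr_reduce_patience epochs.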
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=config.lr_reduce_factor, patience=config.lr_reduce_patience
        )

        # ====================================================
        # loop
        # ====================================================
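        # Log gradient and parameter histograms to W&B every 10 batches.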
        wandb.watch(model, criterion, log="all", log_freq=10)
        for epoch in range(config.epochs):
            try:
                start_time = time.time()

                # Step
                train_loss = train_step(train_loader, model, num_classes, criterion, optimizer, epoch)
                val_loss, _, _ = eval_step(valid_loader, model, num_classes, criterion, epoch)
                scheduler.step(val_loss)

                # scoring
                elapsed = time.time() - start_time
                wandb.log({"epoch": epoch + 1, "train_loss": train_loss, "val_loss": val_loss})
                print(f"Epoch {epoch+1} - avg_train_loss: {train_loss:.4f} avg_val_loss: {val_loss:.4f} time: {elapsed:.0f}s")
            except Exception as e:
                logger.error(f"Epoch {epoch+1}, {e}")

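            # Release cached GPU memory and collect garbage after each epoch.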
            torch.cuda.empty_cache()
            gc.collect()


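# A minimal sketch of the sweep config that read_yaml is assumed to return
# from Cfg.sweep_config_path; the parameter names are taken from tune_loop
# above, and the values are illustrative only:
#
#   method: random
#   metric:
#     name: val_loss
#     goal: minimize
#   parameters:
#     learning_rate: {min: 0.00001, max: 0.001}
#     batch_size: {values: [16, 32, 64]}
#     dropout_pb: {values: [0.1, 0.3, 0.5]}
#     lr_reduce_factor: {value: 0.5}
#     lr_reduce_patience: {value: 2}
#     epochs: {value: 5}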
if __name__ == "__main__":
    sweep_config = read_yaml(Cfg.sweep_config_path)
    sweep_id = wandb.sweep(sweep_config, project="NewsClassifier")
    wandb.agent(sweep_id, tune_loop, count=Cfg.sweep_runs)