982865f
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch.nn as nn def get_loss(name="cross_entropy", device="cuda:0"): print(f"Using loss: '{LOSSES[name]}'") return LOSSES[name].to(device) LOSSES = { "binary_ce": nn.BCEWithLogitsLoss(), "cross_entropy": nn.CrossEntropyLoss() }