File size: 254 Bytes
982865f
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch.nn as nn


def get_loss(name="cross_entropy", device="cuda:0"):
    """Look up a loss module by name and move it to ``device``.

    The module is taken from the module-level ``LOSSES`` registry, its repr
    is printed, and the result of calling ``.to(device)`` on it is returned.
    The registry stores instances, and ``nn.Module.to`` returns the module
    itself, so repeated calls with the same name hand back the same object.

    Args:
        name: Key into ``LOSSES`` selecting which loss to use.
        device: Device specification forwarded unchanged to ``.to()``.

    Returns:
        The registered loss module after ``.to(device)``.

    Raises:
        KeyError: If ``name`` is not registered in ``LOSSES``. The message
            lists the names that are available.
    """
    try:
        loss = LOSSES[name]
    except KeyError:
        # Same exception type as a plain dict lookup, but tell the caller
        # which names would have worked instead of echoing only the bad key.
        raise KeyError(
            f"Unknown loss {name!r}; available losses: {list(LOSSES)}"
        ) from None

    print(f"Using loss: '{loss}'")
    return loss.to(device)


# Registry mapping a loss name to a ready-built loss module instance.
# ``get_loss`` resolves its ``name`` argument against these keys. The
# instances are created once, when this module is imported.
LOSSES = {
    # nn.BCEWithLogitsLoss with its default constructor arguments.
    "binary_ce": nn.BCEWithLogitsLoss(),
    # nn.CrossEntropyLoss with its default constructor arguments.
    "cross_entropy": nn.CrossEntropyLoss(),
}