P-DFD / loss /__init__.py
mrneuralnet's picture
Initial commit
982865f
raw
history blame contribute delete
254 Bytes
import torch.nn as nn
def get_loss(name="cross_entropy", device="cuda:0") -> nn.Module:
    """Return an independent copy of the loss registered as *name*, on *device*.

    Args:
        name: Key into ``LOSSES`` (e.g. ``"binary_ce"`` or ``"cross_entropy"``).
        device: Target passed to ``nn.Module.to`` (e.g. ``"cuda:0"``, ``"cpu"``).

    Returns:
        A fresh copy of the registered loss module, moved to *device*.

    Raises:
        KeyError: If *name* is not a key of ``LOSSES``; the message lists the
            valid names.
    """
    import copy

    try:
        template = LOSSES[name]
    except KeyError:
        valid = ", ".join(sorted(LOSSES))
        raise KeyError(f"Unknown loss {name!r}; expected one of: {valid}") from None

    print(f"Using loss: '{template}'")
    # ``nn.Module.to`` moves a module in place and returns the same object, so
    # handing out the registry entry directly would make every caller share
    # (and re-move / mutate) one instance. Copy it so each caller owns its own.
    return copy.deepcopy(template).to(device)
# Registry of available loss criteria, keyed by the name that get_loss()
# looks up. The module instances are created once, when this file is imported.
LOSSES = dict(
    binary_ce=nn.BCEWithLogitsLoss(),
    cross_entropy=nn.CrossEntropyLoss(),
)