Spaces:
Runtime error
Runtime error
File size: 712 Bytes
5e4b3a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from torch import nn
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss
# Encoder (backbone) identifier passed to smp.Unet as `encoder_name`.
ENCODER = "timm-efficientnet-b0"
# Pretrained-weights identifier passed to smp.Unet as `encoder_weights`.
WEIGHTS = "imagenet"
class SegmentationModel(nn.Module):
    """U-Net segmentation network that can also compute its training loss.

    Wraps an ``smp.Unet`` built from the module-level ``ENCODER`` and
    ``WEIGHTS`` settings. The network takes 3-channel input, predicts a
    single class, and applies no final activation, so its output is raw
    logits.
    """

    def __init__(self):
        """Build the underlying U-Net and store it as ``self.arc``."""
        super().__init__()
        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            activation=None,
        )

    def forward(self, images, masks=None):
        """Run the network and, when masks are given, compute the loss.

        Args:
            images: Input batch forwarded to the U-Net.
            masks: Optional ground-truth masks. When omitted (``None``),
                no loss is computed.
                NOTE(review): presumably the same shape as the logits and
                floating point, as BCEWithLogitsLoss requires -- confirm
                against the data loader.

        Returns:
            The logits alone when ``masks`` is ``None``; otherwise a tuple
            ``(logits, loss)`` where ``loss`` is the sum of the binary Dice
            loss and the BCE-with-logits loss.
        """
        logits = self.arc(images)

        # Identity check on purpose: `!= None` would dispatch to the mask
        # object's own comparison operator (element-wise for array-likes),
        # which is not what "was a mask supplied?" means.
        if masks is None:
            return logits

        # Both criteria consume raw logits, matching `activation=None` above.
        dice_loss = DiceLoss(mode='binary')(logits, masks)
        bce_loss = nn.BCEWithLogitsLoss()(logits, masks)
        return logits, dice_loss + bce_loss
|