# -*- coding: utf-8 -*-
# @Time : 2024/8/6 15:44
# @Author : xiaoshun
# @Email : [email protected]
# @File : unetmobv2.py
# @Software: PyCharm
import segmentation_models_pytorch as smp | |
import torch | |
from torch import nn as nn | |
class UNetMobV2(nn.Module):
    """U-Net segmentation model with a MobileNetV2 encoder.

    Thin wrapper around ``smp.Unet`` built with ``encoder_name='mobilenet_v2'``.
    The wrapped network is stored as ``self.backbone``. The attribute name is
    kept as-is so existing ``state_dict`` keys (``backbone.*``) keep loading.

    Args:
        num_classes: Number of output channels / segmentation classes,
            forwarded to ``smp.Unet`` as ``classes``.
        in_channels: Number of input image channels. Defaults to 3.
        encoder_weights: Pretrained weights for the encoder, forwarded
            unchanged to ``smp.Unet`` (e.g. ``'imagenet'``). Defaults to
            ``None``, which gives a randomly initialised encoder. This is the
            behaviour the class always had before this parameter existed.
    """

    def __init__(self, num_classes: int, in_channels: int = 3, encoder_weights=None):
        super().__init__()
        self.backbone = smp.Unet(
            encoder_name='mobilenet_v2',
            encoder_weights=encoder_weights,
            in_channels=in_channels,
            classes=num_classes,
        )

    def forward(self, x):
        """Run the wrapped U-Net on ``x`` and return its raw output.

        NOTE(review): ``x`` is presumably a (N, in_channels, H, W) float tensor.
        smp's U-Net expects H and W divisible by 32, and its output is
        expected to be (N, num_classes, H, W) logits. Confirm against the smp
        docs for the installed version.
        """
        return self.backbone(x)
if __name__ == '__main__':
    # Smoke test: push one random 3-channel 224x224 image through a 2-class
    # model and print the output shape.
    fake_image = torch.rand(1, 3, 224, 224)
    model = UNetMobV2(num_classes=2)
    # This is a pure forward check. eval() puts BatchNorm/Dropout in inference
    # mode, and no_grad() avoids building an autograd graph that is never used.
    model.eval()
    with torch.no_grad():
        output = model(fake_image)
    print(output.size())