# birdnet / model.py
# Source: Hugging Face Hub file page (uploaded by ulichovick; commit cdfd2e1,
# "Create model.py", verified; 4.73 kB).
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
class SurfinBird(nn.Module, PyTorchModelHubMixin):
    """Convolutional image classifier that can be saved to / loaded from the Hub.

    Layout:

    * stem: 7x7 stride-2 convolution -> batch norm -> ReLU -> 2x2 max pool;
    * five blocks (``conv_block_2`` .. ``conv_block_6``), each made of three
      3x3 convolutions (every one followed by batch norm and ReLU) and a
      closing 2x2 max pool, with output widths 64, 128, 128, 256 and 256;
    * adaptive average pooling to 1x1, flatten, and one linear layer.

    Args:
        config: Mapping with the keys ``"num_channels"`` (input channels of
            the first convolution), ``"hidden_units"`` (input width of the
            linear classifier) and ``"num_classes"`` (number of output
            logits).
    """

    def __init__(self, config: dict) -> None:
        super().__init__()

        # Stem. Attribute names (and therefore state-dict keys) are kept
        # exactly as before so existing checkpoints still load.
        self.conv1 = nn.Conv2d(
            in_channels=config["num_channels"],
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
        )
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        self.mp1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Feature extractor. Blocks are created in the original order so that
        # random weight initialisation under a fixed seed is unchanged.
        self.conv_block_2 = self._conv_block(64, 64)
        self.conv_block_3 = self._conv_block(64, 128)
        self.conv_block_4 = self._conv_block(128, 128)
        self.conv_block_5 = self._conv_block(128, 256)
        self.conv_block_6 = self._conv_block(256, 256)

        self.avgpool = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        )

        # NOTE: conv_block_6 always emits 256 channels and the pooled map is
        # 1x1, so config["hidden_units"] has to be 256 for forward() to work.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(
                in_features=config["hidden_units"] * 1 * 1,
                out_features=config["num_classes"],
            ),
        )

    @staticmethod
    def _conv_block(in_channels: int, out_channels: int) -> nn.Sequential:
        """Build three (3x3 conv -> batch norm -> ReLU) stages plus a 2x2 max pool.

        Only the first convolution changes the channel count; the other two
        map ``out_channels`` to ``out_channels``. The resulting ``Sequential``
        has the same ten positional entries as the hand-written blocks it
        replaces, so parameter names such as ``conv_block_3.6.weight`` are
        preserved.
        """
        layers = []
        channels = in_channels
        for _ in range(3):
            layers.append(
                nn.Conv2d(
                    in_channels=channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the classifier output.

        Args:
            x: Input tensor; assumed to be a batched image tensor of shape
                ``(N, num_channels, H, W)`` -- TODO confirm against callers.

        Returns:
            The output of the final linear layer (``num_classes`` values per
            sample).
        """
        x = self.mp1(self.relu1(self.bn1(self.conv1(x))))
        x = self.conv_block_2(x)
        x = self.conv_block_3(x)
        x = self.conv_block_4(x)
        x = self.conv_block_5(x)
        x = self.conv_block_6(x)
        x = self.avgpool(x)
        return self.classifier(x)
# Configuration mapping with the three keys SurfinBird.__init__ reads.
config = dict(
    num_channels=3,
    hidden_units=256,
    num_classes=525,
)