smishr-18 commited on
Commit
bd72f48
·
verified ·
1 Parent(s): b0c2729

Delete unet.py

Browse files
Files changed (1) hide show
  1. unet.py +0 -102
unet.py DELETED
@@ -1,102 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- class DownSampling(nn.Module):
5
-
6
- def __init__(self, in_channels, out_channels, max_pool):
7
- """
8
- DownSampling block in the U-Net architecture.
9
-
10
- Args:
11
- in_channels (int): Number of input channels.
12
- out_channels (int): Number of output channels.
13
- max_pool (bool): Whether to use max pooling.
14
- """
15
- super(DownSampling, self).__init__()
16
- self.max_pool = max_pool
17
- self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
18
- self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
19
- self.batchnorm2d = nn.BatchNorm2d(out_channels)
20
- self.relu = nn.ReLU()
21
- self.maxpool2d = nn.MaxPool2d(kernel_size=2, stride=2)
22
-
23
- def forward(self, x):
24
- x = self.conv1(x)
25
- x = self.conv2(x)
26
-
27
- x = self.relu(self.batchnorm2d(x))
28
- skip_connection = x
29
-
30
- if self.max_pool:
31
- next_layer = self.maxpool2d(x)
32
- else:
33
- return x
34
- return next_layer, skip_connection
35
-
36
- class UpSampling(nn.Module):
37
- def __init__(self, in_channels, out_channels):
38
- """
39
- UpSampling block in the U-Net architecture.
40
-
41
- Args:
42
- in_channels (int): Number of input channels.
43
- out_channels (int): Number of output channels.
44
- """
45
- super(UpSampling, self).__init__()
46
- self.up = nn.ConvTranspose2d(in_channels, out_channels=out_channels, kernel_size=2, stride=2)
47
- self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
48
- self.relu = nn.ReLU()
49
- self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
50
- self.batchnorm = nn.BatchNorm2d(out_channels)
51
-
52
- def forward(self, x, prev_skip):
53
- x = self.up(x)
54
- x = torch.cat((x, prev_skip), dim=1)
55
- x = self.conv1(x)
56
- x = self.conv2(x)
57
- next_layer = self.relu(self.batchnorm(x))
58
- return next_layer
59
-
60
- class UNet(nn.Module):
61
-
62
- """
63
- U-Net architecture.
64
-
65
- Args:
66
- in_channels (int): Number of input channels.
67
- out_channels (int): Number of output channels.
68
- features (list): List of feature sizes for downsampling and upsampling.
69
- """
70
- def __init__(self, in_channels, out_channels, features):
71
- super(UNet, self).__init__()
72
- self.ups = nn.ModuleList()
73
- self.downs = nn.ModuleList()
74
-
75
- for feature in features:
76
- self.downs.append(DownSampling(in_channels, feature, True))
77
- in_channels = feature
78
-
79
- for feature in reversed(features):
80
- self.ups.append(UpSampling(2 * feature, feature))
81
-
82
- self.bottleneck = DownSampling(features[-1], 2 * features[-1], False)
83
- self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
84
-
85
- def forward(self, x):
86
- skip_connections = []
87
- for down in self.downs:
88
- x, skip_connection = down(x)
89
- skip_connections.append(skip_connection)
90
- skip_connections = skip_connections[::-1]
91
- x = self.bottleneck(x)
92
- for i, up in enumerate(self.ups):
93
- x = up(x, skip_connections[i])
94
-
95
- return self.final_conv(x)
96
-
97
- if __name__ == "__main__":
98
- #Example Usage
99
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
100
- features = [64, 128, 256, 512]
101
- model = UNet(1, 1, features=features).to(device)
102
- print(model(torch.rand(1, 1, 512, 512)).shape)