# A lightweight Swin-style transformer (patch embedding + transformer encoder + linear head)
# for chest X-ray (CXR) classification.
import torch
import torch.nn as nn
from torch.nn import LayerNorm, Linear, Dropout
from torch.nn.functional import gelu
from transformers import PreTrainedModel

from .SwinCXRConfig import SwinCXRConfig

class SwinSelfAttention(nn.Module):
    """Multi-head self-attention over the flattened patch sequence."""

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.query = Linear(embed_dim, embed_dim)
        self.key = Linear(embed_dim, embed_dim)
        self.value = Linear(embed_dim, embed_dim)
        self.dropout = Dropout(p=dropout)

    def forward(self, x):
        batch_size, seq_len, embed_dim = x.shape
        # Project and split into heads: (batch, num_heads, seq_len, head_dim).
        query = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention, computed independently per head.
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / self.head_dim**0.5
        attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
        attention_output = torch.matmul(attention_weights, value)
        # Merge heads back into (batch, seq_len, embed_dim).
        attention_output = attention_output.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
        return self.dropout(attention_output)

class SwinLayer(nn.Module):
    """Pre-norm transformer block: self-attention then an MLP, each with a residual connection."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layernorm_before = LayerNorm(embed_dim)
        self.attention = SwinSelfAttention(embed_dim, num_heads, dropout)
        # Plain dropout on the attention branch (used here in place of stochastic depth).
        self.drop_path = Dropout(p=dropout)
        self.layernorm_after = LayerNorm(embed_dim)

        # Feed-forward network with the usual 4x expansion and GELU activation.
        self.fc1 = Linear(embed_dim, 4 * embed_dim)
        self.fc2 = Linear(4 * embed_dim, embed_dim)
        self.intermediate_act_fn = gelu

    def forward(self, x):
        # Attention sub-block with residual connection.
        normed = self.layernorm_before(x)
        attention_output = self.attention(normed)
        attention_output = self.drop_path(attention_output)
        x = x + attention_output

        # MLP sub-block with residual connection.
        normed = self.layernorm_after(x)
        intermediate = self.fc1(normed)
        intermediate = self.intermediate_act_fn(intermediate)
        output = self.fc2(intermediate)

        return x + output


class SwinPatchEmbedding(nn.Module):
    """Splits the image into non-overlapping patches and projects each patch to an embedding vector."""

    def __init__(self, in_channels=3, patch_size=4, embed_dim=128):
        super().__init__()
        # A strided convolution is equivalent to patchify + linear projection.
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = LayerNorm(embed_dim)

    def forward(self, x):
        # (batch, channels, H, W) -> (batch, embed_dim, H/patch, W/patch)
        x = self.projection(x)
        # Flatten the spatial grid into a sequence: (batch, num_patches, embed_dim).
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

class SwinEncoder(nn.Module):
    """Stack of identical transformer layers applied to the patch sequence."""

    def __init__(self, num_layers, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            SwinLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class SwinModelForCXRClassification(PreTrainedModel):
    """Swin-style transformer backbone with a linear classification head for chest X-ray images."""

    config_class = SwinCXRConfig

    def __init__(self, config):
        super().__init__(config)

        # Bare nn.Module used as a named container so submodules follow the
        # embeddings.patch_embeddings / embeddings.norm / embeddings.dropout layout.
        self.embeddings = nn.Module()
        self.embeddings.patch_embeddings = SwinPatchEmbedding(
            in_channels=3,
            patch_size=4,
            embed_dim=128
        )
        self.embeddings.norm = LayerNorm(128)
        self.embeddings.dropout = Dropout(p=0.0)

        self.encoder = SwinEncoder(
            num_layers=4,
            embed_dim=128,
            num_heads=4,
            dropout=0.1
        )

        self.layernorm = LayerNorm(128)
        # Average-pool the patch dimension down to a single feature vector per image.
        self.pooler = nn.AdaptiveAvgPool1d(output_size=1)

        # Head size follows the config so it matches the num_classes used in the loss below.
        self.classifier = Linear(in_features=128, out_features=config.num_classes, bias=True)

    def forward(self, pixel_values, labels=None):
        # (batch, 3, H, W) -> (batch, num_patches, 128)
        x = self.embeddings.patch_embeddings(pixel_values)
        x = self.embeddings.dropout(x)  # no-op at the default p=0.0
        x = self.encoder(x)

        x = self.layernorm(x)
        # (batch, num_patches, 128) -> (batch, 128)
        x = self.pooler(x.transpose(1, 2)).squeeze(-1)

        logits = self.classifier(x)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
            return loss, logits

        return logits
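

# Minimal usage sketch (not part of the model definition): build a config and model and
# push a dummy batch through the forward pass. It assumes SwinCXRConfig accepts a
# num_classes argument, matching self.config.num_classes used above; adjust to the actual
# config signature. Because of the relative import at the top, run this file as a module
# inside its package (python -m <package>.<this_module>) rather than as a standalone script.
if __name__ == "__main__":
    config = SwinCXRConfig(num_classes=3)
    model = SwinModelForCXRClassification(config)

    # One fake 224x224 RGB chest X-ray and a dummy class label.
    pixel_values = torch.randn(1, 3, 224, 224)
    labels = torch.tensor([1])

    loss, logits = model(pixel_values, labels=labels)
    print(loss.item(), logits.shape)  # logits has shape (1, num_classes)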