import torch
import torch.nn as nn
from torch.nn import Dropout, LayerNorm, Linear
from torch.nn.functional import gelu
from transformers import PreTrainedModel

from .SwinCXRConfig import SwinCXRConfig


class SwinSelfAttention(nn.Module):
    """Scaled dot-product self-attention.

    Note: `num_heads` is accepted for interface compatibility, but the
    attention below is computed as a single head over the full embedding.
    """

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.query = Linear(embed_dim, embed_dim)
        self.key = Linear(embed_dim, embed_dim)
        self.value = Linear(embed_dim, embed_dim)
        self.dropout = Dropout(p=dropout)

    def forward(self, x):
        # x: (batch, num_patches, embed_dim)
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        # Scaled dot-product attention over the patch dimension.
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
        attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
        attention_output = torch.matmul(attention_weights, value)
        return self.dropout(attention_output)


class SwinLayer(nn.Module):
    """Pre-norm transformer block: attention + MLP, each with a residual connection."""

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layernorm_before = LayerNorm(embed_dim)
        self.attention = SwinSelfAttention(embed_dim, num_heads, dropout)
        # Named `drop_path` but implemented as standard dropout, not stochastic depth.
        self.drop_path = Dropout(p=dropout)
        self.layernorm_after = LayerNorm(embed_dim)
        # Feed-forward network with a 4x hidden expansion.
        self.fc1 = Linear(embed_dim, 4 * embed_dim)
        self.fc2 = Linear(4 * embed_dim, embed_dim)
        self.intermediate_act_fn = gelu

    def forward(self, x):
        # Attention sub-block with residual connection.
        normed = self.layernorm_before(x)
        attention_output = self.attention(normed)
        attention_output = self.drop_path(attention_output)
        x = x + attention_output

        # MLP sub-block with residual connection.
        normed = self.layernorm_after(x)
        intermediate = self.fc1(normed)
        intermediate = self.intermediate_act_fn(intermediate)
        output = self.fc2(intermediate)
        return x + output


class SwinPatchEmbedding(nn.Module):
    """Splits the image into non-overlapping patches and projects each to `embed_dim`."""

    def __init__(self, in_channels=3, patch_size=4, embed_dim=128):
        super().__init__()
        self.projection = nn.Conv2d(
            in_channels, embed_dim, kernel_size=patch_size, stride=patch_size
        )
        self.norm = LayerNorm(embed_dim)

    def forward(self, x):
        # (batch, channels, H, W) -> (batch, embed_dim, H/patch, W/patch)
        x = self.projection(x)
        # Flatten the spatial dims into a patch sequence: (batch, num_patches, embed_dim)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class SwinEncoder(nn.Module):
    """Stack of transformer blocks applied sequentially to the patch sequence."""

    def __init__(self, num_layers, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                SwinLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class SwinModelForCXRClassification(PreTrainedModel):
    """Swin-style transformer for chest X-ray (CXR) classification."""

    config_class = SwinCXRConfig

    def __init__(self, config):
        super().__init__(config)

        # Patch embedding. `norm` and `dropout` are registered here for naming
        # compatibility but are not applied in `forward`.
        self.embeddings = nn.Module()
        self.embeddings.patch_embeddings = SwinPatchEmbedding(
            in_channels=3, patch_size=4, embed_dim=128
        )
        self.embeddings.norm = LayerNorm(128)
        self.embeddings.dropout = Dropout(p=0.0)

        # Architecture hyperparameters are fixed here rather than read from the config.
        self.encoder = SwinEncoder(num_layers=4, embed_dim=128, num_heads=4, dropout=0.1)
        self.layernorm = LayerNorm(128)

        # Mean-pool over the patch dimension, then classify.
        # The number of output classes comes from the config, matching the
        # `config.num_classes` used for the loss in `forward`.
        self.pooler = nn.AdaptiveAvgPool1d(output_size=1)
        self.classifier = Linear(in_features=128, out_features=config.num_classes, bias=True)

    def forward(self, pixel_values, labels=None):
        x = self.embeddings.patch_embeddings(pixel_values)  # (batch, num_patches, 128)
        x = self.encoder(x)
        x = self.layernorm(x)
        # (batch, 128, num_patches) -> (batch, 128, 1) -> (batch, 128)
        x = self.pooler(x.transpose(1, 2)).squeeze(-1)
        logits = self.classifier(x)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
            return loss, logits
        return logits
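

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes SwinCXRConfig accepts a `num_classes` keyword argument (the model
# above only relies on `config.num_classes`; any other constructor arguments
# are unknown here) and that inputs are 3-channel images whose height and
# width are divisible by the patch size of 4.
if __name__ == "__main__":
    config = SwinCXRConfig(num_classes=3)  # assumed constructor signature
    model = SwinModelForCXRClassification(config)
    model.eval()

    # Dummy batch of two 224x224 chest X-rays replicated to 3 channels.
    pixel_values = torch.randn(2, 3, 224, 224)

    with torch.no_grad():
        logits = model(pixel_values)  # shape: (2, config.num_classes)
        print(logits.shape)

        # When labels are provided, forward also returns the cross-entropy loss.
        labels = torch.tensor([0, 2])
        loss, logits = model(pixel_values, labels=labels)
        print(loss.item(), logits.shape)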