import torch
import torch.nn as nn
from torch.nn import LayerNorm, Linear, Dropout
from torch.nn.functional import gelu
from transformers import PreTrainedModel

from .SwinCXRConfig import SwinCXRConfig


class SwinSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        # num_heads is stored for reference, but attention below is computed
        # single-headed over the full embedding dimension.
        self.num_heads = num_heads
        self.query = Linear(embed_dim, embed_dim)
        self.key = Linear(embed_dim, embed_dim)
        self.value = Linear(embed_dim, embed_dim)
        self.dropout = Dropout(p=dropout)

    def forward(self, x):
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # Scaled dot-product attention: (B, N, C) x (B, C, N) -> (B, N, N).
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
        attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
        attention_output = torch.matmul(attention_weights, value)
        return self.dropout(attention_output)


class SwinLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layernorm_before = LayerNorm(embed_dim)
        self.attention = SwinSelfAttention(embed_dim, num_heads, dropout)
        # Plain dropout is used here in place of stochastic depth (drop path).
        self.drop_path = Dropout(p=dropout)
        self.layernorm_after = LayerNorm(embed_dim)

        # Feed-forward block with a 4x expansion ratio.
        self.fc1 = Linear(embed_dim, 4 * embed_dim)
        self.fc2 = Linear(4 * embed_dim, embed_dim)
        self.intermediate_act_fn = gelu

    def forward(self, x):
        # Pre-norm attention block with a residual connection.
        normed = self.layernorm_before(x)
        attention_output = self.attention(normed)
        attention_output = self.drop_path(attention_output)
        x = x + attention_output

        # Pre-norm feed-forward block with a residual connection.
        normed = self.layernorm_after(x)
        intermediate = self.fc1(normed)
        intermediate = self.intermediate_act_fn(intermediate)
        output = self.fc2(intermediate)

        return x + output


class SwinPatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=4, embed_dim=128):
        super().__init__()
        # Non-overlapping patch projection: each patch_size x patch_size block
        # becomes one embed_dim-dimensional token.
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = LayerNorm(embed_dim)

    def forward(self, x):
        x = self.projection(x)            # (B, embed_dim, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)
        x = self.norm(x)
        return x


class SwinEncoder(nn.Module):
    def __init__(self, num_layers, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            SwinLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class SwinModelForCXRClassification(PreTrainedModel):
    config_class = SwinCXRConfig

    def __init__(self, config):
        super().__init__(config)

        # Bare nn.Module used as a namespace so parameters register under
        # the "embeddings.*" prefix.
        self.embeddings = nn.Module()
        self.embeddings.patch_embeddings = SwinPatchEmbedding(
            in_channels=3,
            patch_size=4,
            embed_dim=128,
        )
        # Note: norm and dropout are registered but not applied in forward;
        # SwinPatchEmbedding already normalizes its output.
        self.embeddings.norm = LayerNorm(128)
        self.embeddings.dropout = Dropout(p=0.0)

        self.encoder = SwinEncoder(
            num_layers=4,
            embed_dim=128,
            num_heads=4,
            dropout=0.1,
        )

        self.layernorm = LayerNorm(128)
        self.pooler = nn.AdaptiveAvgPool1d(output_size=1)

        # Classification head sized from the config so it stays consistent
        # with the num_classes used in the loss below.
        self.classifier = Linear(in_features=128, out_features=config.num_classes, bias=True)

    def forward(self, pixel_values, labels=None):
        x = self.embeddings.patch_embeddings(pixel_values)
        x = self.encoder(x)

        x = self.layernorm(x)
        # Mean-pool over the patch tokens: (B, N, C) -> (B, C).
        x = self.pooler(x.transpose(1, 2)).squeeze(-1)

        logits = self.classifier(x)

        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
            return loss, logits

        return logits
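

# Minimal usage sketch (not part of the model definition): assumes SwinCXRConfig
# subclasses PretrainedConfig and accepts num_classes as a keyword argument
# (3 matches the classifier head above). Shapes shown are for 224x224 RGB inputs.
if __name__ == "__main__":
    config = SwinCXRConfig(num_classes=3)
    model = SwinModelForCXRClassification(config)
    model.eval()

    # Two dummy chest X-rays; a 224x224 image with patch_size=4 yields 56*56 = 3136 tokens.
    pixel_values = torch.randn(2, 3, 224, 224)
    labels = torch.tensor([0, 2])

    with torch.no_grad():
        logits = model(pixel_values)                # shape: (2, 3)
        loss, logits = model(pixel_values, labels)  # cross-entropy loss plus logits

    print(logits.shape, loss.item())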