# SwinCXR/SwinModelForCXRClassification.py
import torch
import torch.nn as nn
from torch.nn import LayerNorm, Linear, Dropout
from torch.nn.functional import gelu
from transformers import PreTrainedModel

from .SwinCXRConfig import SwinCXRConfig


class SwinSelfAttention(nn.Module):
    """Scaled dot-product self-attention over the full token sequence.

    Note: despite the Swin naming, this is plain global attention (no
    shifted windows), and ``num_heads`` is accepted for API symmetry but
    never used, so the attention is effectively single-headed.
    """

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        self.query = Linear(embed_dim, embed_dim)
        self.key = Linear(embed_dim, embed_dim)
        self.value = Linear(embed_dim, embed_dim)
        self.dropout = Dropout(p=dropout)

    def forward(self, x):
        # x: (batch, seq_len, embed_dim)
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # softmax(QK^T / sqrt(d)) V
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
        attention_weights = torch.nn.functional.softmax(attention_weights, dim=-1)
        attention_output = torch.matmul(attention_weights, value)
        return self.dropout(attention_output)


class SwinLayer(nn.Module):
    """Pre-norm transformer block: attention and a 4x-expansion GELU MLP,
    each wrapped in a residual connection.

    Note: ``drop_path`` is ordinary Dropout here, not true stochastic depth.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layernorm_before = LayerNorm(embed_dim)
        self.attention = SwinSelfAttention(embed_dim, num_heads, dropout)
        self.drop_path = Dropout(p=dropout)
        self.layernorm_after = LayerNorm(embed_dim)
        self.fc1 = Linear(embed_dim, 4 * embed_dim)
        self.fc2 = Linear(4 * embed_dim, embed_dim)
        self.intermediate_act_fn = gelu

    def forward(self, x):
        # Attention sub-block with residual connection
        normed = self.layernorm_before(x)
        attention_output = self.attention(normed)
        attention_output = self.drop_path(attention_output)
        x = x + attention_output
        # MLP sub-block with residual connection
        normed = self.layernorm_after(x)
        intermediate = self.fc1(normed)
        intermediate = self.intermediate_act_fn(intermediate)
        output = self.fc2(intermediate)
        return x + output


class SwinPatchEmbedding(nn.Module):
    """Splits the image into non-overlapping patches with a strided conv,
    flattens them into a token sequence, and layer-normalizes the result."""

    def __init__(self, in_channels=3, patch_size=4, embed_dim=128):
        super().__init__()
        self.projection = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = LayerNorm(embed_dim)

    def forward(self, x):
        # (B, C, H, W) -> (B, embed_dim, H/p, W/p) -> (B, num_patches, embed_dim)
        x = self.projection(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class SwinEncoder(nn.Module):
    """A stack of identical SwinLayer blocks applied in sequence."""

    def __init__(self, num_layers, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            SwinLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class SwinModelForCXRClassification(PreTrainedModel):
    """Swin-style encoder with a mean-pooled linear head for chest X-ray
    classification."""

    config_class = SwinCXRConfig

    def __init__(self, config):
        super().__init__(config)
        # Architecture hyperparameters are hard-coded to match the shipped
        # SwinCXRConfig defaults and checkpoint.
        self.embeddings = nn.Module()
        self.embeddings.patch_embeddings = SwinPatchEmbedding(
            in_channels=3,
            patch_size=4,
            embed_dim=128,
        )
        # Registered so the checkpoint's state-dict keys match; not applied
        # in forward(), since the patch embedding normalizes its own output.
        self.embeddings.norm = LayerNorm(128)
        self.embeddings.dropout = Dropout(p=0.0)
        self.encoder = SwinEncoder(
            num_layers=4,
            embed_dim=128,
            num_heads=4,
            dropout=0.1,
        )
        self.layernorm = LayerNorm(128)
        self.pooler = nn.AdaptiveAvgPool1d(output_size=1)
        # Head size follows config.num_classes (3 for this checkpoint),
        # keeping it consistent with the loss computation below.
        self.classifier = Linear(in_features=128, out_features=config.num_classes, bias=True)

    def forward(self, pixel_values, labels=None):
        x = self.embeddings.patch_embeddings(pixel_values)  # (B, N, 128)
        x = self.encoder(x)
        x = self.layernorm(x)
        # Mean-pool over the token dimension: (B, N, 128) -> (B, 128)
        x = self.pooler(x.transpose(1, 2)).squeeze(-1)
        logits = self.classifier(x)
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_classes), labels.view(-1))
            return loss, logits
        return logits
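

# A minimal smoke test, runnable as a script. It assumes SwinCXRConfig
# accepts ``num_classes`` as a keyword argument (the only config field the
# model above references); adjust to the actual config signature if needed.
if __name__ == "__main__":
    config = SwinCXRConfig(num_classes=3)  # hypothetical constructor call
    model = SwinModelForCXRClassification(config)
    dummy = torch.randn(2, 3, 224, 224)  # two RGB images; 224/4 = 56, so 3136 patch tokens
    labels = torch.tensor([0, 2])
    loss, logits = model(dummy, labels=labels)
    print(loss.item(), logits.shape)  # logits: (2, 3)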