# File size: 4,171 Bytes
# 7292c9f 0e4b91d 7292c9f
import torch
import torch.nn as nn
from torch.nn import LayerNorm, Linear, Dropout
from torch.nn.functional import gelu
from transformers import PretrainedConfig, PreTrainedModel
from .SwinCXRConfig import SwinCXRConfig
class SwinSelfAttention(nn.Module):
    """Scaled dot-product self-attention with learned Q/K/V projections.

    Query, key and value are three separate ``Linear(embed_dim, embed_dim)``
    projections of the same input. Attention scores are computed over the
    full feature dimension in a single pass.

    Args:
        embed_dim: Size of the input and output feature dimension.
        num_heads: Accepted but not used by this implementation; no
            per-head split is performed.
        dropout: Dropout probability applied to the attention output.
    """

    def __init__(self, embed_dim, num_heads, dropout):
        super().__init__()
        # Projections are created in query/key/value order so that seeded
        # initialisation and state-dict key order stay the same.
        self.query = Linear(embed_dim, embed_dim)
        self.key = Linear(embed_dim, embed_dim)
        self.value = Linear(embed_dim, embed_dim)
        self.dropout = Dropout(p=dropout)

    def forward(self, x):
        """Attend every position of ``x`` to every other position.

        Args:
            x: Tensor whose last dimension equals ``embed_dim``; the
                second-to-last dimension is the one attended over.

        Returns:
            Tensor of the same shape as ``x``, after dropout.
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scale by sqrt of the feature size before normalising the scores.
        scale = q.size(-1) ** 0.5
        scores = (q @ k.transpose(-1, -2)) / scale
        probs = scores.softmax(dim=-1)

        # Dropout is applied to the weighted values, not to the weights.
        return self.dropout(probs @ v)
class SwinLayer(nn.Module):
    """One encoder block: normalised self-attention, then a normalised MLP.

    Both sub-blocks are wrapped in residual connections. The MLP expands the
    feature dimension by a factor of four, applies GELU, and projects back.

    Args:
        embed_dim: Feature dimension of the tokens.
        num_heads: Forwarded to ``SwinSelfAttention``.
        dropout: Probability used both inside the attention module and for
            ``drop_path``.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        hidden_dim = 4 * embed_dim

        self.layernorm_before = LayerNorm(embed_dim)
        self.attention = SwinSelfAttention(embed_dim, num_heads, dropout)
        # Despite its name this is an ordinary element-wise Dropout.
        self.drop_path = Dropout(p=dropout)
        self.layernorm_after = LayerNorm(embed_dim)
        self.fc1 = Linear(embed_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, embed_dim)
        self.intermediate_act_fn = gelu

    def forward(self, x):
        """Apply the attention and MLP sub-blocks, each with a skip connection.

        Args:
            x: Tensor whose last dimension equals ``embed_dim``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Attention branch: norm -> attention -> dropout, added to the input.
        x = x + self.drop_path(self.attention(self.layernorm_before(x)))

        # MLP branch: norm -> expand -> GELU -> project; no dropout here.
        mlp_out = self.fc2(self.intermediate_act_fn(self.fc1(self.layernorm_after(x))))
        return x + mlp_out
class SwinPatchEmbedding(nn.Module):
    """Turn an image into a sequence of normalised patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects each non-overlapping patch to ``embed_dim`` channels.

    Args:
        in_channels: Number of channels in the input image.
        patch_size: Side length of each square patch.
        embed_dim: Number of output features per patch.
    """

    def __init__(self, in_channels=3, patch_size=4, embed_dim=128):
        super().__init__()
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.norm = LayerNorm(embed_dim)

    def forward(self, x):
        """Embed ``x`` and return the patches as a token sequence.

        Args:
            x: Image batch accepted by the convolution, i.e. with
                ``in_channels`` channels.

        Returns:
            Tensor of shape ``(batch, num_patches, embed_dim)`` after
            layer normalisation.
        """
        patches = self.projection(x)
        # Collapse the spatial grid into one axis, then put features last.
        tokens = patches.flatten(start_dim=2).permute(0, 2, 1)
        return self.norm(tokens)
class SwinEncoder(nn.Module):
    """A stack of identically configured ``SwinLayer`` blocks run in sequence.

    Args:
        num_layers: How many ``SwinLayer`` blocks to create.
        embed_dim: Feature dimension passed to every layer.
        num_heads: Head count passed to every layer.
        dropout: Dropout probability passed to every layer.
    """

    def __init__(self, num_layers, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                SwinLayer(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through each layer in order and return the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden
class SwinModelForCXRClassification(PreTrainedModel):
    """Image classifier: patch embedding -> encoder -> mean pool -> linear head.

    The architecture is fixed in code rather than read from ``config``:
    3 input channels, 4x4 patches, 128-dimensional embeddings, 4 encoder
    layers and a 3-way linear classification head.

    Args:
        config: Configuration object forwarded to ``PreTrainedModel``.
    """

    config_class = SwinCXRConfig

    def __init__(self, config):
        super().__init__(config)
        embed_dim = 128

        self.embeddings = nn.Module()
        self.embeddings.patch_embeddings = SwinPatchEmbedding(
            in_channels=3,
            patch_size=4,
            embed_dim=embed_dim,
        )
        # NOTE(review): ``embeddings.norm`` and ``embeddings.dropout`` are
        # registered here but never called in ``forward``. Presumably they
        # are kept so parameter names line up with a saved checkpoint;
        # confirm before removing.
        self.embeddings.norm = LayerNorm(embed_dim)
        self.embeddings.dropout = Dropout(p=0.0)

        self.encoder = SwinEncoder(
            num_layers=4,
            embed_dim=embed_dim,
            num_heads=4,
            dropout=0.1,
        )
        self.layernorm = LayerNorm(embed_dim)
        self.pooler = nn.AdaptiveAvgPool1d(output_size=1)
        self.classifier = Linear(in_features=embed_dim, out_features=3, bias=True)

    def forward(self, pixel_values, labels=None):
        """Classify a batch of images and optionally compute the loss.

        Args:
            pixel_values: Image batch with 3 channels; assumed to be
                ``(batch, 3, height, width)`` -- TODO confirm with caller.
            labels: Optional targets for ``nn.CrossEntropyLoss``; flattened
                to one dimension before use (presumably class indices).

        Returns:
            ``logits`` of shape ``(batch, 3)`` when ``labels`` is ``None``,
            otherwise the tuple ``(loss, logits)``.
        """
        x = self.embeddings.patch_embeddings(pixel_values)
        x = self.encoder(x)
        x = self.layernorm(x)
        # Pool over the token axis: move features to dim 1, average the
        # last axis down to length 1, then drop it.
        x = self.pooler(x.transpose(1, 2)).squeeze(-1)
        logits = self.classifier(x)

        if labels is None:
            return logits

        # Reshape by the head's real output width, not ``config.num_classes``:
        # the head is fixed at 3 outputs, so a differing or missing config
        # value would otherwise raise or silently mis-shape the logits.
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits