import torch
import torch.nn as nn

from transformers.models.distilbert.modeling_distilbert import Transformer as T

from .pooling import create_pool1d_layer


class Config:
    """Minimal stand-in for a transformers config object."""

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


class Transformer(nn.Module):
    """
    If predict_sequence is True, the model predicts an output for each element
    in the sequence. If False, it predicts a single output for the whole
    sequence, e.g., classifying each image in a CT scan vs. the entire CT scan.
    """

    def __init__(self,
                 num_classes,
                 embedding_dim=512,
                 hidden_dim=1024,
                 n_layers=4,
                 n_heads=16,
                 dropout=0.2,
                 attention_dropout=0.1,
                 output_attentions=False,
                 activation='gelu',
                 output_hidden_states=False,
                 chunk_size_feed_forward=0,
                 predict_sequence=True,
                 pool=None):
        super().__init__()
        config = Config(
            dim=embedding_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            output_attentions=output_attentions,
            activation=activation,
            output_hidden_states=output_hidden_states,
            chunk_size_feed_forward=chunk_size_feed_forward,
        )
        self.transformer = T(config)
        self.predict_sequence = predict_sequence
        if not predict_sequence:
            if isinstance(pool, str):
                self.pool_layer = create_pool1d_layer(pool)
                if pool == "catavgmax":
                    # Concatenated avg+max pooling doubles the feature dimension
                    embedding_dim *= 2
            else:
                self.pool_layer = nn.Identity()
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def extract_features(self, x):
        x, mask = x
        # head_mask expects one entry per transformer layer
        x = self.transformer(x,
                             attn_mask=mask,
                             head_mask=[None] * self.transformer.n_layers)
        x = x[0]
        if not self.predict_sequence:
            if isinstance(self.pool_layer, nn.Identity):
                # Just take the first vector in the sequence
                x = x[:, 0]
            else:
                x = self.pool_layer(x.transpose(-1, -2))
        return x

    def classify(self, x):
        if not self.predict_sequence:
            if isinstance(self.pool_layer, nn.Identity):
                # Just take the first vector in the sequence
                x = x[:, 0]
            else:
                x = self.pool_layer(x.transpose(-1, -2))
        out = self.classifier(x)
        if self.classifier.out_features == 1:
            # Squeeze the class dimension for single-logit outputs
            return out[..., 0]
        return out

    def forward_tr(self, x, mask):
        output = self.transformer(x,
                                  attn_mask=mask,
                                  head_mask=[None] * self.transformer.n_layers)
        return self.classify(output[0])

    def forward(self, x):
        x, mask = x
        return self.forward_tr(x, mask)


class DualTransformer(nn.Module):
    """
    Essentially the same as above, except it predicts both element-wise
    (sequence) labels and a single study-level label, concatenated into
    one output tensor.
    """

    def __init__(self,
                 num_classes,
                 embedding_dim=512,
                 hidden_dim=1024,
                 n_layers=4,
                 n_heads=16,
                 dropout=0.2,
                 attention_dropout=0.1,
                 output_attentions=False,
                 activation='gelu',
                 output_hidden_states=False,
                 chunk_size_feed_forward=0,
                 pool=None):
        super().__init__()
        config = Config(
            dim=embedding_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            output_attentions=output_attentions,
            activation=activation,
            output_hidden_states=output_hidden_states,
            chunk_size_feed_forward=chunk_size_feed_forward,
        )
        self.transformer = T(config)
        self.pool_layer = nn.Identity()
        self.classifier1 = nn.Linear(embedding_dim, num_classes)
        self.classifier2 = nn.Linear(embedding_dim, num_classes)

    def classify(self, x):
        if isinstance(self.pool_layer, nn.Identity):
            # Just take the first vector in the sequence
            x_summ = x[:, 0]
        else:
            x_summ = self.pool_layer(x.transpose(-1, -2))
        # Element-wise labels (one logit per sequence element)
        out1 = self.classifier1(x)[:, :, 0]
        # Single label for the whole sequence
        out2 = self.classifier2(x_summ)
        out = torch.cat([out1, out2], dim=1)
        return out

    def forward_tr(self, x, mask):
        output = self.transformer(x,
                                  attn_mask=mask,
                                  head_mask=[None] * self.transformer.n_layers)
        return self.classify(output[0])

    def forward(self, x):
        x, mask = x
        return self.forward_tr(x, mask)


class DualTransformerV2(nn.Module):
    """
    More complicated variant of DualTransformer.

    Returns a tuple of:
        (element-wise predictions, sequence prediction)
    """

    def __init__(self,
                 num_seq_classes,
                 num_classes,
                 embedding_dim=512,
                 hidden_dim=1024,
                 n_layers=4,
                 n_heads=16,
                 dropout=0.2,
                 attention_dropout=0.1,
                 output_attentions=False,
                 activation='gelu',
                 output_hidden_states=False,
                 chunk_size_feed_forward=0,
                 pool=None):
        super().__init__()
        config = Config(
            dim=embedding_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            output_attentions=output_attentions,
            activation=activation,
            output_hidden_states=output_hidden_states,
            chunk_size_feed_forward=chunk_size_feed_forward,
        )
        self.transformer = T(config)
        self.pool_layer = nn.Identity()
        self.classifier1 = nn.Linear(embedding_dim, num_seq_classes)
        self.classifier2 = nn.Linear(embedding_dim, num_classes)

    def classify(self, x):
        if isinstance(self.pool_layer, nn.Identity):
            # Just take the first vector in the sequence
            x_summ = x[:, 0]
        else:
            x_summ = self.pool_layer(x.transpose(-1, -2))
        # Element-wise labels
        out1 = self.classifier1(x)
        if self.classifier1.out_features == 1:
            # Squeeze the class dimension for single-logit outputs
            out1 = out1[:, :, 0]
        # Single label for the whole sequence
        out2 = self.classifier2(x_summ)
        return out1, out2

    def forward_tr(self, x, mask):
        output = self.transformer(x,
                                  attn_mask=mask,
                                  head_mask=[None] * self.transformer.n_layers)
        return self.classify(output[0])

    def forward(self, x):
        x, mask = x
        return self.forward_tr(x, mask)