ianpan's picture
Initial commit
231edce
raw
history blame
7.37 kB
import numpy as np
import torch
import torch.nn as nn
from transformers.models.distilbert.modeling_distilbert import Transformer as T
from .pooling import create_pool1d_layer
class Config:
    """Minimal attribute bag: every keyword argument becomes an attribute.

    Used to mimic the config object expected by the DistilBERT
    ``Transformer`` encoder stack.
    """

    def __init__(self, **kwargs):
        # Bulk-assign all keyword args as instance attributes.
        self.__dict__.update(kwargs)
class Transformer(nn.Module):
    """
    Transformer encoder head over a sequence of precomputed embeddings.

    If predict_sequence is True, then the model will predict an output
    for each element in the sequence. If False, then the model will
    predict a single output for the sequence.
    e.g., classifying each image in a CT scan vs the entire CT scan

    Parameters
    ----------
    num_classes : int
        Output dimension of the classification head.
    embedding_dim : int
        Dimension of the input (and hidden) embeddings.
    hidden_dim : int
        Feed-forward hidden dimension inside each transformer block.
    n_layers, n_heads : int
        Encoder depth and attention heads per layer.
    dropout, attention_dropout : float
        Dropout probabilities for hidden states / attention weights.
    predict_sequence : bool
        Per-element predictions (True) vs one prediction per sequence (False).
    pool : str or None
        Name of a 1-D pooling layer (see ``create_pool1d_layer``); only used
        when ``predict_sequence`` is False. ``"catavgmax"`` doubles the
        feature dimension fed to the classifier.
    """
    def __init__(self,
                 num_classes,
                 embedding_dim=512,
                 hidden_dim=1024,
                 n_layers=4,
                 n_heads=16,
                 dropout=0.2,
                 attention_dropout=0.1,
                 output_attentions=False,
                 activation='gelu',
                 output_hidden_states=False,
                 chunk_size_feed_forward=0,
                 predict_sequence=True,
                 pool=None
                 ):
        super().__init__()
        config = Config(**{
            'dim': embedding_dim,
            'hidden_dim': hidden_dim,
            'n_layers': n_layers,
            'n_heads': n_heads,
            'dropout': dropout,
            'attention_dropout': attention_dropout,
            'output_attentions': output_attentions,
            'activation': activation,
            'output_hidden_states': output_hidden_states,
            'chunk_size_feed_forward': chunk_size_feed_forward
        })
        self.transformer = T(config)
        # Stored so head_mask can be sized per layer (see forward_tr).
        self.n_layers = n_layers
        self.predict_sequence = predict_sequence
        # Always give pool_layer a default: previously, predict_sequence=False
        # with pool=None left self.pool_layer undefined, causing an
        # AttributeError in extract_features/classify.
        self.pool_layer = nn.Identity()
        if not predict_sequence and isinstance(pool, str):
            self.pool_layer = create_pool1d_layer(pool)
            if pool == "catavgmax":
                # Concatenated avg+max pooling doubles the feature dim.
                embedding_dim *= 2
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def _summarize(self, x):
        """Reduce per-element features (B, S, D) to one vector per sequence (B, D)."""
        if isinstance(self.pool_layer, nn.Identity):
            # Take the FIRST vector in the sequence (CLS-style summary).
            # (The original comment said "last", but x[:, 0] is the first.)
            return x[:, 0]
        # Pooling layers expect channels-first: (B, D, S).
        return self.pool_layer(x.transpose(-1, -2))

    def extract_features(self, x):
        """Run the encoder; returns (B, S, D) features, or (B, D) when
        predicting a single label per sequence."""
        x, mask = x
        x = self.transformer(x, attn_mask=mask, head_mask=[None] * self.n_layers)
        x = x[0]
        if not self.predict_sequence:
            x = self._summarize(x)
        return x

    def classify(self, x):
        """Apply (optional) pooling and the linear head.

        Squeezes the trailing class dim for single-logit (binary/regression)
        heads so the output is (B,) / (B, S) instead of (B, 1) / (B, S, 1).
        """
        if not self.predict_sequence:
            x = self._summarize(x)
        out = self.classifier(x)
        if self.classifier.out_features == 1:
            return out[..., 0]
        return out

    def forward_tr(self, x, mask):
        # head_mask is indexed per *layer* inside the encoder; the original
        # sized it by sequence length ([None]*x.size(1)), which crashes for
        # sequences shorter than n_layers. All entries are None, so sizing
        # by n_layers is behavior-identical otherwise.
        output = self.transformer(x, attn_mask=mask, head_mask=[None] * self.n_layers)
        return self.classify(output[0])

    def forward(self, x):
        # x is a (embeddings, attention_mask) pair.
        x, mask = x
        return self.forward_tr(x, mask)
class DualTransformer(nn.Module):
    """
    Essentially the same as above except predicts both sequence and study labels.

    ``classify`` concatenates per-element logits (B, S) with whole-sequence
    logits (B, num_classes) along dim 1, yielding (B, S + num_classes).
    """
    def __init__(self,
                 num_classes,
                 embedding_dim=512,
                 hidden_dim=1024,
                 n_layers=4,
                 n_heads=16,
                 dropout=0.2,
                 attention_dropout=0.1,
                 output_attentions=False,
                 activation='gelu',
                 output_hidden_states=False,
                 chunk_size_feed_forward=0,
                 pool=None
                 ):
        super().__init__()
        cfg = Config(
            dim=embedding_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            output_attentions=output_attentions,
            activation=activation,
            output_hidden_states=output_hidden_states,
            chunk_size_feed_forward=chunk_size_feed_forward,
        )
        self.transformer = T(cfg)
        self.pool_layer = nn.Identity()
        # Two heads: per-element labels and a whole-sequence label.
        self.classifier1 = nn.Linear(embedding_dim, num_classes)
        self.classifier2 = nn.Linear(embedding_dim, num_classes)

    def classify(self, x):
        """Produce concatenated per-element and whole-sequence logits."""
        if isinstance(self.pool_layer, nn.Identity):
            # Summarize the sequence with its first vector (x[:, 0]).
            summary = x[:, 0]
        else:
            # Pooling expects channels-first: (B, D, S).
            summary = self.pool_layer(x.transpose(-1, -2))
        # Element-wise labels: keeps only class channel 0 — presumably a
        # binary head regardless of num_classes; TODO confirm with callers.
        elementwise = self.classifier1(x)[:, :, 0]
        # Single label for the whole sequence.
        sequence_level = self.classifier2(summary)
        return torch.cat([elementwise, sequence_level], dim=1)

    def forward_tr(self, x, mask):
        hidden = self.transformer(x, attn_mask=mask, head_mask=[None] * x.size(1))
        return self.classify(hidden[0])

    def forward(self, x):
        # x is a (embeddings, attention_mask) pair.
        features, mask = x
        return self.forward_tr(features, mask)
class DualTransformerV2(nn.Module):
    """
    More complicated variant of DualTransformer.
    Returns tuple of: (element-wise prediction, sequence prediction)

    Unlike ``DualTransformer``, the two heads may have different sizes
    (``num_seq_classes`` per element vs ``num_classes`` per sequence) and the
    outputs are returned as a tuple rather than concatenated.
    """
    def __init__(self,
                 num_seq_classes,
                 num_classes,
                 embedding_dim=512,
                 hidden_dim=1024,
                 n_layers=4,
                 n_heads=16,
                 dropout=0.2,
                 attention_dropout=0.1,
                 output_attentions=False,
                 activation='gelu',
                 output_hidden_states=False,
                 chunk_size_feed_forward=0,
                 pool=None
                 ):
        super().__init__()
        cfg = Config(
            dim=embedding_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            n_heads=n_heads,
            dropout=dropout,
            attention_dropout=attention_dropout,
            output_attentions=output_attentions,
            activation=activation,
            output_hidden_states=output_hidden_states,
            chunk_size_feed_forward=chunk_size_feed_forward,
        )
        self.transformer = T(cfg)
        self.pool_layer = nn.Identity()
        # Per-element head and whole-sequence head, possibly different widths.
        self.classifier1 = nn.Linear(embedding_dim, num_seq_classes)
        self.classifier2 = nn.Linear(embedding_dim, num_classes)

    def classify(self, x):
        """Return (per-element logits, whole-sequence logits)."""
        if isinstance(self.pool_layer, nn.Identity):
            # Summarize the sequence with its first vector (x[:, 0]).
            summary = x[:, 0]
        else:
            # Pooling expects channels-first: (B, D, S).
            summary = self.pool_layer(x.transpose(-1, -2))
        per_element = self.classifier1(x)
        if self.classifier1.out_features == 1:
            # Binary head: squeeze the trailing singleton class dim -> (B, S).
            per_element = per_element[:, :, 0]
        per_sequence = self.classifier2(summary)
        return per_element, per_sequence

    def forward_tr(self, x, mask):
        hidden = self.transformer(x, attn_mask=mask, head_mask=[None] * x.size(1))
        return self.classify(hidden[0])

    def forward(self, x):
        # x is a (embeddings, attention_mask) pair.
        features, mask = x
        return self.forward_tr(features, mask)