import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange
from transformers import AutoModel

from .transformer import TransformerDecoder, TransformerDecoderLayer

"""
Expected config keys:

    N                   number of layers in each transformer decoder
    d_model             decoder embedding dimension
    res_base_model      visual backbone name: "resnet18" or "resnet50"
    H                   number of attention heads in the decoder
    num_queries         number of decoder queries
    dropout             dropout rate applied to decoder outputs
    attribute_set_size  number of output classes per query
"""

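# Illustrative config (hypothetical values, shown only to document the
# expected keys; not taken from the paper or repo). Note that d_model must
# be 256 to match the 768 -> 256 book projection below:
#
#     config = {
#         "N": 4,
#         "d_model": 256,
#         "res_base_model": "resnet50",
#         "H": 4,
#         "num_queries": 75,
#         "dropout": 0.1,
#         "attribute_set_size": 2,
#         "text_encoder": "emilyalsentzer/Bio_ClinicalBERT",
#     }
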
class MeDSLIP(nn.Module):
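    """MeDSLIP dual-stream vision-language model.

    A learned gate splits ResNet grid features into a pathology stream and
    an anatomy stream; each stream is decoded by a transformer decoder
    against text-derived query embeddings (the pathology/anatomy "books"),
    and trained with per-query classification, query-text contrastive, and
    anatomy-pathology alignment losses.
    """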

    def __init__(
        self, config, anatomy_book, pathology_book, mode="train",
    ):
        super().__init__()
        self.mode = mode
        self.d_model = config["d_model"]

        # Encode the anatomy/pathology prompt "books" once with a frozen
        # text encoder and keep only the [CLS] embedding of each entry.
        with torch.no_grad():
            bert_model = self._get_bert_basemodel(
                config["text_encoder"], freeze_layers=None
            ).to(anatomy_book["input_ids"].device)
            self.anatomy_book = bert_model(
                input_ids=anatomy_book["input_ids"],
                attention_mask=anatomy_book["attention_mask"],
            ).last_hidden_state[:, 0, :]
            self.pathology_book = bert_model(
                input_ids=pathology_book["input_ids"],
                attention_mask=pathology_book["attention_mask"],
            ).last_hidden_state[:, 0, :]
        self.pathology_embedding_layer = nn.Linear(768, 256)
        self.cl_fc_pathology = nn.Linear(256, 768)
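
        # Note: anatomy_book / pathology_book are stored as plain tensor
        # attributes (not registered buffers), so model.to(device) and
        # state_dict() will neither move nor save them; callers should keep
        # the model and the books on the same device.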
        # 75 observation/pathology tokens; non-standard spellings such as
        # "coll_eapse" are kept verbatim, presumably to match the dataset's
        # label vocabulary.
        self.pathology_name = [
            "normal",
            "clear",
            "sharp",
            "sharply",
            "unremarkable",
            "intact",
            "stable",
            "free",
            "effusion",
            "opacity",
            "pneumothorax",
            "edema",
            "atelectasis",
            "tube",
            "consolidation",
            "process",
            "abnormality",
            "enlarge",
            "tip",
            "low",
            "pneumonia",
            "line",
            "congestion",
            "catheter",
            "cardiomegaly",
            "fracture",
            "air",
            "tortuous",
            "lead",
            "pathology",
            "calcification",
            "prominence",
            "device",
            "engorgement",
            "picc",
            "clip",
            "elevation",
            "expand",
            "nodule",
            "wire",
            "fluid",
            "degenerative",
            "pacemaker",
            "thicken",
            "marking",
            "scar",
            "hyperinflate",
            "blunt",
            "loss",
            "widen",
            "coll_eapse",
            "density",
            "emphysema",
            "aerate",
            "mass",
            "crowd",
            "infiltrate",
            "obscure",
            "deformity",
            "hernia",
            "drainage",
            "distention",
            "shift",
            "stent",
            "pressure",
            "lesion",
            "finding",
            "borderline",
            "hardware",
            "dilation",
            "chf",
            "redistribution",
            "aspiration",
            "tail_abnorm_obs",
            "excluded_obs",
        ]

        # Classes ignored when exclude_class=True (note: "fibrosis" is not
        # in pathology_name, so it has no effect on the index list below).
        self.excluded_pathology = [
            "pneumonia",
            "infiltrate",
            "mass",
            "nodule",
            "emphysema",
            "fibrosis",
            "thicken",
            "hernia",
        ]

        # Indices of the pathology classes that are kept.
        self.keep_class_dim_pathology = [
            self.pathology_name.index(i)
            for i in self.pathology_name
            if i not in self.excluded_pathology
        ]
""" visual backbone""" |
|
self.resnet_dict = { |
|
"resnet18": models.resnet18(pretrained=False), |
|
"resnet50": models.resnet50(pretrained=False), |
|
} |
|
resnet = self._get_res_basemodel(config["res_base_model"]) |
|
num_ftrs = int(resnet.fc.in_features / 2) |
|
self.res_features = nn.Sequential(*list(resnet.children())[:-3]) |
|
|
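
        # Example: for resnet50, fc.in_features == 2048, so num_ftrs == 1024,
        # and a 224x224 input yields a 14x14 grid of 1024-d features.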

        # Per-location projection heads for the two streams.
        self.res_l1_pathology = nn.Linear(num_ftrs, num_ftrs)
        self.res_l2_pathology = nn.Linear(num_ftrs, self.d_model)

        self.cl_fc_anatomy = nn.Linear(256, 768)
        self.res_l1_anatomy = nn.Linear(num_ftrs, num_ftrs)
        self.res_l2_anatomy = nn.Linear(num_ftrs, self.d_model)

        # Learned elementwise gate that splits visual features into a
        # pathology part (mask * x) and an anatomy part ((1 - mask) * x).
        # Note the gate is a plain linear map, not passed through a sigmoid.
        self.mask_generator = nn.Linear(num_ftrs, num_ftrs)

        """ Query Decoder"""
        self.H = config["H"]
        decoder_layer = TransformerDecoderLayer(
            self.d_model, config["H"], 1024, 0.1, "relu", normalize_before=True
        )
        decoder_norm = nn.LayerNorm(self.d_model)
        # The same layer/norm instances are passed to both decoders; this is
        # safe if TransformerDecoder deep-copies the layer internally, as in
        # DETR-style implementations.
        self.decoder_anatomy = TransformerDecoder(
            decoder_layer, config["N"], decoder_norm, return_intermediate=False
        )
        self.decoder_pathology = TransformerDecoder(
            decoder_layer, config["N"], decoder_norm, return_intermediate=False
        )

        self.dropout_feas_anatomy = nn.Dropout(config["dropout"])
        self.dropout_feas_pathology = nn.Dropout(config["dropout"])

        self.classifier_anatomy = nn.Linear(self.d_model, config["attribute_set_size"])
        self.classifier_pathology = nn.Linear(
            self.d_model, config["attribute_set_size"]
        )

        self.apply(self._init_weights)

    def _get_res_basemodel(self, res_model_name):
        try:
            res_model = self.resnet_dict[res_model_name]
            print("Image feature extractor:", res_model_name)
            return res_model
        except KeyError:
            raise ValueError(
                "Invalid model name. Check the config file and pass one of: "
                "resnet18 or resnet50"
            )

    def _get_bert_basemodel(self, bert_model_name, freeze_layers):
        try:
            model = AutoModel.from_pretrained(bert_model_name)
            print("text feature extractor:", bert_model_name)
        except Exception:
            raise ValueError(
                "Invalid model name. Check the config file and pass a BERT "
                "model from the transformers library"
            )

        # freeze_layers is an iterable of encoder-layer indices to freeze,
        # e.g. [0, 1, 2] freezes the first three blocks.
        if freeze_layers is not None:
            for layer_idx in freeze_layers:
                for param in model.encoder.layer[layer_idx].parameters():
                    param.requires_grad = False
        return model

    def image_encoder(self, xis):
        """Split grid features into pathology/anatomy streams.

        Shape trace for a batch of 16 images with a resnet50 backbone:
            res_fea:            (16, 1024, 14, 14)
            after rearrange:    (16, 196, 1024)
            flattened x:        (3136, 1024)
            each output stream: (16, 196, d_model)
        """
        batch_size = xis.shape[0]
        res_fea = self.res_features(xis)
        res_fea = rearrange(res_fea, "b d n1 n2 -> b (n1 n2) d")
        x = rearrange(res_fea, "b n d -> (b n) d")

        # Gate each feature vector into complementary pathology/anatomy parts.
        mask = self.mask_generator(x)
        x_pathology = mask * x
        x_anatomy = (1 - mask) * x

        x_pathology = F.relu(self.res_l1_pathology(x_pathology))
        x_anatomy = F.relu(self.res_l1_anatomy(x_anatomy))

        x_pathology = self.res_l2_pathology(x_pathology)
        x_anatomy = self.res_l2_anatomy(x_anatomy)

        out_emb_pathology = rearrange(x_pathology, "(b n) d -> b n d", b=batch_size)
        out_emb_anatomy = rearrange(x_anatomy, "(b n) d -> b n d", b=batch_size)
        return out_emb_pathology, out_emb_anatomy

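    # Usage sketch (shapes assume 224x224 inputs and a resnet50 backbone):
    #
    #     imgs = torch.randn(16, 3, 224, 224)
    #     emb_p, emb_a = model.image_encoder(imgs)  # each (16, 196, d_model)
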
    def forward(
        self,
        images,
        labels_pathology=None,
        labels_anatomy=None,
        matrix=None,
        sample_index_pathology=None,
        sample_index_anatomy=None,
        is_train=True,
        text_gen=False,
        no_cl=False,
        exclude_class=False,
    ):

        B = images.shape[0]
        device = images.device
        """ Visual Backbone """
        x_pathology, x_anatomy = self.image_encoder(images)

        # Sequence-first layout for the transformer decoders.
        features_pathology = x_pathology.transpose(0, 1)
        features_anatomy = x_anatomy.transpose(0, 1)

        # Both books share the same 768 -> 256 projection layer.
        query_embed_pathology = self.pathology_embedding_layer(self.pathology_book)
        query_embed_anatomy = self.pathology_embedding_layer(self.anatomy_book)
        query_embed_pathology = query_embed_pathology.unsqueeze(1).repeat(1, B, 1)
        query_embed_anatomy = query_embed_anatomy.unsqueeze(1).repeat(1, B, 1)

        features_pathology, ws_pathology = self.decoder_pathology(
            query_embed_pathology,
            features_pathology,
            memory_key_padding_mask=None,
            pos=None,
            query_pos=None,
        )
        features_anatomy, ws_anatomy = self.decoder_anatomy(
            query_embed_anatomy,
            features_anatomy,
            memory_key_padding_mask=None,
            pos=None,
            query_pos=None,
        )

        ap_pathology = features_pathology
        ap_anatomy = features_anatomy

        # Pairwise anatomy-pathology affinity logits: one score per
        # (anatomy query, pathology query) pair in each batch element.
        ap_logits = torch.bmm(
            ap_pathology.transpose(0, 1), ap_anatomy.transpose(0, 1).transpose(1, 2)
        ).transpose(1, 2)

        if text_gen:
            output_logits = ap_logits
        matrix_zero = matrix

        # Entries below 0 in the supervision matrix (presumably unlabeled
        # pairs) are excluded from the alignment loss.
        masks = matrix_zero >= 0
        ap_logits = ap_logits[masks]
        matrix_zero = matrix_zero[masks]

        loss_ap = F.binary_cross_entropy_with_logits(
            ap_logits.float(), matrix_zero.float()
        )

        out_pathology = self.dropout_feas_pathology(features_pathology)
        out_anatomy = self.dropout_feas_anatomy(features_anatomy)

        if is_train and not no_cl:

            # Build the contrastive candidate banks: for each pathology query,
            # gather anatomy-book embeddings at the sampled indices (and vice
            # versa), zeroing padded entries marked with index -1.
            anatomy_query = self.anatomy_book[sample_index_pathology, :] * (
                sample_index_pathology != -1
            ).int().unsqueeze(-1).repeat(1, 1, 1, 768)
            pathology_query = self.pathology_book[sample_index_anatomy, :] * (
                sample_index_anatomy != -1
            ).int().unsqueeze(-1).repeat(1, 1, 1, 768)

            # Binarize the supervision matrix (entries < 1 become 0); clone
            # first so the caller's tensor is not modified in place.
            matrix_zero_pathology = matrix.clone()
            matrix_zero_anatomy = matrix.transpose(1, 2).clone()
            matrix_zero_pathology[matrix_zero_pathology < 1] = 0
            matrix_zero_anatomy[matrix_zero_anatomy < 1] = 0
            matrix_zero_pathology = matrix_zero_pathology.unsqueeze(3).repeat(
                1, 1, 1, anatomy_query.shape[-1]
            )
            matrix_zero_anatomy = matrix_zero_anatomy.unsqueeze(3).repeat(
                1, 1, 1, pathology_query.shape[-1]
            )

            # Broadcast the text books across the batch and query dimensions
            # so positives can be masked out per (query, candidate) pair.
            anatomy_temp = self.anatomy_book
            pathology_temp = self.pathology_book
            anatomy_temp = anatomy_temp.unsqueeze(0).repeat(
                anatomy_query.shape[0], 1, 1
            )
            pathology_temp = pathology_temp.unsqueeze(0).repeat(
                pathology_query.shape[0], 1, 1
            )
            anatomy_temp = anatomy_temp.unsqueeze(2).repeat(
                1, 1, anatomy_query.shape[1], 1
            )
            pathology_temp = pathology_temp.unsqueeze(2).repeat(
                1, 1, pathology_query.shape[1], 1
            )

            posi_matrix_pathology = (matrix_zero_pathology * anatomy_temp).transpose(
                1, 2
            )
            posi_matrix_anatomy = (matrix_zero_anatomy * pathology_temp).transpose(1, 2)

            # For every query with at least one positive pair, average the
            # positive text embeddings into candidate slot 0; slot 0 is the
            # contrastive target class (the CL labels are all zeros below).
            for i in range(anatomy_query.shape[0]):
                for j in range(anatomy_query.shape[1]):
                    if (posi_matrix_pathology[i, j] != 0).sum() > 0:
                        num_posi = (
                            torch.nonzero(posi_matrix_pathology[i, j], as_tuple=True)[0]
                            .unique()
                            .shape[0]
                        )
                        # Slot 0 must be an empty (padding) slot.
                        assert anatomy_query[i, j, 0, :].sum() == 0
                        anatomy_query[i, j, 0, :] = (
                            posi_matrix_pathology[i, j, :, :].sum(dim=0) / num_posi
                        )

            for i in range(pathology_query.shape[0]):
                for j in range(pathology_query.shape[1]):
                    if (posi_matrix_anatomy[i, j] != 0).sum() > 0:
                        num_posi = (
                            torch.nonzero(posi_matrix_anatomy[i, j], as_tuple=True)[0]
                            .unique()
                            .shape[0]
                        )
                        assert pathology_query[i, j, 0, :].sum() == 0
                        pathology_query[i, j, 0, :] = (
                            posi_matrix_anatomy[i, j, :, :].sum(dim=0) / num_posi
                        )

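            # Net effect: each query gets a bank of 8 candidate embeddings
            # (the averaged positive in slot 0 plus sampled negatives), and
            # the contrastive loss below is an 8-way cross-entropy whose
            # correct class is always 0. (The bank size of 8 is implied by
            # the reshape to (..., 8, 768) further down.)
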
            ll_pathology = out_pathology.transpose(0, 1)
            ll_anatomy = out_anatomy.transpose(0, 1)

            Q_pathology = ll_pathology.shape[1]
            Q_anatomy = ll_anatomy.shape[1]

            ll_pathology = ll_pathology.reshape(
                ll_pathology.shape[0] * ll_pathology.shape[1], -1
            )
            ll_anatomy = ll_anatomy.reshape(
                ll_anatomy.shape[0] * ll_anatomy.shape[1], -1
            )

            # Project decoder outputs back to the 768-d text space.
            ll_pathology = self.cl_fc_pathology(ll_pathology)
            ll_anatomy = self.cl_fc_anatomy(ll_anatomy)

            ll_pathology = ll_pathology.unsqueeze(dim=-1)
            ll_anatomy = ll_anatomy.unsqueeze(dim=-1)

            anatomy_query = anatomy_query.reshape(B * Q_pathology, 8, 768)
            pathology_query = pathology_query.reshape(B * Q_anatomy, 8, 768)

            # Dot product against the 8 candidates -> (B * Q, 8) CL logits.
            ll_pathology = torch.bmm(anatomy_query, ll_pathology).squeeze()
            ll_anatomy = torch.bmm(pathology_query, ll_anatomy).squeeze()

            # Contrastive targets: the positive candidate always sits at
            # slot 0 of each 8-way candidate bank.
            cl_labels_pathology = torch.zeros((ll_pathology.shape[0])).to(device)
            cl_labels_anatomy = torch.zeros((ll_anatomy.shape[0])).to(device)

            if exclude_class:
                # Drop the excluded pathology classes from the pathology-side
                # CL labels and logits. keep_class_dim_pathology indexes
                # pathology classes only, so it is not applied to the anatomy
                # side, which must stay aligned with cl_mask_anatomy below.
                cl_labels_pathology = cl_labels_pathology.reshape(B, Q_pathology)
                cl_labels_pathology = cl_labels_pathology[
                    :, self.keep_class_dim_pathology
                ]
                cl_labels_pathology = cl_labels_pathology.reshape(-1)

                ll_pathology = ll_pathology.reshape(B, Q_pathology, -1)
                ll_pathology = ll_pathology[:, self.keep_class_dim_pathology, :]
                ll_pathology = ll_pathology.reshape(
                    B * (len(self.keep_class_dim_pathology)), -1
                )
                ll_anatomy = ll_anatomy.reshape(B * Q_anatomy, -1)


        # Per-query classification logits for both streams.
        x_pathology = self.classifier_pathology(out_pathology).transpose(0, 1)
        x_anatomy = self.classifier_anatomy(out_anatomy).transpose(0, 1)

        if exclude_class:
            labels_pathology = labels_pathology[:, self.keep_class_dim_pathology]
            x_pathology = x_pathology[:, self.keep_class_dim_pathology, :]

        labels_pathology = labels_pathology.reshape(-1, 1)
        labels_anatomy = labels_anatomy.reshape(-1, 1)
        logits_pathology = x_pathology.reshape(-1, x_pathology.shape[-1])
        logits_anatomy = x_anatomy.reshape(-1, x_anatomy.shape[-1])
        # Labels -1 and 2 (presumably unlabeled/uncertain entries) are
        # excluded from the CE losses; label 1 marks CL positives.
        Mask_pathology = ((labels_pathology != -1) & (labels_pathology != 2)).squeeze()
        Mask_anatomy = ((labels_anatomy != -1) & (labels_anatomy != 2)).squeeze()

        cl_mask_pathology = (labels_pathology == 1).squeeze()
        cl_mask_anatomy = (labels_anatomy == 1).squeeze()
        if is_train:
            labels_pathology = labels_pathology[Mask_pathology].long()
            labels_anatomy = labels_anatomy[Mask_anatomy].long()
            logits_pathology = logits_pathology[Mask_pathology]
            logits_anatomy = logits_anatomy[Mask_anatomy]
            loss_ce_pathology = F.cross_entropy(
                logits_pathology, labels_pathology[:, 0]
            )
            loss_ce_anatomy = F.cross_entropy(logits_anatomy, labels_anatomy[:, 0])
            if not no_cl:
                cl_labels_pathology = cl_labels_pathology[cl_mask_pathology].long()
                cl_labels_anatomy = cl_labels_anatomy[cl_mask_anatomy].long()
                ll_pathology = ll_pathology[cl_mask_pathology]
                ll_anatomy = ll_anatomy[cl_mask_anatomy]
                loss_cl_pathology = F.cross_entropy(ll_pathology, cl_labels_pathology)
                loss_cl_anatomy = F.cross_entropy(ll_anatomy, cl_labels_anatomy)
                loss_ce = loss_ce_pathology + loss_ce_anatomy
                loss_cl = loss_cl_pathology + loss_cl_anatomy
                loss = loss_ce + loss_cl + loss_ap
            else:
                # Define the CL terms so the return tuple below is always
                # well-formed when no_cl=True.
                loss_cl = torch.tensor(0.0, device=device)
                loss_cl_pathology = loss_cl_anatomy = loss_cl
                loss = loss_ce_pathology + loss_ce_anatomy + loss_ap
        else:
            loss = 0

        if is_train:
            if text_gen:
                return (
                    loss,
                    x_pathology,
                    ws_pathology,
                    x_anatomy,
                    ws_anatomy,
                    output_logits,
                )
            else:
                return (
                    loss,
                    loss_ce_pathology,
                    loss_cl_pathology,
                    loss_ce_anatomy,
                    loss_cl_anatomy,
                    loss_ap,
                )
        else:
            return loss, x_pathology, ws_pathology, x_anatomy, ws_anatomy

    @staticmethod
    def _init_weights(module):
        r"""Initialize weights BERT-style: linear/attention/embedding weights
        drawn from N(0.0, 0.02); padding-embedding rows zeroed."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
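

# Usage sketch (illustrative; prompt construction and argument values are
# assumptions, not taken from this file):
#
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained(config["text_encoder"])
#     anatomy_book = tok(anatomy_prompts, padding=True, return_tensors="pt")
#     pathology_book = tok(pathology_prompts, padding=True, return_tensors="pt")
#     model = MeDSLIP(config, anatomy_book, pathology_book)
#     loss, *aux = model(
#         images,
#         labels_pathology,
#         labels_anatomy,
#         matrix,
#         sample_index_pathology,
#         sample_index_anatomy,
#     )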