# MeDSLIP/PreTrain_MeDSLIP/models/model_MeDSLIP.py
# modified from https://github.com/tensorflow/models/blob/master/research/slim/nets/s3dg.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from einops import rearrange
from transformers import AutoModel
from .transformer import *
"""
args.N
args.d_model
args.res_base_model
args.H
args.num_queries
args.dropout
args.attribute_set_size
"""
class MeDSLIP(nn.Module):
def __init__(
self, config, anatomy_book, pathology_book, mode="train",
):
super(MeDSLIP, self).__init__()
self.mode = mode
self.d_model = config["d_model"]
# """ book embedding"""
with torch.no_grad():
bert_model = self._get_bert_basemodel(
config["text_encoder"], freeze_layers=None
).to(anatomy_book["input_ids"].device)
# Frozen [CLS] embeddings of the anatomy / pathology prompt books.
self.anatomy_book = bert_model(
input_ids=anatomy_book["input_ids"],
attention_mask=anatomy_book["attention_mask"],
).last_hidden_state[:, 0, :]
self.pathology_book = bert_model(
input_ids=pathology_book["input_ids"],
attention_mask=pathology_book["attention_mask"],
).last_hidden_state[:, 0, :]
self.pathology_embedding_layer = nn.Linear(768, 256)
self.cl_fc_pathology = nn.Linear(256, 768)
self.pathology_name = [
"normal",
"clear",
"sharp",
"sharply",
"unremarkable",
"intact",
"stable",
"free",
"effusion",
"opacity",
"pneumothorax",
"edema",
"atelectasis",
"tube",
"consolidation",
"process",
"abnormality",
"enlarge",
"tip",
"low",
"pneumonia",
"line",
"congestion",
"catheter",
"cardiomegaly",
"fracture",
"air",
"tortuous",
"lead",
"pathology",
"calcification",
"prominence",
"device",
"engorgement",
"picc",
"clip",
"elevation",
"expand",
"nodule",
"wire",
"fluid",
"degenerative",
"pacemaker",
"thicken",
"marking",
"scar",
"hyperinflate",
"blunt",
"loss",
"widen",
"coll_eapse",
"density",
"emphysema",
"aerate",
"mass",
"crowd",
"infiltrate",
"obscure",
"deformity",
"hernia",
"drainage",
"distention",
"shift",
"stent",
"pressure",
"lesion",
"finding",
"borderline",
"hardware",
"dilation",
"chf",
"redistribution",
"aspiration",
"tail_abnorm_obs",
"excluded_obs",
]
self.excluded_pathology = [
"pneumonia",
"infiltrate",
"mass",
"nodule",
"emphysema",
"fibrosis",
"thicken",
"hernia",
]
self.keep_class_dim_pathology = [
self.pathology_name.index(i)
for i in self.pathology_name
if i not in self.excluded_pathology
]
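# Pathology class indices kept when exclude_class=True; the
# excluded_pathology classes above are dropped from losses and logits.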
""" visual backbone"""
self.resnet_dict = {
"resnet18": models.resnet18(pretrained=False),
"resnet50": models.resnet50(pretrained=False),
}
resnet = self._get_res_basemodel(config["res_base_model"])
# res_features keeps the backbone up to layer3, so the channel dim is
# half of fc.in_features (e.g. 1024 for resnet50).
num_ftrs = int(resnet.fc.in_features / 2)
self.res_features = nn.Sequential(*list(resnet.children())[:-3])
self.res_l1_pathology = nn.Linear(num_ftrs, num_ftrs)
self.res_l2_pathology = nn.Linear(num_ftrs, self.d_model)
self.cl_fc_anatomy = nn.Linear(256, 768)
self.res_l1_anatomy = nn.Linear(num_ftrs, num_ftrs)
self.res_l2_anatomy = nn.Linear(num_ftrs, self.d_model)
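# Learned soft mask that routes each patch feature into a pathology stream
# (mask * x) and an anatomy stream ((1 - mask) * x); see image_encoder.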
self.mask_generator = nn.Linear(num_ftrs, num_ftrs)
###################################
""" Query Decoder"""
###################################
self.H = config["H"]
decoder_layer = TransformerDecoderLayer(
self.d_model, config["H"], 1024, 0.1, "relu", normalize_before=True
)
decoder_norm = nn.LayerNorm(self.d_model)
self.decoder_anatomy = TransformerDecoder(
decoder_layer, config["N"], decoder_norm, return_intermediate=False
)
self.decoder_pathology = TransformerDecoder(
decoder_layer, config["N"], decoder_norm, return_intermediate=False
)
# Dropout on decoder outputs
self.dropout_feas_anatomy = nn.Dropout(config["dropout"])
self.dropout_feas_pathology = nn.Dropout(config["dropout"])
# Attribute classifier
self.classifier_anatomy = nn.Linear(self.d_model, config["attribute_set_size"])
self.classifier_pathology = nn.Linear(
self.d_model, config["attribute_set_size"]
)
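# Per-query attribute heads; attribute_set_size is presumably 2
# (absent / present), matching the cross-entropy targets in forward().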
self.apply(self._init_weights)
def _get_res_basemodel(self, res_model_name):
try:
res_model = self.resnet_dict[res_model_name]
print("Image feature extractor:", res_model_name)
return res_model
except KeyError:
raise ValueError(
"Invalid model name. Check the config file and pass one of: resnet18 or resnet50"
)
def _get_bert_basemodel(self, bert_model_name, freeze_layers):
try:
model = AutoModel.from_pretrained(bert_model_name)
print("text feature extractor:", bert_model_name)
except OSError:
raise ValueError(
"Invalid model name. Check the config file and pass a BERT model from the transformers library"
)
if freeze_layers is not None:
for layer_idx in freeze_layers:
for param in list(model.encoder.layer[layer_idx].parameters()):
param.requires_grad = False
return model
def image_encoder(self, xis):
"""Split backbone patch features into pathology / anatomy embeddings.

Shape flow for a 16-image batch (resnet50, 224x224 input):
[16, 1024, 14, 14] -> [16, 196, 1024] -> [3136, 1024] -> [16, 196, 256].
"""
batch_size = xis.shape[0]
res_fea = self.res_features(xis) # batch_size,feature_size,patch_num,patch_num
res_fea = rearrange(res_fea, "b d n1 n2 -> b (n1 n2) d")
x = rearrange(res_fea, "b n d -> (b n) d")
mask = self.mask_generator(x)
x_pathology = mask * x
x_anatomy = (1 - mask) * x
x_pathology = self.res_l1_pathology(x_pathology)
x_anatomy = self.res_l1_anatomy(x_anatomy)
x_pathology = F.relu(x_pathology)
x_anatomy = F.relu(x_anatomy)
x_pathology = self.res_l2_pathology(x_pathology)
x_anatomy = self.res_l2_anatomy(x_anatomy)
out_emb_pathology = rearrange(x_pathology, "(b n) d -> b n d", b=batch_size)
out_emb_anatomy = rearrange(x_anatomy, "(b n) d -> b n d", b=batch_size)
return out_emb_pathology, out_emb_anatomy
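# Usage sketch (illustrative): for images of shape [B, 3, 224, 224] with
# resnet50, image_encoder returns two [B, 196, d_model] patch embeddings,
# e.g. emb_pathology, emb_anatomy = self.image_encoder(images). Note that
# mask_generator is a plain Linear layer, so the mask is not bounded to
# [0, 1]; the two streams are complementary only in that they sum to x.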
def forward(
self,
images,
labels_pathology=None,
labels_anatomy=None,
matrix=None,
sample_index_pathology=None,
sample_index_anatomy=None,
is_train=True,
text_gen=False,
no_cl=False,
exclude_class=False,
):
B = images.shape[0]
device = images.device
""" Visual Backbone """
x_pathology, x_anatomy = self.image_encoder(images) # batch_size,patch_num,dim
features_pathology = x_pathology.transpose(0, 1) # patch_num b dim
features_anatomy = x_anatomy.transpose(0, 1) # patch_num b dim
# Both prompt books share the same 768 -> 256 projection.
query_embed_pathology = self.pathology_embedding_layer(self.pathology_book)
query_embed_anatomy = self.pathology_embedding_layer(self.anatomy_book)
# [Q, dim] -> [Q, B, dim]
query_embed_pathology = query_embed_pathology.unsqueeze(1).repeat(1, B, 1)
query_embed_anatomy = query_embed_anatomy.unsqueeze(1).repeat(1, B, 1)
features_pathology, ws_pathology = self.decoder_pathology(
query_embed_pathology,
features_pathology,
memory_key_padding_mask=None,
pos=None,
query_pos=None,
)
features_anatomy, ws_anatomy = self.decoder_anatomy(
query_embed_anatomy,
features_anatomy,
memory_key_padding_mask=None,
pos=None,
query_pos=None,
)
# Pathology-anatomy affinity: [B, Q_p, d] x [B, d, Q_a] -> [B, Q_p, Q_a],
# then transposed to [B, Q_a, Q_p].
ap_pathology = features_pathology
ap_anatomy = features_anatomy
ap_logits = torch.bmm(
ap_pathology.transpose(0, 1), ap_anatomy.transpose(0, 1).transpose(1, 2)
).transpose(1, 2)
if text_gen:
output_logits = ap_logits
# Entries of `matrix` below 0 mark pairs ignored by the alignment loss.
masks = matrix >= 0
loss_ap = F.binary_cross_entropy_with_logits(
ap_logits[masks].float(), matrix[masks].float()
)
out_pathology = self.dropout_feas_pathology(features_pathology)
out_anatomy = self.dropout_feas_anatomy(features_anatomy)
if is_train and not no_cl:
# Build positive anatomy / pathology queries for contrastive learning.
# sample_index_* entries of -1 are padding; their book rows are zeroed.
anatomy_query = self.anatomy_book[sample_index_pathology, :] * (
sample_index_pathology != -1
).int().unsqueeze(-1).repeat(
1, 1, 1, 768
) # batch, Q, position_num, dim
entity_query = self.pathology_book[sample_index_anatomy, :] * (
sample_index_anatomy != -1
).int().unsqueeze(-1).repeat(1, 1, 1, 768)
# Clone before binarizing so the caller's `matrix` is not mutated in place.
matrix_zero_pathology = matrix.clone()
matrix_zero_anatomy = matrix.clone().transpose(1, 2)
matrix_zero_pathology[matrix_zero_pathology < 1] = 0
matrix_zero_anatomy[matrix_zero_anatomy < 1] = 0
# Broadcast the binarized matrices over the embedding dim.
matrix_zero_pathology = matrix_zero_pathology.unsqueeze(3).repeat(
1, 1, 1, anatomy_query.shape[-1]
)
matrix_zero_anatomy = matrix_zero_anatomy.unsqueeze(3).repeat(
1, 1, 1, entity_query.shape[-1]
)
anatomy_temp = self.anatomy_book.unsqueeze(0).repeat(
anatomy_query.shape[0], 1, 1
)
pathology_temp = self.pathology_book.unsqueeze(0).repeat(
entity_query.shape[0], 1, 1
)
anatomy_temp = anatomy_temp.unsqueeze(2).repeat(
1, 1, anatomy_query.shape[1], 1
)
pathology_temp = pathology_temp.unsqueeze(2).repeat(
1, 1, entity_query.shape[1], 1
)
# Book embeddings of the positive pairs selected by the matrix.
posi_matrix_pathology = (matrix_zero_pathology * anatomy_temp).transpose(1, 2)
posi_matrix_anatomy = (matrix_zero_anatomy * pathology_temp).transpose(1, 2)
# Average the positive embeddings into slot 0 of each query.
for i in range(anatomy_query.shape[0]):
for j in range(anatomy_query.shape[1]):
if (posi_matrix_pathology[i, j] != 0).sum() > 0:
num_posi = (
torch.nonzero(posi_matrix_pathology[i, j], as_tuple=True)[0]
.unique()
.shape[0]
)
assert anatomy_query[i, j, 0, :].sum() == 0
anatomy_query[i, j, 0, :] = (
posi_matrix_pathology[i, j, :, :].sum(dim=0) / num_posi
)
for i in range(entity_query.shape[0]):
for j in range(entity_query.shape[1]):
if (posi_matrix_anatomy[i, j] != 0).sum() > 0:
num_posi = (
torch.nonzero(posi_matrix_anatomy[i, j], as_tuple=True)[0]
.unique()
.shape[0]
)
assert entity_query[i, j, 0, :].sum() == 0
entity_query[i, j, 0, :] = (
posi_matrix_anatomy[i, j, :, :].sum(dim=0) / num_posi
)
# Anatomy / entity queries built.
# [Q,B,A]
ll_pathology = out_pathology.transpose(0, 1) # B Q A
ll_anatomy = out_anatomy.transpose(0, 1) # B Q A
Q_pathology = ll_pathology.shape[1]
Q_anatomy = ll_anatomy.shape[1]
ll_pathology = ll_pathology.reshape(
ll_pathology.shape[0] * ll_pathology.shape[1], -1
)
ll_anatomy = ll_anatomy.reshape(
ll_anatomy.shape[0] * ll_anatomy.shape[1], -1
)
ll_pathology = self.cl_fc_pathology(ll_pathology)
ll_anatomy = self.cl_fc_anatomy(ll_anatomy)
ll_pathology = ll_pathology.unsqueeze(dim=-1)
ll_anatomy = ll_anatomy.unsqueeze(dim=-1)
# 8 = sampled book entries per query; slot 0 holds the (averaged) positive.
anatomy_query = anatomy_query.reshape(B * Q_pathology, 8, 768)
entity_query = entity_query.reshape(B * Q_anatomy, 8, 768)
ll_pathology = torch.bmm(
anatomy_query, ll_pathology
).squeeze() # B Q position_num
ll_anatomy = torch.bmm(
entity_query, ll_anatomy
).squeeze() # B Q position_num
# Contrastive targets: the positive sits at index 0 of each sampled set.
cl_labels_pathology = torch.zeros((ll_pathology.shape[0])).to(device)
cl_labels_anatomy = torch.zeros((ll_anatomy.shape[0])).to(device)
if exclude_class:
cl_labels_pathology = cl_labels_pathology.reshape(B, Q_pathology)
cl_labels_anatomy = cl_labels_anatomy.reshape(B, Q_anatomy)
cl_labels_pathology = cl_labels_pathology[
:, self.keep_class_dim_pathology
]
cl_labels_anatomy = cl_labels_anatomy[:, self.keep_class_dim_pathology]
cl_labels_pathology = cl_labels_pathology.reshape(-1)
cl_labels_anatomy = cl_labels_anatomy.reshape(-1)
ll_pathology = ll_pathology.reshape(B, Q_pathology, -1)
ll_anatomy = ll_anatomy.reshape(B, Q_anatomy, -1)
ll_pathology = ll_pathology[:, self.keep_class_dim_pathology, :]
ll_pathology = ll_pathology.reshape(
B * (len(self.keep_class_dim_pathology)), -1
)
ll_anatomy = ll_anatomy.reshape(B * Q_anatomy, -1)
x_pathology = self.classifier_pathology(out_pathology).transpose(0, 1)
x_anatomy = self.classifier_anatomy(out_anatomy).transpose(
0, 1
) # B, query, attributes
if exclude_class:
labels_pathology = labels_pathology[:, self.keep_class_dim_pathology]
x_pathology = x_pathology[:, self.keep_class_dim_pathology, :]
labels_pathology = labels_pathology.reshape(-1, 1)
labels_anatomy = labels_anatomy.reshape(-1, 1)
logits_pathology = x_pathology.reshape(-1, x_pathology.shape[-1])
logits_anatomy = x_anatomy.reshape(-1, x_anatomy.shape[-1])
Mask_pathology = ((labels_pathology != -1) & (labels_pathology != 2)).squeeze()
Mask_anatomy = ((labels_anatomy != -1) & (labels_anatomy != 2)).squeeze()
cl_mask_pathology = (labels_pathology == 1).squeeze()
cl_mask_anatomy = (labels_anatomy == 1).squeeze()
if is_train:
labels_pathology = labels_pathology[Mask_pathology].long()
labels_anatomy = labels_anatomy[Mask_anatomy].long()
logits_pathology = logits_pathology[Mask_pathology]
logits_anatomy = logits_anatomy[Mask_anatomy]
loss_ce_pathology = F.cross_entropy(
logits_pathology, labels_pathology[:, 0]
)
loss_ce_anatomy = F.cross_entropy(logits_anatomy, labels_anatomy[:, 0])
if not no_cl:
cl_labels_pathology = cl_labels_pathology[cl_mask_pathology].long()
cl_labels_anatomy = cl_labels_anatomy[cl_mask_anatomy].long()
ll_pathology = ll_pathology[cl_mask_pathology]
ll_anatomy = ll_anatomy[cl_mask_anatomy]
loss_cl_pathology = F.cross_entropy(ll_pathology, cl_labels_pathology)
loss_cl_anatomy = F.cross_entropy(ll_anatomy, cl_labels_anatomy)
loss_ce = loss_ce_pathology + loss_ce_anatomy
loss_cl = loss_cl_pathology + loss_cl_anatomy
loss = loss_ce + loss_cl + loss_ap
else:
# Keep the per-stream CL losses defined so the training return tuple works.
loss_cl_pathology = torch.tensor(0.0, device=device)
loss_cl_anatomy = torch.tensor(0.0, device=device)
loss_cl = torch.tensor(0.0, device=device)
loss = loss_ce_pathology + loss_ce_anatomy + loss_ap
else:
loss = 0
if is_train:
if text_gen:
return (
loss,
x_pathology,
ws_pathology,
x_anatomy,
ws_anatomy,
output_logits,
)
else:
return (
loss,
loss_ce_pathology,
loss_cl_pathology,
loss_ce_anatomy,
loss_cl_anatomy,
loss_ap,
)
else:
return loss, x_pathology, ws_pathology, x_anatomy, ws_anatomy
@staticmethod
def _init_weights(module):
r"""Initialize weights like BERT - N(0.0, 0.02), bias = 0."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.MultiheadAttention):
module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
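# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the released pipeline). It assumes a
# reachable HuggingFace checkpoint -- "emilyalsentzer/Bio_ClinicalBERT" is a
# guess -- and uses tiny toy prompt books, labels, and a dummy co-occurrence
# matrix purely to exercise forward() shapes in eval mode.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    encoder_name = "emilyalsentzer/Bio_ClinicalBERT"  # assumed checkpoint name
    config = {
        "d_model": 256,  # must stay 256 (hard-coded 768 <-> 256 projections)
        "N": 2,
        "H": 4,
        "dropout": 0.1,
        "attribute_set_size": 2,
        "res_base_model": "resnet50",
        "text_encoder": encoder_name,
    }
    tokenizer = AutoTokenizer.from_pretrained(encoder_name)
    anatomy_book = tokenizer(
        ["right lung", "left lung", "heart"], padding=True, return_tensors="pt"
    )
    pathology_book = tokenizer(
        ["effusion", "opacity", "pneumothorax", "edema"],
        padding=True,
        return_tensors="pt",
    )
    model = MeDSLIP(config, anatomy_book, pathology_book)
    B, Q_p, Q_a = 2, 4, 3  # batch, pathology queries, anatomy queries
    images = torch.randn(B, 3, 224, 224)
    labels_p = torch.randint(0, 2, (B, Q_p))
    labels_a = torch.randint(0, 2, (B, Q_a))
    matrix = torch.randint(0, 2, (B, Q_a, Q_p))  # ap_logits is [B, Q_a, Q_p]
    loss, x_p, ws_p, x_a, ws_a = model(
        images, labels_p, labels_a, matrix, is_train=False, no_cl=True
    )
    print(x_p.shape, x_a.shape)  # expected: [B, Q_p, 2], [B, Q_a, 2]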