least1924 committed (verified)
Commit 6775edf · 1 Parent(s): 514e8bb

Upload 10 files

Add inference code and the trained model file.

checkpoints/model.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1e7ccea2005799786c18f5adc27b1e6586fc614fb6bf8e56092a44bc23adf398
size 101028923
inference.py ADDED
@@ -0,0 +1,125 @@
from __future__ import print_function, division
import os
import sys
import time
import argparse
import warnings
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, utils
from models.modeling import PATHOLOGICAL_CLASSFIER, CONFIGS

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_weights(model, weight_path):
    print("Loading PATHOLOGICAL_CLASSFIER...", weight_path)
    loadnet = torch.load(weight_path, map_location=device)
    # The checkpoint may store the weights under "model_state_dict" or as a bare state dict.
    state_dict = loadnet["model_state_dict"] if "model_state_dict" in loadnet else loadnet
    model.load_state_dict(state_dict, strict=True)
    return model


class MyDataset(Dataset):
    """Loads pre-extracted image features, text features and OS/DFS targets from pickle files."""

    def __init__(self, root_path):
        m_data = []
        img_pkl_file_path = os.path.join(root_path, "img_feature")
        txt_pkl_file_path = os.path.join(root_path, "txt_feature")
        target_pkl_file_path = os.path.join(root_path, "target")
        for file in os.listdir(img_pkl_file_path):
            img_pkl_file = os.path.join(img_pkl_file_path, file)
            txt_pkl_file = os.path.join(txt_pkl_file_path, file)
            target_pkl_file = os.path.join(target_pkl_file_path, file)
            with open(img_pkl_file, "rb") as img_f:
                img_load_dict = pickle.load(img_f)
                m_input_img = img_load_dict["img_feature"]
            with open(txt_pkl_file, "rb") as txt_f:
                txt_load_dict = pickle.load(txt_f)
                m_input_txt = txt_load_dict["txt_feature"]
            with open(target_pkl_file, "rb") as target_f:
                target_load_dict = pickle.load(target_f)
                m_output_os = target_load_dict["target_os"]
                m_output_dfs = target_load_dict["target_dfs"]
            m_data.append((m_input_img, m_input_txt, m_output_os, m_output_dfs, file))
        self.m_data = m_data

    def __getitem__(self, idx):
        inp_i, inp_txt, oup_os, oup_dfs, f_name = self.m_data[idx]
        return inp_i, inp_txt, oup_os, oup_dfs, f_name

    def __len__(self):
        return len(self.m_data)


def valid(args):
    torch.manual_seed(0)
    num_classes = 2
    config = CONFIGS["PATHOLOGICAL_CLASSFIER"]
    model = PATHOLOGICAL_CLASSFIER(config, num_classes=num_classes, vis=True, mm=True)

    model_path = '/your/trained/model/path/'  # e.g. checkpoints/model.pt from this repository
    p_c_model = load_weights(model, model_path)

    p_c_model.to(device)
    test_dataset = MyDataset("/your/dataset/path/")
    test_loader = DataLoader(test_dataset, batch_size=1)

    # ----- Test -----
    print("--------Start testing-------")
    p_c_model.eval()

    valid_1_acc = 0
    valid_1_total = 0
    valid_1_cnt = 0

    valid_2_acc = 0
    valid_2_total = 0
    valid_2_cnt = 0
    valid_total_cnt = 0

    target_cnt_0 = 0
    target_cnt_1 = 0
    with torch.no_grad():
        for imgs, txt, target_1, target_2, file_name in test_loader:
            output_1, output_2 = p_c_model(imgs.to(device), txt.to(device))

            out_1_list_prob = torch.softmax(output_1.squeeze(1), dim=-1).cpu().numpy().tolist()  # class probabilities (not used below)

            out_1_list = torch.argmax(output_1.squeeze(1), dim=-1).cpu().numpy().tolist()
            target_1_list = target_1.tolist()

            out_2_list = torch.argmax(output_2.squeeze(1), dim=-1).cpu().numpy().tolist()
            target_2_list = target_2.tolist()

            valid_1_total += len(out_1_list)
            valid_2_total += len(out_2_list)

            for i in range(len(out_1_list)):
                if out_1_list[i] == target_1_list[i]:
                    valid_1_cnt += 1
                if out_2_list[i] == target_2_list[i]:
                    valid_2_cnt += 1
                if out_1_list[i] == target_1_list[i] and out_2_list[i] == target_2_list[i]:
                    valid_total_cnt += 1

    valid_1_acc = valid_1_cnt / valid_1_total
    valid_2_acc = valid_2_cnt / valid_2_total
    valid_total_acc = valid_total_cnt / valid_1_total

    print(valid_1_acc, valid_1_total, valid_2_acc, valid_2_total, valid_total_acc, valid_total_cnt)
    print("=" * 100)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    args = parser.parse_args()
    valid(args)
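For reference, MyDataset above expects three sibling folders (img_feature, txt_feature, target) holding identically named pickle files with the keys it reads. Below is a minimal sketch of how one such sample could be written; the file name and the sequence lengths are illustrative assumptions, and only the feature width (config.hidden_size = 512) is fixed by the model's linear projections.

import os
import pickle
import torch

root = "/your/dataset/path/"  # same placeholder as in inference.py
for sub in ("img_feature", "txt_feature", "target"):
    os.makedirs(os.path.join(root, sub), exist_ok=True)

name = "sample_000.pkl"  # hypothetical file name; it only has to match across the three folders
with open(os.path.join(root, "img_feature", name), "wb") as f:
    pickle.dump({"img_feature": torch.randn(196, 512)}, f)  # (image tokens, hidden) -- assumed length
with open(os.path.join(root, "txt_feature", name), "wb") as f:
    pickle.dump({"txt_feature": torch.randn(40, 512)}, f)   # (text tokens, hidden) -- assumed length
with open(os.path.join(root, "target", name), "wb") as f:
    pickle.dump({"target_os": 0, "target_dfs": 1}, f)       # binary OS / DFS labels read by MyDataset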
models/__init__.py ADDED
File without changes
models/attention.py ADDED
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
import models.configs as configs
import math


class Attention(nn.Module):
    def __init__(self, config, vis, mm=True):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        if mm:
            # Separate projections and dropouts for the text modality and the cross-modal paths.
            self.query_text = Linear(config.hidden_size, self.all_head_size)
            self.key_text = Linear(config.hidden_size, self.all_head_size)
            self.value_text = Linear(config.hidden_size, self.all_head_size)
            self.out_text = Linear(config.hidden_size, config.hidden_size)
            self.attn_dropout_text = Dropout(config.transformer["attention_dropout_rate"])
            self.attn_dropout_it = Dropout(config.transformer["attention_dropout_rate"])
            self.attn_dropout_ti = Dropout(config.transformer["attention_dropout_rate"])
            self.proj_dropout_text = Dropout(config.transformer["attention_dropout_rate"])

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, text=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        if text is not None:
            text_q = self.query_text(text)
            text_k = self.key_text(text)
            text_v = self.value_text(text)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        if text is not None:
            query_layer_img = query_layer
            key_layer_img = key_layer
            value_layer_img = value_layer
            query_layer_text = self.transpose_for_scores(text_q)
            key_layer_text = self.transpose_for_scores(text_k)
            value_layer_text = self.transpose_for_scores(text_v)

        if text is None:
            # Standard single-modality self-attention.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            attention_probs = self.softmax(attention_scores)
            weights = attention_probs if self.vis else None
            attention_probs = self.attn_dropout(attention_probs)

            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
            attention_output = self.out(context_layer)
            attention_output = self.proj_dropout(attention_output)

            return attention_output, None, weights

        else:
            # Self-attention within each modality plus image->text and text->image cross-attention.
            attention_scores_img = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores_text = torch.matmul(query_layer_text, key_layer_text.transpose(-1, -2))
            attention_scores_it = torch.matmul(query_layer_img, key_layer_text.transpose(-1, -2))
            attention_scores_ti = torch.matmul(query_layer_text, key_layer_img.transpose(-1, -2))
            attention_scores_img = attention_scores_img / math.sqrt(self.attention_head_size)

            attention_probs_img = self.softmax(attention_scores_img)
            weights_img = attention_probs_img if self.vis else None

            attention_probs_img = self.attn_dropout(attention_probs_img)

            attention_scores_text = attention_scores_text / math.sqrt(self.attention_head_size)
            attention_probs_text = self.softmax(attention_scores_text)

            text_per_weights = attention_probs_text.mean(dim=-1)
            text_per_weights = self.softmax(text_per_weights)

            weights_text = attention_probs_text if self.vis else None

            attention_probs_text = self.attn_dropout_text(attention_probs_text)

            attention_scores_it = attention_scores_it / math.sqrt(self.attention_head_size)
            attention_probs_it = self.softmax(attention_scores_it)
            attention_probs_it = self.attn_dropout_it(attention_probs_it)

            attention_scores_ti = attention_scores_ti / math.sqrt(self.attention_head_size)
            attention_probs_ti = self.softmax(attention_scores_ti)
            attention_probs_ti = self.attn_dropout_ti(attention_probs_ti)

            context_layer_img = torch.matmul(attention_probs_img, value_layer_img)
            context_layer_img = context_layer_img.permute(0, 2, 1, 3).contiguous()
            context_layer_text = torch.matmul(attention_probs_text, value_layer_text)
            context_layer_text = context_layer_text.permute(0, 2, 1, 3).contiguous()
            context_layer_it = torch.matmul(attention_probs_it, value_layer_text)
            context_layer_it = context_layer_it.permute(0, 2, 1, 3).contiguous()
            context_layer_ti = torch.matmul(attention_probs_ti, value_layer_img)
            context_layer_ti = context_layer_ti.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer_img.size()[:-2] + (self.all_head_size,)
            context_layer_img = context_layer_img.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_text.size()[:-2] + (self.all_head_size,)
            context_layer_text = context_layer_text.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_it.size()[:-2] + (self.all_head_size,)
            context_layer_it = context_layer_it.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_ti.size()[:-2] + (self.all_head_size,)
            context_layer_ti = context_layer_ti.view(*new_context_layer_shape)
            # Average each modality's self-attention output with its cross-modal counterpart.
            attention_output_img = self.out((context_layer_img + context_layer_it) / 2)
            attention_output_text = self.out((context_layer_text + context_layer_ti) / 2)
            attention_output_img = self.proj_dropout(attention_output_img)
            attention_output_text = self.proj_dropout_text(attention_output_text)

            # Return three values so the caller (Block.forward) can unpack
            # (image output, text output, attention weights) in both branches.
            return attention_output_img, attention_output_text, weights_img
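A short sketch of how this module behaves in its two modes, using the repository's own config; the token counts below are illustrative assumptions.

import torch
from models.attention import Attention
from models.configs import get_IRENE_config

config = get_IRENE_config()
attn = Attention(config, vis=True, mm=True)

img = torch.randn(1, 196, config.hidden_size)  # (batch, image tokens, hidden)
txt = torch.randn(1, 40, config.hidden_size)   # (batch, text tokens, hidden)

# Multimodal mode: one output per modality (self- and cross-attention averaged) plus image attention weights.
img_out, txt_out, img_weights = attn(img, txt)

# Single-modality mode: plain self-attention; the text slot comes back as None.
out, _, weights = attn(img)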
models/block.py ADDED
@@ -0,0 +1,120 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Convert numpy weights to torch tensors (helper assumed by load_from; HWIO -> OIHW for conv kernels)."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


class Block(nn.Module):
    def __init__(self, config, vis, mm=True):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        if mm:
            self.att_norm_text = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_norm_text = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_text = Mlp(config)

        self.ffn = Mlp(config)
        self.attn = Attention(config, vis, mm)

    def forward(self, x, text=None):
        if text is None:
            # Single-modality path: pre-norm attention and MLP with residual connections.
            h = x
            x = self.attention_norm(x)
            x, text, weights = self.attn(x)

            x = x + h

            h = x
            x = self.ffn_norm(x)
            x = self.ffn(x)
            x = x + h
            return x, text, weights
        else:
            # Multimodal path: image and text streams are normalised, attended jointly,
            # and each stream gets its own residual and MLP.
            h = x
            h_text = text
            x = self.attention_norm(x)
            text = self.att_norm_text(text)

            x, text, weights_img = self.attn(x, text)

            x = x + h
            text = text + h_text

            h = x
            h_text = text
            x = self.ffn_norm(x)
            text = self.ffn_norm_text(text)
            x = self.ffn(x)
            text = self.ffn_text(text)
            x = x + h
            text = text + h_text

            return x, text, weights_img

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
models/configs.py ADDED
@@ -0,0 +1,33 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ml_collections


def get_IRENE_config():
    """Returns the PATHOLOGICAL_CLASSFIER configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 512
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1024
    config.transformer.num_heads = 1  # hidden_size must be divisible by num_heads
    config.transformer.num_layers = 4  # models trained on the other three hospitals' data use 4 layers; TCGA used 2
    # config.transformer.num_layers = 2  # TCGA setting
    config.transformer.attention_dropout_rate = 0.2  # 0.0 - 0.2
    config.transformer.dropout_rate = 0.3  # 0.1 - 0.3
    config.classifier = 'token'
    config.representation_size = None
    config.cc_len = 40
    config.lab_len = 92
    return config
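If a checkpoint was trained with a different depth (for example the TCGA setting mentioned in the comment above), the config can be adjusted before the model is built; a brief sketch:

from models.configs import get_IRENE_config
from models.modeling import PATHOLOGICAL_CLASSFIER

config = get_IRENE_config()
config.transformer.num_layers = 2  # TCGA setting noted above; the default is 4
model = PATHOLOGICAL_CLASSFIER(config, num_classes=2, vis=True, mm=True)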
models/embed.py ADDED
@@ -0,0 +1,97 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
import pdb


class Embeddings(nn.Module):
    """Construct the embeddings from patch and position embeddings."""

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)
        tk_lim = config.cc_len
        num_lab = config.lab_len

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.cc_embeddings = Linear(768, config.hidden_size)
        self.lab_embeddings = Linear(1, config.hidden_size)
        self.sex_embeddings = Linear(1, config.hidden_size)
        self.age_embeddings = Linear(1, config.hidden_size)

        self.position_embeddings = nn.Parameter(torch.zeros(1, 1 + n_patches, config.hidden_size))
        self.pe_txt = nn.Parameter(torch.zeros(1, tk_lim, config.hidden_size))
        self.pe_lab = nn.Parameter(torch.zeros(1, num_lab, config.hidden_size))
        self.pe_sex = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.pe_age = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])
        self.dropout_txt = Dropout(config.transformer["dropout_rate"])
        self.dropout_lab = Dropout(config.transformer["dropout_rate"])
        self.dropout_sex = Dropout(config.transformer["dropout_rate"])
        self.dropout_age = Dropout(config.transformer["dropout_rate"])

    def forward(self, x, txt, lab, sex, age):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)  # note: hybrid_model is only available when a hybrid backbone is configured
        x = self.patch_embeddings(x)  # 16x16 patches -> CNN -> 1x1 tokens
        txt = self.cc_embeddings(txt)
        lab = self.lab_embeddings(lab)
        sex = self.sex_embeddings(sex)
        age = self.age_embeddings(age)

        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)

        embeddings = x + self.position_embeddings
        cc_embeddings = txt + self.pe_txt
        lab_embeddings = lab + self.pe_lab
        sex_embeddings = sex + self.pe_sex
        age_embeddings = age + self.pe_age

        embeddings = self.dropout(embeddings)
        cc_embeddings = self.dropout_txt(cc_embeddings)
        lab_embeddings = self.dropout_lab(lab_embeddings)
        sex_embeddings = self.dropout_sex(sex_embeddings)
        age_embeddings = self.dropout_age(age_embeddings)
        return embeddings, cc_embeddings, lab_embeddings, sex_embeddings, age_embeddings
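Embeddings is not used by the inference path in this commit (inference.py feeds pre-extracted features straight into the encoder), but its expected input shapes follow from the layer definitions above; a standalone sketch with the default config (img_size=224, cc_len=40, lab_len=92), where the inputs are random stand-ins:

import torch
from models.configs import get_IRENE_config
from models.embed import Embeddings

config = get_IRENE_config()
emb = Embeddings(config, img_size=224, in_channels=3)

x = torch.randn(1, 3, 224, 224)            # raw image
txt = torch.randn(1, config.cc_len, 768)   # clinical-text token features (768-d, per cc_embeddings)
lab = torch.randn(1, config.lab_len, 1)    # one scalar per laboratory value
sex = torch.randn(1, 1, 1)
age = torch.randn(1, 1, 1)

img_e, cc_e, lab_e, sex_e, age_e = emb(x, txt, lab, sex, age)
print(img_e.shape)  # (1, 197, 512): 14*14 patches plus CLS token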
models/encoder.py ADDED
@@ -0,0 +1,56 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp
from models.block import Block


class Encoder(nn.Module):
    def __init__(self, config, vis, mm):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        # The first two blocks are multimodal (image + text); the remaining blocks
        # operate on the fused sequence only.
        for i in range(config.transformer["num_layers"]):
            if i < 2:
                layer = Block(config, vis, mm)
            else:
                layer = Block(config, vis, mm=False)
            self.layer.append(copy.deepcopy(layer))
        self.img_adaptive_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 512))
        self.txt_adaptive_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 512))

    def forward(self, hidden_states, text=None):
        for i, layer_block in enumerate(self.layer):
            if i == 2:
                if text is not None:
                    # Pool each modality to a single token and concatenate them before the unimodal blocks.
                    hidden_states = self.img_adaptive_avg_pool(hidden_states)
                    text = self.txt_adaptive_avg_pool(text)
                    hidden_states = torch.cat((hidden_states, text), 1)
                    hidden_states, text, weights = layer_block(hidden_states)
                else:
                    hidden_states, text, weights = layer_block(hidden_states)
            elif i < 2:
                hidden_states, text, weights = layer_block(hidden_states, text)
            else:
                hidden_states, text, weights = layer_block(hidden_states)
        encoded = self.encoder_norm(hidden_states)
        return encoded
models/mlp.py ADDED
@@ -0,0 +1,56 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings

import pdb


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
models/modeling.py ADDED
@@ -0,0 +1,57 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import numpy as np
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp
from models.block import Block
from models.encoder import Encoder


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis, mm):
        super(Transformer, self).__init__()
        self.encoder = Encoder(config, vis, mm)

    def forward(self, input_ids, txt=None):
        text = txt
        encoded = self.encoder(input_ids, text)
        return encoded


class PATHOLOGICAL_CLASSFIER(nn.Module):
    def __init__(self, config, img_size=224, num_classes=2, vis=True, mm=True):
        super(PATHOLOGICAL_CLASSFIER, self).__init__()
        self.num_classes = num_classes
        # Two parallel transformers, one classification head per task (OS and DFS in inference.py).
        self.transformer1 = Transformer(config, img_size, vis=True, mm=mm)
        self.transformer2 = Transformer(config, img_size, vis=True, mm=mm)
        self.head1 = Linear(config.hidden_size, num_classes)
        self.head2 = Linear(config.hidden_size, num_classes)

    def forward(self, x, txt=None):
        x1 = self.transformer1(x, txt)
        logits_1 = self.head1(torch.mean(x1, dim=1))

        x2 = self.transformer2(x, txt)
        logits_2 = self.head2(torch.mean(x2, dim=1))

        return logits_1, logits_2


CONFIGS = {
    'PATHOLOGICAL_CLASSFIER': configs.get_IRENE_config(),
}
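As a quick smoke test, the full model can be run on random features shaped like the pre-extracted inputs used by inference.py; the sequence lengths below are assumptions, only the feature width (config.hidden_size = 512) is fixed by the linear layers.

import torch
from models.modeling import PATHOLOGICAL_CLASSFIER, CONFIGS

config = CONFIGS["PATHOLOGICAL_CLASSFIER"]
model = PATHOLOGICAL_CLASSFIER(config, num_classes=2, vis=True, mm=True)
model.eval()

imgs = torch.randn(1, 196, config.hidden_size)  # image features: (batch, tokens, hidden)
txt = torch.randn(1, 40, config.hidden_size)    # text features: (batch, tokens, hidden)

with torch.no_grad():
    logits_os, logits_dfs = model(imgs, txt)
print(logits_os.shape, logits_dfs.shape)  # both (1, 2): one head per task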