Upload 10 files
add inference code and the trained model file.
- checkpoints/model.pt +3 -0
- inference.py +125 -0
- models/__init__.py +0 -0
- models/attention.py +128 -0
- models/block.py +120 -0
- models/configs.py +33 -0
- models/embed.py +97 -0
- models/encoder.py +56 -0
- models/mlp.py +56 -0
- models/modeling.py +57 -0
checkpoints/model.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1e7ccea2005799786c18f5adc27b1e6586fc614fb6bf8e56092a44bc23adf398
size 101028923
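checkpoints/model.pt is a Git LFS pointer, not the weights themselves; the real ~101 MB checkpoint has to be fetched with git lfs pull. A minimal sketch (our own helper, not part of the upload) of checking a downloaded file against the oid and size recorded above:

import hashlib
import os

def verify_lfs_object(path, expected_oid, expected_size):
    # Compare on-disk size and SHA-256 with the values from the LFS pointer.
    assert os.path.getsize(path) == expected_size, "size mismatch"
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    assert h.hexdigest() == expected_oid, "sha256 mismatch"

verify_lfs_object(
    "checkpoints/model.pt",
    "1e7ccea2005799786c18f5adc27b1e6586fc614fb6bf8e56092a44bc23adf398",
    101028923,
)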
inference.py
ADDED
@@ -0,0 +1,125 @@
from __future__ import print_function, division
import os
import sys
import time
import argparse
import warnings
import torch
import pickle
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms, utils
from models.modeling import PATHOLOGICAL_CLASSFIER, CONFIGS

device = "cuda" if torch.cuda.is_available() else "cpu"


def load_weights(model, weight_path):
    """Load a trained checkpoint; weights are stored under "model_state_dict"."""
    print("Loading PATHOLOGICAL_CLASSFIER...", weight_path)
    loadnet = torch.load(weight_path, map_location=device)
    model.load_state_dict(loadnet["model_state_dict"], strict=True)
    return model


class MyDataset(Dataset):
    """Pairs pre-extracted image/text feature pickles with their OS/DFS targets."""

    def __init__(self, root_path):
        m_data = []
        img_pkl_file_path = os.path.join(root_path, "img_feature")
        txt_pkl_file_path = os.path.join(root_path, "txt_feature")
        target_pkl_file_path = os.path.join(root_path, "target")
        for file in os.listdir(img_pkl_file_path):
            img_pkl_file = os.path.join(img_pkl_file_path, file)
            txt_pkl_file = os.path.join(txt_pkl_file_path, file)
            target_pkl_file = os.path.join(target_pkl_file_path, file)
            with open(img_pkl_file, "rb") as img_f:
                img_load_dict = pickle.load(img_f)
                m_input_img = img_load_dict["img_feature"]
            with open(txt_pkl_file, "rb") as txt_f:
                txt_load_dict = pickle.load(txt_f)
                m_input_txt = txt_load_dict["txt_feature"]
            with open(target_pkl_file, "rb") as target_f:
                target_load_dict = pickle.load(target_f)
                m_output_os = target_load_dict["target_os"]
                m_output_dfs = target_load_dict["target_dfs"]
            m_data.append((m_input_img, m_input_txt, m_output_os, m_output_dfs, file))
        self.m_data = m_data

    def __getitem__(self, idx):
        inp_i, inp_txt, oup_os, oup_dfs, f_name = self.m_data[idx]
        return inp_i, inp_txt, oup_os, oup_dfs, f_name

    def __len__(self):
        return len(self.m_data)


def valid(args):
    torch.manual_seed(0)
    num_classes = 2
    config = CONFIGS["PATHOLOGICAL_CLASSFIER"]
    model = PATHOLOGICAL_CLASSFIER(config, num_classes=num_classes, vis=True, mm=True)

    model_path = '/your/trained/model/path/'
    p_c_model = load_weights(model, model_path)

    p_c_model.to(device)
    test_dataset = MyDataset("/your/dataset/path/")
    test_loader = DataLoader(test_dataset, batch_size=1)

    # ----- Test ------
    print("--------Start testing-------")
    p_c_model.eval()

    valid_1_total = 0
    valid_1_cnt = 0

    valid_2_total = 0
    valid_2_cnt = 0
    valid_total_cnt = 0

    with torch.no_grad():
        for imgs, txt, target_1, target_2, file_name in test_loader:
            output_1, output_2 = p_c_model(imgs.to(device), txt.to(device))

            # Per-class probabilities (computed for inspection/export).
            out_1_list_prob = torch.softmax(output_1.squeeze(1), dim=-1).cpu().numpy().tolist()

            out_1_list = torch.argmax(output_1.squeeze(1), dim=-1).cpu().numpy().tolist()
            target_1_list = target_1.tolist()

            out_2_list = torch.argmax(output_2.squeeze(1), dim=-1).cpu().numpy().tolist()
            target_2_list = target_2.tolist()

            valid_1_total += len(out_1_list)
            valid_2_total += len(out_2_list)

            for i in range(len(out_1_list)):
                if out_1_list[i] == target_1_list[i]:
                    valid_1_cnt += 1
                if out_2_list[i] == target_2_list[i]:
                    valid_2_cnt += 1
                if out_1_list[i] == target_1_list[i] and out_2_list[i] == target_2_list[i]:
                    valid_total_cnt += 1

    valid_1_acc = valid_1_cnt / valid_1_total
    valid_2_acc = valid_2_cnt / valid_2_total
    valid_total_acc = valid_total_cnt / valid_1_total

    print(valid_1_acc, valid_1_total, valid_2_acc, valid_2_total, valid_total_acc, valid_total_cnt)
    print("=" * 100)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="")
    args = parser.parse_args()
    valid(args)
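MyDataset above expects three sibling folders under the dataset root (img_feature/, txt_feature/, target/), each holding one pickle per case with the same file name and the keys read in __init__. A minimal sketch that writes a single dummy case in that layout; the feature shapes here are our assumptions, not dictated by inference.py:

import os
import pickle
import torch

root = "/your/dataset/path/"
for sub in ("img_feature", "txt_feature", "target"):
    os.makedirs(os.path.join(root, sub), exist_ok=True)

name = "case_0001.pkl"
# Illustrative shapes: (tokens, hidden) features already projected to 512-d.
with open(os.path.join(root, "img_feature", name), "wb") as f:
    pickle.dump({"img_feature": torch.randn(196, 512)}, f)
with open(os.path.join(root, "txt_feature", name), "wb") as f:
    pickle.dump({"txt_feature": torch.randn(40, 512)}, f)
with open(os.path.join(root, "target", name), "wb") as f:
    pickle.dump({"target_os": 1, "target_dfs": 0}, f)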
models/__init__.py
ADDED
File without changes
models/attention.py
ADDED
@@ -0,0 +1,128 @@
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
import models.configs as configs
import math


class Attention(nn.Module):
    def __init__(self, config, vis, mm=True):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        if mm:
            self.query_text = Linear(config.hidden_size, self.all_head_size)
            self.key_text = Linear(config.hidden_size, self.all_head_size)
            self.value_text = Linear(config.hidden_size, self.all_head_size)
            self.out_text = Linear(config.hidden_size, config.hidden_size)
            self.attn_dropout_text = Dropout(config.transformer["attention_dropout_rate"])
            self.attn_dropout_it = Dropout(config.transformer["attention_dropout_rate"])
            self.attn_dropout_ti = Dropout(config.transformer["attention_dropout_rate"])
            self.proj_dropout_text = Dropout(config.transformer["attention_dropout_rate"])

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, text=None):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        if text is not None:
            text_q = self.query_text(text)
            text_k = self.key_text(text)
            text_v = self.value_text(text)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        if text is not None:
            query_layer_img = query_layer
            key_layer_img = key_layer
            value_layer_img = value_layer
            query_layer_text = self.transpose_for_scores(text_q)
            key_layer_text = self.transpose_for_scores(text_k)
            value_layer_text = self.transpose_for_scores(text_v)

        if text is None:
            # Single-modality scaled dot-product self-attention.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)
            attention_probs = self.softmax(attention_scores)
            weights = attention_probs if self.vis else None
            attention_probs = self.attn_dropout(attention_probs)

            context_layer = torch.matmul(attention_probs, value_layer)
            context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
            context_layer = context_layer.view(*new_context_layer_shape)
            attention_output = self.out(context_layer)
            attention_output = self.proj_dropout(attention_output)

            return attention_output, None, weights

        else:
            # Bidirectional multimodal attention: image->image, text->text,
            # image->text (it) and text->image (ti).
            attention_scores_img = torch.matmul(query_layer, key_layer.transpose(-1, -2))
            attention_scores_text = torch.matmul(query_layer_text, key_layer_text.transpose(-1, -2))
            attention_scores_it = torch.matmul(query_layer_img, key_layer_text.transpose(-1, -2))
            attention_scores_ti = torch.matmul(query_layer_text, key_layer_img.transpose(-1, -2))
            attention_scores_img = attention_scores_img / math.sqrt(self.attention_head_size)

            attention_probs_img = self.softmax(attention_scores_img)
            weights_img = attention_probs_img if self.vis else None

            attention_probs_img = self.attn_dropout(attention_probs_img)

            attention_scores_text = attention_scores_text / math.sqrt(self.attention_head_size)
            attention_probs_text = self.softmax(attention_scores_text)

            text_per_weights = attention_probs_text.mean(dim=-1)
            text_per_weights = self.softmax(text_per_weights)

            weights_text = attention_probs_text if self.vis else None

            attention_probs_text = self.attn_dropout_text(attention_probs_text)

            attention_scores_it = attention_scores_it / math.sqrt(self.attention_head_size)
            attention_probs_it = self.softmax(attention_scores_it)
            attention_probs_it = self.attn_dropout_it(attention_probs_it)

            attention_scores_ti = attention_scores_ti / math.sqrt(self.attention_head_size)
            attention_probs_ti = self.softmax(attention_scores_ti)
            attention_probs_ti = self.attn_dropout_ti(attention_probs_ti)

            context_layer_img = torch.matmul(attention_probs_img, value_layer_img)
            context_layer_img = context_layer_img.permute(0, 2, 1, 3).contiguous()
            context_layer_text = torch.matmul(attention_probs_text, value_layer_text)
            context_layer_text = context_layer_text.permute(0, 2, 1, 3).contiguous()
            context_layer_it = torch.matmul(attention_probs_it, value_layer_text)
            context_layer_it = context_layer_it.permute(0, 2, 1, 3).contiguous()
            context_layer_ti = torch.matmul(attention_probs_ti, value_layer_img)
            context_layer_ti = context_layer_ti.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer_img.size()[:-2] + (self.all_head_size,)
            context_layer_img = context_layer_img.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_text.size()[:-2] + (self.all_head_size,)
            context_layer_text = context_layer_text.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_it.size()[:-2] + (self.all_head_size,)
            context_layer_it = context_layer_it.view(*new_context_layer_shape)
            new_context_layer_shape = context_layer_ti.size()[:-2] + (self.all_head_size,)
            context_layer_ti = context_layer_ti.view(*new_context_layer_shape)
            # Average the self- and cross-attention contexts of each modality
            # before the output projection.
            attention_output_img = self.out((context_layer_img + context_layer_it) / 2)
            attention_output_text = self.out((context_layer_text + context_layer_ti) / 2)
            attention_output_img = self.proj_dropout(attention_output_img)
            attention_output_text = self.proj_dropout_text(attention_output_text)

            # Return the attention weights as a third value so callers can
            # unpack (output_img, output_text, weights) in both branches.
            return attention_output_img, attention_output_text, weights_img
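A minimal sketch of the two forward modes of Attention, built with the repo's own config (hidden_size 512, a single head); the token counts are illustrative:

import torch
from models.attention import Attention
from models.configs import get_IRENE_config

config = get_IRENE_config()                 # hidden_size=512, num_heads=1
attn = Attention(config, vis=True, mm=True).eval()

img = torch.randn(2, 197, 512)              # image tokens
txt = torch.randn(2, 40, 512)               # text tokens

# Image-only: plain self-attention.
out, _, weights = attn(img)                 # out: (2, 197, 512), weights: (2, 1, 197, 197)

# Image + text: per-modality self-attention plus image<->text cross-attention,
# averaged before the output projection.
out_img, out_txt, w_img = attn(img, txt)    # (2, 197, 512) and (2, 40, 512)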
models/block.py
ADDED
@@ -0,0 +1,120 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Convert numpy weights from the original ViT checkpoints to torch tensors
    (required by load_from below); HWIO kernels become OIHW when conv=True."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


class Block(nn.Module):
    def __init__(self, config, vis, mm=True):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        if mm:
            self.att_norm_text = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_norm_text = LayerNorm(config.hidden_size, eps=1e-6)
            self.ffn_text = Mlp(config)

        self.ffn = Mlp(config)
        self.attn = Attention(config, vis, mm)

    def forward(self, x, text=None):
        if text is None:
            # Pre-norm residual self-attention followed by a pre-norm residual MLP.
            h = x
            x = self.attention_norm(x)
            x, text, weights = self.attn(x)

            x = x + h

            h = x
            x = self.ffn_norm(x)
            x = self.ffn(x)
            x = x + h
            return x, text, weights
        else:
            # Multimodal path: both streams are normalized, cross-attended and
            # passed through their own MLPs, each with residual connections.
            h = x
            h_text = text
            x = self.attention_norm(x)
            text = self.att_norm_text(text)

            x, text, weights_img = self.attn(x, text)

            x = x + h
            text = text + h_text

            h = x
            h_text = text
            x = self.ffn_norm(x)
            text = self.ffn_norm_text(text)
            x = self.ffn(x)
            text = self.ffn_text(text)
            x = x + h
            text = text + h_text

            return x, text, weights_img

    def load_from(self, weights, n_block):
        """Copy pretrained ViT weights (a numpy dict) into this block."""
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
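A minimal usage sketch of Block in both modes — a pre-norm residual attention step followed by a pre-norm residual MLP, with a separate norm/MLP pair for the text stream when mm=True; token counts are again illustrative:

import torch
from models.block import Block
from models.configs import get_IRENE_config

config = get_IRENE_config()
block = Block(config, vis=True, mm=True).eval()

img = torch.randn(2, 197, 512)
txt = torch.randn(2, 40, 512)

x, _, w = block(img)          # image-only path: (2, 197, 512)
x, t, w = block(img, txt)     # multimodal path: both streams are updated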
models/configs.py
ADDED
@@ -0,0 +1,33 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ml_collections

def get_IRENE_config():
    """Returns the PATHOLOGICAL_CLASSFIER configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 512
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 1024
    config.transformer.num_heads = 1  # must divide hidden_size evenly
    config.transformer.num_layers = 4  # 4 for the follow-up models trained at the other three hospitals; TCGA used 2
    # config.transformer.num_layers = 2  # 4 for the follow-up models trained at the other three hospitals; TCGA used 2
    config.transformer.attention_dropout_rate = 0.2  # 0.0 - 0.2
    config.transformer.dropout_rate = 0.3  # 0.1 - 0.3
    config.classifier = 'token'
    config.representation_size = None
    config.cc_len = 40
    config.lab_len = 92
    return config
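Since the config is a plain ml_collections.ConfigDict, individual fields can be overridden after construction, e.g. the 2-layer TCGA variant mentioned in the comments:

from models.configs import get_IRENE_config

config = get_IRENE_config()
config.transformer.num_layers = 2    # TCGA variant noted in the comment above
config.transformer.num_heads = 8     # any value that evenly divides hidden_size (512)
assert config.hidden_size % config.transformer.num_heads == 0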
models/embed.py
ADDED
@@ -0,0 +1,97 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
import pdb


class Embeddings(nn.Module):
    """Construct the embeddings from patch and position embeddings."""

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)
        tk_lim = config.cc_len
        num_lab = config.lab_len

        if config.patches.get("grid") is not None:
            # Hybrid (CNN backbone) variant; note that self.hybrid_model is not
            # defined in this repo, so this path is unused with the default config.
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.cc_embeddings = Linear(768, config.hidden_size)
        self.lab_embeddings = Linear(1, config.hidden_size)
        self.sex_embeddings = Linear(1, config.hidden_size)
        self.age_embeddings = Linear(1, config.hidden_size)

        self.position_embeddings = nn.Parameter(torch.zeros(1, 1 + n_patches, config.hidden_size))
        self.pe_txt = nn.Parameter(torch.zeros(1, tk_lim, config.hidden_size))
        self.pe_lab = nn.Parameter(torch.zeros(1, num_lab, config.hidden_size))
        self.pe_sex = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.pe_age = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])
        self.dropout_txt = Dropout(config.transformer["dropout_rate"])
        self.dropout_lab = Dropout(config.transformer["dropout_rate"])
        self.dropout_sex = Dropout(config.transformer["dropout_rate"])
        self.dropout_age = Dropout(config.transformer["dropout_rate"])

    def forward(self, x, txt, lab, sex, age):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)  # 16x16 patches --> CNN --> 1x1 tokens
        txt = self.cc_embeddings(txt)
        lab = self.lab_embeddings(lab)
        sex = self.sex_embeddings(sex)
        age = self.age_embeddings(age)

        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)

        embeddings = x + self.position_embeddings
        cc_embeddings = txt + self.pe_txt
        lab_embeddings = lab + self.pe_lab
        sex_embeddings = sex + self.pe_sex
        age_embeddings = age + self.pe_age

        embeddings = self.dropout(embeddings)
        cc_embeddings = self.dropout_txt(cc_embeddings)
        lab_embeddings = self.dropout_lab(lab_embeddings)
        sex_embeddings = self.dropout_sex(sex_embeddings)
        age_embeddings = self.dropout_age(age_embeddings)
        return embeddings, cc_embeddings, lab_embeddings, sex_embeddings, age_embeddings
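With the default 16x16 patches and a 224x224 input, Embeddings yields 224/16 * 224/16 = 196 patch tokens plus one CLS token, next to 40 clinical-text tokens, 92 lab tokens and one token each for sex and age. Note that modeling.Transformer below feeds pre-extracted features straight into the Encoder and does not use this module. A minimal standalone sketch; the dummy input shapes follow from the layer definitions above:

import torch
from models.embed import Embeddings
from models.configs import get_IRENE_config

config = get_IRENE_config()
emb = Embeddings(config, img_size=224).eval()

x   = torch.randn(2, 3, 224, 224)   # raw image
txt = torch.randn(2, 40, 768)       # cc_len tokens of 768-d text features
lab = torch.randn(2, 92, 1)         # lab_len scalar lab values
sex = torch.randn(2, 1, 1)
age = torch.randn(2, 1, 1)

img_e, cc_e, lab_e, sex_e, age_e = emb(x, txt, lab, sex, age)
# img_e: (2, 197, 512); cc_e: (2, 40, 512); lab_e: (2, 92, 512); sex_e/age_e: (2, 1, 512)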
models/encoder.py
ADDED
@@ -0,0 +1,56 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp
from models.block import Block


class Encoder(nn.Module):
    def __init__(self, config, vis, mm):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for i in range(config.transformer["num_layers"]):
            # The first two blocks are multimodal (image + text); the remaining
            # blocks operate on the fused representation only.
            if i < 2:
                layer = Block(config, vis, mm)
            else:
                layer = Block(config, vis, mm=False)
            self.layer.append(copy.deepcopy(layer))
        self.img_adaptive_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 512))
        self.txt_adaptive_avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 512))

    def forward(self, hidden_states, text=None):
        for (i, layer_block) in enumerate(self.layer):
            if i == 2:
                if text is not None:
                    # Pool each modality to a single token and concatenate them
                    # before the remaining unimodal blocks.
                    hidden_states = self.img_adaptive_avg_pool(hidden_states)
                    text = self.txt_adaptive_avg_pool(text)
                    hidden_states = torch.cat((hidden_states, text), 1)
                    hidden_states, text, weights = layer_block(hidden_states)
                else:
                    hidden_states, text, weights = layer_block(hidden_states)
            elif i < 2:
                hidden_states, text, weights = layer_block(hidden_states, text)
            else:
                hidden_states, text, weights = layer_block(hidden_states)
        encoded = self.encoder_norm(hidden_states)
        return encoded
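The Encoder runs the first two blocks as joint image/text blocks, then pools each stream to a single 512-d token, concatenates the two tokens, and runs the remaining blocks as plain self-attention over that fused pair. A minimal sketch of the shape flow (token counts illustrative):

import torch
from models.encoder import Encoder
from models.configs import get_IRENE_config

config = get_IRENE_config()                 # num_layers = 4: 2 multimodal + 2 fused
enc = Encoder(config, vis=True, mm=True).eval()

img = torch.randn(2, 197, 512)
txt = torch.randn(2, 40, 512)

encoded = enc(img, txt)
# Blocks 0-1: (2, 197, 512) and (2, 40, 512) updated jointly.
# Before block 2: AdaptiveAvgPool2d -> (2, 1, 512) each, concatenated to (2, 2, 512).
# Blocks 2-3 and the final LayerNorm keep that shape, so encoded is (2, 2, 512).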
models/mlp.py
ADDED
@@ -0,0 +1,56 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings

import pdb


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
models/modeling.py
ADDED
@@ -0,0 +1,57 @@
# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math
from os.path import join as pjoin
import torch
import torch.nn as nn
import numpy as np
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings
from models.mlp import Mlp
from models.block import Block
from models.encoder import Encoder


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis, mm):
        super(Transformer, self).__init__()
        self.encoder = Encoder(config, vis, mm)

    def forward(self, input_ids, txt=None):
        # input_ids are pre-extracted image features; txt are text features.
        text = txt
        encoded = self.encoder(input_ids, text)
        return encoded


class PATHOLOGICAL_CLASSFIER(nn.Module):
    def __init__(self, config, img_size=224, num_classes=2, vis=True, mm=True):
        super(PATHOLOGICAL_CLASSFIER, self).__init__()
        self.num_classes = num_classes
        self.transformer1 = Transformer(config, img_size, vis=True, mm=mm)
        self.transformer2 = Transformer(config, img_size, vis=True, mm=mm)
        self.head1 = Linear(config.hidden_size, num_classes)
        self.head2 = Linear(config.hidden_size, num_classes)

    def forward(self, x, txt=None):
        # Two parallel transformers and classification heads, one per target.
        x1 = self.transformer1(x, txt)
        logits_1 = self.head1(torch.mean(x1, dim=1))

        x2 = self.transformer2(x, txt)
        logits_2 = self.head2(torch.mean(x2, dim=1))

        return logits_1, logits_2


CONFIGS = {
    'PATHOLOGICAL_CLASSFIER': configs.get_IRENE_config(),
}
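A minimal end-to-end sketch with random tensors standing in for the pickled features; the shapes are our assumptions, and the two heads line up with the two targets (target_os, target_dfs) read in inference.py:

import torch
from models.modeling import PATHOLOGICAL_CLASSFIER, CONFIGS

config = CONFIGS["PATHOLOGICAL_CLASSFIER"]
model = PATHOLOGICAL_CLASSFIER(config, num_classes=2, vis=True, mm=True).eval()

imgs = torch.randn(1, 196, 512)   # pre-extracted image features
txt  = torch.randn(1, 40, 512)    # pre-extracted text features

with torch.no_grad():
    logits_os, logits_dfs = model(imgs, txt)
print(logits_os.shape, logits_dfs.shape)   # torch.Size([1, 2]) torch.Size([1, 2])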