yuxin committed
Commit 6e933c4
1 Parent(s): da33dfe

add config
- config.json +0 -1
- merges.txt +0 -0
- model_segvol_single.py +8 -13
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +1 -0
- vocab.json +0 -0
config.json
CHANGED
@@ -6,7 +6,6 @@
     "AutoConfig": "model_segvol_single.SegVolConfig",
     "AutoModel": "model_segvol_single.SegVolModel"
   },
-  "clip_model": "openai/clip-vit-base-patch32",
   "model_type": "segvol",
   "patch_size": [
     4,
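The removal above drops the hard-coded CLIP checkpoint path from the config while keeping the auto_map entries, so the model should still load through the Transformers auto classes with trust_remote_code. A minimal loading sketch under that assumption; "<repo-id>" is a placeholder, not confirmed by this commit:

from transformers import AutoConfig, AutoModel

# "<repo-id>" is a placeholder for this model repository on the Hub.
config = AutoConfig.from_pretrained("<repo-id>", trust_remote_code=True)  # resolves SegVolConfig via auto_map
model = AutoModel.from_pretrained("<repo-id>", trust_remote_code=True)    # resolves SegVolModel via auto_map
print(config.patch_size)  # [4, 16, 16], per SegVolConfig defaults in model_segvol_single.py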
merges.txt
ADDED
The diff for this file is too large to render.
model_segvol_single.py
CHANGED
@@ -9,13 +9,13 @@ class SegVolConfig(PretrainedConfig):
     def __init__(
         self,
         test_mode=True,
-        clip_model='
+        # clip_model='.',
         **kwargs,
     ):
         self.spatial_size = [32, 256, 256]
         self.patch_size = [4, 16, 16]
         self.test_mode = test_mode
-        self.clip_model = clip_model
+        # self.clip_model = clip_model
         super().__init__(**kwargs)
 
 class SegVolModel(PreTrainedModel):
@@ -36,7 +36,7 @@ class SegVolModel(PreTrainedModel):
             prompt_encoder=sam_model.prompt_encoder,
             roi_size=self.config.spatial_size,
             patch_size=self.config.patch_size,
-            clip_model=self.config.clip_model,
+            # clip_model=self.config.clip_model,
             test_mode=self.config.test_mode,
         )
 
@@ -118,7 +118,6 @@ class SegVolModel(PreTrainedModel):
         return logits_global_single
 
     def forward_train(self, image, train_organs, train_labels):
-        print('in forward_train')
         loss = self.model(image, text=None, boxes=None, points=None,
                           train_organs=train_organs,
                           train_labels=train_labels)
@@ -318,7 +317,6 @@ def generate_box(pred_pre, bbox_shift=None):
     ones_idx = (meaning_post_label > 0).nonzero(as_tuple=True)
     if all(tensor.nelement() == 0 for tensor in ones_idx):
         bboxes = torch.tensor([-1,-1,-1,-1,-1,-1])
-        # print(bboxes, bboxes.shape)
         return bboxes
     min_coords = [dim.min() for dim in ones_idx]  # [x_min, y_min, z_min]
     max_coords = [dim.max() for dim in ones_idx]  # [x_max, y_max, z_max]
@@ -395,8 +393,6 @@ def select_points(preds, num_positive_extra=4, num_negative_extra=0, fix_extra_p
         extra_negative_points = torch.tensor(extra_negative_points).reshape(-1, 3)
         points = torch.cat((points, extra_negative_points), dim=0)
         labels = torch.cat((labels, torch.zeros((extra_negative_points.shape[0]))))
-        # print('extra_negative_points ', extra_negative_points, extra_negative_points.shape)
-        # print('==> points ', points.shape, labels)
 
     if fix_extra_point_num is None:
         left_point_num = num_positive_extra + num_negative_extra + 1 - labels.shape[0]
@@ -415,7 +411,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
-from transformers import
+from transformers import CLIPTextModel, CLIPTextConfig
 import random
 
 #%% set up model
@@ -426,7 +422,7 @@ class SegVol(nn.Module):
                  prompt_encoder,
                  roi_size,
                  patch_size,
-                 clip_model,
+                 # clip_model,
                  test_mode=False,
                  ):
         super().__init__()
@@ -434,7 +430,7 @@ class SegVol(nn.Module):
         self.image_encoder = image_encoder
         self.mask_decoder = mask_decoder
         self.prompt_encoder = prompt_encoder
-        self.text_encoder = TextEncoder(
+        self.text_encoder = TextEncoder()
         self.feat_shape = np.array(roi_size)/np.array(patch_size)
         self.test_mode = test_mode
         self.dice_loss = BinaryDiceLoss().to(self.custom_device)
@@ -453,7 +449,6 @@ class SegVol(nn.Module):
 
         # train mode
         ## sl
-        print('supervised_forward ready')
         sl_loss = self.supervised_forward(image, image_embedding, img_shape, kwargs['train_organs'], kwargs['train_labels'])
         ## ssl
         # ssl_loss = self.unsupervised_forward(image, image_embedding, kwargs['pseudo_seg_cleaned'], img_shape)
@@ -612,12 +607,12 @@ class SegVol(nn.Module):
         # return pseudo_labels, bboxes
 
 class TextEncoder(nn.Module):
-    def __init__(self
+    def __init__(self):
         super().__init__()
         self.custom_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         config = CLIPTextConfig()
         self.clip_text_model = CLIPTextModel(config)
-        self.tokenizer =
+        self.tokenizer = None
         self.dim_align = nn.Linear(512, 768)
         # freeze text encoder
         for param in self.clip_text_model.parameters():
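Net effect of the model_segvol_single.py changes: the CLIP checkpoint path is commented out everywhere, the CLIPTextModel/CLIPTextConfig import is made explicit, and TextEncoder is now built from a default CLIPTextConfig with its tokenizer left as None (presumably to be attached from the tokenizer files added below). A minimal sketch of that wiring, assuming the default CLIPTextConfig (hidden_size=512) and using random token ids purely for illustration:

import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTextConfig

config = CLIPTextConfig()                 # defaults: hidden_size=512, vocab_size=49408
clip_text_model = CLIPTextModel(config)   # randomly initialized here; real weights come from the checkpoint
dim_align = nn.Linear(512, 768)           # mirrors self.dim_align in TextEncoder

dummy_ids = torch.randint(0, config.vocab_size, (1, 8))        # stand-in for tokenizer output
pooled = clip_text_model(input_ids=dummy_ids).pooler_output    # shape (1, 512)
aligned = dim_align(pooled)                                    # shape (1, 768)
print(aligned.shape)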
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
+{"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
tokenizer.json
ADDED
The diff for this file is too large to render.
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
+{"unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true, "__type": "AddedToken"}, "pad_token": "<|endoftext|>", "add_prefix_space": false, "errors": "replace", "do_lower_case": true, "name_or_path": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip", "special_tokens_map_file": "/home/yuxin/BAAI/code_release/segvol_transformers/config/clip/special_tokens_map.json", "tokenizer_class": "CLIPTokenizer"}
vocab.json
ADDED
The diff for this file is too large to render.
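The added merges.txt, vocab.json, tokenizer.json, tokenizer_config.json, and special_tokens_map.json form a standard CLIP BPE tokenizer bundle (tokenizer_class is CLIPTokenizer), presumably the counterpart to the self.tokenizer = None slot above. A hedged loading sketch; "<repo-id>" is a placeholder for this repository:

from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("<repo-id>")
print(tokenizer.bos_token, tokenizer.eos_token, tokenizer.pad_token)
# Per special_tokens_map.json: <|startoftext|> <|endoftext|> <|endoftext|>
batch = tokenizer(["liver", "kidney tumor"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)   # (2, longest sequence length in the batch)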