Updated the kappamask model
Files changed:
- .netrc +3 -3
- app.py +4 -1
- configs/data/hrcwhu/hrcwhu.yaml +0 -3
- configs/experiment/cnn.yaml +55 -55
- configs/experiment/hrcwhu_hrcloud.yaml +46 -46
- configs/experiment/hrcwhu_kappamask.yaml +47 -0
- configs/experiment/lnn.yaml +56 -56
- configs/model/kappamask/kappamask.yaml +23 -0
- environment.yaml +3 -1
- logs/train/runs/hrcwhu_kappamask/2024-08-07_16-34-34/checkpoints/epoch_015.ckpt +3 -0
- logs/train/runs/hrcwhu_kappamask/2024-08-07_16-34-34/checkpoints/last.ckpt +3 -0
- src/models/components/kappamask.py +152 -0
- wandb_vis.py +5 -1
.netrc
CHANGED
@@ -1,3 +1,3 @@
-machine api.wandb.ai
-login user
-password 76211aa17d75da9ddab7b8cba5743454194fe1d5
+machine api.wandb.ai
+login user
+password 76211aa17d75da9ddab7b8cba5743454194fe1d5
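All three fields are re-saved unchanged (a whitespace-only rewrite). They follow standard netrc syntax, which wandb reads when authenticating against api.wandb.ai; as a sketch using only the Python standard library (the path is illustrative):

```python
# Sketch: how a netrc consumer such as wandb resolves these fields.
import netrc

auth = netrc.netrc(".netrc").authenticators("api.wandb.ai")
if auth is not None:
    login, _account, password = auth  # login and the API key from the file above
```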
app.py
CHANGED
@@ -18,6 +18,7 @@ from src.models.components.cdnetv1 import CDnetV1
 from src.models.components.cdnetv2 import CDnetV2
 from src.models.components.dbnet import DBNet
 from src.models.components.hrcloudnet import HRCloudNet
+from src.models.components.kappamask import KappaMask
 from src.models.components.mcdnet import MCDNet
 from src.models.components.scnn import SCNN
 from src.models.components.unetmobv2 import UNetMobV2
@@ -36,11 +37,13 @@ class Application:
                 self.device
             ),
             "unetmobv2": UNetMobV2(num_classes=2).to(self.device),
+            "kappamask":KappaMask(num_classes=2,in_channels=3).to(self.device)
         }
         self.__load_weight()
         self.transform = albu.Compose(
             [
                 albu.Resize(256, 256, always_apply=True),
+                albu.ToFloat(),
                 ToTensorV2(),
             ]
         )
@@ -119,7 +122,7 @@ class Application:
             [
                 gr.Image(sources=["clipboard", "upload"], type="pil"),
                 gr.Radio(
-                    ["cdnetv1", "cdnetv2", "hrcloudnet", "mcdnet", "scnn", "dbnet", "unetmobv2"],
+                    ["cdnetv1", "cdnetv2", "hrcloudnet", "mcdnet", "scnn", "dbnet", "unetmobv2","kappamask"],
                     label="model_name",
                     info="选择使用的模型",
                 ),
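With these changes, choosing "kappamask" in the Gradio radio routes inference through the same pipeline as the other models. A minimal sketch of that path, assuming the model dict and transform exactly as diffed above (names outside the diff, such as the input file, are illustrative):

```python
# Minimal inference sketch for the new "kappamask" entry.
import albumentations as albu
import numpy as np
import torch
from albumentations.pytorch import ToTensorV2
from PIL import Image

from src.models.components.kappamask import KappaMask

device = "cuda" if torch.cuda.is_available() else "cpu"
model = KappaMask(num_classes=2, in_channels=3).to(device).eval()

transform = albu.Compose([
    albu.Resize(256, 256, always_apply=True),
    albu.ToFloat(),  # new in this commit: scales uint8 [0, 255] to float [0, 1]
    ToTensorV2(),
])

image = np.array(Image.open("scene.png").convert("RGB"))  # hypothetical input
x = transform(image=image)["image"].unsqueeze(0).to(device)
with torch.no_grad():
    mask = model(x).argmax(dim=1)  # (1, 256, 256) class indices
```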
configs/data/hrcwhu/hrcwhu.yaml
CHANGED
@@ -46,7 +46,6 @@ train_pipeline:
     _target_: albumentations.Compose
     transforms:
       - _target_: albumentations.ToFloat
-        max_value: 255.0
       - _target_: albumentations.pytorch.transforms.ToTensorV2
 
   ann_transform: null
@@ -62,7 +61,6 @@ val_pipeline:
     _target_: albumentations.Compose
     transforms:
       - _target_: albumentations.ToFloat
-        max_value: 255.0
       - _target_: albumentations.pytorch.transforms.ToTensorV2
   ann_transform: null
 
@@ -78,7 +76,6 @@ test_pipeline:
     _target_: albumentations.Compose
     transforms:
       - _target_: albumentations.ToFloat
-        max_value: 255.0
       - _target_: albumentations.pytorch.transforms.ToTensorV2
   ann_transform: null
 
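Dropping `max_value: 255.0` should be behavior-preserving for 8-bit imagery: when `max_value` is omitted, `albumentations.ToFloat` infers the divisor from the input dtype (255 for uint8). A quick check under that assumption:

```python
# ToFloat() without max_value still maps uint8 into [0, 1]
# (divisor inferred from dtype; equivalent to max_value=255.0 for uint8).
import albumentations as albu
import numpy as np

img = np.full((4, 4, 3), 255, dtype=np.uint8)
out = albu.ToFloat()(image=img)["image"]
assert out.dtype == np.float32 and out.max() == 1.0
```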
configs/experiment/cnn.yaml
CHANGED
@@ -1,56 +1,56 @@
Whitespace-only rewrite: all 55 changed lines are deleted and re-added with identical text (likely a line-ending or trailing-whitespace normalization); the closing `mode: "max"` line is unchanged context. The file after the change:

# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /trainer: gpu
  - override /data: mnist
  - override /model: cnn
  - override /logger: wandb
  - override /callbacks: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["mnist", "cnn"]

seed: 42

trainer:
  min_epochs: 10
  max_epochs: 10
  gradient_clip_val: 0.5
  devices: 1

data:
  batch_size: 128
  train_val_test_split: [55_000, 5_000, 10_000]
  num_workers: 31
  pin_memory: False
  persistent_workers: False

model:
  net:
    dim: 32

logger:
  wandb:
    project: "mnist"
    name: "cnn"
  aim:
    experiment: "cnn"

callbacks:
  model_checkpoint:
    dirpath: ${paths.output_dir}/checkpoints
    filename: "epoch_{epoch:03d}"
    monitor: "val/acc"
    mode: "max"
    save_last: True
    auto_insert_metric_name: False

  early_stopping:
    monitor: "val/acc"
    patience: 100
    mode: "max"
configs/experiment/hrcwhu_hrcloud.yaml
CHANGED
@@ -1,47 +1,47 @@
Whitespace-only rewrite: all 46 changed lines are deleted and re-added with identical text; the closing `mode: "min"` line is unchanged context. The file after the change:

# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /trainer: gpu
  - override /data: hrcwhu/hrcwhu
  - override /model: hrcloudnet/hrcloudnet
  - override /logger: wandb
  - override /callbacks: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["hrcWhu", "hrcloud"]

seed: 42


# scheduler:
#   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
#   _partial_: true
#   mode: min
#   factor: 0.1
#   patience: 10

logger:
  wandb:
    project: "hrcWhu"
    name: "hrcloud"
  aim:
    experiment: "hrcwhu_hrcloud"

callbacks:
  model_checkpoint:
    dirpath: ${paths.output_dir}/checkpoints
    filename: "epoch_{epoch:03d}"
    monitor: "val/loss"
    mode: "min"
    save_last: True
    auto_insert_metric_name: False

  early_stopping:
    monitor: "val/loss"
    patience: 10
    mode: "min"
configs/experiment/hrcwhu_kappamask.yaml
ADDED
@@ -0,0 +1,47 @@
+# @package _global_
+
+# to execute this experiment run:
+# python train.py experiment=example
+
+defaults:
+  - override /trainer: gpu
+  - override /data: hrcwhu/hrcwhu
+  - override /model: kappamask/kappamask
+  - override /logger: wandb
+  - override /callbacks: default
+
+# all parameters below will be merged with parameters from default configurations set above
+# this allows you to overwrite only specified parameters
+
+tags: ["hrcWhu", "kappamask"]
+
+seed: 42
+
+
+# scheduler:
+#   _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+#   _partial_: true
+#   mode: min
+#   factor: 0.1
+#   patience: 10
+
+logger:
+  wandb:
+    project: "hrcWhu"
+    name: "kappamask"
+  aim:
+    experiment: "hrcwhu_kappamask"
+
+callbacks:
+  model_checkpoint:
+    dirpath: ${paths.output_dir}/checkpoints
+    filename: "epoch_{epoch:03d}"
+    monitor: "val/loss"
+    mode: "min"
+    save_last: True
+    auto_insert_metric_name: False
+
+  early_stopping:
+    monitor: "val/loss"
+    patience: 10
+    mode: "min"
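The header comment is the template placeholder; for this file the run command would be `python train.py experiment=hrcwhu_kappamask`. The same config can be composed programmatically with Hydra's compose API — a sketch, where the root config name `train` is an assumption from the usual lightning-hydra-template layout:

```python
# Sketch: compose the experiment config and build the network it selects.
# Assumes the template's "train" root config; adjust if the repo differs.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="train", overrides=["experiment=hrcwhu_kappamask"])

net = instantiate(cfg.model.net)  # -> KappaMask(num_classes=2, in_channels=3)
```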
configs/experiment/lnn.yaml
CHANGED
@@ -1,57 +1,57 @@
Whitespace-only rewrite: all 56 changed lines are deleted and re-added with identical text; the closing `mode: "max"` line is unchanged context. The file after the change:

# @package _global_

# to execute this experiment run:
# python train.py experiment=example

defaults:
  - override /trainer: gpu
  - override /data: mnist
  - override /model: lnn
  - override /logger: wandb
  - override /callbacks: default

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["mnist", "lnn"]

seed: 42

trainer:
  min_epochs: 10
  max_epochs: 10
  gradient_clip_val: 0.5
  devices: 1

data:
  batch_size: 128
  train_val_test_split: [55_000, 5_000, 10_000]
  num_workers: 31
  pin_memory: False
  persistent_workers: False

model:
  net:
    _target_: src.models.components.lnn.LNN
    dim: 32

logger:
  wandb:
    project: "mnist"
    name: "lnn"
  aim:
    experiment: "lnn"

callbacks:
  model_checkpoint:
    dirpath: ${paths.output_dir}/checkpoints
    filename: "epoch_{epoch:03d}"
    monitor: "val/acc"
    mode: "max"
    save_last: True
    auto_insert_metric_name: False

  early_stopping:
    monitor: "val/acc"
    patience: 100
    mode: "max"
configs/model/kappamask/kappamask.yaml
ADDED
@@ -0,0 +1,23 @@
+_target_: src.models.base_module.BaseLitModule
+
+net:
+  _target_: src.models.components.kappamask.KappaMask
+  num_classes: 2
+  in_channels: 3
+
+num_classes: 2
+
+criterion:
+  _target_: torch.nn.CrossEntropyLoss
+
+optimizer:
+  _target_: torch.optim.AdamW
+  _partial_: true
+  lr: 0.00001
+
+scheduler:
+  _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+  _partial_: true
+  mode: min
+  factor: 0.1
+  patience: 4
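For reference, this is what Hydra assembles from the config above once the `_partial_` optimizer/scheduler factories are bound to the network's parameters (the second, top-level `num_classes: 2` apparently goes to BaseLitModule itself, not the net). A hand-expanded sketch, not code from the repo:

```python
# Hand-expanded equivalent of configs/model/kappamask/kappamask.yaml.
import torch
from torch import nn

from src.models.components.kappamask import KappaMask

net = KappaMask(num_classes=2, in_channels=3)
criterion = nn.CrossEntropyLoss()
# _partial_: true -> Hydra returns factories; BaseLitModule binds them later.
optimizer = torch.optim.AdamW(net.parameters(), lr=0.00001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=4
)
```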
environment.yaml
CHANGED
@@ -24,4 +24,6 @@ dependencies:
   - aim
   - gradio
   - image-dehazer
-  - thop
+  - thop
+  - albumentations
+  - segmentation_models_pytorch
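The two added packages back this commit's code paths: `albumentations` supplies the Resize/ToFloat/ToTensorV2 pipeline used in app.py and wandb_vis.py, and `segmentation_models_pytorch` presumably backs models such as UNetMobV2. A trivial smoke test that the environment resolves them:

```python
# Smoke test for the newly declared dependencies.
import albumentations
import segmentation_models_pytorch as smp

print(albumentations.__version__, smp.__version__)
```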
logs/train/runs/hrcwhu_kappamask/2024-08-07_16-34-34/checkpoints/epoch_015.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fb7a77410abe75b5a66c9ec7d46f4203c8c41d70180174b0d0e2db2ff21b31ca
+size 372474542
logs/train/runs/hrcwhu_kappamask/2024-08-07_16-34-34/checkpoints/last.ckpt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b2cdda7ab749f7a75c157da0cbe88cb379b712ab604294e8436e204748f5beb4
+size 372474606
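Both checkpoints are committed as Git LFS pointers, so the actual ~372 MB files need `git lfs pull` before use. A hedged loading sketch — it assumes, per the model config above, that BaseLitModule stores the network under a `net` attribute, so Lightning prefixes its weights with `net.`:

```python
# Hedged sketch: strip the assumed "net." prefix Lightning adds for
# BaseLitModule's network attribute, then load into a plain KappaMask.
import torch

from src.models.components.kappamask import KappaMask

ckpt = torch.load(
    "logs/train/runs/hrcwhu_kappamask/2024-08-07_16-34-34/checkpoints/last.ckpt",
    map_location="cpu",
)
state = {
    k.removeprefix("net."): v
    for k, v in ckpt["state_dict"].items()
    if k.startswith("net.")
}
model = KappaMask(num_classes=2, in_channels=3)
model.load_state_dict(state)
model.eval()
```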
src/models/components/kappamask.py
ADDED
@@ -0,0 +1,152 @@
+# -*- coding: utf-8 -*-
+# @Time    : 2024/8/7 下午3:51
+# @Author  : xiaoshun
+# @Email   : [email protected]
+# @File    : kappamask.py.py
+# @Software: PyCharm
+
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+
+class KappaMask(nn.Module):
+    def __init__(self, num_classes=2, in_channels=3):
+        super().__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_channels, 64, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(64, 128, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        self.conv3 = nn.Sequential(
+            nn.Conv2d(128, 256, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+
+        self.conv4 = nn.Sequential(
+            nn.Conv2d(256, 512, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        self.drop4 = nn.Dropout(0.5)
+
+        self.conv5 = nn.Sequential(
+            nn.Conv2d(512, 1024, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(1024, 1024, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        self.drop5 = nn.Dropout(0.5)
+
+        self.up6 = nn.Sequential(
+            nn.Upsample(scale_factor=2),
+            nn.ZeroPad2d((0, 1, 0, 1)),
+            nn.Conv2d(1024, 512, 2),
+            nn.ReLU(inplace=True)
+        )
+        self.conv6 = nn.Sequential(
+            nn.Conv2d(1024, 512, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        self.up7 = nn.Sequential(
+            nn.Upsample(scale_factor=2),
+            nn.ZeroPad2d((0, 1, 0, 1)),
+            nn.Conv2d(512, 256, 2),
+            nn.ReLU(inplace=True)
+        )
+        self.conv7 = nn.Sequential(
+            nn.Conv2d(512, 256, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+
+        self.up8 = nn.Sequential(
+            nn.Upsample(scale_factor=2),
+            nn.ZeroPad2d((0, 1, 0, 1)),
+            nn.Conv2d(256, 128, 2),
+            nn.ReLU(inplace=True)
+        )
+        self.conv8 = nn.Sequential(
+            nn.Conv2d(256, 128, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+
+        self.up9 = nn.Sequential(
+            nn.Upsample(scale_factor=2),
+            nn.ZeroPad2d((0, 1, 0, 1)),
+            nn.Conv2d(128, 64, 2),
+            nn.ReLU(inplace=True)
+        )
+        self.conv9 = nn.Sequential(
+            nn.Conv2d(128, 64, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, 3, 1, 1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(64, 2, 3, 1, 1),
+            nn.ReLU(inplace=True),
+        )
+        self.conv10 = nn.Conv2d(2, num_classes, 1)
+        self.__init_weights()
+
+    def __init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+
+    def forward(self, x):
+        conv1 = self.conv1(x)
+        pool1 = F.max_pool2d(conv1, 2, 2)
+
+        conv2 = self.conv2(pool1)
+        pool2 = F.max_pool2d(conv2, 2, 2)
+
+        conv3 = self.conv3(pool2)
+        pool3 = F.max_pool2d(conv3, 2, 2)
+
+        conv4 = self.conv4(pool3)
+        drop4 = self.drop4(conv4)
+        pool4 = F.max_pool2d(drop4, 2, 2)
+
+        conv5 = self.conv5(pool4)
+        drop5 = self.drop5(conv5)
+
+        up6 = self.up6(drop5)
+        merge6 = torch.cat((drop4, up6), dim=1)
+        conv6 = self.conv6(merge6)
+
+        up7 = self.up7(conv6)
+        merge7 = torch.cat((conv3, up7), dim=1)
+        conv7 = self.conv7(merge7)
+
+        up8 = self.up8(conv7)
+        merge8 = torch.cat((conv2, up8), dim=1)
+        conv8 = self.conv8(merge8)
+
+        up9 = self.up9(conv8)
+        merge9 = torch.cat((conv1, up9), dim=1)
+        conv9 = self.conv9(merge9)
+
+        output = self.conv10(conv9)
+        return output
+
+
+if __name__ == '__main__':
+    model = KappaMask(num_classes=2, in_channels=3)
+    fake_data = torch.rand(2, 3, 256, 256)
+    output = model(fake_data)
+    print(output.shape)
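The file implements a U-Net-style KappaMask: a five-stage VGG-like encoder (64 up to 1024 channels, with dropout on the two deepest stages), four upsample-pad-conv decoder stages with skip concatenations, and a 1x1 classifier head. Since `thop` is already declared in environment.yaml, a quick cost check at the 256x256 RGB input size used by app.py — a sketch, not part of the commit:

```python
# Optional profiling sketch with thop: MAC and parameter counts.
import torch
from thop import profile

from src.models.components.kappamask import KappaMask

model = KappaMask(num_classes=2, in_channels=3)
macs, params = profile(model, inputs=(torch.rand(1, 3, 256, 256),))
print(f"{macs / 1e9:.2f} GMacs, {params / 1e6:.2f} M params")
```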
wandb_vis.py
CHANGED
@@ -25,6 +25,7 @@ from src.models.components.cdnetv1 import CDnetV1
 from src.models.components.cdnetv2 import CDnetV2
 from src.models.components.dbnet import DBNet
 from src.models.components.hrcloudnet import HRCloudNet
+from src.models.components.kappamask import KappaMask
 from src.models.components.mcdnet import MCDNet
 from src.models.components.scnn import SCNN
 from src.models.components.unetmobv2 import UNetMobV2
@@ -69,6 +70,9 @@ class WandbVis:
         if self.model_name == "unetmobv2":
             return UNetMobV2(num_classes=2).to(self.device)
 
+        if self.model_name == "kappamask":
+            return KappaMask(num_classes=2, in_channels=3).to(self.device)
+
         raise ValueError(f"{self.model_name}模型不存在")
 
     def load_model(self):
@@ -91,7 +95,7 @@ class WandbVis:
             ]
         )
         img_transform = albu.Compose([
-            albu.ToFloat(
+            albu.ToFloat(),
             ToTensorV2()
         ])
         ann_transform = None