caixiaoshun commited on
Commit
3692a8e
·
verified ·
1 Parent(s): e4a7f6f

更新了kappamask模型

Browse files
.netrc CHANGED
@@ -1,3 +1,3 @@
1
- machine api.wandb.ai
2
- login user
3
- password 76211aa17d75da9ddab7b8cba5743454194fe1d5
 
1
+ machine api.wandb.ai
2
+ login user
3
+ password 76211aa17d75da9ddab7b8cba5743454194fe1d5
app.py CHANGED
@@ -18,6 +18,7 @@ from src.models.components.cdnetv1 import CDnetV1
18
  from src.models.components.cdnetv2 import CDnetV2
19
  from src.models.components.dbnet import DBNet
20
  from src.models.components.hrcloudnet import HRCloudNet
 
21
  from src.models.components.mcdnet import MCDNet
22
  from src.models.components.scnn import SCNN
23
  from src.models.components.unetmobv2 import UNetMobV2
@@ -36,11 +37,13 @@ class Application:
36
  self.device
37
  ),
38
  "unetmobv2": UNetMobV2(num_classes=2).to(self.device),
 
39
  }
40
  self.__load_weight()
41
  self.transform = albu.Compose(
42
  [
43
  albu.Resize(256, 256, always_apply=True),
 
44
  ToTensorV2(),
45
  ]
46
  )
@@ -119,7 +122,7 @@ class Application:
119
  [
120
  gr.Image(sources=["clipboard", "upload"], type="pil"),
121
  gr.Radio(
122
- ["cdnetv1", "cdnetv2", "hrcloudnet", "mcdnet", "scnn", "dbnet", "unetmobv2"],
123
  label="model_name",
124
  info="选择使用的模型",
125
  ),
 
18
  from src.models.components.cdnetv2 import CDnetV2
19
  from src.models.components.dbnet import DBNet
20
  from src.models.components.hrcloudnet import HRCloudNet
21
+ from src.models.components.kappamask import KappaMask
22
  from src.models.components.mcdnet import MCDNet
23
  from src.models.components.scnn import SCNN
24
  from src.models.components.unetmobv2 import UNetMobV2
 
37
  self.device
38
  ),
39
  "unetmobv2": UNetMobV2(num_classes=2).to(self.device),
40
+ "kappamask":KappaMask(num_classes=2,in_channels=3).to(self.device)
41
  }
42
  self.__load_weight()
43
  self.transform = albu.Compose(
44
  [
45
  albu.Resize(256, 256, always_apply=True),
46
+ albu.ToFloat(),
47
  ToTensorV2(),
48
  ]
49
  )
 
122
  [
123
  gr.Image(sources=["clipboard", "upload"], type="pil"),
124
  gr.Radio(
125
+ ["cdnetv1", "cdnetv2", "hrcloudnet", "mcdnet", "scnn", "dbnet", "unetmobv2","kappamask"],
126
  label="model_name",
127
  info="选择使用的模型",
128
  ),
configs/data/hrcwhu/hrcwhu.yaml CHANGED
@@ -46,7 +46,6 @@ train_pipeline:
46
  _target_: albumentations.Compose
47
  transforms:
48
  - _target_: albumentations.ToFloat
49
- max_value: 255.0
50
  - _target_: albumentations.pytorch.transforms.ToTensorV2
51
 
52
  ann_transform: null
@@ -62,7 +61,6 @@ val_pipeline:
62
  _target_: albumentations.Compose
63
  transforms:
64
  - _target_: albumentations.ToFloat
65
- max_value: 255.0
66
  - _target_: albumentations.pytorch.transforms.ToTensorV2
67
  ann_transform: null
68
 
@@ -78,7 +76,6 @@ test_pipeline:
78
  _target_: albumentations.Compose
79
  transforms:
80
  - _target_: albumentations.ToFloat
81
- max_value: 255.0
82
  - _target_: albumentations.pytorch.transforms.ToTensorV2
83
  ann_transform: null
84
 
 
46
  _target_: albumentations.Compose
47
  transforms:
48
  - _target_: albumentations.ToFloat
 
49
  - _target_: albumentations.pytorch.transforms.ToTensorV2
50
 
51
  ann_transform: null
 
61
  _target_: albumentations.Compose
62
  transforms:
63
  - _target_: albumentations.ToFloat
 
64
  - _target_: albumentations.pytorch.transforms.ToTensorV2
65
  ann_transform: null
66
 
 
76
  _target_: albumentations.Compose
77
  transforms:
78
  - _target_: albumentations.ToFloat
 
79
  - _target_: albumentations.pytorch.transforms.ToTensorV2
80
  ann_transform: null
81
 
configs/experiment/cnn.yaml CHANGED
@@ -1,56 +1,56 @@
1
- # @package _global_
2
-
3
- # to execute this experiment run:
4
- # python train.py experiment=example
5
-
6
- defaults:
7
- - override /trainer: gpu
8
- - override /data: mnist
9
- - override /model: cnn
10
- - override /logger: wandb
11
- - override /callbacks: default
12
-
13
- # all parameters below will be merged with parameters from default configurations set above
14
- # this allows you to overwrite only specified parameters
15
-
16
- tags: ["mnist", "cnn"]
17
-
18
- seed: 42
19
-
20
- trainer:
21
- min_epochs: 10
22
- max_epochs: 10
23
- gradient_clip_val: 0.5
24
- devices: 1
25
-
26
- data:
27
- batch_size: 128
28
- train_val_test_split: [55_000, 5_000, 10_000]
29
- num_workers: 31
30
- pin_memory: False
31
- persistent_workers: False
32
-
33
- model:
34
- net:
35
- dim: 32
36
-
37
- logger:
38
- wandb:
39
- project: "mnist"
40
- name: "cnn"
41
- aim:
42
- experiment: "cnn"
43
-
44
- callbacks:
45
- model_checkpoint:
46
- dirpath: ${paths.output_dir}/checkpoints
47
- filename: "epoch_{epoch:03d}"
48
- monitor: "val/acc"
49
- mode: "max"
50
- save_last: True
51
- auto_insert_metric_name: False
52
-
53
- early_stopping:
54
- monitor: "val/acc"
55
- patience: 100
56
  mode: "max"
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=example
5
+
6
+ defaults:
7
+ - override /trainer: gpu
8
+ - override /data: mnist
9
+ - override /model: cnn
10
+ - override /logger: wandb
11
+ - override /callbacks: default
12
+
13
+ # all parameters below will be merged with parameters from default configurations set above
14
+ # this allows you to overwrite only specified parameters
15
+
16
+ tags: ["mnist", "cnn"]
17
+
18
+ seed: 42
19
+
20
+ trainer:
21
+ min_epochs: 10
22
+ max_epochs: 10
23
+ gradient_clip_val: 0.5
24
+ devices: 1
25
+
26
+ data:
27
+ batch_size: 128
28
+ train_val_test_split: [55_000, 5_000, 10_000]
29
+ num_workers: 31
30
+ pin_memory: False
31
+ persistent_workers: False
32
+
33
+ model:
34
+ net:
35
+ dim: 32
36
+
37
+ logger:
38
+ wandb:
39
+ project: "mnist"
40
+ name: "cnn"
41
+ aim:
42
+ experiment: "cnn"
43
+
44
+ callbacks:
45
+ model_checkpoint:
46
+ dirpath: ${paths.output_dir}/checkpoints
47
+ filename: "epoch_{epoch:03d}"
48
+ monitor: "val/acc"
49
+ mode: "max"
50
+ save_last: True
51
+ auto_insert_metric_name: False
52
+
53
+ early_stopping:
54
+ monitor: "val/acc"
55
+ patience: 100
56
  mode: "max"
configs/experiment/hrcwhu_hrcloud.yaml CHANGED
@@ -1,47 +1,47 @@
1
- # @package _global_
2
-
3
- # to execute this experiment run:
4
- # python train.py experiment=example
5
-
6
- defaults:
7
- - override /trainer: gpu
8
- - override /data: hrcwhu/hrcwhu
9
- - override /model: hrcloudnet/hrcloudnet
10
- - override /logger: wandb
11
- - override /callbacks: default
12
-
13
- # all parameters below will be merged with parameters from default configurations set above
14
- # this allows you to overwrite only specified parameters
15
-
16
- tags: ["hrcWhu", "hrcloud"]
17
-
18
- seed: 42
19
-
20
-
21
- # scheduler:
22
- # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
23
- # _partial_: true
24
- # mode: min
25
- # factor: 0.1
26
- # patience: 10
27
-
28
- logger:
29
- wandb:
30
- project: "hrcWhu"
31
- name: "hrcloud"
32
- aim:
33
- experiment: "hrcwhu_hrcloud"
34
-
35
- callbacks:
36
- model_checkpoint:
37
- dirpath: ${paths.output_dir}/checkpoints
38
- filename: "epoch_{epoch:03d}"
39
- monitor: "val/loss"
40
- mode: "min"
41
- save_last: True
42
- auto_insert_metric_name: False
43
-
44
- early_stopping:
45
- monitor: "val/loss"
46
- patience: 10
47
  mode: "min"
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=example
5
+
6
+ defaults:
7
+ - override /trainer: gpu
8
+ - override /data: hrcwhu/hrcwhu
9
+ - override /model: hrcloudnet/hrcloudnet
10
+ - override /logger: wandb
11
+ - override /callbacks: default
12
+
13
+ # all parameters below will be merged with parameters from default configurations set above
14
+ # this allows you to overwrite only specified parameters
15
+
16
+ tags: ["hrcWhu", "hrcloud"]
17
+
18
+ seed: 42
19
+
20
+
21
+ # scheduler:
22
+ # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
23
+ # _partial_: true
24
+ # mode: min
25
+ # factor: 0.1
26
+ # patience: 10
27
+
28
+ logger:
29
+ wandb:
30
+ project: "hrcWhu"
31
+ name: "hrcloud"
32
+ aim:
33
+ experiment: "hrcwhu_hrcloud"
34
+
35
+ callbacks:
36
+ model_checkpoint:
37
+ dirpath: ${paths.output_dir}/checkpoints
38
+ filename: "epoch_{epoch:03d}"
39
+ monitor: "val/loss"
40
+ mode: "min"
41
+ save_last: True
42
+ auto_insert_metric_name: False
43
+
44
+ early_stopping:
45
+ monitor: "val/loss"
46
+ patience: 10
47
  mode: "min"
configs/experiment/hrcwhu_kappamask.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=example
5
+
6
+ defaults:
7
+ - override /trainer: gpu
8
+ - override /data: hrcwhu/hrcwhu
9
+ - override /model: kappamask/kappamask
10
+ - override /logger: wandb
11
+ - override /callbacks: default
12
+
13
+ # all parameters below will be merged with parameters from default configurations set above
14
+ # this allows you to overwrite only specified parameters
15
+
16
+ tags: ["hrcWhu", "kappamask"]
17
+
18
+ seed: 42
19
+
20
+
21
+ # scheduler:
22
+ # _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
23
+ # _partial_: true
24
+ # mode: min
25
+ # factor: 0.1
26
+ # patience: 10
27
+
28
+ logger:
29
+ wandb:
30
+ project: "hrcWhu"
31
+ name: "kappamask"
32
+ aim:
33
+ experiment: "hrcwhu_kappamask"
34
+
35
+ callbacks:
36
+ model_checkpoint:
37
+ dirpath: ${paths.output_dir}/checkpoints
38
+ filename: "epoch_{epoch:03d}"
39
+ monitor: "val/loss"
40
+ mode: "min"
41
+ save_last: True
42
+ auto_insert_metric_name: False
43
+
44
+ early_stopping:
45
+ monitor: "val/loss"
46
+ patience: 10
47
+ mode: "min"
configs/experiment/lnn.yaml CHANGED
@@ -1,57 +1,57 @@
1
- # @package _global_
2
-
3
- # to execute this experiment run:
4
- # python train.py experiment=example
5
-
6
- defaults:
7
- - override /trainer: gpu
8
- - override /data: mnist
9
- - override /model: lnn
10
- - override /logger: wandb
11
- - override /callbacks: default
12
-
13
- # all parameters below will be merged with parameters from default configurations set above
14
- # this allows you to overwrite only specified parameters
15
-
16
- tags: ["mnist", "lnn"]
17
-
18
- seed: 42
19
-
20
- trainer:
21
- min_epochs: 10
22
- max_epochs: 10
23
- gradient_clip_val: 0.5
24
- devices: 1
25
-
26
- data:
27
- batch_size: 128
28
- train_val_test_split: [55_000, 5_000, 10_000]
29
- num_workers: 31
30
- pin_memory: False
31
- persistent_workers: False
32
-
33
- model:
34
- net:
35
- _target_: src.models.components.lnn.LNN
36
- dim: 32
37
-
38
- logger:
39
- wandb:
40
- project: "mnist"
41
- name: "lnn"
42
- aim:
43
- experiment: "lnn"
44
-
45
- callbacks:
46
- model_checkpoint:
47
- dirpath: ${paths.output_dir}/checkpoints
48
- filename: "epoch_{epoch:03d}"
49
- monitor: "val/acc"
50
- mode: "max"
51
- save_last: True
52
- auto_insert_metric_name: False
53
-
54
- early_stopping:
55
- monitor: "val/acc"
56
- patience: 100
57
  mode: "max"
 
1
+ # @package _global_
2
+
3
+ # to execute this experiment run:
4
+ # python train.py experiment=example
5
+
6
+ defaults:
7
+ - override /trainer: gpu
8
+ - override /data: mnist
9
+ - override /model: lnn
10
+ - override /logger: wandb
11
+ - override /callbacks: default
12
+
13
+ # all parameters below will be merged with parameters from default configurations set above
14
+ # this allows you to overwrite only specified parameters
15
+
16
+ tags: ["mnist", "lnn"]
17
+
18
+ seed: 42
19
+
20
+ trainer:
21
+ min_epochs: 10
22
+ max_epochs: 10
23
+ gradient_clip_val: 0.5
24
+ devices: 1
25
+
26
+ data:
27
+ batch_size: 128
28
+ train_val_test_split: [55_000, 5_000, 10_000]
29
+ num_workers: 31
30
+ pin_memory: False
31
+ persistent_workers: False
32
+
33
+ model:
34
+ net:
35
+ _target_: src.models.components.lnn.LNN
36
+ dim: 32
37
+
38
+ logger:
39
+ wandb:
40
+ project: "mnist"
41
+ name: "lnn"
42
+ aim:
43
+ experiment: "lnn"
44
+
45
+ callbacks:
46
+ model_checkpoint:
47
+ dirpath: ${paths.output_dir}/checkpoints
48
+ filename: "epoch_{epoch:03d}"
49
+ monitor: "val/acc"
50
+ mode: "max"
51
+ save_last: True
52
+ auto_insert_metric_name: False
53
+
54
+ early_stopping:
55
+ monitor: "val/acc"
56
+ patience: 100
57
  mode: "max"
configs/model/kappamask/kappamask.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: src.models.base_module.BaseLitModule
2
+
3
+ net:
4
+ _target_: src.models.components.kappamask.KappaMask
5
+ num_classes: 2
6
+ in_channels: 3
7
+
8
+ num_classes: 2
9
+
10
+ criterion:
11
+ _target_: torch.nn.CrossEntropyLoss
12
+
13
+ optimizer:
14
+ _target_: torch.optim.AdamW
15
+ _partial_: true
16
+ lr: 0.00001
17
+
18
+ scheduler:
19
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
20
+ _partial_: true
21
+ mode: min
22
+ factor: 0.1
23
+ patience: 4
environment.yaml CHANGED
@@ -24,4 +24,6 @@ dependencies:
24
  - aim
25
  - gradio
26
  - image-dehazer
27
- - thop
 
 
 
24
  - aim
25
  - gradio
26
  - image-dehazer
27
+ - thop
28
+ - albumentations
29
+ - segmentation_models_pytorch
logs/train/runs/hrcwhu_kappamask/2024-08-07_16-34-34/checkpoints/epoch_015.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb7a77410abe75b5a66c9ec7d46f4203c8c41d70180174b0d0e2db2ff21b31ca
3
+ size 372474542
logs/train/runs/hrcwhu_kappamask/2024-08-07_16-34-34/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2cdda7ab749f7a75c157da0cbe88cb379b712ab604294e8436e204748f5beb4
3
+ size 372474606
src/models/components/kappamask.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/8/7 下午3:51
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : kappamask.py.py
6
+ # @Software: PyCharm
7
+
8
+ import torch
9
+ from torch import nn as nn
10
+ from torch.nn import functional as F
11
+
12
+
13
+ class KappaMask(nn.Module):
14
+ def __init__(self, num_classes=2, in_channels=3):
15
+ super().__init__()
16
+ self.conv1 = nn.Sequential(
17
+ nn.Conv2d(in_channels, 64, 3, 1, 1),
18
+ nn.ReLU(inplace=True),
19
+ nn.Conv2d(64, 64, 3, 1, 1),
20
+ nn.ReLU(inplace=True),
21
+ )
22
+ self.conv2 = nn.Sequential(
23
+ nn.Conv2d(64, 128, 3, 1, 1),
24
+ nn.ReLU(inplace=True),
25
+ nn.Conv2d(128, 128, 3, 1, 1),
26
+ nn.ReLU(inplace=True),
27
+ )
28
+ self.conv3 = nn.Sequential(
29
+ nn.Conv2d(128, 256, 3, 1, 1),
30
+ nn.ReLU(inplace=True),
31
+ nn.Conv2d(256, 256, 3, 1, 1),
32
+ nn.ReLU(inplace=True),
33
+ )
34
+
35
+ self.conv4 = nn.Sequential(
36
+ nn.Conv2d(256, 512, 3, 1, 1),
37
+ nn.ReLU(inplace=True),
38
+ nn.Conv2d(512, 512, 3, 1, 1),
39
+ nn.ReLU(inplace=True),
40
+ )
41
+ self.drop4 = nn.Dropout(0.5)
42
+
43
+ self.conv5 = nn.Sequential(
44
+ nn.Conv2d(512, 1024, 3, 1, 1),
45
+ nn.ReLU(inplace=True),
46
+ nn.Conv2d(1024, 1024, 3, 1, 1),
47
+ nn.ReLU(inplace=True),
48
+ )
49
+ self.drop5 = nn.Dropout(0.5)
50
+
51
+ self.up6 = nn.Sequential(
52
+ nn.Upsample(scale_factor=2),
53
+ nn.ZeroPad2d((0, 1, 0, 1)),
54
+ nn.Conv2d(1024, 512, 2),
55
+ nn.ReLU(inplace=True)
56
+ )
57
+ self.conv6 = nn.Sequential(
58
+ nn.Conv2d(1024, 512, 3, 1, 1),
59
+ nn.ReLU(inplace=True),
60
+ nn.Conv2d(512, 512, 3, 1, 1),
61
+ nn.ReLU(inplace=True),
62
+ )
63
+ self.up7 = nn.Sequential(
64
+ nn.Upsample(scale_factor=2),
65
+ nn.ZeroPad2d((0, 1, 0, 1)),
66
+ nn.Conv2d(512, 256, 2),
67
+ nn.ReLU(inplace=True)
68
+ )
69
+ self.conv7 = nn.Sequential(
70
+ nn.Conv2d(512, 256, 3, 1, 1),
71
+ nn.ReLU(inplace=True),
72
+ nn.Conv2d(256, 256, 3, 1, 1),
73
+ nn.ReLU(inplace=True),
74
+ )
75
+
76
+ self.up8 = nn.Sequential(
77
+ nn.Upsample(scale_factor=2),
78
+ nn.ZeroPad2d((0, 1, 0, 1)),
79
+ nn.Conv2d(256, 128, 2),
80
+ nn.ReLU(inplace=True)
81
+ )
82
+ self.conv8 = nn.Sequential(
83
+ nn.Conv2d(256, 128, 3, 1, 1),
84
+ nn.ReLU(inplace=True),
85
+ nn.Conv2d(128, 128, 3, 1, 1),
86
+ nn.ReLU(inplace=True),
87
+ )
88
+
89
+ self.up9 = nn.Sequential(
90
+ nn.Upsample(scale_factor=2),
91
+ nn.ZeroPad2d((0, 1, 0, 1)),
92
+ nn.Conv2d(128, 64, 2),
93
+ nn.ReLU(inplace=True)
94
+ )
95
+ self.conv9 = nn.Sequential(
96
+ nn.Conv2d(128, 64, 3, 1, 1),
97
+ nn.ReLU(inplace=True),
98
+ nn.Conv2d(64, 64, 3, 1, 1),
99
+ nn.ReLU(inplace=True),
100
+ nn.Conv2d(64, 2, 3, 1, 1),
101
+ nn.ReLU(inplace=True),
102
+ )
103
+ self.conv10 = nn.Conv2d(2, num_classes, 1)
104
+ self.__init_weights()
105
+
106
+ def __init_weights(self):
107
+ for m in self.modules():
108
+ if isinstance(m, nn.Conv2d):
109
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
110
+
111
+ def forward(self, x):
112
+ conv1 = self.conv1(x)
113
+ pool1 = F.max_pool2d(conv1, 2, 2)
114
+
115
+ conv2 = self.conv2(pool1)
116
+ pool2 = F.max_pool2d(conv2, 2, 2)
117
+
118
+ conv3 = self.conv3(pool2)
119
+ pool3 = F.max_pool2d(conv3, 2, 2)
120
+
121
+ conv4 = self.conv4(pool3)
122
+ drop4 = self.drop4(conv4)
123
+ pool4 = F.max_pool2d(drop4, 2, 2)
124
+
125
+ conv5 = self.conv5(pool4)
126
+ drop5 = self.drop5(conv5)
127
+
128
+ up6 = self.up6(drop5)
129
+ merge6 = torch.cat((drop4, up6), dim=1)
130
+ conv6 = self.conv6(merge6)
131
+
132
+ up7 = self.up7(conv6)
133
+ merge7 = torch.cat((conv3, up7), dim=1)
134
+ conv7 = self.conv7(merge7)
135
+
136
+ up8 = self.up8(conv7)
137
+ merge8 = torch.cat((conv2, up8), dim=1)
138
+ conv8 = self.conv8(merge8)
139
+
140
+ up9 = self.up9(conv8)
141
+ merge9 = torch.cat((conv1, up9), dim=1)
142
+ conv9 = self.conv9(merge9)
143
+
144
+ output = self.conv10(conv9)
145
+ return output
146
+
147
+
148
+ if __name__ == '__main__':
149
+ model = KappaMask(num_classes=2, in_channels=3)
150
+ fake_data = torch.rand(2, 3, 256, 256)
151
+ output = model(fake_data)
152
+ print(output.shape)
wandb_vis.py CHANGED
@@ -25,6 +25,7 @@ from src.models.components.cdnetv1 import CDnetV1
25
  from src.models.components.cdnetv2 import CDnetV2
26
  from src.models.components.dbnet import DBNet
27
  from src.models.components.hrcloudnet import HRCloudNet
 
28
  from src.models.components.mcdnet import MCDNet
29
  from src.models.components.scnn import SCNN
30
  from src.models.components.unetmobv2 import UNetMobV2
@@ -69,6 +70,9 @@ class WandbVis:
69
  if self.model_name == "unetmobv2":
70
  return UNetMobV2(num_classes=2).to(self.device)
71
 
 
 
 
72
  raise ValueError(f"{self.model_name}模型不存在")
73
 
74
  def load_model(self):
@@ -91,7 +95,7 @@ class WandbVis:
91
  ]
92
  )
93
  img_transform = albu.Compose([
94
- albu.ToFloat(255.0),
95
  ToTensorV2()
96
  ])
97
  ann_transform = None
 
25
  from src.models.components.cdnetv2 import CDnetV2
26
  from src.models.components.dbnet import DBNet
27
  from src.models.components.hrcloudnet import HRCloudNet
28
+ from src.models.components.kappamask import KappaMask
29
  from src.models.components.mcdnet import MCDNet
30
  from src.models.components.scnn import SCNN
31
  from src.models.components.unetmobv2 import UNetMobV2
 
70
  if self.model_name == "unetmobv2":
71
  return UNetMobV2(num_classes=2).to(self.device)
72
 
73
+ if self.model_name == "kappamask":
74
+ return KappaMask(num_classes=2, in_channels=3).to(self.device)
75
+
76
  raise ValueError(f"{self.model_name}模型不存在")
77
 
78
  def load_model(self):
 
95
  ]
96
  )
97
  img_transform = albu.Compose([
98
+ albu.ToFloat(),
99
  ToTensorV2()
100
  ])
101
  ann_transform = None