File size: 11,679 Bytes
8e5cc83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import random
import sys
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from tqdm import tqdm

from efficientvit.apps.trainer import Trainer
from efficientvit.apps.utils import AverageMeter, get_dist_local_rank, get_dist_size, is_master, sync_tensor
from efficientvit.models.utils import list_join
from efficientvit.samcore.data_provider import SAMDataProvider
from efficientvit.samcore.trainer import SAMRunConfig
from efficientvit.samcore.trainer.utils import (
    compute_boundary_iou,
    compute_iou,
    loss_masks,
    mask_iou_batch,
    masks_sample_points,
)

__all__ = ["SAMTrainer"]


class SAMTrainer(Trainer):
    def __init__(
        self,
        path: str,
        model: nn.Module,
        data_provider: SAMDataProvider,
    ) -> None:
        """Initialise the base trainer and, on the master process, a wandb run.

        Args:
            path: Forwarded unchanged to ``Trainer.__init__``.
            model: Network to train; forwarded to ``Trainer.__init__``.
            data_provider: Data provider; forwarded to ``Trainer.__init__``.
        """
        super().__init__(path=path, model=model, data_provider=data_provider)

        # Only the master process opens a wandb run, so ``self.wandb_log`` is
        # left undefined on every other rank; all later uses are guarded by
        # ``is_master()`` as well.
        if not is_master():
            return
        self.wandb_log = wandb.init(project="efficientvit-sam")

    def _validate(self, model, data_loader, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Evaluate ``model`` on ``data_loader`` using box prompts only.

        For every batch the model is run with multi-mask output, the candidate
        with the highest predicted IoU is kept for each mask slot, and the
        result is upsampled to the input resolution before the loss, IoU and
        boundary IoU are computed. Metrics are synchronised across processes
        with ``sync_tensor`` and accumulated in running averages.

        Args:
            model: Callable invoked as ``model(batched_input, True)`` that
                returns ``(masks, iou_predictions)``.
            data_loader: Iterable of dicts with ``"image"``, ``"masks"`` and
                ``"bboxs"`` tensors.
            epoch: Zero-based epoch index (only used for the progress bar).
            sub_epoch: Zero-based sub-epoch index (only used for the progress bar).

        Returns:
            Dict with the running averages ``"val_loss"``, ``"val_iou"`` and
            ``"val_boundary_iou"``.
        """
        val_loss = AverageMeter()
        val_iou = AverageMeter()
        val_iou_boundary = AverageMeter()

        with torch.no_grad(), tqdm(
            total=len(data_loader),
            desc=f"Validate Epoch #{epoch + 1}, Sub Epoch #{sub_epoch+1}",
            disable=not is_master(),
            file=sys.stdout,
        ) as t:
            for data in data_loader:
                image = data["image"].cuda()
                masks = data["masks"].cuda()
                height, width = image.shape[2], image.shape[3]

                # Box coordinates are doubled for 512-pixel inputs (presumably
                # the annotations are stored at half that resolution -- confirm).
                # Otherwise clone: ``.cuda()`` hands back the very same tensor
                # when it already lives on the GPU, and the in-place writes
                # below must not leak into the loader's batch.
                bboxs = data["bboxs"].cuda()
                bboxs = bboxs * 2 if height == 512 else bboxs.clone()

                # Entries 2 and 3 become entry 0 + entry 2 and entry 1 + entry 3,
                # i.e. (x, y, w, h) -> (x1, y1, x2, y2) assuming xywh input.
                bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2]
                bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3]

                batched_input = [{"image": image[b_i], "boxes": bboxs[b_i]} for b_i in range(len(image))]

                output, iou_predictions = model(batched_input, True)

                # Only M is needed below; the unpacking still enforces a 5-D output.
                _, M, _, _, _ = output.shape
                # For each sample keep, per mask slot, the candidate whose
                # predicted IoU is the largest.
                output = torch.stack(
                    [
                        output[k][torch.arange(M), iou_predictions[k].argmax(-1).squeeze()]
                        for k in range(len(output))
                    ],
                    dim=0,
                )
                output = (
                    F.interpolate(output, size=(height, width), mode="bilinear")
                    .reshape(-1, height, width)
                    .unsqueeze(1)
                )
                masks = masks.reshape(-1, height, width).unsqueeze(1)

                loss_mask, loss_dice = loss_masks(output, masks, len(output))
                loss = loss_mask * 20 + loss_dice

                iou = compute_iou(output, masks * 255)
                boundary_iou = compute_boundary_iou(output, masks * 255)

                loss = sync_tensor(loss)
                iou = sync_tensor(iou)
                boundary_iou = sync_tensor(boundary_iou)

                # Weight of this batch in the running averages: per-process
                # batch size times the number of processes.
                global_batch_size = image.shape[0] * get_dist_size()
                val_loss.update(loss, global_batch_size)
                val_iou.update(iou, global_batch_size)
                val_iou_boundary.update(boundary_iou, global_batch_size)

                t.set_postfix(
                    {
                        "loss": val_loss.avg,
                        "iou": val_iou.avg,
                        "boundary_iou": val_iou_boundary.avg,
                        "bs": global_batch_size,
                    }
                )
                t.update()

        results = {
            "val_loss": val_loss.avg,
            "val_iou": val_iou.avg,
            "val_boundary_iou": val_iou_boundary.avg,
        }

        # ``self.wandb_log`` only exists on the master process.
        if is_master():
            self.wandb_log.log(dict(results))

        return results

    def validate(self, model=None, data_loader=None, epoch=0, sub_epoch=0) -> dict[str, Any]:
        """Switch the model to eval mode and run ``_validate`` on it.

        Args:
            model: Model to evaluate; ``self.eval_network`` when ``None``.
            data_loader: Validation data; ``self.data_provider.valid`` when ``None``.
            epoch: Zero-based epoch index, forwarded to ``_validate``.
            sub_epoch: Zero-based sub-epoch index, forwarded to ``_validate``.

        Returns:
            The metrics dict produced by ``_validate``.
        """
        # Test against ``None`` rather than truthiness: objects that define
        # ``__len__`` (e.g. an empty container module) are falsy and would
        # otherwise be silently replaced by the default network.
        if model is None:
            model = self.eval_network
        if data_loader is None:
            data_loader = self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch, sub_epoch)

    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        """Move a batch to the GPU and rewrite its boxes for the model.

        Args:
            feed_dict: Batch with ``"image"``, ``"masks"``, ``"bboxs"`` and
                ``"points"`` tensors.

        Returns:
            Dict with the same four keys holding the CUDA tensors; ``"bboxs"``
            has entries 2 and 3 replaced by entry 0 + entry 2 and
            entry 1 + entry 3. The tensors in ``feed_dict`` are not modified.
        """
        image = feed_dict["image"].cuda()
        masks = feed_dict["masks"].cuda()
        bboxs = feed_dict["bboxs"].cuda()
        points = feed_dict["points"].cuda()

        if image.shape[2] == 512:
            # Prompt coordinates are doubled for 512-pixel inputs (presumably
            # the annotations are stored at half that resolution -- confirm).
            bboxs = bboxs * 2
            points = points * 2
        else:
            # ``.cuda()`` returns the very same tensor when it is already on
            # the GPU; clone so the in-place writes below never alter the
            # caller's batch.
            bboxs = bboxs.clone()

        # (x, y, w, h) -> (x1, y1, x2, y2), assuming xywh input.
        bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2]
        bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3]

        return {
            "image": image,
            "masks": masks,
            "points": points,
            "bboxs": bboxs,
        }

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        """Run one forward/backward pass with randomly chosen prompts.

        Each sample is prompted either with its boxes or, with equal
        probability, with 1-10 points sampled from its masks; if point
        sampling fails the sample falls back to its boxes. The model is run
        with ``multimask_output`` chosen at random, and for every mask only
        the candidate with the smallest loss contributes to the final loss.
        The scaled loss is back-propagated here; the optimizer step is left
        to the caller.

        Args:
            feed_dict: Batch as returned by ``before_step`` (``"image"``,
                ``"masks"``, ``"bboxs"``, ``"points"``).

        Returns:
            Dict with the scalar ``"loss"`` and the raw model ``"output"``.
        """
        image = feed_dict["image"]
        masks = feed_dict["masks"]
        bboxs = feed_dict["bboxs"]
        points = feed_dict["points"]

        batched_input = []
        for b_i in range(len(image)):
            dict_input = {"image": image[b_i]}

            if random.random() >= 0.5:
                dict_input["boxes"] = bboxs[b_i]
            else:
                try:
                    n_p = int(random.random() * 10 + 1)
                    point_coords = masks_sample_points(masks[b_i], k=n_p)
                    if image.shape[2] == 512:
                        point_coords = point_coords * 2
                    point_labels = torch.ones((points[b_i].shape[0], n_p), device=image.device)
                except Exception:
                    # Best-effort fallback to the box prompt. ``Exception``
                    # (not a bare ``except``) so Ctrl-C / SystemExit still
                    # propagate.
                    dict_input["boxes"] = bboxs[b_i]
                else:
                    # Assign both keys together so a failure half-way through
                    # can never leave coordinates without labels.
                    dict_input["point_coords"] = point_coords
                    dict_input["point_labels"] = point_labels

            batched_input.append(dict_input)

        with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
            multimask_output = random.random() >= 0.5
            output, _ = self.model(batched_input, multimask_output=multimask_output)

            height, width = image.shape[2], image.shape[3]
            masks = masks.reshape(-1, height, width).unsqueeze(1)

            # One loss entry per candidate mask (dimension 2 of the output).
            loss_list = []
            for i in range(output.shape[2]):
                output_i = (
                    F.interpolate(output[:, :, i], size=(height, width), mode="bilinear")
                    .reshape(-1, height, width)
                    .unsqueeze(1)
                )
                loss_mask_i, loss_dice_i = loss_masks(output_i, masks, len(output_i), mode="none")
                loss_list.append(loss_mask_i * 20 + loss_dice_i)
            loss = torch.stack(loss_list, -1)

            # One-hot selector of the lowest-loss candidate along dim 1.
            min_indices = torch.argmin(loss, dim=1)
            mask = torch.zeros_like(loss)
            mask.scatter_(1, min_indices.unsqueeze(1), 1)

            # The mean runs over every candidate slot, so rescale by the
            # number of candidates to average over the selected ones only
            # (assuming ``loss_masks(mode="none")`` yields one value per mask).
            loss = (loss * mask).mean() * loss.shape[-1]

        self.scaler.scale(loss).backward()

        return {"loss": loss, "output": output}

    def _train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Iterate once over the training loader, updating the model each batch.

        Per batch: preprocess with ``before_step``, zero the gradients, run
        ``run_step`` (forward and backward), then ``after_step``. The loss is
        synchronised across processes and tracked as a running average, which
        the master process also logs to wandb.

        Args:
            epoch: Zero-based epoch index (progress bar and logging only).
            sub_epoch: Zero-based sub-epoch index (progress bar and logging only).

        Returns:
            Dict with the running average ``"train_loss"``.
        """
        train_loss = AverageMeter()

        with tqdm(
            total=len(self.data_provider.train),
            desc=f"Train Epoch #{epoch + 1}, Sub Epoch #{sub_epoch + 1}",
            disable=not is_master(),
            file=sys.stdout,
        ) as t:
            for data in self.data_provider.train:
                # preprocessing
                feed_dict = self.before_step(data)
                # clear gradient
                self.optimizer.zero_grad()
                # forward & backward
                output_dict = self.run_step(feed_dict)
                # update: optimizer, lr_scheduler
                self.after_step()

                # Per-process batch size times the number of processes.
                global_batch_size = data["image"].shape[0] * get_dist_size()

                loss = sync_tensor(output_dict["loss"])
                train_loss.update(loss, global_batch_size)

                # Distinct learning rates across parameter groups, ascending;
                # computed once and shared by wandb and the progress bar.
                learning_rates = sorted({group["lr"] for group in self.optimizer.param_groups})

                # ``self.wandb_log`` only exists on the master process.
                if is_master():
                    self.wandb_log.log(
                        {
                            "train_loss": train_loss.avg,
                            "epoch": epoch,
                            "sub_epoch": sub_epoch,
                            "learning_rate": learning_rates[0],
                        }
                    )

                t.set_postfix(
                    {
                        "loss": train_loss.avg,
                        "bs": global_batch_size,
                        "res": data["image"].shape[2],
                        "lr": list_join(learning_rates, "#", "%.1E"),
                        "progress": self.run_config.progress,
                    }
                )
                t.update()

        return {
            "train_loss": train_loss.avg,
        }

    def train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Put the model in training mode and train for one sub-epoch.

        Args:
            epoch: Zero-based epoch index.
            sub_epoch: Zero-based sub-epoch index.

        Returns:
            The info dict produced by ``_train_one_sub_epoch``.
        """
        self.model.train()

        # Tell the data provider which epoch / sub-epoch is about to be iterated.
        self.data_provider.set_epoch_and_sub_epoch(epoch, sub_epoch)

        return self._train_one_sub_epoch(epoch, sub_epoch)

    def train(self) -> None:
        """Run the full training loop: train, validate and checkpoint.

        The loop counter runs from ``self.start_epoch`` up to
        ``self.run_config.n_epochs`` and is treated as a sub-epoch index; the
        epoch index is derived from it by integer division with
        ``self.data_provider.sub_epochs_per_epoch``. After every sub-epoch the
        model is validated, ``self.best_val`` is raised to the best
        ``"val_iou"`` seen so far, and a checkpoint is written.
        """
        for sub_epoch in range(self.start_epoch, self.run_config.n_epochs):
            epoch = sub_epoch // self.data_provider.sub_epochs_per_epoch

            self.train_one_sub_epoch(epoch, sub_epoch)

            val_info_dict = self.validate(epoch=epoch, sub_epoch=sub_epoch)
            self.best_val = max(val_info_dict["val_iou"], self.best_val)

            # A checkpoint is saved after every sub-epoch, whether or not it
            # improved on the best validation IoU.
            self.save_model(
                only_state_dict=False,
                epoch=sub_epoch,
                model_name=f"checkpoint_{epoch}_{sub_epoch}.pt",
            )

    def prep_for_training(self, run_config: SAMRunConfig, amp="fp32") -> None:
        """Wrap the model for distributed training and build the optimizer state.

        Args:
            run_config: Run configuration; stored on the trainer and used to
                build the optimizer and learning-rate scheduler.
            amp: Mixed-precision setting stored as ``self.amp``.

        Raises:
            AssertionError: If the training loader yields no batches.
        """
        self.run_config = run_config

        # Move the network to the GPU and wrap it in DDP on this process's
        # local device.
        local_rank = get_dist_local_rank()
        ddp_model = nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[local_rank],
            find_unused_parameters=True,
        )
        self.model = ddp_model

        # Reset the step counter and record how many batches one pass over the
        # training loader contains.
        self.run_config.global_step = 0
        n_batches = len(self.data_provider.train)
        self.run_config.batch_per_epoch = n_batches
        assert self.run_config.batch_per_epoch > 0, "Training set is empty"

        # build optimizer
        optimizer, lr_scheduler = self.run_config.build_optimizer(self.model)
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        # amp: ``self.amp`` is set before ``self.enable_amp`` is read for the
        # gradient scaler.
        self.amp = amp
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp)