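"""Lightning modules for StructDiffusion: a transformer-based conditional pose
diffusion model and a pairwise collision discriminator."""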
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from StructDiffusion.models.models import TransformerDiffusionModel, PCTDiscriminator, FocalLoss

from StructDiffusion.diffusion.noise_schedule import NoiseSchedule, q_sample
from StructDiffusion.diffusion.pose_conversion import get_diffusion_variables_from_H


class ConditionalPoseDiffusionModel(pl.LightningModule):
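    """Diffusion model that denoises object goal poses, conditioned on object
    point clouds and a language sentence."""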

    def __init__(self, vocab_size, model_cfg, loss_cfg, noise_scheduler_cfg, optimizer_cfg):
        super().__init__()
        self.save_hyperparameters()

        self.model = TransformerDiffusionModel(vocab_size, **model_cfg)

        self.noise_schedule = NoiseSchedule(**noise_scheduler_cfg)

        self.loss_type = loss_cfg.type

        self.optimizer_cfg = optimizer_cfg

        # set in forward() so self.log can report the correct batch size
        self.batch_size = None

    def forward(self, batch):
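        """Corrupt the ground-truth goal poses to a random timestep and predict
        the injected noise. Returns (noise, predicted_noise), restricted to
        unmasked pose positions."""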

        # unpack the batch
        pcs = batch["pcs"]
        B = pcs.shape[0]
        self.batch_size = B
        sentence = batch["sentence"]
        goal_poses = batch["goal_poses"]
        type_index = batch["type_index"]
        position_index = batch["position_index"]
        pad_mask = batch["pad_mask"]

        # sample one diffusion timestep per example
        t = torch.randint(0, self.noise_schedule.timesteps, (B,), dtype=torch.long, device=self.device)

        # convert 4x4 homogeneous goal poses to diffusion variables,
        # then corrupt them to timestep t with Gaussian noise
        x_start = get_diffusion_variables_from_H(goal_poses)
        noise = torch.randn_like(x_start)
        x_noisy = q_sample(x_start=x_start, t=t, noise_schedule=self.noise_schedule, noise=noise)

        predicted_noise = self.model.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)

        # important: skip loss computation for padded (masked) pose positions
        num_poses = goal_poses.shape[1]  # goal_poses: B x N x 4 x 4
        pose_pad_mask = pad_mask[:, -num_poses:]
        keep_mask = (pose_pad_mask == 0)
        noise = noise[keep_mask]  # flattened over unmasked positions only
        predicted_noise = predicted_noise[keep_mask]

        return noise, predicted_noise

    def compute_loss(self, noise, predicted_noise, prefix="train/"):
        if self.loss_type == 'l1':
            loss = F.l1_loss(noise, predicted_noise)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, predicted_noise)
        elif self.loss_type == "huber":
            loss = F.smooth_l1_loss(noise, predicted_noise)
        else:
            raise NotImplementedError(f"unsupported loss type: {self.loss_type}")

        self.log(prefix + "loss", loss, prog_bar=True, batch_size=self.batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        noise, pred_noise = self.forward(batch)
        loss = self.compute_loss(noise, pred_noise, prefix="train/")
        return loss

    def validation_step(self, batch, batch_idx):
        noise, pred_noise = self.forward(batch)
        self.compute_loss(noise, pred_noise, prefix="val/")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.optimizer_cfg.lr,
                                     weight_decay=self.optimizer_cfg.weight_decay)  # typical weight decay: 1e-5
        return optimizer


class PairwiseCollisionModel(pl.LightningModule):
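    """Discriminator that predicts a collision label for a scene point cloud
    containing a pair of objects."""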

    def __init__(self, model_cfg, loss_cfg, optimizer_cfg, data_cfg):
        super().__init__()
        self.save_hyperparameters()

        self.model = PCTDiscriminator(**model_cfg)

        self.loss_cfg = loss_cfg
        self.loss = None
        self.configure_loss()

        self.optimizer_cfg = optimizer_cfg

        # stored because some of the data parameters affect model behavior
        self.data_cfg = data_cfg

    def forward(self, batch):
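        """Return the ground-truth label and the predicted logit for the scene point cloud."""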
        label = batch["label"]
        predicted_label = self.model.forward(batch["scene_xyz"])
        return label, predicted_label

    def compute_loss(self, label, predicted_label, prefix="train/"):
        # the model outputs logits; MSE regression compares probabilities, so apply sigmoid first
        if self.loss_cfg.type == "MSE":
            predicted_label = torch.sigmoid(predicted_label)
        loss = self.loss(predicted_label, label)
        self.log(prefix + "loss", loss, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        label, predicted_label = self.forward(batch)
        loss = self.compute_loss(label, predicted_label, prefix="train/")
        return loss

    def validation_step(self, batch, batch_idx):
        label, predicted_label = self.forward(batch)
        self.compute_loss(label, predicted_label, prefix="val/")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.optimizer_cfg.lr,
                                     weight_decay=self.optimizer_cfg.weight_decay)  # typical weight decay: 1e-5
        return optimizer

    def configure_loss(self):
        if self.loss_cfg.type == "Focal":
            print("use focal loss with gamma {}".format(self.loss_cfg.focal_gamma))
            self.loss = FocalLoss(gamma=self.loss_cfg.focal_gamma)
        elif self.loss_cfg.type == "MSE":
            print("use regression L2 loss")
            self.loss = torch.nn.MSELoss()
        elif self.loss_cfg.type == "BCE":
            print("use standard BCE logit loss")
            self.loss = torch.nn.BCEWithLogitsLoss(reduction="mean")
        else:
            raise NotImplementedError(f"unsupported loss type: {self.loss_cfg.type}")
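

# Minimal usage sketch (hypothetical: the config objects and dataloaders below
# come from the repo's training configs and datasets, not from this file):
#
#   model = ConditionalPoseDiffusionModel(vocab_size, model_cfg, loss_cfg,
#                                         noise_scheduler_cfg, optimizer_cfg)
#   trainer = pl.Trainer(max_epochs=...)
#   trainer.fit(model, train_dataloader, val_dataloader)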