import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from StructDiffusion.models.models import TransformerDiffusionModel, PCTDiscriminator, FocalLoss
from StructDiffusion.diffusion.noise_schedule import NoiseSchedule, q_sample
from StructDiffusion.diffusion.pose_conversion import get_diffusion_variables_from_H, get_diffusion_variables_from_9D_actions


class ConditionalPoseDiffusionModel(pl.LightningModule):
    """Lightning module that trains a transformer denoiser for object goal poses.

    Each training step samples a random diffusion timestep per example, noises
    the ground-truth poses with ``q_sample``, and regresses the injected noise
    with the configured loss (l1 / l2 / huber). Padded pose slots are excluded
    from the loss.
    """

    def __init__(self, vocab_size, model_cfg, loss_cfg, noise_scheduler_cfg, optimizer_cfg):
        """
        :param vocab_size: token vocabulary size forwarded to the transformer
        :param model_cfg: kwargs for TransformerDiffusionModel
        :param loss_cfg: config with a ``type`` field ("l1" | "l2" | "huber")
        :param noise_scheduler_cfg: kwargs for NoiseSchedule
        :param optimizer_cfg: config with ``lr`` and ``weight_decay``
        """
        super().__init__()
        self.save_hyperparameters()

        self.model = TransformerDiffusionModel(vocab_size, **model_cfg)
        self.noise_schedule = NoiseSchedule(**noise_scheduler_cfg)
        self.loss_type = loss_cfg.type
        self.optimizer_cfg = optimizer_cfg
        # NOTE: do NOT call self.configure_optimizers() here — Lightning calls
        # it during trainer setup; the original eager call built an Adam
        # instance and immediately discarded it.

        # Updated on every forward() so loss logging uses the true batch size.
        self.batch_size = None

    def forward(self, batch):
        """Noise the goal poses at a random timestep and predict that noise.

        Returns ``(noise, predicted_noise)`` restricted to non-padded pose
        positions, ready for the regression loss.
        """
        pcs = batch["pcs"]
        B = pcs.shape[0]
        self.batch_size = B
        sentence = batch["sentence"]
        goal_poses = batch["goal_poses"]
        type_index = batch["type_index"]
        position_index = batch["position_index"]
        pad_mask = batch["pad_mask"]

        # One random diffusion timestep per example, sampled directly on the
        # module's device (avoids a CPU allocation + copy).
        t = torch.randint(0, self.noise_schedule.timesteps, (B,), dtype=torch.long, device=self.device)

        # --------------
        x_start = get_diffusion_variables_from_H(goal_poses)
        noise = torch.randn_like(x_start)  # randn_like inherits x_start's device/dtype
        x_noisy = q_sample(x_start=x_start, t=t, noise_schedule=self.noise_schedule, noise=noise)

        predicted_noise = self.model.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)

        # important: skip computing loss for masked (padded) pose positions.
        # The pose tokens occupy the trailing slots of the pad mask.
        num_poses = goal_poses.shape[1]  # goal_poses: B, N, 4, 4
        pose_pad_mask = pad_mask[:, -num_poses:]
        keep_mask = (pose_pad_mask == 0)
        # dim: number of positions that need loss calculation
        noise = noise[keep_mask]
        predicted_noise = predicted_noise[keep_mask]

        return noise, predicted_noise

    def compute_loss(self, noise, predicted_noise, prefix="train/"):
        """Regression loss between injected and predicted noise; logs it under ``prefix + "loss"``."""
        if self.loss_type == "l1":
            loss = F.l1_loss(noise, predicted_noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(noise, predicted_noise)
        elif self.loss_type == "huber":
            loss = F.smooth_l1_loss(noise, predicted_noise)
        else:
            raise NotImplementedError("unsupported loss type: {}".format(self.loss_type))

        self.log(prefix + "loss", loss, prog_bar=True, batch_size=self.batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        noise, pred_noise = self.forward(batch)
        return self.compute_loss(noise, pred_noise, prefix="train/")

    def validation_step(self, batch, batch_idx):
        noise, pred_noise = self.forward(batch)
        # Logged inside compute_loss; Lightning does not need the return value.
        self.compute_loss(noise, pred_noise, prefix="val/")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.optimizer_cfg.lr,
            weight_decay=self.optimizer_cfg.weight_decay,  # e.g. 1e-5
        )
        return optimizer


class PairwiseCollisionModel(pl.LightningModule):
    """Lightning module that trains a point-cloud discriminator to predict pairwise collision labels."""

    def __init__(self, model_cfg, loss_cfg, optimizer_cfg, data_cfg):
        """
        :param model_cfg: kwargs for PCTDiscriminator
        :param loss_cfg: config with ``type`` ("Focal" | "MSE" | "BCE") and
            ``focal_gamma`` when type is "Focal"
        :param optimizer_cfg: config with ``lr`` and ``weight_decay``
        :param data_cfg: stored because some data parameters affect model behavior
        """
        super().__init__()
        self.save_hyperparameters()

        self.model = PCTDiscriminator(**model_cfg)

        self.loss_cfg = loss_cfg
        self.loss = None
        self.configure_loss()  # sets self.loss (raises on unknown type)

        self.optimizer_cfg = optimizer_cfg
        # NOTE: Lightning calls configure_optimizers() itself during trainer
        # setup; the original eager call here discarded its result.

        # this is stored, because some of the data parameters affect the model behavior
        self.data_cfg = data_cfg

    def forward(self, batch):
        """Return ``(label, predicted_label)`` for the batch's scene point cloud."""
        label = batch["label"]
        predicted_label = self.model.forward(batch["scene_xyz"])
        return label, predicted_label

    def compute_loss(self, label, predicted_label, prefix="train/"):
        """Apply the configured loss; for MSE the logit is squashed through a sigmoid first."""
        if self.loss_cfg.type == "MSE":
            # MSELoss regresses the probability, so convert the raw logit.
            predicted_label = torch.sigmoid(predicted_label)
        loss = self.loss(predicted_label, label)
        # Pass batch_size explicitly so Lightning does not have to infer it.
        self.log(prefix + "loss", loss, prog_bar=True, batch_size=label.shape[0])
        return loss

    def training_step(self, batch, batch_idx):
        label, predicted_label = self.forward(batch)
        return self.compute_loss(label, predicted_label, prefix="train/")

    def validation_step(self, batch, batch_idx):
        label, predicted_label = self.forward(batch)
        # Logged inside compute_loss; no return value needed for validation.
        self.compute_loss(label, predicted_label, prefix="val/")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.optimizer_cfg.lr,
            weight_decay=self.optimizer_cfg.weight_decay,  # e.g. 1e-5
        )
        return optimizer

    def configure_loss(self):
        """Instantiate self.loss from loss_cfg.type; fail fast on unknown types."""
        if self.loss_cfg.type == "Focal":
            print("use focal loss with gamma {}".format(self.loss_cfg.focal_gamma))
            self.loss = FocalLoss(gamma=self.loss_cfg.focal_gamma)
        elif self.loss_cfg.type == "MSE":
            print("use regression L2 loss")
            self.loss = torch.nn.MSELoss()
        elif self.loss_cfg.type == "BCE":
            print("use standard BCE logit loss")
            self.loss = torch.nn.BCEWithLogitsLoss(reduction="mean")
        else:
            # Previously this fell through silently, leaving self.loss = None
            # and failing later with an opaque TypeError at the first step.
            raise ValueError("unsupported loss type: {}".format(self.loss_cfg.type))