"""Lightning training modules for StructDiffusion: a conditional pose
diffusion model and a pairwise collision discriminator."""
import torch
import torch.nn.functional as F
import pytorch_lightning as pl

from StructDiffusion.models.models import TransformerDiffusionModel, PCTDiscriminator, FocalLoss
from StructDiffusion.diffusion.noise_schedule import NoiseSchedule, q_sample
from StructDiffusion.diffusion.pose_conversion import get_diffusion_variables_from_H, get_diffusion_variables_from_9D_actions


class ConditionalPoseDiffusionModel(pl.LightningModule):
    """Trains a transformer to predict the noise added to object goal poses,
    conditioned on object point clouds and a sentence describing the structure."""

    def __init__(self, vocab_size, model_cfg, loss_cfg, noise_scheduler_cfg, optimizer_cfg):
        super().__init__()
        self.save_hyperparameters()

        self.model = TransformerDiffusionModel(vocab_size, **model_cfg)
        self.noise_schedule = NoiseSchedule(**noise_scheduler_cfg)
        self.loss_type = loss_cfg.type
        self.optimizer_cfg = optimizer_cfg

        # set in forward() so compute_loss() can log with the correct batch size
        self.batch_size = None

    def forward(self, batch):
        # unpack conditioning inputs
        pcs = batch["pcs"]
        B = pcs.shape[0]
        self.batch_size = B
        sentence = batch["sentence"]
        goal_poses = batch["goal_poses"]
        type_index = batch["type_index"]
        position_index = batch["position_index"]
        pad_mask = batch["pad_mask"]

        # sample one diffusion timestep per example
        t = torch.randint(0, self.noise_schedule.timesteps, (B,), dtype=torch.long, device=self.device)

        # forward diffusion: convert poses to diffusion variables and add noise
        x_start = get_diffusion_variables_from_H(goal_poses)
        noise = torch.randn_like(x_start, device=self.device)
        x_noisy = q_sample(x_start=x_start, t=t, noise_schedule=self.noise_schedule, noise=noise)

        predicted_noise = self.model.forward(t, pcs, sentence, x_noisy, type_index, position_index, pad_mask)

        # important: skip computing loss for padded (masked) positions
        num_poses = goal_poses.shape[1]  # goal_poses: (B, N, 4, 4)
        pose_pad_mask = pad_mask[:, -num_poses:]
        keep_mask = (pose_pad_mask == 0)
        noise = noise[keep_mask]  # flattened to the positions that contribute to the loss
        predicted_noise = predicted_noise[keep_mask]

        return noise, predicted_noise

    def compute_loss(self, noise, predicted_noise, prefix="train/"):
        if self.loss_type == 'l1':
            loss = F.l1_loss(noise, predicted_noise)
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, predicted_noise)
        elif self.loss_type == "huber":
            loss = F.smooth_l1_loss(noise, predicted_noise)
        else:
            raise NotImplementedError("unsupported loss type: {}".format(self.loss_type))

        self.log(prefix + "loss", loss, prog_bar=True, batch_size=self.batch_size)
        return loss

    def training_step(self, batch, batch_idx):
        noise, pred_noise = self.forward(batch)
        loss = self.compute_loss(noise, pred_noise, prefix="train/")
        return loss

    def validation_step(self, batch, batch_idx):
        noise, pred_noise = self.forward(batch)
        self.compute_loss(noise, pred_noise, prefix="val/")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.optimizer_cfg.lr, weight_decay=self.optimizer_cfg.weight_decay)
        return optimizer
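

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the closed-form
# forward-diffusion step that q_sample implements,
#
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
#
# written out with a locally defined linear beta schedule. Everything below is
# local to this sketch and assumes nothing about NoiseSchedule's internals;
# the model above is trained to recover eps from (x_t, t, conditioning).
# ---------------------------------------------------------------------------
def _q_sample_sketch(x_start, t, timesteps=200):
    betas = torch.linspace(1e-4, 0.02, timesteps, device=x_start.device)  # assumed linear schedule
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)                    # alpha_bar_t
    noise = torch.randn_like(x_start)
    shape = (-1,) + (1,) * (x_start.dim() - 1)  # broadcast per-example scalars over pose dims
    sqrt_ab = alphas_cumprod[t].sqrt().view(shape)
    sqrt_one_minus_ab = (1.0 - alphas_cumprod[t]).sqrt().view(shape)
    return sqrt_ab * x_start + sqrt_one_minus_ab * noise, noise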


class PairwiseCollisionModel(pl.LightningModule):
    """Discriminator that predicts a pairwise collision label from a scene point cloud."""

    def __init__(self, model_cfg, loss_cfg, optimizer_cfg, data_cfg):
        super().__init__()
        self.save_hyperparameters()

        self.model = PCTDiscriminator(**model_cfg)

        self.loss_cfg = loss_cfg
        self.loss = None
        self.configure_loss()

        self.optimizer_cfg = optimizer_cfg

        # stored because some of the data parameters affect the model behavior
        self.data_cfg = data_cfg

    def forward(self, batch):
        label = batch["label"]
        predicted_label = self.model.forward(batch["scene_xyz"])
        return label, predicted_label

    def compute_loss(self, label, predicted_label, prefix="train/"):
        if self.loss_cfg.type == "MSE":
            # MSE regresses probabilities, so squash the raw logits first
            predicted_label = torch.sigmoid(predicted_label)
        loss = self.loss(predicted_label, label)
        self.log(prefix + "loss", loss, prog_bar=True, batch_size=label.shape[0])
        return loss

    def training_step(self, batch, batch_idx):
        label, predicted_label = self.forward(batch)
        loss = self.compute_loss(label, predicted_label, prefix="train/")
        return loss

    def validation_step(self, batch, batch_idx):
        label, predicted_label = self.forward(batch)
        self.compute_loss(label, predicted_label, prefix="val/")

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.optimizer_cfg.lr, weight_decay=self.optimizer_cfg.weight_decay)
        return optimizer

    def configure_loss(self):
        if self.loss_cfg.type == "Focal":
            print("use focal loss with gamma {}".format(self.loss_cfg.focal_gamma))
            self.loss = FocalLoss(gamma=self.loss_cfg.focal_gamma)
        elif self.loss_cfg.type == "MSE":
            print("use regression L2 loss")
            self.loss = torch.nn.MSELoss()
        elif self.loss_cfg.type == "BCE":
            print("use standard BCE logit loss")
            self.loss = torch.nn.BCEWithLogitsLoss(reduction="mean")
        else:
            raise NotImplementedError("unsupported loss type: {}".format(self.loss_cfg.type))
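

# ---------------------------------------------------------------------------
# Hedged usage sketch (an addition, not part of the original file): how these
# LightningModules are typically driven. The dataloaders and Trainer settings
# below are hypothetical placeholders; the repo's actual training scripts may
# configure logging, checkpointing, and devices differently.
# ---------------------------------------------------------------------------
def _example_fit(model, train_loader, val_loader, max_epochs=100):
    # minimal pl.Trainer invocation with illustrative settings only
    trainer = pl.Trainer(max_epochs=max_epochs, accelerator="auto", devices=1)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)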