|
import os |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from pytorch_lightning import LightningModule |
|
|
|
from cliport.tasks import cameras |
|
from cliport.utils import utils |
|
from cliport.models.core.attention import Attention |
|
from cliport.models.core.transport import Transport |
|
from cliport.models.streams.two_stream_attention import TwoStreamAttention |
|
from cliport.models.streams.two_stream_transport import TwoStreamTransport |
|
|
|
from cliport.models.streams.two_stream_attention import TwoStreamAttentionLat |
|
from cliport.models.streams.two_stream_transport import TwoStreamTransportLat |
|
import time |
|
import IPython |
|
|
|
class TransporterAgent(LightningModule):
    """Base agent that trains an attention (pick) and a transport (place) model.

    Subclasses must override ``_build_model`` to assign ``self.attention`` and
    ``self.transport``. Optimization is manual (``automatic_optimization`` is
    disabled): each model owns an Adam optimizer stored in
    ``self._optimizers`` and is stepped inside the matching ``*_criterion``
    method.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        """Store configuration, build the models and create their optimizers.

        Args:
            name: Agent name; stored on ``self.name`` and printed at start-up.
            cfg: Nested config mapping. This class reads ``cfg['train']`` keys
                ``task``, ``n_rotations``, ``val_repeats``, ``save_steps``,
                ``lr``, ``log`` and (when checkpointing) ``train_dir``.
            train_ds: Training loader; returned unchanged by
                ``train_dataloader`` and must expose a ``.dataset`` attribute.
            test_ds: Validation loader; returned unchanged by
                ``val_dataloader`` and must expose a ``.dataset`` attribute.
        """
        super().__init__()
        utils.set_seed(0)

        # Optimizers are stepped by hand in the criterion methods.
        self.automatic_optimization = False
        self.device_type = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.name = name
        self.cfg = cfg
        self.train_loader = train_ds
        self.test_loader = test_ds

        self.train_ds = train_ds.dataset
        self.test_ds = test_ds.dataset

        self.task = cfg['train']['task']
        self.total_steps = 0
        self.crop_size = 64
        self.n_rotations = cfg['train']['n_rotations']

        # Workspace geometry constants handed to utils.pix_to_xyz in act().
        self.pix_size = 0.003125
        self.in_shape = (320, 160, 6)
        self.cam_config = cameras.RealSenseD415.CONFIG
        self.bounds = np.array([[0.25, 0.75], [-0.5, 0.5], [0, 0.28]])

        self.val_repeats = cfg['train']['val_repeats']
        self.save_steps = cfg['train']['save_steps']

        # Must run before the optimizers are created: it defines the models.
        self._build_model()

        lr = self.cfg['train']['lr']
        self._optimizers = {
            'attn': torch.optim.Adam(self.attention.parameters(), lr=lr),
            'trans': torch.optim.Adam(self.transport.parameters(), lr=lr)
        }
        print("Agent: {}, Logging: {}".format(name, cfg['train']['log']))

    def _build_model(self):
        """Create ``self.attention`` and ``self.transport``; subclasses override.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        self.attention = None
        self.transport = None
        raise NotImplementedError()

    def forward(self, x):
        """Unused entry point; the agent is driven through ``act`` and the step methods.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def cross_entropy_with_logits(self, pred, labels, reduction='mean'):
        """Soft-label cross entropy between logits and a label distribution.

        Both tensors are flattened to ``(len(labels), -1)`` and a log-softmax
        is taken over the flattened axis of ``pred``.

        Args:
            pred: Logits tensor with the same number of elements as ``labels``.
            labels: Target weights (one-hot in this file's callers).
            reduction: ``'sum'`` or ``'mean'``. ``'mean'`` averages over every
                element of the flattened product, not per sample.

        Returns:
            Scalar loss tensor.

        Raises:
            NotImplementedError: If ``reduction`` is neither ``'sum'`` nor ``'mean'``.
        """
        batch = len(labels)
        x = (-labels.view(batch, -1) * F.log_softmax(pred.view(batch, -1), -1))
        if reduction == 'sum':
            return x.sum()
        elif reduction == 'mean':
            return x.mean()
        else:
            raise NotImplementedError()

    def attn_forward(self, inp, softmax=True):
        """Run the attention model on ``inp['inp_img']`` and return its output."""
        inp_img = inp['inp_img']
        output = self.attention.forward(inp_img, softmax=softmax)
        return output

    def attn_training_step(self, frame, backprop=True, compute_err=False):
        """Compute (and optionally back-propagate) the pick loss for one batch.

        Args:
            frame: Mapping with keys ``'img'``, ``'p0'`` and ``'p0_theta'``.
            backprop: When true, take an optimizer step on the attention model.
            compute_err: When true, also compute pixel/angle errors.

        Returns:
            ``(loss, err)`` as produced by ``attn_criterion``.
        """
        inp_img = frame['img']
        p0, p0_theta = frame['p0'], frame['p0_theta']

        inp = {'inp_img': inp_img}
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def attn_criterion(self, backprop, compute_err, inp, out, p, theta):
        """Build one-hot pick labels, compute the loss and optionally step/evaluate.

        Args:
            backprop: When true, run backward and step the ``'attn'`` optimizer.
            compute_err: When true, run a softmax forward pass and compute the
                error for the FIRST batch element only.
            inp: Mapping with key ``'inp_img'``; its first three dims size the label.
            out: Attention logits; only used for the loss and for its device.
            p: Per-sample pick pixels, indexable as ``p[i][0], p[i][1]``.
            theta: Per-sample pick angles (tensor or array), assumed radians
                given the ``2 * pi`` discretisation below.

        Returns:
            ``(loss, err)`` where ``err`` is ``{}`` or a dict with keys
            ``'dist'`` (L1 pixel distance) and ``'theta'``.
        """
        # isinstance (rather than an exact type check) also covers Tensor subclasses.
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()

        # Discretise the angle into one of n_rotations bins.
        n_rotations = self.attention.n_rotations
        theta_i = theta / (2 * np.pi / n_rotations)
        theta_i = np.int32(np.round(theta_i)) % n_rotations
        inp_img = inp['inp_img'].float()

        label_size = inp_img.shape[:3] + (n_rotations,)
        label = torch.zeros(label_size, dtype=torch.float, device=out.device)

        # One-hot label per sample at (row, col, rotation bin).
        for idx, p_i in enumerate(p):
            label[idx, int(p_i[0]), int(p_i[1]), theta_i[idx]] = 1
        # Move the rotation axis next to the batch axis to match the logits layout.
        label = label.permute((0, 3, 1, 2)).contiguous()

        loss = self.cross_entropy_with_logits(out, label)

        if backprop:
            attn_optim = self._optimizers['attn']
            self.manual_backward(loss)
            attn_optim.step()
            attn_optim.zero_grad()

        # Pixel and rotation error, evaluated on the first sample of the batch.
        err = {}
        if compute_err:
            with torch.no_grad():
                pick_conf = self.attn_forward(inp)
                pick_conf = pick_conf[0].permute(1, 2, 0)
                pick_conf = pick_conf.detach().cpu().numpy()
            p = p[0]
            theta = theta[0]

            argmax = np.argmax(pick_conf)
            argmax = np.unravel_index(argmax, shape=pick_conf.shape)
            p0_pix = argmax[:2]
            p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

            err = {
                'dist': np.linalg.norm(np.array(p.detach().cpu().numpy()) - p0_pix, ord=1),
                'theta': np.absolute((theta - p0_theta) % np.pi)
            }
        return loss, err

    def trans_forward(self, inp, softmax=True):
        """Run the transport model on ``inp['inp_img']`` conditioned on ``inp['p0']``."""
        inp_img = inp['inp_img']
        p0 = inp['p0']

        output = self.transport.forward(inp_img, p0, softmax=softmax)
        return output

    def transport_training_step(self, frame, backprop=True, compute_err=False):
        """Compute (and optionally back-propagate) the place loss for one batch.

        Args:
            frame: Mapping with keys ``'img'``, ``'p0'``, ``'p1'`` and ``'p1_theta'``.
            backprop: When true, take an optimizer step on the transport model.
            compute_err: When true, also compute pixel/angle errors.

        Returns:
            ``(loss, err)``. Note ``transport_criterion`` returns them in the
            opposite order; they are swapped here.
        """
        inp_img = frame['img'].float()
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']

        inp = {'inp_img': inp_img, 'p0': p0}
        output = self.trans_forward(inp, softmax=False)
        err, loss = self.transport_criterion(backprop, compute_err, inp, output, p0, p1, p1_theta)
        return loss, err

    def transport_criterion(self, backprop, compute_err, inp, output, p, q, theta):
        """Build one-hot place labels, compute the loss and optionally step/evaluate.

        Args:
            backprop: When true, run backward and step the ``'trans'`` optimizer.
            compute_err: When true, run a softmax forward pass and compute the
                error for the FIRST batch element only.
            inp: Mapping with keys ``'inp_img'`` and ``'p0'``.
            output: Transport logits; used for the loss and for its device.
            p: Pick pixels (unused here; kept for signature compatibility).
            q: Tensor of per-sample place pixels, indexed as ``q[:, 0]``/``q[:, 1]``.
                It is NOT modified; clamping is done on a copy.
            theta: Per-sample place angles (tensor or array), assumed radians
                given the ``2 * pi`` discretisation below.

        Returns:
            ``(err, loss)`` -- note the order -- where ``err`` is ``{}`` or a
            dict with keys ``'dist'`` and ``'theta'``.
        """
        if isinstance(theta, torch.Tensor):
            theta = theta.detach().cpu().numpy()

        # Discretise the angle into one of n_rotations bins.
        n_rotations = self.transport.n_rotations
        itheta = theta / (2 * np.pi / n_rotations)
        itheta = np.int32(np.round(itheta)) % n_rotations

        inp_img = inp['inp_img']

        label_size = inp_img.shape[:3] + (n_rotations,)
        label = torch.zeros(label_size, dtype=torch.float, device=output.device)

        # Clamp out-of-bounds targets into the label grid. Work on a clone so
        # the caller's tensor (e.g. frame['p1']) is not mutated in place.
        q = q.clone()
        q[:, 0] = torch.clamp(q[:, 0], 0, label.shape[1] - 1)
        q[:, 1] = torch.clamp(q[:, 1], 0, label.shape[2] - 1)

        for idx, q_i in enumerate(q):
            label[idx, int(q_i[0]), int(q_i[1]), itheta[idx]] = 1
        # Move the rotation axis next to the batch axis to match the logits layout.
        label = label.permute((0, 3, 1, 2)).contiguous()

        loss = self.cross_entropy_with_logits(output, label)

        if backprop:
            transport_optim = self._optimizers['trans']
            transport_optim.zero_grad()
            self.manual_backward(loss)
            transport_optim.step()

        # Pixel and rotation error, evaluated on the first sample of the batch.
        err = {}
        if compute_err:
            with torch.no_grad():
                place_conf = self.trans_forward(inp)

            place_conf = place_conf[0]
            q = q[0]
            theta = theta[0]
            place_conf = place_conf.permute(1, 2, 0)
            place_conf = place_conf.detach().cpu().numpy()
            argmax = np.argmax(place_conf)
            argmax = np.unravel_index(argmax, shape=place_conf.shape)
            p1_pix = argmax[:2]
            p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

            err = {
                'dist': np.linalg.norm(np.array(q.detach().cpu().numpy()) - p1_pix, ord=1),
                'theta': np.absolute((theta - p1_theta) % np.pi)
            }

        self.transport.iters += 1
        return err, loss

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step on both models and log the losses.

        Args:
            batch: Pair whose first element is the frame mapping consumed by
                the attention/transport step methods.
            batch_idx: Unused.

        Returns:
            ``dict(loss=...)`` with the sum of the attention and transport losses.
        """
        self.attention.train()
        self.transport.train()

        frame, _ = batch
        self.start_time = time.time()

        step = self.total_steps + 1
        loss0, _ = self.attn_training_step(frame)

        self.start_time = time.time()

        # An agent may use an Attention module as its "transport" model.
        if isinstance(self.transport, Attention):
            loss1, _ = self.attn_training_step(frame)
        else:
            loss1, _ = self.transport_training_step(frame)

        total_loss = loss0 + loss1
        self.total_steps = step
        self.start_time = time.time()
        self.log('tr/attn/loss', loss0)
        self.log('tr/trans/loss', loss1)
        self.log('tr/loss', total_loss)
        self.check_save_iteration()

        return dict(
            loss=total_loss,
        )

    def check_save_iteration(self):
        """Write ``last.ckpt`` whenever ``total_steps + 1`` is a multiple of 100."""
        global_step = self.total_steps

        if (global_step + 1) % 100 == 0:
            print(f"Saving last.ckpt Epoch: {self.trainer.current_epoch} | Global Step: {self.trainer.global_step}")
            self.save_last_checkpoint()

    def save_last_checkpoint(self):
        """Save a trainer checkpoint to ``<train_dir>/checkpoints/last.ckpt``."""
        checkpoint_path = os.path.join(self.cfg['train']['train_dir'], 'checkpoints')
        ckpt_path = os.path.join(checkpoint_path, 'last.ckpt')
        self.trainer.save_checkpoint(ckpt_path)

    def validation_step(self, batch, batch_idx):
        """Evaluate both models on one batch, averaged over ``val_repeats`` passes.

        Args:
            batch: Pair whose first element is the frame mapping.
            batch_idx: Unused.

        Returns:
            Dict with the averaged losses (``val_loss``, ``val_loss0``,
            ``val_loss1``) and the error metrics from the LAST repeat.

        Raises:
            ValueError: If ``self.val_repeats`` is smaller than 1.
        """
        self.attention.eval()
        self.transport.eval()

        # Explicit check: an ``assert`` would vanish under ``python -O`` and the
        # method would then fail later on the unbound error dicts.
        if self.val_repeats < 1:
            raise ValueError("val_repeats must be >= 1, got {}".format(self.val_repeats))

        loss0, loss1 = 0, 0
        for _ in range(self.val_repeats):
            frame, _ = batch
            l0, err0 = self.attn_training_step(frame, backprop=False, compute_err=True)
            loss0 += l0
            if isinstance(self.transport, Attention):
                l1, err1 = self.attn_training_step(frame, backprop=False, compute_err=True)
            else:
                l1, err1 = self.transport_training_step(frame, backprop=False, compute_err=True)
            loss1 += l1
        loss0 /= self.val_repeats
        loss1 /= self.val_repeats
        val_total_loss = loss0 + loss1

        return dict(
            val_loss=val_total_loss,
            val_loss0=loss0,
            val_loss1=loss1,
            val_attn_dist_err=err0['dist'],
            val_attn_theta_err=err0['theta'],
            val_trans_dist_err=err1['dist'],
            val_trans_theta_err=err1['theta'],
        )

    def training_epoch_end(self, all_outputs):
        """Delegate to the parent hook, then reseed with ``current_epoch + 1``."""
        super().training_epoch_end(all_outputs)
        utils.set_seed(self.trainer.current_epoch+1)

    def validation_epoch_end(self, all_outputs):
        """Aggregate per-batch validation outputs, log them and print a summary.

        Losses are averaged over batches; error metrics are summed.

        Args:
            all_outputs: List of dicts returned by ``validation_step``.

        Returns:
            Dict of the aggregated values. The transport-loss key is
            ``mean_val_loss1`` (kept as-is for backward compatibility even
            though the sibling key is ``val_loss0``).
        """
        mean_val_total_loss = np.mean([v['val_loss'].item() for v in all_outputs])
        mean_val_loss0 = np.mean([v['val_loss0'].item() for v in all_outputs])
        mean_val_loss1 = np.mean([v['val_loss1'].item() for v in all_outputs])
        total_attn_dist_err = np.sum([v['val_attn_dist_err'].sum() for v in all_outputs])
        total_attn_theta_err = np.sum([v['val_attn_theta_err'].sum() for v in all_outputs])
        total_trans_dist_err = np.sum([v['val_trans_dist_err'].sum() for v in all_outputs])
        total_trans_theta_err = np.sum([v['val_trans_theta_err'].sum() for v in all_outputs])

        self.log('vl/attn/loss', mean_val_loss0)
        self.log('vl/trans/loss', mean_val_loss1)
        self.log('vl/loss', mean_val_total_loss)
        self.log('vl/total_attn_dist_err', total_attn_dist_err)
        self.log('vl/total_attn_theta_err', total_attn_theta_err)
        self.log('vl/total_trans_dist_err', total_trans_dist_err)
        self.log('vl/total_trans_theta_err', total_trans_theta_err)

        print("\nAttn Err - Dist: {:.2f}, Theta: {:.2f}".format(total_attn_dist_err, total_attn_theta_err))
        print("Transport Err - Dist: {:.2f}, Theta: {:.2f}".format(total_trans_dist_err, total_trans_theta_err))

        return dict(
            val_loss=mean_val_total_loss,
            val_loss0=mean_val_loss0,
            mean_val_loss1=mean_val_loss1,
            total_attn_dist_err=total_attn_dist_err,
            total_attn_theta_err=total_attn_theta_err,
            total_trans_dist_err=total_trans_dist_err,
            total_trans_theta_err=total_trans_theta_err,
        )

    def act(self, obs, info=None, goal=None):
        """Run inference and return best action given visual observations.

        Args:
            obs: Observation passed to ``self.test_ds.get_image``.
            info: Unused.
            goal: Unused.

        Returns:
            Dict with ``'pose0'``/``'pose1'`` (position, quaternion pairs as
            arrays) and the raw ``'pick'``/``'place'`` pixel coordinates.
        """
        img = self.test_ds.get_image(obs)

        # Attention model forward pass.
        pick_inp = {'inp_img': img}
        pick_conf = self.attn_forward(pick_inp)

        # NOTE(review): unlike attn_criterion, the output is neither
        # batch-indexed nor permuted before the argmax here -- confirm this
        # matches the layout the attention model returns for a single image.
        pick_conf = pick_conf.detach().cpu().numpy()
        argmax = np.argmax(pick_conf)
        argmax = np.unravel_index(argmax, shape=pick_conf.shape)
        p0_pix = argmax[:2]
        p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

        # Transport model forward pass, conditioned on the chosen pick pixel.
        place_inp = {'inp_img': img, 'p0': p0_pix}
        place_conf = self.trans_forward(place_inp)
        place_conf = place_conf.permute(1, 2, 0)
        place_conf = place_conf.detach().cpu().numpy()
        argmax = np.argmax(place_conf)
        argmax = np.unravel_index(argmax, shape=place_conf.shape)
        p1_pix = argmax[:2]
        p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

        # Pixels to end-effector poses; channel 3 of img is used as the height map.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix,
        }

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_i, second_order_closure, on_tpu, using_native_amp, using_lbfgs):
        """No-op: optimizers are stepped manually inside the criterion methods."""
        pass

    def configure_optimizers(self):
        """Return ``None``; the optimizers live in ``self._optimizers``.

        They are created in ``__init__`` and stepped by hand in
        ``attn_criterion`` / ``transport_criterion``.
        """
        return None

    def train_dataloader(self):
        """Return the training loader supplied to ``__init__``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader supplied to ``__init__``."""
        return self.test_loader

    def load(self, model_path):
        """Load weights from a checkpoint file and move the agent to its device.

        Args:
            model_path: Path to a checkpoint containing a ``'state_dict'`` entry.
        """
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=self.device_type)
        self.load_state_dict(checkpoint['state_dict'])
        self.to(device=self.device_type)
|
|
|
|
|
class OriginalTransporterAgent(TransporterAgent):
    """Agent whose attention and transport models both use the 'plain_resnet' stream."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the single-stream ``Attention`` and ``Transport`` models."""
        # Keyword arguments common to both model constructors.
        shared = dict(
            stream_fcn=('plain_resnet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # Picking uses a single rotation; placing uses the configured count.
        self.attention = Attention(n_rotations=1, **shared)
        self.transport = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
|
|
|
|
|
class ClipUNetTransporterAgent(TransporterAgent):
    """Agent whose attention and transport models both use the 'clip_unet' stream."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the single-stream ``Attention`` and ``Transport`` models."""
        # Keyword arguments common to both model constructors.
        shared = dict(
            stream_fcn=('clip_unet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # Picking uses a single rotation; placing uses the configured count.
        self.attention = Attention(n_rotations=1, **shared)
        self.transport = Transport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
|
|
|
|
|
class TwoStreamClipUNetTransporterAgent(TransporterAgent):
    """Two-stream agent combining 'plain_resnet' and 'clip_unet' streams."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the ``TwoStreamAttention`` and ``TwoStreamTransport`` models."""
        # Keyword arguments common to both model constructors.
        shared = dict(
            stream_fcn=('plain_resnet', 'clip_unet'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # Picking uses a single rotation; placing uses the configured count.
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
|
|
|
|
|
class TwoStreamClipUNetLatTransporterAgent(TransporterAgent):
    """Two-stream agent combining 'plain_resnet_lat' and 'clip_unet_lat' streams."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the ``TwoStreamAttentionLat`` and ``TwoStreamTransportLat`` models."""
        # Keyword arguments common to both model constructors.
        shared = dict(
            stream_fcn=('plain_resnet_lat', 'clip_unet_lat'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # Picking uses a single rotation; placing uses the configured count.
        self.attention = TwoStreamAttentionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
|
|
|
|
|
class TwoStreamClipWithoutSkipsTransporterAgent(TransporterAgent):
    """Two-stream agent combining 'plain_resnet' and 'clip_woskip' streams."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the ``TwoStreamAttention`` and ``TwoStreamTransport`` models."""
        # Keyword arguments common to both model constructors.
        shared = dict(
            stream_fcn=('plain_resnet', 'clip_woskip'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # Picking uses a single rotation; placing uses the configured count.
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
|
|
|
|
|
class TwoStreamRN50BertUNetTransporterAgent(TransporterAgent):
    """Two-stream agent combining 'plain_resnet' and 'rn50_bert_unet' streams."""

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the ``TwoStreamAttention`` and ``TwoStreamTransport`` models."""
        # Keyword arguments common to both model constructors.
        shared = dict(
            stream_fcn=('plain_resnet', 'rn50_bert_unet'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        # Picking uses a single rotation; placing uses the configured count.
        self.attention = TwoStreamAttention(n_rotations=1, **shared)
        self.transport = TwoStreamTransport(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **shared,
        )
|
|