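# File: Gen1/cliport/agents/transporter_image_goal.py
# NOTE(review): the three lines below are residue from the hosting web page
# (presumably uploader avatar text, commit message, and commit hash); they
# are kept verbatim as comments so the module remains valid Python.
# gensim2's picture
# init
# ff66cf3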
import numpy as np
from cliport.utils import utils
from cliport.agents.transporter import OriginalTransporterAgent
from cliport.models.core.attention import Attention
from cliport.models.core.attention_image_goal import AttentionImageGoal
from cliport.models.core.transport_image_goal import TransportImageGoal
class ImageGoalTransporterAgent(OriginalTransporterAgent):
    """Transporter agent whose models are conditioned on a goal image.

    Both the attention model and the transport model receive the current
    observation image together with a goal image. Everything not overridden
    here (criteria, optimizers, bookkeeping such as ``total_steps``) comes
    from ``OriginalTransporterAgent``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        """Forward all arguments unchanged to the parent agent's constructor."""
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the goal-conditioned attention and transport models.

        Sets ``self.attention`` (built with a single rotation) and
        ``self.transport`` (built with ``self.n_rotations`` rotations and
        ``self.crop_size``). Both use the ``'plain_resnet'`` stream.
        """
        stream_fcn = 'plain_resnet'
        self.attention = AttentionImageGoal(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = TransportImageGoal(
            stream_fcn=(stream_fcn, None),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )

    def attn_forward(self, inp, softmax=True):
        """Run the attention model on an input dict.

        Args:
            inp: Mapping with keys ``'inp_img'`` and ``'goal_img'``.
            softmax: Passed through to ``self.attention.forward``.

        Returns:
            The output of ``self.attention.forward``.
        """
        return self.attention.forward(
            inp['inp_img'], inp['goal_img'], softmax=softmax)

    def attn_training_step(self, frame, goal, backprop=True, compute_err=False):
        """Run one attention forward pass and evaluate the attention criterion.

        Args:
            frame: Mapping providing ``'img'``, ``'p0'`` and ``'p0_theta'``.
            goal: Mapping providing ``'img'`` (the goal image).
            backprop: Forwarded to ``self.attn_criterion``.
            compute_err: Forwarded to ``self.attn_criterion``.

        Returns:
            Whatever ``self.attn_criterion`` returns; callers in this class
            unpack it as ``(loss, err)``.
        """
        p0, p0_theta = frame['p0'], frame['p0_theta']
        inp = {'inp_img': frame['img'], 'goal_img': goal['img']}
        # The criterion receives raw (non-softmaxed) model output.
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def trans_forward(self, inp, softmax=True):
        """Run the transport model on an input dict.

        Args:
            inp: Mapping with keys ``'inp_img'``, ``'goal_img'`` and ``'p0'``.
            softmax: Passed through to ``self.transport.forward``.

        Returns:
            The output of ``self.transport.forward``.
        """
        return self.transport.forward(
            inp['inp_img'], inp['goal_img'], inp['p0'], softmax=softmax)

    def transport_training_step(self, frame, goal, backprop=True, compute_err=False):
        """Run one transport forward pass and evaluate the transport criterion.

        Args:
            frame: Mapping providing ``'img'``, ``'p0'``, ``'p1'`` and
                ``'p1_theta'``.
            goal: Mapping providing ``'img'`` (the goal image).
            backprop: Forwarded to ``self.transport_criterion``.
            compute_err: Forwarded to ``self.transport_criterion``.

        Returns:
            Tuple ``(loss, err)``.
        """
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']
        inp = {'inp_img': frame['img'], 'goal_img': goal['img'], 'p0': p0}
        # The criterion receives raw (non-softmaxed) model output.
        out = self.trans_forward(inp, softmax=False)
        # The criterion's result is unpacked as (err, loss) and re-ordered so
        # this method returns (loss, err), matching attn_training_step's use.
        err, loss = self.transport_criterion(
            backprop, compute_err, inp, out, p0, p1, p1_theta)
        return loss, err

    def training_step(self, batch, batch_idx):
        """Perform one training step on a ``(frame, goal)`` batch.

        Puts both models in train mode, computes the attention and transport
        losses, logs them, increments ``self.total_steps``, records the total
        loss on the trainer's running loss and calls
        ``self.check_save_iteration()``.

        Args:
            batch: Pair ``(frame, goal)``.
            batch_idx: Unused.

        Returns:
            Dict with the single key ``'loss'`` (attention + transport loss).
        """
        self.attention.train()
        self.transport.train()

        frame, goal = batch

        # Get training losses.
        step = self.total_steps + 1
        loss0, _ = self.attn_training_step(frame, goal)
        if isinstance(self.transport, Attention):
            loss1, _ = self.attn_training_step(frame, goal)
        else:
            loss1, _ = self.transport_training_step(frame, goal)
        total_loss = loss0 + loss1

        self.log('tr/attn/loss', loss0)
        self.log('tr/trans/loss', loss1)
        self.log('tr/loss', total_loss)
        # The step counter is only committed after both losses were computed.
        self.total_steps = step

        self.trainer.train_loop.running_loss.append(total_loss)

        self.check_save_iteration()

        return dict(
            loss=total_loss,
        )

    def validation_step(self, batch, batch_idx):
        """Evaluate a ``(frame, goal)`` batch ``self.val_repeats`` times.

        Puts both models in eval mode and averages the attention and transport
        losses over the repeats. The reported error values come from the last
        repeat only.

        Args:
            batch: Pair ``(frame, goal)``.
            batch_idx: Unused.

        Returns:
            Dict with ``val_loss``, ``val_loss0``, ``val_loss1`` and the
            ``'dist'`` / ``'theta'`` errors of both models.

        Raises:
            ValueError: If ``self.val_repeats`` is smaller than 1, since no
                loss or error could be computed.
        """
        repeats = self.val_repeats
        if repeats < 1:
            # Without at least one repeat the averages divide by zero and the
            # error dicts below would never be assigned.
            raise ValueError(f'val_repeats must be >= 1, got {repeats!r}.')

        self.attention.eval()
        self.transport.eval()

        frame, goal = batch
        loss0, loss1 = 0, 0
        for _ in range(repeats):
            l0, err0 = self.attn_training_step(
                frame, goal, backprop=False, compute_err=True)
            loss0 += l0
            if isinstance(self.transport, Attention):
                l1, err1 = self.attn_training_step(
                    frame, goal, backprop=False, compute_err=True)
            else:
                l1, err1 = self.transport_training_step(
                    frame, goal, backprop=False, compute_err=True)
            loss1 += l1
        loss0 /= repeats
        loss1 /= repeats
        val_total_loss = loss0 + loss1

        self.trainer.evaluation_loop.trainer.train_loop.running_loss.append(val_total_loss)

        return dict(
            val_loss=val_total_loss,
            val_loss0=loss0,
            val_loss1=loss1,
            val_attn_dist_err=err0['dist'],
            val_attn_theta_err=err0['theta'],
            val_trans_dist_err=err1['dist'],
            val_trans_theta_err=err1['theta'],
        )

    def act(self, obs, info=None, goal=None):  # pylint: disable=unused-argument
        """Run inference and return best action given visual observations.

        Args:
            obs: Current observation, converted with ``self.test_ds.get_image``.
            info: Unused.
            goal: Indexable whose first element is the goal observation. It
                defaults to ``None`` only for signature compatibility and must
                be supplied.

        Returns:
            Dict with ``'pose0'`` and ``'pose1'`` (each a pair of position and
            quaternion arrays) plus the chosen ``'pick'`` and ``'place'``
            pixel coordinates.

        Raises:
            TypeError: If ``goal`` is ``None``.
        """
        if goal is None:
            raise TypeError(
                'act() requires `goal` (an indexable whose first element is '
                'the goal observation); got None.'
            )

        # Get heightmap from RGB-D images.
        img = self.test_ds.get_image(obs)
        goal_img = self.test_ds.get_image(goal[0])

        # Attention model forward pass.
        pick_conf = self.attention.forward(img, goal_img)
        pick_conf = pick_conf.detach().cpu().numpy()
        argmax = np.argmax(pick_conf)
        argmax = np.unravel_index(argmax, shape=pick_conf.shape)
        p0_pix = argmax[:2]
        # The third index selects a rotation bin; scale it to radians.
        p0_theta = argmax[2] * (2 * np.pi / pick_conf.shape[2])

        # Transport model forward pass.
        place_conf = self.transport.forward(img, goal_img, p0_pix)
        # Move the leading axis last so indices read (row, col, rotation),
        # the same layout used for pick_conf above.
        place_conf = place_conf.permute(1, 2, 0)
        place_conf = place_conf.detach().cpu().numpy()
        argmax = np.argmax(place_conf)
        argmax = np.unravel_index(argmax, shape=place_conf.shape)
        p1_pix = argmax[:2]
        p1_theta = argmax[2] * (2 * np.pi / place_conf.shape[2])

        # Pixels to end effector poses.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': p0_pix,
            'place': p1_pix,
        }