# GenSim2 / cliport / agents / transporter_lang_goal.py
# Author: gensim2 — commit ff66cf3 ("init")
import numpy as np
from cliport.utils import utils
from cliport.agents.transporter import TransporterAgent
from cliport.models.streams.one_stream_attention_lang_fusion import OneStreamAttentionLangFusion
from cliport.models.streams.one_stream_transport_lang_fusion import OneStreamTransportLangFusion
from cliport.models.streams.two_stream_attention_lang_fusion import TwoStreamAttentionLangFusion
from cliport.models.streams.two_stream_transport_lang_fusion import TwoStreamTransportLangFusion, TwoStreamTransportLangFusionLatReduce, TwoStreamTransportLangFusionLatPretrained18
from cliport.models.streams.two_stream_attention_lang_fusion import TwoStreamAttentionLangFusionLat, TwoStreamAttentionLangFusionLatReduce
from cliport.models.streams.two_stream_transport_lang_fusion import TwoStreamTransportLangFusionLatReduceOneStream
from cliport.models.streams.two_stream_transport_lang_fusion import TwoStreamTransportLangFusionLat
import torch
import time
class TwoStreamClipLingUNetTransporterAgent(TransporterAgent):
    """Language-conditioned Transporter agent with a two-stream backbone.

    Stream one is ``plain_resnet`` and stream two is ``clip_lingunet``. The
    attention network scores pick locations and the transport network scores
    place locations; both receive the language goal. Subclasses override only
    ``_build_model`` to swap backbones and fusion modules.
    """

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        stream_one_fcn = 'plain_resnet'
        stream_two_fcn = 'clip_lingunet'
        self.attention = TwoStreamAttentionLangFusion(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = TwoStreamTransportLangFusion(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )

    def _as_tensor(self, img):
        """Return ``img`` as a tensor.

        Tensors are returned unchanged. Numpy arrays are converted to a
        contiguous float tensor on ``self.device_type``, the same device the
        networks are built on. This was previously hard-coded to ``'cuda'``,
        which broke CPU-only runs.
        """
        if isinstance(img, torch.Tensor):
            return img
        return torch.from_numpy(img).to(self.device_type).float().contiguous()

    def attn_forward(self, inp, softmax=True):
        """Run the attention (pick) network.

        Args:
            inp: dict with keys 'inp_img' (numpy array or tensor) and 'lang_goal'.
            softmax: forwarded unchanged to ``self.attention.forward``.

        Returns:
            The output of ``self.attention.forward``.
        """
        inp_img = self._as_tensor(inp['inp_img'])
        return self.attention.forward(inp_img.float(), inp['lang_goal'], softmax=softmax)

    def attn_training_step(self, frame, backprop=True, compute_err=False):
        """Compute the pick loss for one frame via ``self.attn_criterion``.

        ``frame`` must provide 'img', 'p0', 'p0_theta' and 'lang_goal'. The
        image is converted to a tensor before it is stored in ``inp``, so the
        criterion receives the tensor version.
        """
        inp_img = self._as_tensor(frame['img'])
        p0, p0_theta = frame['p0'], frame['p0_theta']
        inp = {'inp_img': inp_img, 'lang_goal': frame['lang_goal']}
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def trans_forward(self, inp, softmax=True):
        """Run the transport (place) network.

        Args:
            inp: dict with keys 'inp_img' (numpy array or tensor), 'p0'
                (pick location) and 'lang_goal'.
            softmax: forwarded unchanged to ``self.transport.forward``.

        Returns:
            The output of ``self.transport.forward``.
        """
        inp_img = self._as_tensor(inp['inp_img'])
        return self.transport.forward(inp_img.float(), inp['p0'], inp['lang_goal'], softmax=softmax)

    def transport_training_step(self, frame, backprop=True, compute_err=False):
        """Compute the place loss for one frame via ``self.transport_criterion``.

        ``frame`` must provide 'img', 'p0', 'p1', 'p1_theta' and 'lang_goal'.

        Returns:
            ``(loss, err)``. The criterion's result is unpacked as ``(err, loss)``
            and returned in swapped order.
        """
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']
        inp = {'inp_img': frame['img'], 'p0': p0, 'lang_goal': frame['lang_goal']}
        out = self.trans_forward(inp, softmax=False)
        err, loss = self.transport_criterion(backprop, compute_err, inp, out, p0, p1, p1_theta)
        return loss, err

    @staticmethod
    def _argmax_pixel_and_theta(conf):
        """Locate the maximum of a ``(rows, cols, rotations)`` confidence array.

        Returns:
            ``((row, col), theta)``. ``theta`` is the rotation index scaled so
            that the ``conf.shape[2]`` bins evenly cover ``2 * pi`` radians.
        """
        idx = np.unravel_index(np.argmax(conf), shape=conf.shape)
        return idx[:2], idx[2] * (2 * np.pi / conf.shape[2])

    def _predict_action(self, img, lang_goal):
        """Run pick-then-place inference on ``img`` and convert the result to poses.

        This is the shared implementation behind ``act`` and ``real_act``.
        """
        # Attention model forward pass.
        pick_conf = self.attn_forward({'inp_img': img, 'lang_goal': lang_goal})
        pick_conf = pick_conf[0].permute(1, 2, 0).detach().cpu().numpy()
        p0_pix, p0_theta = self._argmax_pixel_and_theta(pick_conf)

        # Transport model forward pass, conditioned on the chosen pick pixel.
        place_conf = self.trans_forward({'inp_img': img, 'p0': p0_pix, 'lang_goal': lang_goal})
        # NOTE(review): squeeze() also drops a size-1 rotation axis. That would
        # make the 3-axis permute below fail if n_rotations == 1.
        place_conf = place_conf.squeeze().permute(1, 2, 0).detach().cpu().numpy()
        p1_pix, p1_theta = self._argmax_pixel_and_theta(place_conf)

        # Pixels to end-effector poses; image channel 3 is used as the heightmap.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))
        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': [p0_pix[0], p0_pix[1], p0_theta],
            'place': [p1_pix[0], p1_pix[1], p1_theta],
        }

    def act(self, obs, info, goal=None):  # pylint: disable=unused-argument
        """Run inference and return best action given visual observations."""
        # Get heightmap from RGB-D images.
        img = self.test_ds.get_image(obs)
        return self._predict_action(img, info['lang_goal'])

    def real_act(self, obs, info, goal=None):  # pylint: disable=unused-argument
        """Run inference and return best action given real images.

        ``obs`` is used directly as the network input image, without going
        through ``test_ds.get_image``.
        """
        return self._predict_action(obs, info['lang_goal'])
class TwoStreamClipFilmLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Agent using the ``plain_resnet_lat`` + ``clip_film_lingunet_lat`` streams.

    Pick: ``TwoStreamAttentionLangFusionLat``. Place:
    ``TwoStreamTransportLangFusionLat``. Everything except model construction
    is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet_lat', 'clip_film_lingunet_lat'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class TwoStreamClipLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):  # This is our model
    """Agent using the ``plain_resnet_lat`` + ``clip_lingunet_lat`` streams.

    Pick: ``TwoStreamAttentionLangFusionLat``. Place:
    ``TwoStreamTransportLangFusionLat``. Everything except model construction
    is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet_lat', 'clip_lingunet_lat'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class TwoStreamMdetrLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Agent using the ``plain_resnet_lat_origin`` + ``mdetr_lingunet_lat_fuse`` streams.

    Pick: ``TwoStreamAttentionLangFusionLat``. Place:
    ``TwoStreamTransportLangFusionLat``. Everything except model construction
    is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet_lat_origin', 'mdetr_lingunet_lat_fuse'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class TwoStreamClipLingUNetLatTransporterAgentReduce(TwoStreamClipLingUNetTransporterAgent):  # This is our model
    """``plain_resnet_lat`` + ``clip_lingunet_lat`` agent with the "Reduce" transport.

    Pick: ``TwoStreamAttentionLangFusionLat``. Place:
    ``TwoStreamTransportLangFusionLatReduce``. Everything except model
    construction is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet_lat', 'clip_lingunet_lat'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusionLatReduce(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class TwoStreamClipLingUNetLatTransporterAgentReduceOneStream(TwoStreamClipLingUNetTransporterAgent):  # This is our model
    """``plain_resnet_lat`` + ``clip_lingunet_lat`` agent with "Reduce" pick and place modules.

    Pick: ``TwoStreamAttentionLangFusionLatReduce``. Place:
    ``TwoStreamTransportLangFusionLatReduceOneStream``. Everything except
    model construction is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet_lat', 'clip_lingunet_lat'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusionLatReduce(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusionLatReduceOneStream(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class TwoStreamClipLingUNetLatTransporterAgentReducePretrained(TwoStreamClipLingUNetTransporterAgent):  # This is our model
    """``plain_resnet_lat`` + ``clip_lingunet_lat`` agent with the "Pretrained18" transport.

    Pick: ``TwoStreamAttentionLangFusionLat``. Place:
    ``TwoStreamTransportLangFusionLatPretrained18``. Everything except model
    construction is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet_lat', 'clip_lingunet_lat'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusionLatPretrained18(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class TwoStreamRN50BertLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Agent using the ``plain_resnet`` + ``rn50_bert_lingunet`` streams.

    Pick: ``TwoStreamAttentionLangFusion``. Place:
    ``TwoStreamTransportLangFusion``. Everything except model construction is
    inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet', 'rn50_bert_lingunet'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusion(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusion(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class TwoStreamUntrainedRN50BertLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Agent using the ``plain_resnet`` + ``untrained_rn50_bert_lingunet`` streams.

    Pick: ``TwoStreamAttentionLangFusion``. Place:
    ``TwoStreamTransportLangFusion``. Everything except model construction is
    inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet', 'untrained_rn50_bert_lingunet'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusion(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusion(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class TwoStreamRN50BertLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Agent using the ``plain_resnet_lat`` + ``rn50_bert_lingunet_lat`` streams.

    Pick: ``TwoStreamAttentionLangFusionLat``. Place:
    ``TwoStreamTransportLangFusionLat``. Everything except model construction
    is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the attention and transport networks for this backbone pair."""
        shared = dict(
            stream_fcn=('plain_resnet_lat', 'rn50_bert_lingunet_lat'),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **shared)
        self.transport = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class OriginalTransporterLangFusionAgent(TwoStreamClipLingUNetTransporterAgent):
    """Single-stream agent built on ``plain_resnet_lang``.

    Pick: ``OneStreamAttentionLangFusion``. Place:
    ``OneStreamTransportLangFusion``. The second stream slot is ``None``.
    Everything except model construction is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the single-stream attention and transport networks."""
        shared = dict(
            stream_fcn=('plain_resnet_lang', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = OneStreamAttentionLangFusion(n_rotations=1, **shared)
        self.transport = OneStreamTransportLangFusion(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)
class ClipLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Single-stream agent built on ``clip_lingunet``.

    Pick: ``OneStreamAttentionLangFusion``. Place:
    ``OneStreamTransportLangFusion``. The second stream slot is ``None``.
    Everything except model construction is inherited from the parent agent.
    """

    def _build_model(self):
        """Create the single-stream attention and transport networks."""
        shared = dict(
            stream_fcn=('clip_lingunet', None),
            in_shape=self.in_shape,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.attention = OneStreamAttentionLangFusion(n_rotations=1, **shared)
        self.transport = OneStreamTransportLangFusion(
            n_rotations=self.n_rotations, crop_size=self.crop_size, **shared)