|
import numpy as np |
|
|
|
from cliport.utils import utils |
|
from cliport.agents.transporter import TransporterAgent |
|
|
|
from cliport.models.streams.one_stream_attention_lang_fusion import OneStreamAttentionLangFusion |
|
from cliport.models.streams.one_stream_transport_lang_fusion import OneStreamTransportLangFusion |
|
from cliport.models.streams.two_stream_attention_lang_fusion import TwoStreamAttentionLangFusion |
|
from cliport.models.streams.two_stream_transport_lang_fusion import TwoStreamTransportLangFusion, TwoStreamTransportLangFusionLatReduce, TwoStreamTransportLangFusionLatPretrained18 |
|
from cliport.models.streams.two_stream_attention_lang_fusion import TwoStreamAttentionLangFusionLat, TwoStreamAttentionLangFusionLatReduce |
|
|
|
from cliport.models.streams.two_stream_transport_lang_fusion import TwoStreamTransportLangFusionLatReduceOneStream |
|
from cliport.models.streams.two_stream_transport_lang_fusion import TwoStreamTransportLangFusionLat |
|
import torch |
|
import time |
|
|
|
|
|
class TwoStreamClipLingUNetTransporterAgent(TransporterAgent):
    """Language-conditioned Transporter agent with two-stream pick/place heads.

    ``self.attention`` produces the pick confidence map and ``self.transport``
    produces the place confidence map given the chosen pick pixel; both are
    conditioned on ``lang_goal``. This class wires a 'plain_resnet' stream
    with a 'clip_lingunet' stream. Subclasses override only ``_build_model``
    and reuse the forward passes, training steps and action decoding below.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        stream_one_fcn = 'plain_resnet'
        stream_two_fcn = 'clip_lingunet'
        self.attention = TwoStreamAttentionLangFusion(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=1,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )
        self.transport = TwoStreamTransportLangFusion(
            stream_fcn=(stream_one_fcn, stream_two_fcn),
            in_shape=self.in_shape,
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            preprocess=utils.preprocess,
            cfg=self.cfg,
            device=self.device_type,
        )

    def _to_device_tensor(self, img):
        """Return ``img`` as a contiguous float tensor on ``self.device_type``.

        Tensors are returned unchanged; any other input is assumed to be a
        NumPy array. The target device is the one the networks are built on
        in ``_build_model``. Previously this was a hard-coded ``'cuda'``,
        which crashed on CPU-only machines and could disagree with the
        model's device.
        """
        if isinstance(img, torch.Tensor):
            return img
        return torch.from_numpy(img).to(self.device_type).float().contiguous()

    def attn_forward(self, inp, softmax=True):
        """Run the pick (attention) network.

        Args:
            inp: dict with 'inp_img' (NumPy array or tensor) and 'lang_goal'.
            softmax: forwarded to ``self.attention.forward``;
                ``attn_training_step`` passes False.

        Returns:
            The output of ``self.attention.forward``.
        """
        inp_img = self._to_device_tensor(inp['inp_img'])
        lang_goal = inp['lang_goal']
        return self.attention.forward(inp_img.float(), lang_goal, softmax=softmax)

    def attn_training_step(self, frame, backprop=True, compute_err=False):
        """Compute the pick loss for one training frame.

        Args:
            frame: dict with 'img', 'p0', 'p0_theta' and 'lang_goal'.
            backprop: forwarded to ``self.attn_criterion``.
            compute_err: forwarded to ``self.attn_criterion``.

        Returns:
            Whatever ``self.attn_criterion`` returns.
        """
        inp_img = self._to_device_tensor(frame['img'])
        p0, p0_theta = frame['p0'], frame['p0_theta']
        lang_goal = frame['lang_goal']

        inp = {'inp_img': inp_img, 'lang_goal': lang_goal}
        out = self.attn_forward(inp, softmax=False)
        return self.attn_criterion(backprop, compute_err, inp, out, p0, p0_theta)

    def trans_forward(self, inp, softmax=True):
        """Run the place (transport) network.

        Args:
            inp: dict with 'inp_img' (NumPy array or tensor), 'p0' (pick
                pixel) and 'lang_goal'.
            softmax: forwarded to ``self.transport.forward``.

        Returns:
            The output of ``self.transport.forward``.
        """
        inp_img = self._to_device_tensor(inp['inp_img'])
        p0 = inp['p0']
        lang_goal = inp['lang_goal']
        return self.transport.forward(inp_img.float(), p0, lang_goal, softmax=softmax)

    def transport_training_step(self, frame, backprop=True, compute_err=False):
        """Compute the place loss for one training frame.

        Args:
            frame: dict with 'img', 'p0', 'p1', 'p1_theta' and 'lang_goal'.
            backprop: forwarded to ``self.transport_criterion``.
            compute_err: forwarded to ``self.transport_criterion``.

        Returns:
            ``(loss, err)``. Note that ``self.transport_criterion`` returns
            ``(err, loss)``; the order is swapped here.
        """
        inp_img = frame['img']
        p0 = frame['p0']
        p1, p1_theta = frame['p1'], frame['p1_theta']
        lang_goal = frame['lang_goal']

        inp = {'inp_img': inp_img, 'p0': p0, 'lang_goal': lang_goal}
        out = self.trans_forward(inp, softmax=False)
        err, loss = self.transport_criterion(backprop, compute_err, inp, out, p0, p1, p1_theta)
        return loss, err

    @staticmethod
    def _argmax_pix_theta(conf):
        """Locate the maximum of a 3-D confidence array.

        The first two axes are returned as the pixel; the index along the
        last axis is mapped to an angle by evenly dividing 2*pi over that
        axis' length.

        Returns:
            ``(pix, theta)`` where ``pix`` is a 2-tuple of indices.
        """
        idx = np.unravel_index(np.argmax(conf), shape=conf.shape)
        return idx[:2], idx[2] * (2 * np.pi / conf.shape[2])

    def _act_on_image(self, img, lang_goal):
        """Decode a pick-and-place action from an input image and goal text.

        ``img`` must support ``img[:, :, 3]``; that channel is passed to
        ``utils.pix_to_xyz`` as the heightmap. This presumably assumes the
        channel is height/depth — confirm against the dataset image layout.
        """
        # Pick: attention confidence map.
        pick_inp = {'inp_img': img, 'lang_goal': lang_goal}
        pick_conf = self.attn_forward(pick_inp)
        pick_conf = pick_conf[0].permute(1, 2, 0).detach().cpu().numpy()
        p0_pix, p0_theta = self._argmax_pix_theta(pick_conf)

        # Place: transport confidence map conditioned on the pick pixel.
        place_inp = {'inp_img': img, 'p0': p0_pix, 'lang_goal': lang_goal}
        place_conf = self.trans_forward(place_inp)
        place_conf = place_conf.squeeze().permute(1, 2, 0).detach().cpu().numpy()
        p1_pix, p1_theta = self._argmax_pix_theta(place_conf)

        # Convert pixels to 3-D positions and yaw angles to quaternions.
        hmap = img[:, :, 3]
        p0_xyz = utils.pix_to_xyz(p0_pix, hmap, self.bounds, self.pix_size)
        p1_xyz = utils.pix_to_xyz(p1_pix, hmap, self.bounds, self.pix_size)
        p0_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p0_theta))
        p1_xyzw = utils.eulerXYZ_to_quatXYZW((0, 0, -p1_theta))

        return {
            'pose0': (np.asarray(p0_xyz), np.asarray(p0_xyzw)),
            'pose1': (np.asarray(p1_xyz), np.asarray(p1_xyzw)),
            'pick': [p0_pix[0], p0_pix[1], p0_theta],
            'place': [p1_pix[0], p1_pix[1], p1_theta],
        }

    def act(self, obs, info, goal=None):
        """Run inference and return best action given visual observations."""
        img = self.test_ds.get_image(obs)
        return self._act_on_image(img, info['lang_goal'])

    def real_act(self, obs, info, goal=None):
        """Run inference and return best action given real images."""
        # Real observations are already images; no dataset conversion needed.
        return self._act_on_image(obs, info['lang_goal'])
|
|
|
|
|
class TwoStreamClipFilmLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream agent using the 'plain_resnet_lat' and 'clip_film_lingunet_lat' streams.

    Both heads use the ``*Lat`` fusion modules; all other behaviour is
    inherited from ``TwoStreamClipLingUNetTransporterAgent``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet_lat', 'clip_film_lingunet_lat'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
class TwoStreamClipLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream agent using the 'plain_resnet_lat' and 'clip_lingunet_lat' streams.

    Both heads use the ``*Lat`` fusion modules; all other behaviour is
    inherited from ``TwoStreamClipLingUNetTransporterAgent``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet_lat', 'clip_lingunet_lat'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
class TwoStreamMdetrLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream agent using the 'plain_resnet_lat_origin' and 'mdetr_lingunet_lat_fuse' streams.

    Both heads use the ``*Lat`` fusion modules; all other behaviour is
    inherited from ``TwoStreamClipLingUNetTransporterAgent``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet_lat_origin', 'mdetr_lingunet_lat_fuse'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
|
|
|
|
class TwoStreamClipLingUNetLatTransporterAgentReduce(TwoStreamClipLingUNetTransporterAgent):
    """Lat two-stream agent whose place head is ``TwoStreamTransportLangFusionLatReduce``.

    Streams: 'plain_resnet_lat' + 'clip_lingunet_lat'. The pick head is the
    regular ``TwoStreamAttentionLangFusionLat``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet_lat', 'clip_lingunet_lat'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusionLatReduce(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
|
|
class TwoStreamClipLingUNetLatTransporterAgentReduceOneStream(TwoStreamClipLingUNetTransporterAgent):
    """Lat agent using the ``Reduce`` attention and ``ReduceOneStream`` transport variants.

    Streams: 'plain_resnet_lat' + 'clip_lingunet_lat'.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet_lat', 'clip_lingunet_lat'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusionLatReduce(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusionLatReduceOneStream(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
class TwoStreamClipLingUNetLatTransporterAgentReducePretrained(TwoStreamClipLingUNetTransporterAgent):
    """Lat agent whose place head is ``TwoStreamTransportLangFusionLatPretrained18``.

    Streams: 'plain_resnet_lat' + 'clip_lingunet_lat'. The pick head is the
    regular ``TwoStreamAttentionLangFusionLat``.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet_lat', 'clip_lingunet_lat'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusionLatPretrained18(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
|
|
class TwoStreamRN50BertLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream agent using the 'plain_resnet' and 'rn50_bert_lingunet' streams.

    Uses the non-lat ``TwoStreamAttentionLangFusion`` /
    ``TwoStreamTransportLangFusion`` modules.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet', 'rn50_bert_lingunet'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusion(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusion(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
class TwoStreamUntrainedRN50BertLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream agent using the 'plain_resnet' and 'untrained_rn50_bert_lingunet' streams.

    Uses the non-lat ``TwoStreamAttentionLangFusion`` /
    ``TwoStreamTransportLangFusion`` modules.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet', 'untrained_rn50_bert_lingunet'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusion(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusion(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
class TwoStreamRN50BertLingUNetLatTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Two-stream agent using the 'plain_resnet_lat' and 'rn50_bert_lingunet_lat' streams.

    Both heads use the ``*Lat`` fusion modules.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet_lat', 'rn50_bert_lingunet_lat'),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = TwoStreamAttentionLangFusionLat(n_rotations=1, **common)
        self.transport = TwoStreamTransportLangFusionLat(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
class OriginalTransporterLangFusionAgent(TwoStreamClipLingUNetTransporterAgent):
    """Single-stream agent built on the 'plain_resnet_lang' stream.

    The second stream slot is passed as ``None`` to the one-stream fusion
    modules.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('plain_resnet_lang', None),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = OneStreamAttentionLangFusion(n_rotations=1, **common)
        self.transport = OneStreamTransportLangFusion(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )
|
|
|
|
|
|
|
class ClipLingUNetTransporterAgent(TwoStreamClipLingUNetTransporterAgent):
    """Single-stream agent built on the 'clip_lingunet' stream.

    The second stream slot is passed as ``None`` to the one-stream fusion
    modules.
    """

    def __init__(self, name, cfg, train_ds, test_ds):
        super().__init__(name, cfg, train_ds, test_ds)

    def _build_model(self):
        """Create the pick (attention) and place (transport) networks."""
        common = {
            'stream_fcn': ('clip_lingunet', None),
            'in_shape': self.in_shape,
            'preprocess': utils.preprocess,
            'cfg': self.cfg,
            'device': self.device_type,
        }
        # Attention is built with a single rotation; transport uses the
        # agent's configured rotation count and crop size.
        self.attention = OneStreamAttentionLangFusion(n_rotations=1, **common)
        self.transport = OneStreamTransportLangFusion(
            n_rotations=self.n_rotations,
            crop_size=self.crop_size,
            **common,
        )