"""Attention module."""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import cliport.models as models
from cliport.utils import utils


class Attention(nn.Module):
    """Attention (a.k.a Pick) module."""

    def __init__(self, stream_fcn, in_shape, n_rotations, preprocess, cfg, device):
        super().__init__()
        self.stream_fcn = stream_fcn
        self.n_rotations = n_rotations
        self.preprocess = preprocess
        self.cfg = cfg
        self.device = device
        self.batchnorm = self.cfg['train']['batchnorm']

        # Zero-pad the two spatial dims up to the max dim so the image is square.
        self.padding = np.zeros((3, 2), dtype=int)
        max_dim = np.max(in_shape[:2])
        pad = (max_dim - np.array(in_shape[:2])) / 2
        self.padding[:2] = pad.reshape(2, 1)  # symmetric (before, after) pad per spatial dim

        in_shape = np.array(in_shape)
        in_shape += np.sum(self.padding, axis=1)
        in_shape = tuple(in_shape)
        self.in_shape = in_shape
        self.rotator = utils.ImageRotator(self.n_rotations)
        self._build_nets()

    def _build_nets(self):
        stream_one_fcn, _ = self.stream_fcn
        self.attn_stream = models.names[stream_one_fcn](self.in_shape, 1, self.cfg, self.device, self.preprocess)
        print(f"Attn FCN: {stream_one_fcn}")

    def attend(self, x):
        return self.attn_stream(x)

    def forward(self, inp_img, softmax=True):
        """Forward pass."""
        # print("in_img.shape", inp_img.shape)
        in_data = np.pad(inp_img, self.padding, mode='constant')
        in_shape =  input_data.shape
        if len(inp_shape) == 3:
            inp_shape = (1,) + inp_shape
        in_data = in_data.reshape(in_shape)
        in_tens = torch.from_numpy(in_data.copy()).to(dtype=torch.float, device=self.device) # [B W H 6]

        # Rotation pivot.
        pv = np.array(in_data.shape[1:3]) // 2

        # Rotate input.
        in_tens = in_tens.permute(0, 3, 1, 2)  # [B 6 W H]
        in_tens = in_tens.repeat(self.n_rotations, 1, 1, 1)
        in_tens = self.rotator(in_tens, pivot=pv)

        # Forward pass.
        logits = self.attend(torch.cat(in_tens, dim=0))

        # Rotate back output.
        logits = self.rotator(logits, reverse=True, pivot=pv)
        logits = torch.cat(logits, dim=0)
        c0 = self.padding[:2, 0]
        c1 = c0 + inp_img.shape[:2]
        logits = logits[:, :, c0[0]:c1[0], c0[1]:c1[1]]

        logits = logits.permute(1, 2, 3, 0)  # [B W H n_rotations]
        output = logits.reshape(1, np.prod(logits.shape))
        if softmax:
            output = F.softmax(output, dim=-1)
            output = output.reshape(logits.shape[1:])
        return output
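

# Minimal usage sketch (illustrative only, not part of the original module).
# It shows how the module might be constructed and run on a dummy observation.
# The 'plain_resnet' key in cliport.models.names, utils.preprocess, the cfg
# dict, and the (320, 160, 6) input shape are assumptions here; in CLIPort
# itself these values come from the Hydra training config.
if __name__ == '__main__':
    cfg = {'train': {'batchnorm': False}}  # assumed minimal config
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    attn = Attention(
        stream_fcn=('plain_resnet', None),  # single-stream FCN; second slot unused
        in_shape=(320, 160, 6),             # H x W x (RGB + depth-derived channels)
        n_rotations=1,                      # pick typically uses a single rotation
        preprocess=utils.preprocess,        # assumed cliport.utils.utils.preprocess
        cfg=cfg,
        device=device,
    )
    inp_img = np.zeros((320, 160, 6), dtype=np.float32)
    heatmap = attn(inp_img, softmax=True)   # pixel-wise pick distribution
    print(heatmap.shape)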