Zai
test
06db6e9
#!/usr/bin/python
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from .layers import build_mlp
"""
PyTorch modules for dealing with graphs.
"""
def _init_weights(module):
if hasattr(module, 'weight'):
if isinstance(module, nn.Linear):
nn.init.kaiming_normal_(module.weight)
class GraphTripleConv(nn.Module):
"""
A single layer of scene graph convolution.
"""
def __init__(self, input_dim, attributes_dim=0, output_dim=None, hidden_dim=512,
pooling='avg', mlp_normalization='none'):
super(GraphTripleConv, self).__init__()
if output_dim is None:
output_dim = input_dim
self.input_dim = input_dim
self.output_dim = output_dim
self.hidden_dim = hidden_dim
assert pooling in ['sum', 'avg'], 'Invalid pooling "%s"' % pooling
self.pooling = pooling
net1_layers = [3 * input_dim + 2 * attributes_dim, hidden_dim, 2 * hidden_dim + output_dim]
net1_layers = [l for l in net1_layers if l is not None]
self.net1 = build_mlp(net1_layers, batch_norm=mlp_normalization)
self.net1.apply(_init_weights)
net2_layers = [hidden_dim, hidden_dim, output_dim]
self.net2 = build_mlp(net2_layers, batch_norm=mlp_normalization)
self.net2.apply(_init_weights)
def forward(self, obj_vecs, pred_vecs, edges):
"""
Inputs:
- obj_vecs: FloatTensor of shape (O, D) giving vectors for all objects
- pred_vecs: FloatTensor of shape (T, D) giving vectors for all predicates
- edges: LongTensor of shape (T, 2) where edges[k] = [i, j] indicates the
presence of a triple [obj_vecs[i], pred_vecs[k], obj_vecs[j]]
Outputs:
- new_obj_vecs: FloatTensor of shape (O, D) giving new vectors for objects
- new_pred_vecs: FloatTensor of shape (T, D) giving new vectors for predicates
"""
dtype, device = obj_vecs.dtype, obj_vecs.device
O, T = obj_vecs.size(0), pred_vecs.size(0)
Din, H, Dout = self.input_dim, self.hidden_dim, self.output_dim
# Break apart indices for subjects and objects; these have shape (T,)
s_idx = edges[:, 0].contiguous()
o_idx = edges[:, 1].contiguous()
# Get current vectors for subjects and objects; these have shape (T, Din)
cur_s_vecs = obj_vecs[s_idx]
cur_o_vecs = obj_vecs[o_idx]
# Get current vectors for triples; shape is (T, 3 * Din)
# Pass through net1 to get new triple vecs; shape is (T, 2 * H + Dout)
cur_t_vecs = torch.cat([cur_s_vecs, pred_vecs, cur_o_vecs], dim=1)
new_t_vecs = self.net1(cur_t_vecs)
# Break apart into new s, p, and o vecs; s and o vecs have shape (T, H) and
# p vecs have shape (T, Dout)
new_s_vecs = new_t_vecs[:, :H]
new_p_vecs = new_t_vecs[:, H:(H + Dout)]
new_o_vecs = new_t_vecs[:, (H + Dout):(2 * H + Dout)]
# Allocate space for pooled object vectors of shape (O, H)
pooled_obj_vecs = torch.zeros(O, H, dtype=dtype, device=device)
# Use scatter_add to sum vectors for objects that appear in multiple triples;
# we first need to expand the indices to have shape (T, D)
s_idx_exp = s_idx.view(-1, 1).expand_as(new_s_vecs)
o_idx_exp = o_idx.view(-1, 1).expand_as(new_o_vecs)
pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, s_idx_exp, new_s_vecs)
pooled_obj_vecs = pooled_obj_vecs.scatter_add(0, o_idx_exp, new_o_vecs)
if self.pooling == 'avg':
# Figure out how many times each object has appeared, again using
# some scatter_add trickery.
obj_counts = torch.zeros(O, dtype=dtype, device=device)
ones = torch.ones(T, dtype=dtype, device=device)
obj_counts = obj_counts.scatter_add(0, s_idx, ones)
obj_counts = obj_counts.scatter_add(0, o_idx, ones)
# Divide the new object vectors by the number of times they
# appeared, but first clamp at 1 to avoid dividing by zero;
# objects that appear in no triples will have output vector 0
# so this will not affect them.
obj_counts = obj_counts.clamp(min=1)
pooled_obj_vecs = pooled_obj_vecs / obj_counts.view(-1, 1)
# Send pooled object vectors through net2 to get output object vectors,
# of shape (O, Dout)
new_obj_vecs = self.net2(pooled_obj_vecs)
return new_obj_vecs, new_p_vecs
class GraphTripleConvNet(nn.Module):
""" A sequence of scene graph convolution layers """
def __init__(self, input_dim, num_layers=5, hidden_dim=512, pooling='avg',
mlp_normalization='none'):
super(GraphTripleConvNet, self).__init__()
self.num_layers = num_layers
self.gconvs = nn.ModuleList()
gconv_kwargs = {
'input_dim': input_dim,
'hidden_dim': hidden_dim,
'pooling': pooling,
'mlp_normalization': mlp_normalization,
}
for _ in range(self.num_layers):
self.gconvs.append(GraphTripleConv(**gconv_kwargs))
def forward(self, obj_vecs, pred_vecs, edges):
for i in range(self.num_layers):
gconv = self.gconvs[i]
obj_vecs, pred_vecs = gconv(obj_vecs, pred_vecs, edges)
return obj_vecs, pred_vecs