#!/usr/bin/python
#
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import RoIAlign

from . import box_utils
from .graph import GraphTripleConv, GraphTripleConvNet
from .layout import boxes_to_layout, masks_to_layout, boxes_to_seg, masks_to_seg
from .layers import build_mlp, build_cnn
from .utils import vocab


class Model(nn.Module):
    def __init__(
        self,
        embedding_dim=128,
        image_size=(128, 128),
        input_dim=3,
        attribute_dim=35,
        # graph_net
        gconv_dim=128,
        gconv_hidden_dim=512,
        gconv_num_layers=5,
        # inside_cnn
        inside_cnn_arch="C3-32-2,C3-64-2,C3-128-2,C3-256-2",
        # refinement_net
        refinement_dims=(1024, 512, 256, 128, 64),
        # box_refine
        box_refine_arch="I15,C3-64-2,C3-128-2,C3-256-2",
        roi_output_size=(8, 8),
        roi_spatial_scale=1.0 / 8.0,
        roi_cat_feature=True,
        # others
        mlp_activation='leakyrelu',
        mlp_normalization='none',
        cnn_activation='leakyrelu',
        cnn_normalization='batch',
    ):
        super(Model, self).__init__()

        ''' embedding '''
        self.vocab = vocab
        num_objs = len(vocab['object_idx_to_name'])
        num_preds = len(vocab['pred_idx_to_name'])
        num_doors = len(vocab['door_idx_to_name'])
        self.obj_embeddings = nn.Embedding(num_objs, embedding_dim)
        self.pred_embeddings = nn.Embedding(num_preds, embedding_dim)
        self.image_size = image_size
        self.feature_dim = embedding_dim + attribute_dim

        ''' graph_net '''
        self.gconv = GraphTripleConv(
            embedding_dim,
            attributes_dim=attribute_dim,
            output_dim=gconv_dim,
            hidden_dim=gconv_hidden_dim,
            mlp_normalization=mlp_normalization,
        )
        self.gconv_net = GraphTripleConvNet(
            gconv_dim,
            num_layers=gconv_num_layers - 1,
            mlp_normalization=mlp_normalization,
        )

        ''' inside_cnn '''
        # Encodes the boundary image into a single feature vector per image.
        inside_cnn, inside_feat_dim = build_cnn(
            f'I{input_dim},{inside_cnn_arch}',
            padding='valid',
        )
        self.inside_cnn = nn.Sequential(
            inside_cnn,
            nn.AdaptiveAvgPool2d(1),
        )
        inside_output_dim = inside_feat_dim
        obj_vecs_dim = gconv_dim + inside_output_dim

        ''' box_net '''
        box_net_dim = 4
        box_net_layers = [obj_vecs_dim, gconv_hidden_dim, box_net_dim]
        self.box_net = build_mlp(
            box_net_layers,
            activation=mlp_activation,
            batch_norm=mlp_normalization,
        )

        ''' relationship_net '''
        rel_aux_layers = [obj_vecs_dim, gconv_hidden_dim, num_doors]
        self.rel_aux_net = build_mlp(
            rel_aux_layers,
            activation=mlp_activation,
            batch_norm=mlp_normalization,
        )

        ''' refinement_net '''
        if refinement_dims is not None:
            # Note: refinement_dims itself is unused; the architecture string
            # below is fixed and outputs num_objs channels.
            self.refinement_net, _ = build_cnn(f"I{obj_vecs_dim},C3-128,C3-64,C3-{num_objs}")
        else:
            self.refinement_net = None

        ''' roi '''
        self.box_refine_backbone = None
        self.roi_cat_feature = roi_cat_feature
        if box_refine_arch is not None:
            # The "I15" input channels of the default arch must match the
            # num_objs output channels of refinement_net above.
            box_refine_cnn, box_feat_dim = build_cnn(
                box_refine_arch,
                padding='valid',
            )
            self.box_refine_backbone = box_refine_cnn
            self.roi_align = RoIAlign(roi_output_size, roi_spatial_scale, -1)  # (box_feat_dim, 8, 8) per RoI
            self.down_sample = nn.AdaptiveAvgPool2d(1)
            # box_feat_dim is 256 with the default arch; hardcoded as
            # obj_vecs_dim+256 in the original.
            box_refine_layers = [
                obj_vecs_dim + box_feat_dim if self.roi_cat_feature else box_feat_dim,
                512,
                4,
            ]
            self.box_reg = build_mlp(
                box_refine_layers,
                activation=mlp_activation,
                batch_norm=mlp_normalization,
            )

    def forward(
        self,
        objs,
        triples,
        boundary,
        obj_to_img=None,
        attributes=None,
        boxes_gt=None,
        generate=False,
        refine=False,
        relative=False,
        inside_box=None,
    ):
        """
        Required Inputs:
        - objs: LongTensor of shape (O,) giving categories for all objects
        - triples: LongTensor of shape (T, 3) where triples[t] = [s, p, o]
          means that there is a triple (objs[s], p, objs[o])

        Optional Inputs:
        - obj_to_img: LongTensor of shape (O,) where obj_to_img[o] = i means
          that objs[o] is an object in image i. If not given then all objects
          are assumed to belong to the same image.
        - boxes_gt: FloatTensor of shape (O, 4) giving boxes to use for
          computing the spatial layout; if not given then use predicted boxes.
        """
        # input size
        O, T = objs.size(0), triples.size(0)
        s, p, o = triples.chunk(3, dim=1)            # All have shape (T, 1)
        s, p, o = [x.squeeze(1) for x in [s, p, o]]  # Now have shape (T,)
        edges = torch.stack([s, o], dim=1)           # Shape is (T, 2)
        B = boundary.size(0)
        H, W = self.image_size
        if obj_to_img is None:
            obj_to_img = torch.zeros(O, dtype=objs.dtype, device=objs.device)

        ''' embedding '''
        obj_vecs = self.obj_embeddings(objs)
        pred_vecs = self.pred_embeddings(p)

        ''' attribute '''
        if attributes is not None:
            obj_vecs = torch.cat([obj_vecs, attributes], 1)
        obj_vecs_orig = obj_vecs  # kept for reference; not used below

        ''' gconv '''
        obj_vecs, pred_vecs = self.gconv(obj_vecs, pred_vecs, edges)
        obj_vecs, pred_vecs = self.gconv_net(obj_vecs, pred_vecs, edges)

        ''' inside '''
        # Append the per-image boundary feature to every object in that image.
        inside_vecs = self.inside_cnn(boundary).view(B, -1)
        obj_vecs = torch.cat([obj_vecs, inside_vecs[obj_to_img]], dim=1)

        ''' box '''
        boxes_pred = self.box_net(obj_vecs)
        if relative:
            boxes_pred = box_utils.box_rel2abs(boxes_pred, inside_box, obj_to_img)

        ''' relation '''
        # unused, for door position prediction
        # rel_scores = self.rel_aux_net(obj_vecs)

        ''' generate '''
        gene_layout = None
        boxes_refine = None
        layout_boxes = boxes_pred if boxes_gt is None else boxes_gt
        if generate:
            layout_features = boxes_to_layout(obj_vecs, layout_boxes, obj_to_img, H, W)
            gene_layout = self.refinement_net(layout_features)

        ''' box refine '''
        if refine:
            gene_feat = self.box_refine_backbone(gene_layout)
            rois = torch.cat([
                obj_to_img.float().view(-1, 1),
                box_utils.centers_to_extents(layout_boxes) * H,
            ], -1)
            roi_feat = self.down_sample(self.roi_align(gene_feat, rois)).flatten(1)
            # Concatenate graph features only when box_reg was built for them;
            # the original cat was unconditional, which would break with
            # roi_cat_feature=False.
            if self.roi_cat_feature:
                roi_feat = torch.cat([roi_feat, obj_vecs], -1)
            boxes_refine = self.box_reg(roi_feat)
            if relative:
                boxes_refine = box_utils.box_rel2abs(boxes_refine, inside_box, obj_to_img)

        return boxes_pred, gene_layout, boxes_refine
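
# ----------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal smoke test showing
# the tensor shapes forward() expects, assuming the package's `vocab` supplies
# the index->name lists used above. The sizes O/T/B below are arbitrary.
# Because of the relative imports, this file must run in its package context
# (e.g. `python -m <package>.model`, package name hypothetical).
if __name__ == '__main__':
    num_objs = len(vocab['object_idx_to_name'])
    num_preds = len(vocab['pred_idx_to_name'])

    model = Model()
    O, T, B = 6, 8, 2
    objs = torch.randint(0, num_objs, (O,))    # (O,) object categories
    s = torch.randint(0, O, (T,))              # subject indices into objs
    p = torch.randint(0, num_preds, (T,))      # predicate categories
    o = torch.randint(0, O, (T,))              # object indices into objs
    triples = torch.stack([s, p, o], dim=1)    # (T, 3) [s, p, o] triples
    boundary = torch.randn(B, 3, 128, 128)     # (B, input_dim, H, W)
    obj_to_img = torch.randint(0, B, (O,))     # (O,) image index per object
    attributes = torch.randn(O, 35)            # (O, attribute_dim)

    # generate=True also produces the layout map; refine=True would further
    # require num_objs to match the "I15" input channels of the default
    # box_refine_arch, so it is left off in this sketch.
    boxes_pred, gene_layout, boxes_refine = model(
        objs, triples, boundary,
        obj_to_img=obj_to_img, attributes=attributes,
        generate=True,
    )
    print(boxes_pred.shape)   # (O, 4)
    print(gene_layout.shape)  # expected (B, num_objs, H, W) with default archs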