""" code from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py This file contains the Definition of GraphCNN GraphCNN includes ResNet50 as a submodule """ from __future__ import division import torch import torch.nn as nn # from .resnet import resnet50 import torchvision.models as models import os import sys sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..')) from src.graph_networks.graphcmr.utils_mesh import Mesh from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear class GraphCNN(nn.Module): def __init__(self, A, ref_vertices, n_resnet_in, n_resnet_out, num_layers=5, num_channels=512): super(GraphCNN, self).__init__() self.A = A self.ref_vertices = ref_vertices # self.resnet = resnet50(pretrained=True) # -> within the GraphCMR network they ignore the last fully connected layer # replace the first layer self.resnet = models.resnet34(pretrained=False) n_in = 3 + 1 self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) # replace the last layer self.resnet.fc = nn.Linear(512, n_resnet_out) layers = [GraphLinear(3 + n_resnet_out, 2 * num_channels)] # [GraphLinear(3 + 2048, 2 * num_channels)] layers.append(GraphResBlock(2 * num_channels, num_channels, A)) for i in range(num_layers): layers.append(GraphResBlock(num_channels, num_channels, A)) self.n_out_gc = 2 # two labels per vertex self.gc = nn.Sequential(GraphResBlock(num_channels, 64, A), GraphResBlock(64, 32, A), nn.GroupNorm(32 // 8, 32), nn.ReLU(inplace=True), GraphLinear(32, self.n_out_gc)) self.gcnn = nn.Sequential(*layers) self.n_out_flatground = 1 self.flat_ground = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels), nn.ReLU(inplace=True), GraphLinear(num_channels, 1), nn.ReLU(inplace=True), nn.Linear(A.shape[0], self.n_out_flatground)) def forward(self, image): """Forward pass Inputs: image: size = (B, 3, 256, 256) Returns: Regressed (subsampled) non-parametric shape: size = (B, 1723, 3) Weak-perspective camera: size = (B, 3) """ # import pdb; pdb.set_trace() batch_size = image.shape[0] ref_vertices = self.ref_vertices[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973) image_resnet = self.resnet(image) # (bs, 512) image_enc = image_resnet.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973) x = torch.cat([ref_vertices, image_enc], dim=1) x = self.gcnn(x) # (bs, 512, 973) ground_contact = self.gc(x) # (bs, 2, 973) ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1) return ground_contact, ground_flatness # how to use it: # # from src.graph_networks.graphcmr.utils_mesh import Mesh # # create Mesh object # self.mesh = Mesh() # self.faces = self.mesh.faces.to(self.device) # # create GraphCNN # self.graph_cnn = GraphCNN(self.mesh.adjmat, # self.mesh.ref_vertices.t(), # num_channels=self.options.num_channels, # num_layers=self.options.num_layers # ).to(self.device) # ------------ # # Feed image in the GraphCNN # Returns subsampled mesh and camera parameters # pred_vertices_sub, pred_camera = self.graph_cnn(images) # # Upsample mesh in the original size # pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2)) #