File size: 3,991 Bytes
753fd9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
code from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
This file contains the Definition of GraphCNN
GraphCNN includes a torchvision ResNet34 (with replaced first and last layers) as a submodule
"""
from __future__ import division

import torch
import torch.nn as nn

# from .resnet import resnet50
import torchvision.models as models


import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
from src.graph_networks.graphcmr.utils_mesh import Mesh
from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear


class GraphCNN(nn.Module):
    """Graph CNN with a ResNet-34 image encoder and two prediction heads.

    A global image feature vector produced by the ResNet is attached to every
    template vertex and the result is processed by a stack of graph layers
    (``gcnn``). Two heads are applied to the resulting per-vertex features:

    * ``gc``: per-vertex head whose last layer has ``n_out_gc`` (= 2) output
      channels; its result is returned as ``ground_contact``.
    * ``flat_ground``: head that reduces over channels and vertices to
      ``n_out_flatground`` (= 1) value per sample; its result is returned as
      ``ground_flatness``.
    """

    def __init__(self, A, ref_vertices, n_resnet_in, n_resnet_out, num_layers=5, num_channels=512):
        """Create the image encoder, the graph trunk and both heads.

        Args:
            A: graph adjacency matrix handed to every ``GraphResBlock``.
                ``A.shape[0]`` is used as the number of vertices by the
                final linear layer of the ``flat_ground`` head.
            ref_vertices: template vertex coordinates. ``forward`` indexes it
                as a 2-D tensor and concatenates it with the image features
                along dim 1, and the first graph layer takes
                ``3 + n_resnet_out`` input channels, so a shape of (3, N) is
                expected.
            n_resnet_in: number of channels of the input image (input width
                of the replaced first ResNet convolution).
            n_resnet_out: length of the image feature vector (output width
                of the replaced final ResNet layer).
            num_layers: number of additional ``GraphResBlock`` layers
                (``num_channels`` -> ``num_channels``) in the trunk.
            num_channels: feature width of the graph trunk.
        """
        super(GraphCNN, self).__init__()
        # NOTE(review): A and ref_vertices are stored as plain attributes, not
        # registered buffers, so ``module.to(device)`` does not move them;
        # presumably the caller supplies them on the target device - verify.
        self.A = A
        self.ref_vertices = ref_vertices

        # Image encoder: randomly initialised ResNet-34. Calling the factory
        # without arguments is equivalent to the former ``pretrained=False``
        # and avoids the keyword that recent torchvision releases deprecate.
        self.resnet = models.resnet34()
        # Replace the first layer so that images with n_resnet_in channels
        # are accepted.
        self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # Replace the last layer; the input width is read from the layer that
        # is being replaced (512 for ResNet-34) instead of being hard-coded.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, n_resnet_out)

        # Graph trunk: vertex coordinates (3) + image features per vertex.
        # In the original GraphCMR code the input width was 3 + 2048.
        layers = [
            GraphLinear(3 + n_resnet_out, 2 * num_channels),
            GraphResBlock(2 * num_channels, num_channels, A),
        ]
        layers.extend(GraphResBlock(num_channels, num_channels, A) for _ in range(num_layers))

        # Per-vertex head: two labels per vertex.
        self.n_out_gc = 2
        self.gc = nn.Sequential(GraphResBlock(num_channels, 64, A),
                                GraphResBlock(64, 32, A),
                                nn.GroupNorm(32 // 8, 32),
                                nn.ReLU(inplace=True),
                                GraphLinear(32, self.n_out_gc))
        self.gcnn = nn.Sequential(*layers)

        # Per-sample head: collapse the channels to one value per vertex,
        # then map the A.shape[0] vertex values to a single output.
        self.n_out_flatground = 1
        self.flat_ground = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels),
                                         nn.ReLU(inplace=True),
                                         GraphLinear(num_channels, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Linear(A.shape[0], self.n_out_flatground))

    def forward(self, image):
        """Forward pass.

        Inputs:
            image: batch of images; dim 0 is the batch dimension B and the
                channel count must match ``n_resnet_in`` given at
                construction.
        Returns:
            ground_contact: output of the ``gc`` head, presumably of size
                (B, n_out_gc, N) with N the number of template vertices.
            ground_flatness: output of the ``flat_ground`` head, reshaped to
                size (B, n_out_flatground).
        """
        batch_size = image.shape[0]
        # Share the template vertices across the batch: (B, 3, N).
        ref_vertices = self.ref_vertices[None, :, :].expand(batch_size, -1, -1)
        num_vertices = ref_vertices.shape[-1]
        # One global feature vector per image: (B, n_resnet_out).
        image_resnet = self.resnet(image)
        # Attach the same image feature vector to every vertex:
        # (B, n_resnet_out, N).
        image_enc = image_resnet.view(batch_size, -1, 1).expand(-1, -1, num_vertices)
        # Per-vertex input features: (B, 3 + n_resnet_out, N).
        x = torch.cat([ref_vertices, image_enc], dim=1)
        x = self.gcnn(x)        # presumably (B, num_channels, N)
        ground_contact = self.gc(x)
        ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground)
        return ground_contact, ground_flatness




# how to use it:
#
# from src.graph_networks.graphcmr.utils_mesh import Mesh 
#
# create Mesh object
# self.mesh = Mesh()
# self.faces = self.mesh.faces.to(self.device)
#
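# create GraphCNN (n_resnet_in and n_resnet_out are required arguments)
# self.graph_cnn = GraphCNN(self.mesh.adjmat,
#                     self.mesh.ref_vertices.t(),
#                     n_resnet_in=...,       # number of channels of the input image
#                     n_resnet_out=...,      # length of the ResNet feature vector
#                     num_channels=self.options.num_channels,
#                     num_layers=self.options.num_layers
#                     ).to(self.device)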
# ------------
#
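# Feed image in the GraphCNN
# Returns the per-vertex ground contact output and the ground flatness output
# ground_contact, ground_flatness = self.graph_cnn(images)
#
# (In the original GraphCMR code the network returned a subsampled mesh and
# camera parameters instead, and the mesh was upsampled to the original size:
# pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2)) )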
#