|
|
|
""" |
|
Created on Sun Jul 4 15:07:27 2021 |
|
|
|
@author: AlexandreN |
|
""" |
|
from __future__ import print_function, division |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torchvision |
|
|
|
|
|
class SingleTractionHead(nn.Module):
    """Two-branch prediction head applied to pooled backbone features.

    Both branches read the same flattened feature vector:

    * ``head_locs`` -> 4 outputs squashed into [0, 1] by a Sigmoid
      (presumably normalized box coordinates -- TODO confirm against the
      training code).
    * ``head_class`` -> 1 raw, unnormalized score (no activation applied).

    The sub-module layout is kept fixed so that ``state_dict`` keys such as
    ``head_locs.0.weight`` / ``head_class.3.bias`` stay stable and existing
    checkpoints keep loading.

    Args:
        in_features: Width of the flattened feature vector. Defaults to 2048,
            the width this head was originally hard-coded for (it is paired
            with a ResNet-50 body in ``create_model``).
        dropout: Dropout probability used inside both branches.
    """

    def __init__(self, in_features=2048, dropout=0.3):
        super(SingleTractionHead, self).__init__()

        # Box branch: Linear -> ReLU -> Dropout -> Linear -> Sigmoid.
        self.head_locs = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(1024, 4),
            nn.Sigmoid(),
        )

        # Class branch: Linear -> ReLU -> Dropout -> Linear, left as a logit
        # (callers apply any sigmoid / loss-with-logits themselves).
        self.head_class = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(128, 1),
        )

    def forward(self, features):
        """Run both branches on ``features``.

        Args:
            features: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened and must multiply to ``in_features``
                (e.g. ``(B, 2048, 1, 1)`` from an average-pooled backbone).

        Returns:
            Tuple ``(y_bbox, y_class)`` of shapes ``(B, 4)`` and ``(B, 1)``.
        """
        # flatten(start_dim=1) matches view(B, -1) on contiguous input, but it
        # also accepts non-contiguous tensors (view raises) and empty batches
        # (view(0, -1) raises because -1 is ambiguous when B == 0).
        flat = torch.flatten(features, start_dim=1)

        y_bbox = self.head_locs(flat)
        y_class = self.head_class(flat)
        return y_bbox, y_class
|
|
|
|
|
def create_model(pretrained=True, freeze_backbone=True):
    """Build a ResNet-50 body followed by a ``SingleTractionHead``.

    The final child module of torchvision's ResNet-50 (its fully connected
    classifier) is dropped, so the body ends at global average pooling and
    feeds the head directly.

    Args:
        pretrained: Initialise the ResNet-50 body with ImageNet weights
            (the original, default behaviour).
        freeze_backbone: Set ``requires_grad = False`` on every body parameter
            so only the head is trained (the original, default behaviour).

    Returns:
        ``nn.Sequential(body, head)``; calling it returns the head's
        ``(y_bbox, y_class)`` tuple.

    NOTE: ``requires_grad = False`` does not stop BatchNorm layers in the body
    from updating their running statistics while the model is in ``train()``
    mode.
    """
    weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13 only understands the boolean flag.
        feature_extractor = torchvision.models.resnet50(pretrained=pretrained)
    else:
        # torchvision >= 0.13 deprecates ``pretrained``. IMAGENET1K_V1 is the
        # exact checkpoint ``pretrained=True`` maps to, so the body's weights
        # are unchanged (DEFAULT would silently switch to IMAGENET1K_V2).
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        feature_extractor = torchvision.models.resnet50(weights=weights)

    # Keep everything except the last child (the classifier layer).
    model_body = nn.Sequential(*list(feature_extractor.children())[:-1])

    if freeze_backbone:
        for param in model_body.parameters():
            param.requires_grad = False

    model_head = SingleTractionHead()
    return nn.Sequential(model_body, model_head)
|
|
|
|
|
def load_weights(model, path='model.pt', device_='cpu'):
    """Restore a saved ``state_dict`` from ``path`` into ``model``.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        path: File produced by ``torch.save`` holding a state dict.
        device_: Device the checkpoint tensors are mapped onto while loading.
            This does NOT move ``model``: ``load_state_dict`` copies the values
            into the model's existing tensors, which keep their current device.

    Returns:
        The same ``model`` object, for call chaining.

    NOTE: ``torch.load`` unpickles the file; only load checkpoints from
    trusted sources.
    """
    target_device = torch.device(device_)
    state = torch.load(path, map_location=target_device)

    model.load_state_dict(state)
    return model
|
|