import torch.nn as nn | |
class CRFCombinedModel(nn.Module):
    """Run a base model and pass its logits through a CRF module.

    The wrapped ``base_model`` must return a ``(logits, reg_loss)`` pair.
    The logits are fed to ``crf``. The regularisation loss is passed
    through unchanged, so the caller receives ``(crf_output, reg_loss)``.
    """

    def __init__(self, base_model, crf):
        """Store the two sub-modules.

        Args:
            base_model: Module called as ``base_model(x, x_text)`` that
                returns a ``(logits, reg_loss)`` tuple.
            crf: Module called as ``crf(logits, spatial_spacings=...)``.
        """
        super().__init__()
        self.base_model = base_model
        self.crf = crf

    def forward(self, x, x_text=None, spatial_spacings=None):
        """Compute base-model logits, then refine them with the CRF.

        Args:
            x: Input forwarded to ``base_model``.
            x_text: Optional second input forwarded to ``base_model``.
            spatial_spacings: Optional value forwarded to ``crf`` as the
                ``spatial_spacings`` keyword.

        Returns:
            Tuple ``(crf_output, reg_loss)``. ``reg_loss`` is whatever
            ``base_model`` returned as its second element.
        """
        logits, reg_loss = self.base_model(x, x_text)
        # A 3-D logits tensor gets a singleton axis inserted at dim 1 so the
        # CRF always receives a 4-D tensor. Presumably this maps
        # (batch, H, W) -> (batch, 1, H, W); confirm against the CRF's
        # expected layout.
        if logits.dim() == 3:
            logits = logits.unsqueeze(1)
        output = self.crf(logits, spatial_spacings=spatial_spacings)
        return output, reg_loss