import torch.nn as nn


class CRFCombinedModel(nn.Module):
    """Chain a base model with a CRF module applied to the base model's logits.

    The base model's first return value (``logits``) is fed to the CRF; its
    second return value (``reg_loss``) is passed through to the caller
    unchanged.
    """

    def __init__(self, base_model, crf):
        """Store the two sub-modules.

        Args:
            base_model: Callable invoked as ``base_model(x, x_text)``; it must
                return a 2-tuple ``(logits, reg_loss)``.
            crf: Callable invoked as ``crf(logits, spatial_spacings=...)``.
        """
        super().__init__()
        self.base_model = base_model
        self.crf = crf

    def forward(self, x, x_text=None, spatial_spacings=None):
        """Run the base model, then refine its logits with the CRF.

        Args:
            x: Primary input, forwarded to ``base_model`` as-is.
            x_text: Optional secondary input, forwarded to ``base_model``.
            spatial_spacings: Forwarded to ``crf`` as a keyword argument.

        Returns:
            Tuple ``(output, reg_loss)``: the CRF output and the base model's
            second return value, unchanged.
        """
        logits, reg_loss = self.base_model(x, x_text)

        # A 3-D logits tensor gets a singleton axis inserted at position 1 so
        # the CRF always receives a 4-D input.
        # NOTE(review): presumably (batch, H, W) -> (batch, 1, H, W), i.e. a
        # single class/channel axis — confirm against the CRF's expected layout.
        if len(logits.shape) == 3:
            dim0, dim1, dim2 = logits.shape
            logits = logits.reshape(dim0, 1, dim1, dim2)

        output = self.crf(logits, spatial_spacings=spatial_spacings)
        # The original printed output.shape here on every call; that debug
        # print is removed so forward passes do not spam stdout.
        return output, reg_loss