File size: 602 Bytes
4a1f918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18

import torch.nn as nn

class CRFCombinedModel(nn.Module):
    """Chain a base model with a CRF module applied to its logits.

    The base model is called as ``base_model(x, x_text)`` and must return a
    ``(logits, reg_loss)`` pair. The logits are passed to ``crf`` (with
    ``spatial_spacings`` forwarded as a keyword argument), and the CRF output
    is returned together with the base model's ``reg_loss``.
    """

    def __init__(self, base_model, crf):
        """Store the two sub-modules.

        Args:
            base_model: Module called as ``base_model(x, x_text)``; must return
                a ``(logits, reg_loss)`` pair.
            crf: Module called as ``crf(logits, spatial_spacings=...)``.
        """
        super().__init__()
        self.base_model = base_model
        self.crf = crf

    def forward(self, x, x_text=None, spatial_spacings=None):
        """Run the base model, then pass its logits through the CRF.

        Args:
            x: Input forwarded to the base model.
            x_text: Optional second input, forwarded positionally to the base
                model.
            spatial_spacings: Optional value forwarded to the CRF as the
                ``spatial_spacings`` keyword argument.

        Returns:
            Tuple ``(output, reg_loss)``, where ``output`` is the CRF result and
            ``reg_loss`` is passed through unchanged from the base model.
        """
        logits, reg_loss = self.base_model(x, x_text)
        # A 3-D logits tensor gets a singleton axis inserted at dim 1 so the
        # CRF always receives a 4-D tensor. NOTE(review): presumably this is
        # (batch, H, W) -> (batch, 1, H, W); confirm against the base model.
        if logits.dim() == 3:
            logits = logits.unsqueeze(1)
        output = self.crf(logits, spatial_spacings=spatial_spacings)
        return output, reg_loss