# SALT-SAM / AllinonSAM / combined_model.py
# Uploaded by pythn with huggingface_hub (commit 4a1f918, verified).
import torch.nn as nn
class CRFCombinedModel(nn.Module):
    """Chain a base segmentation model with a CRF refinement module.

    ``base_model`` is called first and must return a ``(logits, reg_loss)``
    pair. The logits are then passed through ``crf``. The CRF output is
    returned together with the untouched ``reg_loss``.
    """

    def __init__(self, base_model, crf):
        """Store the two stages.

        Args:
            base_model: Callable invoked as ``base_model(x, x_text)`` that
                returns a two-element ``(logits, reg_loss)`` result.
            crf: Callable invoked as ``crf(logits, spatial_spacings=...)``.
        """
        super().__init__()
        self.base_model = base_model
        self.crf = crf

    def forward(self, x, x_text=None, spatial_spacings=None):
        """Run the base model, then refine its logits with the CRF.

        Args:
            x: Input forwarded unchanged to ``base_model``.
            x_text: Optional second input forwarded to ``base_model``.
            spatial_spacings: Forwarded to ``crf`` as a keyword argument.

        Returns:
            ``(crf_output, reg_loss)``. Here ``reg_loss`` is the second value
            returned by ``base_model``, passed through unchanged.
        """
        logits, reg_loss = self.base_model(x, x_text)
        # A 3-D logits tensor gets a singleton axis inserted at dim 1 before
        # the CRF call. This is presumably (batch, H, W) -> (batch, 1, H, W),
        # giving the CRF an explicit channel axis. TODO: confirm against crf.
        if logits.dim() == 3:
            logits = logits.unsqueeze(1)
        output = self.crf(logits, spatial_spacings=spatial_spacings)
        return output, reg_loss