from typing import List, Tuple

import torch
from segment_anything.modeling import ImageEncoderViT
# This class (and any supporting functions below) is lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViTHQ(ImageEncoderViT):
    """``ImageEncoderViT`` that also returns selected intermediate block outputs.

    The forward pass mirrors the parent encoder (patch embedding, optional
    position embedding, transformer blocks, neck). It additionally collects
    the output of every block whose ``window_size`` is 0 and returns those
    alongside the neck output.

    NOTE(review): in segment_anything, blocks listed in ``global_attn_indexes``
    are constructed with ``window_size=0``. The collected outputs are therefore
    presumably those of the global-attention blocks; confirm this against the
    installed segment_anything version.
    """

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Encode ``x`` and return the neck output plus intermediate embeddings.

        Args:
            x: Input batch, passed unchanged to ``self.patch_embed``.

        Returns:
            A ``(features, interm_embeddings)`` tuple:

            * ``features``: the output of ``self.neck``.
            * ``interm_embeddings``: the outputs of the blocks with
              ``window_size == 0``, in block order. They are captured before
              the permute that precedes the neck, so they keep the block
              output layout rather than the neck's input layout.
        """
        x = self.patch_embed(x)
        # The position embedding is optional; it is skipped when the encoder
        # was built without one.
        if self.pos_embed is not None:
            x = x + self.pos_embed

        interm_embeddings: List[torch.Tensor] = []
        for blk in self.blocks:
            x = blk(x)
            if blk.window_size == 0:
                interm_embeddings.append(x)

        # x is 4-D here. permute(0, 3, 1, 2) moves its last axis to dim 1
        # before the neck (presumably (B, H, W, C) -> (B, C, H, W); confirm).
        x = self.neck(x.permute(0, 3, 1, 2))
        return x, interm_embeddings