from typing import List, Tuple

import torch

from segment_anything.modeling import ImageEncoderViT


# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViTHQ(ImageEncoderViT):
    """``ImageEncoderViT`` variant whose forward pass also exposes intermediate block outputs.

    The encoder layers themselves (``patch_embed``, ``pos_embed``, ``blocks``,
    ``neck``) are all inherited from ``ImageEncoderViT``; only ``forward``
    is overridden.
    """

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Encode ``x`` and collect the outputs of blocks with ``window_size == 0``.

        Args:
            x: Input passed straight to ``self.patch_embed``.
                NOTE(review): presumably a batch of images (B, C, H, W) —
                confirm against the parent class.

        Returns:
            A tuple ``(features, interm_embeddings)``:
              - ``features``: the output of ``self.neck`` applied to the
                final block output after ``permute(0, 3, 1, 2)``.
              - ``interm_embeddings``: the outputs of every block whose
                ``window_size`` is 0, in block order, taken *before* the
                permute. They are therefore in the same axis layout as the
                blocks' outputs.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        interm_embeddings: List[torch.Tensor] = []
        for blk in self.blocks:
            x = blk(x)
            # NOTE(review): in SAM, window_size == 0 presumably marks the
            # global-attention blocks; these outputs are kept for the HQ head.
            if blk.window_size == 0:
                interm_embeddings.append(x)

        # Move axis 3 to position 1 before the neck. This presumably converts
        # channels-last (B, H, W, C) to (B, C, H, W) — confirm with parent class.
        x = self.neck(x.permute(0, 3, 1, 2))
        return x, interm_embeddings