방재호
init
b5ba7a5
raw
history blame contribute delete
728 Bytes
from typing import List, Tuple

import torch
from segment_anything.modeling import ImageEncoderViT
# This class (and the ImageEncoderViT it extends) is lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViTHQ(ImageEncoderViT):
    """ViT image encoder that also returns intermediate block outputs.

    Subclass of ``segment_anything``'s ``ImageEncoderViT`` that overrides only
    ``forward``. Construction and every submodule used here (``patch_embed``,
    ``pos_embed``, ``blocks``, ``neck``) come from the parent class. The
    override additionally collects the output of each transformer block whose
    ``window_size`` is 0 and returns those alongside the normal neck output.
    """

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Encode ``x`` and return the neck features plus intermediate embeddings.

        Args:
            x: Input batch fed to ``self.patch_embed``.
                NOTE(review): presumably an image batch shaped (B, 3, H, W);
                confirm against callers.

        Returns:
            A tuple ``(features, interm_embeddings)``:

            * ``features``: the output of ``self.neck`` applied to the
              permuted output of the last block.
            * ``interm_embeddings``: in block order, the output of every block
              with ``window_size == 0``. These are stored exactly as the
              block produced them, i.e. *before* the permute applied to the
              final output.
        """
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed

        # Keep the outputs of blocks with window_size == 0 so callers can use
        # them as intermediate features. (In segment_anything these are
        # presumably the global-attention blocks -- confirm.)
        interm_embeddings: List[torch.Tensor] = []
        for blk in self.blocks:
            x = blk(x)
            if blk.window_size == 0:
                interm_embeddings.append(x)

        # permute(0, 3, 1, 2) moves the last axis to position 1. This assumes
        # block outputs are channels-last (B, H, W, C) and that the neck
        # expects channels-first (B, C, H, W) -- TODO confirm.
        x = self.neck(x.permute(0, 3, 1, 2))

        return x, interm_embeddings