File size: 649 Bytes
ca2145e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn
import sys

# Add the vendored MoGe checkout to the module search path, presumably so
# that MoGe's own absolute imports resolve when loaded via the relative
# import below. NOTE(review): this entry is a relative path, so Python
# resolves it against the current working directory, not this file's
# location. It only works when the process is started from the directory
# containing ``third_party/``; consider anchoring it to ``__file__``.
sys.path.append('third_party/moge')
from .moge.moge.model.moge_model import MoGeModel

class MoGe(nn.Module):
    """Wrapper around a pretrained ``MoGeModel`` used for inference only.

    The model is loaded once at construction time and switched to eval mode.
    ``forward_image`` runs its ``infer`` method under ``torch.no_grad`` and
    returns the ``'points'`` and ``'mask'`` entries of the result.
    """

    def __init__(self, cache_dir, model_name='Ruicheng/moge-vitl'):
        """Load a pretrained MoGe checkpoint.

        Args:
            cache_dir: Forwarded to ``MoGeModel.from_pretrained`` as
                ``cache_dir`` (presumably a model-download cache directory;
                confirm against MoGe).
            model_name: Checkpoint identifier passed to
                ``MoGeModel.from_pretrained``. Defaults to
                ``'Ruicheng/moge-vitl'``, the previously hard-coded value,
                so existing callers are unaffected.
        """
        super().__init__()
        self.model = MoGeModel.from_pretrained(
            model_name, cache_dir=cache_dir).eval()

    @torch.no_grad()
    def forward_image(self, image: torch.Tensor, **kwargs):
        """Run MoGe inference on ``image`` and return points and mask.

        Args:
            image: Input image batch. The original note describes it as
                ``(b, 3, h, w)`` with values in ``[0, 1]``; confirm against
                ``MoGeModel.infer``.
            **kwargs: Extra keyword arguments forwarded to
                ``MoGeModel.infer``. ``resolution_level`` (default ``9``) and
                ``apply_mask`` (default ``False``) may be overridden here.
                Previously, passing either one raised ``TypeError`` for a
                duplicated keyword argument.

        Returns:
            ``(points, masks)``: the ``'points'`` and ``'mask'`` entries of
            the ``infer`` output. Per the original notes these are presumably
            shaped ``(b, h, w, 3)`` and ``(b, h, w)``.
        """
        # setdefault keeps the historical defaults while letting callers
        # override them, instead of colliding with an explicit keyword.
        kwargs.setdefault('resolution_level', 9)
        kwargs.setdefault('apply_mask', False)
        output = self.model.infer(image, **kwargs)
        points = output['points']  # b, h, w, 3 (per original note)
        masks = output['mask']  # b, h, w (per original note)
        return points, masks