File size: 850 Bytes
dac7840 6f6eb5f dac7840 6f6eb5f dac7840 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from typing import List
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
import numpy as np
# Hugging Face Hub checkpoint that both the processor and the model are
# loaded from.
MODEL_ID: str = "facebook/sam-vit-huge"
class PreTrainedPipeline:
    """Image-embedding pipeline backed by the SAM model at ``MODEL_ID``.

    Calling an instance with a PIL image runs the SAM processor and
    ``SamModel.get_image_embeddings`` and returns the resulting embedding
    tensor converted to (nested) plain Python lists.
    """

    def __init__(self, path=""):
        """Load the SAM processor and model onto the best available device.

        Args:
            path: Unused. The processor and weights are always loaded from
                ``MODEL_ID``, never from this path. NOTE(review): presumably
                kept so the hosting framework can pass its repo directory —
                confirm before wiring it up.
        """
        # Prefer the GPU when one is visible; otherwise fall back to CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SamProcessor.from_pretrained(MODEL_ID)
        # ``.to`` returns the module already placed on ``self.device``, so a
        # single move is enough.
        self.model = SamModel.from_pretrained(MODEL_ID).to(self.device)
        # Switch to evaluation mode; this pipeline never trains.
        self.model.eval()

    def __call__(self, inputs: "Image.Image") -> List:
        """Return the SAM image embedding for ``inputs``.

        Args:
            inputs: A PIL image. It is converted to RGB before preprocessing.

        Returns:
            The tensor returned by ``get_image_embeddings`` as nested Python
            lists (one nesting level per tensor dimension). NOTE(review): for
            a single SAM ViT-H image this is presumably shaped
            (1, 256, 64, 64) — confirm against the transformers version used.
        """
        raw_image = inputs.convert("RGB")
        # Use a separate name so the caller's ``inputs`` image is not shadowed.
        model_inputs = self.processor(
            raw_image, return_tensors="pt").to(self.device)
        # Inference only: disable autograd so no graph is kept alive.
        with torch.no_grad():
            embeddings = self.model.get_image_embeddings(
                model_inputs["pixel_values"])
        # Bring the result to host memory before building Python lists.
        return embeddings.cpu().tolist()
|