from typing import List | |
import torch | |
from transformers import SamModel, SamProcessor | |
from PIL import Image | |
import numpy as np | |
class PreTrainedPipeline:
    """Image feature-extraction pipeline backed by a SAM checkpoint.

    The processor and model are loaded once at construction time; calling the
    instance with a PIL image returns that image's embeddings as plain Python
    lists.
    """

    # Hub id of the checkpoint used for both the processor and the model.
    MODEL_ID = "facebook/sam-vit-base"

    def __init__(self, path=""):
        """Load the processor and the model and move the model to the device.

        Args:
            path: Accepted for interface compatibility but currently unused;
                the checkpoint is always loaded from ``MODEL_ID``.
        """
        # Prefer the GPU when one is visible to torch, otherwise run on CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SamProcessor.from_pretrained(self.MODEL_ID)
        # A single device transfer is enough; the model stays there afterwards.
        self.model = SamModel.from_pretrained(self.MODEL_ID).to(self.device)
        self.model.eval()

    def __call__(self, inputs: "Image.Image") -> List[float]:
        """Compute the image embeddings for one image.

        Args:
            inputs: Image to embed. It is converted to RGB before being handed
                to the processor.

        Returns:
            The result of ``get_image_embeddings`` converted with ``tolist()``.
            NOTE(review): the annotation says ``List[float]``, but ``tolist()``
            on a tensor with more than one dimension yields nested lists --
            confirm what consumers of this pipeline expect.
        """
        raw_image = inputs.convert("RGB")
        # Keep the processed batch under its own name instead of rebinding the
        # ``inputs`` parameter, so the image and the tensors stay distinct.
        batch = self.processor(raw_image, return_tensors="pt").to(self.device)
        # Inference only: make sure no autograd graph is built for this pass.
        with torch.no_grad():
            embeddings = self.model.get_image_embeddings(
                batch["pixel_values"])
        return embeddings.tolist()