|
from typing import List |
|
import torch |
|
from transformers import SamModel, SamProcessor |
|
from PIL import Image |
|
import numpy as np |
|
|
|
# Hugging Face Hub checkpoint used for both the processor and the model.
MODEL_ID: str = "facebook/sam-vit-huge"
|
|
|
class PreTrainedPipeline():
    """Image-embedding pipeline built on the SAM image encoder.

    The ``SamProcessor`` and ``SamModel`` for ``MODEL_ID`` are loaded once
    at construction time. Each call returns the model's image embeddings
    for a single PIL image as nested Python lists.
    """

    def __init__(self, path=""):
        """Load the SAM processor and model onto the best available device.

        Args:
            path: Accepted but not used; weights are always loaded from
                ``MODEL_ID``. NOTE(review): presumably kept to satisfy a
                custom-handler constructor signature — confirm whether
                loading from ``path`` was intended.
        """
        # Prefer the GPU when one is visible to torch, otherwise run on CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SamProcessor.from_pretrained(MODEL_ID)
        # Move the model to the target device exactly once.
        self.model = SamModel.from_pretrained(MODEL_ID).to(self.device)
        # The pipeline only does inference, so put the model in eval mode.
        self.model.eval()

    def __call__(self, inputs: "Image.Image") -> list:
        """Compute SAM image embeddings for one image.

        Args:
            inputs: A PIL image; it is converted to RGB before preprocessing.

        Returns:
            The tensor returned by ``SamModel.get_image_embeddings``,
            converted with ``Tensor.tolist()``. The result is nested lists
            of floats, one nesting level per tensor dimension.
            NOTE(review): for sam-vit-huge this is presumably shaped
            (1, 256, 64, 64) — confirm.
        """
        raw_image = inputs.convert("RGB")
        batch = self.processor(raw_image, return_tensors="pt").to(self.device)
        # Feature extraction never backpropagates; disabling autograd avoids
        # recording the encoder's computation graph on every call.
        with torch.no_grad():
            embeddings = self.model.get_image_embeddings(batch["pixel_values"])
        return embeddings.tolist()
|
|