radames's picture
first
ee9318a
raw
history blame
877 Bytes
from typing import List
import torch
from transformers import SamModel, SamProcessor
from PIL import Image
from io import BytesIO
import numpy as np
class PreTrainedPipeline():
    """Image-embedding pipeline built on a Segment Anything (SAM) checkpoint.

    Calling an instance with a PIL image runs SAM's image encoder
    (``SamModel.get_image_embeddings``) and returns the resulting tensor
    converted to nested Python lists.
    """

    def __init__(self, path="", *, model_id="facebook/sam-vit-base"):
        """Load the SAM processor and model and place the model on a device.

        Args:
            path: Accepted for compatibility with the loader's calling
                convention; this implementation does not read it. Weights are
                always loaded from ``model_id``.
            model_id: Identifier passed to both ``SamProcessor.from_pretrained``
                and ``SamModel.from_pretrained`` so processor and model always
                come from the same checkpoint. Defaults to the checkpoint that
                was previously hard-coded.
        """
        # Use the GPU when one is visible, otherwise fall back to CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SamProcessor.from_pretrained(model_id)
        # One .to() call is sufficient; eval() switches modules to inference
        # behaviour (e.g. disables dropout, if the model has any).
        self.model = SamModel.from_pretrained(model_id).to(self.device)
        self.model.eval()

    def __call__(self, inputs: "Image.Image") -> List:
        """Compute SAM image embeddings for a single image.

        Args:
            inputs: A PIL image; it is converted to RGB before preprocessing.

        Returns:
            The tensor returned by ``get_image_embeddings`` converted with
            ``.tolist()`` — a nested list of floats (presumably
            (batch, channels, height, width); confirm against the checkpoint).
        """
        raw_image = inputs.convert("RGB")
        # Keep the processor output under its own name instead of rebinding
        # the ``inputs`` parameter to a different type.
        encoded = self.processor(raw_image, return_tensors="pt").to(self.device)
        # No gradients are needed for feature extraction; disabling autograd
        # avoids building and retaining a graph for the encoder forward pass.
        with torch.no_grad():
            feature_vector = self.model.get_image_embeddings(
                encoded["pixel_values"])
        return feature_vector.tolist()