Spaces:
Runtime error
Runtime error
File size: 1,186 Bytes
68cd8f8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from functools import lru_cache
from typing import Union

import clip
import torch
from PIL import Image
class VisionLanguageModel:
    """Thin wrapper around an OpenAI CLIP model that embeds text or images.

    Attributes:
        device: Device string the inputs are moved to before encoding.
        model: The CLIP model returned by ``clip.load``.
        processor: The image transform returned by ``clip.load``.
        preprocess: Alias of ``processor`` (CLIP's own name for the transform).
    """

    def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
        """Load ``model_name`` onto ``device`` and remember the device.

        Args:
            model_name: CLIP checkpoint name passed to ``clip.load``.
            device: Device string passed to ``clip.load`` and used for inputs.
        """
        self._load_model(model_name, device)
        self.device = device

    @staticmethod
    @lru_cache(maxsize=1)
    def _cached_clip(model_name: str, device: str):
        """Return ``clip.load(model_name, device=device)``, memoized.

        The cache key is ``(model_name, device)`` only, so instances created
        with the same arguments share one loaded model instead of reloading
        the weights, and the cache holds no reference to any instance.
        """
        return clip.load(model_name, device=device)

    def _load_model(self, model_name, device: str = "cpu"):
        """Attach the CLIP model and its image transform to this instance.

        Args:
            model_name: CLIP checkpoint name.
            device: Device string to load the model onto.
        """
        self.model, self.processor = self._cached_clip(model_name, device)
        # get_embedding historically referred to the transform as
        # ``preprocess``; expose both names so either spelling works.
        self.preprocess = self.processor

    @staticmethod
    def _to_unit_numpy(features):
        """L2-normalize ``features`` along the last axis; return float32 NumPy."""
        features = features / features.norm(dim=-1, keepdim=True)
        return features.cpu().detach().numpy().astype("float32")

    def get_embedding(self, input: Union[str, Image.Image]):
        """Return the unit-length CLIP embedding of a string or PIL image.

        Args:
            input: Text to tokenize and encode, or a ``PIL.Image.Image`` to
                preprocess and encode.

        Returns:
            A float32 NumPy array holding the normalized embedding. The image
            path adds a leading batch axis of size 1; the text path is
            expected to yield the same layout for a single string.

        Raises:
            TypeError: If ``input`` is neither ``str`` nor ``PIL.Image.Image``.
        """
        if isinstance(input, str):
            batch = clip.tokenize(input).to(self.device)
            encode = self.model.encode_text
        elif isinstance(input, Image.Image):
            # unsqueeze(0) adds the batch dimension the encoder expects.
            batch = self.processor(input).unsqueeze(0).to(self.device)
            encode = self.model.encode_image
        else:
            raise TypeError("Invalid input type")

        # Inference only: skip building an autograd graph for the forward pass.
        with torch.no_grad():
            features = encode(batch)
        return self._to_unit_numpy(features)
|