Spaces:
Running
Running
import torch | |
from torch.utils.data import DataLoader | |
from starvector.metrics.base_metric import BaseMetric | |
from tqdm import tqdm | |
from transformers import AutoModel, AutoImageProcessor | |
from PIL import Image | |
import torch.nn as nn | |
class DINOScoreCalculator(BaseMetric): | |
def __init__(self, config=None, device='cuda'): | |
super().__init__() | |
self.class_name = self.__class__.__name__ | |
self.config = config | |
self.model, self.processor = self.get_DINOv2_model("base") | |
self.model = self.model.to(device) | |
self.device = device | |
self.metric = self.calculate_DINOv2_similarity_score | |
def get_DINOv2_model(self, model_size): | |
if model_size == "small": | |
model_size = "facebook/dinov2-small" | |
elif model_size == "base": | |
model_size = "facebook/dinov2-base" | |
elif model_size == "large": | |
model_size = "facebook/dinov2-large" | |
else: | |
raise ValueError(f"model_size should be either 'small', 'base' or 'large', got {model_size}") | |
return AutoModel.from_pretrained(model_size), AutoImageProcessor.from_pretrained(model_size) | |
def process_input(self, image, processor): | |
if isinstance(image, str): | |
image = Image.open(image) | |
if isinstance(image, Image.Image): | |
with torch.no_grad(): | |
inputs = processor(images=image, return_tensors="pt").to(self.device) | |
outputs = self.model(**inputs) | |
features = outputs.last_hidden_state.mean(dim=1) | |
elif isinstance(image, torch.Tensor): | |
features = image.unsqueeze(0) if image.dim() == 1 else image | |
else: | |
raise ValueError("Input must be a file path, PIL Image, or tensor of features") | |
return features | |
def calculate_DINOv2_similarity_score(self, **kwargs): | |
image1 = kwargs.get('gt_im') | |
image2 = kwargs.get('gen_im') | |
features1 = self.process_input(image1, self.processor) | |
features2 = self.process_input(image2, self.processor) | |
cos = nn.CosineSimilarity(dim=1) | |
sim = cos(features1, features2).item() | |
sim = (sim + 1) / 2 | |
return sim | |