# File: XVerse/eval/tools/dino.py
# (Hugging Face file-viewer residue: uploaded by alexnasa, "Upload 113 files",
# commit 4479f79 verified, 3.47 kB)
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import ViTImageProcessor, ViTModel
from torch.nn import functional as F
from PIL import Image
import requests
from torchvision import transforms
import torch, os
class DINOScore:
    """Score the visual similarity of two images with DINO ViT features.

    Both images are preprocessed, passed through a ``ViTModel`` in one batch
    of two, and compared by cosine similarity of either the first token of
    the last hidden state (``"class"``) or the mean over all tokens
    (``"avg"``). Scores are cosine similarities multiplied by 100.
    """

    # Channel mean / std handed to ``transforms.Normalize`` (ImageNet values).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)
    # Accepted values for ``similarity_type`` in ``__call__``.
    _SIMILARITY_TYPES = ("class", "avg")

    def __init__(self, device, use_center_crop=True):
        """Build the preprocessing pipeline and load the model.

        Args:
            device: Anything accepted by ``torch.device`` (e.g. ``"cuda"``).
            use_center_crop: If true, resize to 256 and center-crop to 224;
                otherwise only resize to 224.

        The checkpoint is read from the ``DINO_MODEL_PATH`` environment
        variable, defaulting to ``"facebook/dino-vits16"``.
        """
        # https://github.com/facebookresearch/dino/issues/72#issuecomment-932874140
        # https://github.com/facebookresearch/dino/blob/main/eval_linear.py
        # https://gist.github.com/woctezuma/a30ee1de2e5efc1a3beff8e108795374
        # according to this, we should use center crop with class token
        self.device = torch.device(device)
        self.use_center_crop = use_center_crop

        # Integer interpolation codes are deprecated in torchvision; prefer
        # the enum and fall back to 3 (bicubic) where the enum is missing.
        interpolation_mode = getattr(transforms, "InterpolationMode", None)
        bicubic = interpolation_mode.BICUBIC if interpolation_mode is not None else 3

        if use_center_crop:
            resize_steps = [
                transforms.Resize(256, interpolation=bicubic),
                transforms.CenterCrop(224),
            ]
        else:
            # NOTE(review): with an int size, Resize keeps the aspect ratio,
            # so non-square inputs will not come out 224x224 on this path and
            # two differently shaped images cannot be stacked -- confirm that
            # callers only pass square images when use_center_crop=False.
            resize_steps = [transforms.Resize(224, interpolation=bicubic)]

        self.T = transforms.Compose(
            resize_steps
            + [
                transforms.ToTensor(),
                transforms.Normalize(self._MEAN, self._STD),
            ]
        )

        model_path = os.getenv("DINO_MODEL_PATH", "facebook/dino-vits16")
        self.model = ViTModel.from_pretrained(model_path).to(self.device)
        # Scoring is inference-only; make evaluation mode explicit.
        self.model.eval()

    @staticmethod
    def _to_rgb(image):
        """Return *image* converted to RGB when it exposes a non-RGB PIL mode."""
        # ``mode`` must be a string: tensors also have a ``mode`` attribute,
        # but there it is a method, and such inputs are passed through as-is.
        mode = getattr(image, "mode", None)
        if isinstance(mode, str) and mode != "RGB":
            return image.convert("RGB")
        return image

    def __call__(self, image_x, image_y, similarity_type="class"):
        """Return the similarity score (cosine * 100) between two images.

        Args:
            image_x: First image, in a form accepted by the transform
                pipeline (PIL images are converted to RGB if needed).
            image_y: Second image, same requirements as ``image_x``.
            similarity_type: ``"class"`` to compare the first token of the
                last hidden state, ``"avg"`` to compare the token mean.

        Returns:
            The score as a Python float.

        Raises:
            ValueError: If ``similarity_type`` is not ``"class"`` or ``"avg"``.
        """
        # Validate before the (expensive) forward pass, and with a real
        # exception so the check survives ``python -O``.
        if similarity_type not in self._SIMILARITY_TYPES:
            raise ValueError(
                f"similarity_type must be one of {self._SIMILARITY_TYPES}, "
                f"got {similarity_type!r}"
            )

        # (2, 3, 224, 224). Batchsize = 2
        inputs = torch.stack(
            [self.T(self._to_rgb(image)) for image in (image_x, image_y)]
        )
        # Only a float is returned, so no autograd graph is needed.
        with torch.no_grad():
            outputs = self.model(inputs.to(self.device))
        last_hidden_states = outputs.last_hidden_state

        if similarity_type == "class":
            return self.cls_similarity(last_hidden_states[0], last_hidden_states[1])
        return self.avg_similarity(last_hidden_states[0], last_hidden_states[1])

    def avg_similarity(self, x, y):
        """Cosine similarity (* 100) of the means of *x* and *y* over dim 0."""
        return F.cosine_similarity(x.mean(dim=0), y.mean(dim=0), dim=0).item() * 100

    # Backward-compatible alias for the original (misspelled) method name.
    avg_similairty = avg_similarity

    def cls_similarity(self, x, y):
        """Cosine similarity (* 100) of the first rows of *x* and *y*."""
        return F.cosine_similarity(x[0], y[0], dim=0).item() * 100
if __name__ == "__main__":
    # Demo: print the "class" and "avg" DINO scores for one image pair.
    #
    # Remote alternative (DreamBooth samples fetched over HTTP):
    # urls = [
    #     'https://github.com/google/dreambooth/blob/main/dataset/rc_car/03.jpg?raw=true',  # reference from Fig 11
    #     'https://github.com/google/dreambooth/blob/main/dataset/rc_car/02.jpg?raw=true',  # Real Sample from Fig 11
    # ]
    # images = [Image.open(requests.get(url, stream=True).raw) for url in urls]
    image_paths = [
        "assets/idipbench_base/object/3_pinkbackpack.png",
        "tmp/backpack_0.png",
    ]
    images = []
    for path in image_paths:
        # convert() returns an independent in-memory image, so the file
        # handle can be closed as soon as the conversion is done.
        with Image.open(path) as img:
            images.append(img.convert("RGB"))

    # Fall back to CPU so the demo also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dino_score_model = DINOScore(device, use_center_crop=True)
    print(dino_score_model(images[0], images[1], "class"))
    print(dino_score_model(images[0], images[1], "avg"))