"""Backbone loaders and a geotagged-image dataset for feature extraction."""
import os
from os.path import join

import PIL.Image
import numpy as np
import pandas as pd
import reverse_geocoder
import torch
from torch.utils.data import Dataset


class GeoDataset(Dataset):
    """Geotagged image dataset; items are produced by a model-specific transformation."""

    def __init__(self, image_folder, annotation_file, transformation, tag="image_id"):
        self.image_folder = image_folder
        gt = pd.read_csv(annotation_file, dtype={tag: str})
        # Keep only annotations whose image file actually exists in the folder.
        files = set(f.replace(".jpg", "") for f in os.listdir(image_folder))
        gt = gt[gt[tag].isin(files)]
        self.processor = transformation
        self.gt = [
            (row[tag], row["latitude"], row["longitude"]) for _, row in gt.iterrows()
        ]
        self.tag = tag

    def fid(self, i):
        return self.gt[i][0]

    def latlon(self, i):
        # Return the (latitude, longitude) pair for sample i.
        return self.gt[i][1], self.gt[i][2]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
        return self.processor(self, idx, fp)


def load_plonk(path):
    from hydra import initialize, compose
    from hydra.utils import instantiate
    from models.module import Geolocalizer
    from omegaconf import open_dict

    with initialize(version_base=None, config_path="osv5m__best_model"):
        cfg = compose(config_name="config", overrides=[])

    # The released checkpoint contains a CLIP `position_ids` buffer that recent
    # transformers versions refuse to load; drop it and save a patched copy.
    checkpoint = torch.load(join(path, "last.ckpt"))
    del checkpoint["state_dict"][
        "model.backbone.clip.vision_model.embeddings.position_ids"
    ]
    torch.save(checkpoint, join(path, "last2.ckpt"))

    with open_dict(cfg):
        cfg.checkpoint = join(path, "last2.ckpt")

        cfg.num_classes = 11399
        cfg.model.network.mid.instance.final_dim = cfg.num_classes * 3
        cfg.model.network.head.final_dim = cfg.num_classes * 3
        cfg.model.network.head.instance.quadtree_path = join(path, "quadtree_10_1000.csv")

        # Dataset and logging paths are irrelevant for inference-only use.
        cfg.dataset.train_dataset.path = ""
        cfg.dataset.val_dataset.path = ""
        cfg.dataset.test_dataset.path = ""
        cfg.logger.save_dir = ""
        cfg.data_dir = ""
        cfg.root_dir = ""
        cfg.mode = "test"
        cfg.model.network.backbone.instance.path = (
            "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
        )

    transform = instantiate(cfg.dataset.test_transform)
    model = Geolocalizer.load_from_checkpoint(join(path, "last2.ckpt"), cfg=cfg.model)
    os.remove(join(path, "last2.ckpt"))

    @torch.no_grad()
    def inference(model, x):
        # x = (image_ids, image_batch); return ids and the CLS-token backbone features.
        return x[0], model.model.backbone({"img": x[1].to(model.device)})[:, 0, :].cpu()

    def collate_fn(batch):
        return [b[0] for b in batch], torch.stack([b[1] for b in batch], dim=0)

    def operate(self, idx, fp):
        # Apply the Hydra-instantiated test transform (not `self.processor`, which is
        # `operate` itself once it is passed to GeoDataset as the transformation).
        proc = transform(PIL.Image.open(fp))
        return self.gt[idx][0], proc

    return model, operate, inference, collate_fn


def load_clip(which):
    from transformers import CLIPProcessor, CLIPModel

    @torch.no_grad()
    def inference(model, img):
        image_ids = img.data.pop("image_id")
        image_input = img.to(model.device)
        # Default collation stacks (1, 3, H, W) tensors into (B, 1, 3, H, W); drop the extra dim.
        image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
        features = model.get_image_features(**image_input)
        features /= features.norm(dim=-1, keepdim=True)  # L2-normalize the embeddings
        return image_ids, features.cpu()

    processor = CLIPProcessor.from_pretrained(which)

    def operate(self, idx, fp):
        pil = PIL.Image.open(fp)
        proc = processor(images=pil, return_tensors="pt")
        proc["image_id"] = self.gt[idx][0]
        return proc

    # No custom collate_fn: the DataLoader's default collation is used.
    return CLIPModel.from_pretrained(which), operate, inference, None


def load_dino(which):
    from transformers import AutoImageProcessor, AutoModel

    @torch.no_grad()
    def inference(model, img):
        image_ids = img.data.pop("image_id")
        image_input = img.to(model.device)
        image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
        features = model(**image_input).last_hidden_state[:, 0]  # CLS-token features
        features /= features.norm(dim=-1, keepdim=True)
        return image_ids, features.cpu()

    # Use the requested DINOv2 variant rather than a hard-coded one.
    processor = AutoImageProcessor.from_pretrained(which)

    def operate(self, idx, fp):
        pil = PIL.Image.open(fp)
        proc = processor(images=pil, return_tensors="pt")
        proc["image_id"] = self.gt[idx][0]
        return proc

    return AutoModel.from_pretrained(which), operate, inference, None


def get_backbone(name):
    # A local checkpoint directory selects PLONK; otherwise dispatch on the model name.
    if os.path.isdir(name):
        return load_plonk(name)
    elif "clip" in name.lower():
        return load_clip(name)
    elif "dino" in name.lower():
        return load_dino(name)
    raise ValueError(f"Unsupported backbone: {name}")
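

# Minimal usage sketch of how the pieces above fit together for feature extraction.
# The image folder, annotation CSV (columns: image_id, latitude, longitude), and
# batch size are placeholders; the backbone name reuses the CLIP checkpoint already
# referenced in load_plonk.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    model, operate, inference, collate_fn = get_backbone(
        "laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K"
    )
    model = model.eval()

    # `operate` becomes the dataset's per-item transformation; `collate_fn` may be
    # None, in which case the DataLoader falls back to its default collation.
    dataset = GeoDataset("images/", "annotations.csv", operate)
    loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

    for batch in loader:
        image_ids, features = inference(model, batch)
        print(image_ids[:2], features.shape)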