import operator
import os
from os.path import join

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from backbone import get_backbone
# GeoDataset is used below but was never imported; the module path here is an
# assumption -- adjust it to wherever your project defines GeoDataset.
from dataset import GeoDataset
from utils import haversine, get_filenames, get_match_values, compute_print_accuracy
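
# Pipeline: compute_features() embeds every image with the chosen backbone and
# writes one .npy feature file per image; get_results() then pairs each test
# image's GPS with that of its nearest training neighbor, and __main__ reports
# accuracy and haversine statistics over those pairs.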


def compute_features(path, data_dir, csv_file, tag, args):
    data = GeoDataset(data_dir, csv_file, tag=tag)
    # Skip extraction when every image already has a saved feature file.
    if not os.path.isdir(path) or len(os.listdir(path)) != len(data):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model, transform, inference, collate_fn = get_backbone(args.name)
        dataloader = DataLoader(
            data,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=8,
            collate_fn=collate_fn,
        )
        model = model.to(device)
        os.makedirs(path, exist_ok=True)

        for x in tqdm(dataloader):
            image_ids, features = inference(model, x)

            # One .npy file per image, so an interrupted run can resume.
            for j, image_id in enumerate(image_ids):
                np.save(
                    join(path, f"{image_id}.npy"),
                    features[j].unsqueeze(0).detach().cpu().numpy(),
                )
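

# Note: the 1-NN retrieval below assumes five training shards (range(1, 6)),
# with get_filenames(idx) returning the idx-th shard as a (faiss_index,
# image_ids) pair; neighbor scores are treated as similarities (larger is
# better), hence the descending sort before selecting neighbors[0].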
def get_results(args):
    import joblib

    cache_dir = join(args.features_parent, ".cache")
    cache_path = join(cache_dir, "1-nn.pkl")
    test_features_dir = join(args.features_parent, "features-test")

    if not os.path.isfile(cache_path):
        import bisect

        indexes = [
            get_filenames(idx) for idx in tqdm(range(1, 6), desc="Loading indexes...")
        ]

        train_gt = pd.read_csv(
            join(args.data_parent, args.annotation_file), dtype={"image_id": str}
        )[["image_id", "latitude", "longitude"]]
        test_gt = pd.read_csv(
            join(args.data_parent, "test.csv"), dtype={"id": str}
        )[["id", "latitude", "longitude"]]

        # Index the ground-truth GPS coordinates by image id for O(1) lookup.
        train_gt = {
            row["image_id"]: np.array([row["latitude"], row["longitude"]])
            for _, row in tqdm(
                train_gt.iterrows(), total=len(train_gt), desc="Loading train_gt"
            )
        }
        test_gt = {
            row["id"]: np.array([row["latitude"], row["longitude"]])
            for _, row in tqdm(
                test_gt.iterrows(), total=len(test_gt), desc="Loading test_gt"
            )
        }

        train_test = []
        os.makedirs(cache_dir, exist_ok=True)
        for f in tqdm(os.listdir(test_features_dir)):
            query_vector = np.load(join(test_features_dir, f))

            # Query each shard for its best match, keeping the candidates
            # sorted by score (bisect.insort with key= needs Python 3.10+).
            neighbors = []
            for index, ids in indexes:
                distances, indices = index.search(query_vector, 1)
                distances, indices = np.squeeze(distances), np.squeeze(indices)
                bisect.insort(
                    neighbors, (ids[indices], distances), key=operator.itemgetter(1)
                )

            # Highest score first, then keep the single best neighbor.
            neighbors = list(reversed(neighbors))
            train_gps = train_gt[neighbors[0][0].replace(".npy", "")][None, :]
            test_gps = test_gt[f.replace(".npy", "")][None, :]
            train_test.append((train_gps, test_gps))
        joblib.dump(train_test, cache_path)
    else:
        train_test = joblib.load(cache_path)

    return train_test
if __name__ == "__main__":
|
|
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--id", type=int, default=1)
|
|
parser.add_argument("--batch_size", type=int, default=512)
|
|
parser.add_argument(
|
|
"--annotation_file", type=str, required=False, default="train.csv"
|
|
)
|
|
parser.add_argument("--name", type=str, default="openai/clip-vit-base-patch32")
|
|
parser.add_argument("--features_parent", type=str, default="faiss/")
|
|
parser.add_argument("--data_parent", type=str, default="data/")
|
|
parser.add_argument("--test", action="store_true")
|
|
|
|
args = parser.parse_args()
|
|
args.features_parent = join(args.features_parent, args.name)
|
|
if args.test:
|
|
csv_file = join(args.data_parent, "test.csv")
|
|
data_dir = join(args.data_parent, "test")
|
|
path = join(args.features_parent, "features-test")
|
|
model = get_backbone(args.name)
|
|
compute_features(path, data_dir, csv_file, tag="id", args=args)
|
|
train_test = get_results(args, train_test)
|
|
|
|
from collections import Counter
|
|
|
|
N, pos = Counter(), Counter()
|
|
for train_gps, test_gps in tqdm(train_test, desc="Computing accuracy..."):
|
|
get_match_values(train_gps, test_gps, N, pos)
|
|
|
|
for train_gps, test_gps in tqdm(train_test, desc="Computing haversine..."):
|
|
haversine(train_gps, test_gps, N, pos)
|
|
|
|
compute_print_accuracy(N, pos)
|
|
else:
|
|
csv_file = join(args.data_parent, args.annotation_file)
|
|
path = join(args.features_parent, f"features-{args.id}")
|
|
data_dir = join(args.data_parent, f"images-{args.id}", "train")
|
|
compute_features(path, data_dir, csv_file, tag="image_id", args=args)
|
|
|
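
# Example invocations (the script filename `features.py` is an assumption):
#   python features.py --id 3        # embed training shard data/images-3/train
#   python features.py --test        # embed test set, run 1-NN, print accuracy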