File size: 1,709 Bytes
9665c2c
 
 
 
 
 
 
 
b0cf684
9665c2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) Meta Platforms, Inc. and affiliates.

import json
import shutil
from pathlib import Path

import cv2
import numpy as np
import requests
import torch
from tqdm.auto import tqdm

from .. import logger

# Base URL of the remote data host.
# NOTE(review): presumably code elsewhere appends relative file names to this
# and fetches them (e.g. via download_file) — confirm against callers.
DATA_URL = "https://cvg-data.inf.ethz.ch/OrienterNet_CVPR2023"


def read_image(path, grayscale=False):
    """Load an image file with OpenCV.

    Args:
        path: location of the image; converted with ``str()`` before being
            handed to ``cv2.imread``.
        grayscale: if True, decode with ``cv2.IMREAD_GRAYSCALE``; otherwise
            decode with ``cv2.IMREAD_COLOR`` and reorder the channel axis
            from OpenCV's BGR to RGB.

    Returns:
        The decoded numpy array. Color results are C-contiguous RGB.

    Raises:
        ValueError: if OpenCV cannot read or decode the file.
    """
    flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
    decoded = cv2.imread(str(path), flag)
    if decoded is None:
        raise ValueError(f"Cannot read image {path}.")

    has_channel_axis = decoded.ndim == 3
    if has_channel_axis and not grayscale:
        # Reverse the last axis (BGR -> RGB) and materialize the flipped view
        # so callers receive a contiguous buffer.
        decoded = np.ascontiguousarray(decoded[..., ::-1])
    return decoded


def write_torch_image(path, image):
    """Save a floating-point image with values in [0, 1] as an 8-bit file.

    Args:
        path: destination file, converted with ``str()``; ``cv2.imwrite``
            chooses the encoder from its extension.
        image: numpy array of shape (H, W) or channels-last (H, W, C) in RGB
            order. Values are clipped to [0, 1], scaled to [0, 255] and
            rounded. Color channels are reordered to OpenCV's BGR before
            writing.

    Raises:
        ValueError: if ``cv2.imwrite`` reports that the file was not written.
    """
    # imwrite supports 8-bit unsigned data for all formats, but not int64,
    # which is what ``astype(int)`` used to produce.
    image_u8 = np.round(image.clip(0, 1) * 255).astype(np.uint8)
    if image_u8.ndim == 3:
        # RGB -> BGR. A 2-D grayscale image has no channel axis: reversing its
        # last axis would mirror it horizontally instead.
        image_u8 = image_u8[..., ::-1]
    if not cv2.imwrite(str(path), np.ascontiguousarray(image_u8)):
        raise ValueError(f"Cannot write image {path}.")


class JSONEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes numpy and torch values.

    numpy scalars (``np.generic``) are converted with ``item()`` into the
    corresponding Python scalar. numpy arrays and torch tensors are converted
    with ``tolist()`` into (nested) Python lists. Any other object is
    delegated to ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        # np.generic scalars and ndarray/Tensor containers are disjoint types,
        # so the order of these checks does not matter.
        if isinstance(obj, np.generic):
            return obj.item()
        is_array_like = isinstance(obj, (np.ndarray, torch.Tensor))
        if is_array_like:
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)


def write_json(path, data):
    """Serialize ``data`` as JSON and write it to ``path``.

    numpy arrays/scalars and torch tensors are handled by ``JSONEncoder``.
    The whole document is encoded before the file is opened. An
    unserializable value therefore raises without truncating or half-writing
    an existing file at ``path``.

    Args:
        path: destination file; overwritten if it already exists.
        data: object to serialize (may contain numpy/torch values).

    Raises:
        TypeError: if ``data`` contains a value the encoder cannot handle.
    """
    # json.dump would stream chunks into an already-truncated file and leave
    # invalid JSON behind on failure. json.dumps builds the full text first
    # (and uses the C encoder), so errors surface before the file is touched.
    text = json.dumps(data, cls=JSONEncoder)
    with open(path, "w") as f:
        f.write(text)


def download_file(url, path):
    """Download ``url`` to ``path`` while showing a tqdm progress bar.

    Args:
        url: URL to fetch with ``requests.get``.
        path: destination file, or an existing directory. For a directory,
            the last ``/``-separated component of ``url`` is used as the
            file name. Missing parent directories are created.

    Returns:
        The ``Path`` of the written file.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        Other ``requests`` or ``OSError`` exceptions propagate unchanged.

    Data is streamed into a sibling ``<name>.part`` file, which is renamed
    onto ``path`` only after the transfer completes. A failed or interrupted
    download therefore never leaves a truncated file at the destination.
    """
    path = Path(path)
    if path.is_dir():
        path = path / Path(url).name
    path.parent.mkdir(exist_ok=True, parents=True)
    logger.info("Downloading %s to %s.", url, path)
    tmp_path = path.with_name(path.name + ".part")
    try:
        with requests.get(url, stream=True) as r:
            # Without this, an error page (404, 500, ...) is saved as the file.
            r.raise_for_status()
            # The header is optional (e.g. chunked transfer); tqdm accepts
            # total=None and then shows a plain byte counter.
            content_length = r.headers.get("Content-Length")
            total_length = int(content_length) if content_length is not None else None
            # r.raw does not undo Content-Encoding (gzip/deflate) by default.
            # For encoded responses the progress total (compressed size) is
            # then only approximate.
            r.raw.decode_content = True
            with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
                with open(tmp_path, "wb") as output:
                    shutil.copyfileobj(raw, output)
    except BaseException:
        # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
        if tmp_path.exists():
            tmp_path.unlink()
        raise
    tmp_path.replace(path)
    return path