"""Fetch ``ADE_val_*`` images/annotations from an R2 bucket and render segmentation maps.

Segmentation maps are handled as 2-D uint8 class-id tensors; ``u8_to_img``
colors them with the fixed palette built by ``make_colors``.
"""

import base64
import functools
import io
import itertools
import logging
import random

import beartype
import einops.layers.torch
import numpy as np
import requests
from jaxtyping import Integer, UInt8, jaxtyped
from PIL import Image
from torch import Tensor
from torchvision.transforms import v2

logger = logging.getLogger("data.py")

R2_URL = "https://pub-129e98faed1048af94c4d4119ea47be7.r2.dev"

# Seconds to wait on the bucket before giving up. Without a timeout a stalled
# connection would block the caller indefinitely.
REQUEST_TIMEOUT_S = 30


@beartype.beartype
def _fetch_image(url: str) -> Image.Image:
    """Download ``url`` and decode it into a fully loaded PIL image.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.Timeout: if the server does not respond within REQUEST_TIMEOUT_S.
        PIL.UnidentifiedImageError: if the body is not a decodable image.
    """
    response = requests.get(url, timeout=REQUEST_TIMEOUT_S)
    # Fail loudly on 404s etc. instead of handing an error page to PIL.
    response.raise_for_status()
    img = Image.open(io.BytesIO(response.content))
    # Image.open is lazy. Decode now so the (cached) object is self-contained
    # and no longer depends on a network stream that may have gone away.
    img.load()
    return img


@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_img(i: int) -> Image.Image:
    """Return validation image ``i`` (0-based; the remote file name is 1-based).

    Results are memoised (up to 512 entries) and the same object is returned
    to every caller, so copy it (``img.copy()``) before mutating it in place.
    """
    fpath = f"/images/ADE_val_{i + 1:08}.jpg"
    url = R2_URL + fpath
    logger.info("Getting image from '%s'.", url)
    return _fetch_image(url)


@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_seg(i: int) -> Image.Image:
    """Return the annotation PNG for validation image ``i`` (0-based).

    Like ``get_img``, results are memoised and shared between callers.
    """
    fpath = f"/annotations/ADE_val_{i + 1:08}.png"
    url = R2_URL + fpath
    logger.info("Getting annotations from '%s'.", url)
    return _fetch_image(url)


@jaxtyped(typechecker=beartype.beartype)
def make_colors() -> UInt8[np.ndarray, "n 3"]:
    """Build the deterministic RGB palette used to render segmentation maps.

    Starts from the 6 * 6 * 6 = 216 colors whose channels are multiples of 51,
    shuffles them with a fixed seed, then overrides a few entries with
    hand-picked colors. Entry ``k`` is used for class id ``k + 1`` (see
    ``u8_to_img``).
    """
    values = (0, 51, 102, 153, 204, 255)
    # itertools.product yields (r, g, b) in the same order as three nested
    # loops with r outermost, so the seeded shuffle result is unchanged.
    palette = list(itertools.product(values, repeat=3))
    # Fixed seed
    random.Random(42).shuffle(palette)
    colors = np.array(palette, dtype=np.uint8)

    # Fixed colors. Must be synced with Segmentation.elm.
    overrides = {
        2: (201, 249, 255),
        4: (151, 204, 4),
        13: (104, 139, 88),
        16: (54, 48, 32),
        21: (120, 202, 210),  # water
        26: (45, 125, 210),
        29: (116, 142, 84),
        46: (238, 185, 2),
        52: (88, 91, 86),
        60: (72, 99, 156),  # river
        72: (76, 46, 5),
        94: (12, 15, 10),
    }
    for idx, rgb in overrides.items():
        colors[idx] = rgb
    return colors


colors = make_colors()

# Nearest-neighbour interpolation so that, when applied to segmentation maps,
# no new (blended) class ids are invented.
resize_transform = v2.Compose([
    v2.Resize((512, 512), interpolation=v2.InterpolationMode.NEAREST),
    v2.CenterCrop((448, 448)),
])


@beartype.beartype
def to_sized(img_raw: Image.Image) -> Image.Image:
    """Resize to 512x512 (nearest neighbour), then center-crop to 448x448."""
    return resize_transform(img_raw)


# v2.ToImage produces a (C, H, W) tensor; the "()" axis requires C == 1 and
# drops it, leaving (H, W).
u8_transform = v2.Compose([
    v2.ToImage(),
    einops.layers.torch.Rearrange("() height width -> height width"),
])


@beartype.beartype
def to_u8(seg_raw: Image.Image) -> UInt8[Tensor, "height width"]:
    """Convert a single-channel segmentation image into a 2-D uint8 tensor."""
    return u8_transform(seg_raw)


@jaxtyped(typechecker=beartype.beartype)
def u8_to_img(map: UInt8[Tensor, "height width"]) -> Image.Image:
    """Color a class-id map with the module palette and return an RGB image.

    Class id ``k`` (1 <= k <= len(colors)) is drawn as ``colors[k - 1]``;
    id 0 and any id beyond the palette are drawn black.
    """
    ids = map.cpu().numpy()
    # One lookup-table gather instead of one full-image mask per palette entry.
    n_colors = min(len(colors), 255)
    lut = np.zeros((256, 3), dtype=np.uint8)
    lut[1 : n_colors + 1] = colors[:n_colors]
    return Image.fromarray(lut[ids])


@jaxtyped(typechecker=beartype.beartype)
def to_classes(map: Integer[Tensor, "height width"]) -> list[int]:
    """Return the distinct class ids present in ``map``, in ascending order."""
    # Integer is any signed or unsigned int: https://docs.kidger.site/jaxtyping/api/array/#dtype
    # Tensor.unique works on non-contiguous tensors (unlike .view(-1)) and
    # returns sorted values, so the result order is deterministic.
    return map.unique(sorted=True).tolist()


@beartype.beartype
def img_to_base64(img: Image.Image) -> str:
    """Encode ``img`` as WebP and return it as a ``data:`` URL string."""
    buf = io.BytesIO()
    img.save(buf, format="webp")
    b64 = base64.b64encode(buf.getvalue())
    s64 = b64.decode("utf8")
    return "data:image/webp;base64," + s64