|
import base64 |
|
import functools |
|
import io |
|
import logging |
|
import random |
|
|
|
import beartype |
|
import einops.layers.torch |
|
import numpy as np |
|
import requests |
|
from jaxtyping import Integer, UInt8, jaxtyped |
|
from PIL import Image |
|
from torch import Tensor |
|
from torchvision.transforms import v2 |
|
|
|
# Module-wide logger. The name is a fixed string, not ``__name__``.
logger: logging.Logger = logging.getLogger("data.py")

# Base URL that get_img() and get_seg() prepend to their file paths.
R2_URL: str = "https://pub-129e98faed1048af94c4d4119ea47be7.r2.dev"
|
|
|
|
|
@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_img(i: int) -> Image.Image:
    """Download validation image ``i`` and return it as a PIL image.

    The index is zero-based; the remote file name uses ``i + 1`` zero-padded
    to eight digits. Up to 512 results are memoized, so repeated calls with
    the same index return the same ``Image`` object without a new request.
    Failed downloads raise and are therefore not cached.

    Args:
        i: Zero-based index of the image.

    Returns:
        The image opened from the downloaded bytes.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            4xx/5xx response.
        PIL.UnidentifiedImageError: If the payload is not a readable image.
    """
    fpath = f"/images/ADE_val_{i + 1:08}.jpg"
    url = R2_URL + fpath
    logger.info("Getting image from '%s'.", url)

    # Without an explicit timeout a stalled connection blocks forever.
    resp = requests.get(url, timeout=30)
    # Surface HTTP errors here instead of handing an error page to PIL.
    resp.raise_for_status()
    # ``resp.content`` is fully read and content-decoded, so the connection
    # is released and the image is backed by an in-memory, seekable buffer.
    return Image.open(io.BytesIO(resp.content))
|
|
|
|
|
@beartype.beartype
@functools.lru_cache(maxsize=512)
def get_seg(i: int) -> Image.Image:
    """Download the annotation PNG for validation image ``i``.

    The index is zero-based; the remote file name uses ``i + 1`` zero-padded
    to eight digits. Up to 512 results are memoized, so repeated calls with
    the same index return the same ``Image`` object without a new request.
    Failed downloads raise and are therefore not cached.

    Args:
        i: Zero-based index of the annotation.

    Returns:
        The annotation image opened from the downloaded bytes.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            4xx/5xx response.
        PIL.UnidentifiedImageError: If the payload is not a readable image.
    """
    fpath = f"/annotations/ADE_val_{i + 1:08}.png"
    url = R2_URL + fpath
    logger.info("Getting annotations from '%s'.", url)

    # Without an explicit timeout a stalled connection blocks forever.
    resp = requests.get(url, timeout=30)
    # Surface HTTP errors here instead of handing an error page to PIL.
    resp.raise_for_status()
    # ``resp.content`` is fully read and content-decoded, so the connection
    # is released and the image is backed by an in-memory, seekable buffer.
    return Image.open(io.BytesIO(resp.content))
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
def make_colors() -> UInt8[np.ndarray, "n 3"]:
    """Build the fixed RGB palette as a ``uint8`` array with one row per color.

    The palette starts as every combination of six evenly spaced channel
    levels (6**3 = 216 colors), is shuffled with a fixed seed so the order is
    reproducible, and then has a handful of rows replaced by hand-picked
    colors.

    Returns:
        Array of shape ``(216, 3)`` holding ``(r, g, b)`` rows.
    """
    levels = (0, 51, 102, 153, 204, 255)
    # Red varies slowest and blue fastest; the order matters because the
    # seeded shuffle below depends on the starting sequence.
    palette = [(r, g, b) for r in levels for g in levels for b in levels]

    # Fixed seed: every run produces the same permutation.
    random.Random(42).shuffle(palette)
    table = np.array(palette, dtype=np.uint8)

    # Hand-picked replacements, keyed by palette row.
    overrides = {
        2: (201, 249, 255),
        4: (151, 204, 4),
        13: (104, 139, 88),
        16: (54, 48, 32),
        21: (120, 202, 210),
        26: (45, 125, 210),
        29: (116, 142, 84),
        46: (238, 185, 2),
        52: (88, 91, 86),
        60: (72, 99, 156),
        72: (76, 46, 5),
        94: (12, 15, 10),
    }
    for row, rgb in overrides.items():
        table[row] = np.array(rgb, dtype=np.uint8)

    return table
|
|
|
|
|
# Palette used by u8_to_img(); computed once when the module is imported.
colors = make_colors()

# Resize to 512x512 with nearest-neighbour sampling, then keep the central
# 448x448 region.
resize_transform = v2.Compose(
    [
        v2.Resize(size=(512, 512), interpolation=v2.InterpolationMode.NEAREST),
        v2.CenterCrop(size=(448, 448)),
    ]
)
|
|
|
|
|
@beartype.beartype
def to_sized(img_raw: Image.Image) -> Image.Image:
    """Run ``img_raw`` through the module-level ``resize_transform``.

    Args:
        img_raw: Image to resize and center-crop.

    Returns:
        The transformed image.
    """
    resized = resize_transform(img_raw)
    return resized
|
|
|
|
|
# Convert the input to a torchvision image tensor, then drop the leading
# singleton axis (the ``()`` in the pattern) so a 2-D tensor remains.
# NOTE(review): ToImage presumably yields (channels, height, width), so the
# axis labelled "width" is likely the image height. Labels only; confirm.
u8_transform = v2.Compose(
    [
        v2.ToImage(),
        einops.layers.torch.Rearrange("() width height -> width height"),
    ]
)
|
|
|
|
|
@beartype.beartype
def to_u8(seg_raw: Image.Image) -> UInt8[Tensor, "width height"]:
    """Convert a segmentation image into a 2-D ``uint8`` tensor.

    Args:
        seg_raw: Segmentation image; the rearrange pattern in
            ``u8_transform`` requires it to have exactly one channel.

    Returns:
        The 2-D tensor produced by the module-level ``u8_transform``.
    """
    label_map = u8_transform(seg_raw)
    return label_map
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
def u8_to_img(map: UInt8[Tensor, "width height"]) -> Image.Image:
    """Render a 2-D label map as an RGB image using the module palette.

    Label ``k`` (for ``k >= 1``) is painted with ``colors[k - 1]``. Label 0
    and any label beyond the end of the palette are left black.

    Args:
        map: 2-D ``uint8`` tensor of labels (moved to CPU before use).

    Returns:
        RGB image with one pixel per entry of ``map``.
    """
    labels = map.cpu().numpy()

    # Lookup table indexed directly by the raw uint8 label. Row 0 stays black
    # and rows past the palette stay black, matching labels with no color.
    lut = np.zeros((256, 3), dtype=np.uint8)
    n_usable = min(len(colors), 255)
    lut[1 : n_usable + 1] = colors[:n_usable]

    # A single gather replaces one full-array boolean mask per palette entry.
    return Image.fromarray(lut[labels])
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
def to_classes(map: Integer[Tensor, "width height"]) -> list[int]:
    """Return the distinct values present in ``map``, in ascending order.

    Args:
        map: 2-D integer tensor.

    Returns:
        Sorted list of the unique values in ``map`` as Python ints.
    """
    # ``reshape`` also handles non-contiguous tensors (e.g. transposed or
    # sliced inputs), where ``view(-1)`` raises a RuntimeError.
    flat = map.reshape(-1).tolist()
    # Sorting makes the output deterministic; set iteration order is not.
    return sorted(set(flat))
|
|
|
|
|
@beartype.beartype
def img_to_base64(img: Image.Image) -> str:
    """Encode ``img`` as WebP and return it as a base64 ``data:`` URI.

    Args:
        img: Image to serialize.

    Returns:
        String of the form ``data:image/webp;base64,<payload>``.
    """
    with io.BytesIO() as buffer:
        img.save(buffer, format="webp")
        encoded = base64.b64encode(buffer.getvalue())
    return "data:image/webp;base64," + encoded.decode("utf8")
|
|