# pi-tagger / tagger/common.py
# author: neggles · commit 2b6048b ("init") · 3.01 kB
import json
from dataclasses import asdict, dataclass
from functools import lru_cache
from os import PathLike
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from PIL import Image
class DictJsonMixin:
    """Mixin that adds ``asdict()`` / ``asjson()`` helpers to a dataclass.

    Only usable on dataclass instances: both helpers delegate to
    :func:`dataclasses.asdict`, which raises ``TypeError`` otherwise.
    """

    def asdict(self, *args, **kwargs) -> dict[str, Any]:
        """Return this dataclass's fields as a (recursively converted) dict.

        Extra arguments are forwarded to :func:`dataclasses.asdict`
        (e.g. ``dict_factory=``).
        """
        return asdict(self, *args, **kwargs)

    def asjson(self, *args, **kwargs) -> str:
        """Return this dataclass serialised as a JSON string.

        Extra arguments are forwarded to :func:`dataclasses.asdict`, not to
        :func:`json.dumps`. NumPy scalars and arrays (e.g. the ``np.int64``
        indices held by ``LabelData``) are converted to plain Python values
        first, because the stdlib JSON encoder rejects them.
        """
        return json.dumps(asdict(self, *args, **kwargs), default=self._json_default)

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """``json.dumps`` fallback: unwrap NumPy values, reject everything else."""
        if isinstance(obj, np.generic):
            # np.int64 / np.float32 / np.bool_ etc. -> int / float / bool
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Same error json.dumps would raise without a default hook.
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
@dataclass
class LabelData(DictJsonMixin):
    """Tag vocabulary as produced by ``load_labels``.

    ``names`` lists every tag in CSV row order; each of the other fields holds
    the row indices (positions in ``names``) of the tags in one category.
    """

    # every tag name, in CSV row order
    names: list[str]
    # indices into `names` of rows whose category code is 9
    rating: list[np.int64]
    # indices into `names` of rows whose category code is 0
    general: list[np.int64]
    # indices into `names` of rows whose category code is 4
    character: list[np.int64]
@dataclass
class ImageLabels(DictJsonMixin):
    """Tagging result for a single image.

    NOTE(review): the producer of this object is not visible here; the field
    descriptions below are inferred from names and types — confirm.
    """

    # text rendering of the predicted tags (format set by the producer)
    caption: str
    # alternative text rendering of the tags, presumably booru-style — confirm
    booru: str
    # presumably rating-tag name -> score
    rating: dict[str, float]
    # presumably general-tag name -> score
    general: dict[str, float]
    # presumably character-tag name -> score
    character: dict[str, float]
@lru_cache(maxsize=5)
def load_labels(csv_path: str | PathLike = "data/selected_tags.csv") -> LabelData:
    """Load the tag CSV and group tag row indices by category.

    Only the ``name`` and ``category`` columns are read. Category codes 9, 0
    and 4 go into the ``rating``, ``general`` and ``character`` groups
    respectively; rows with any other code appear in ``names`` only.

    Results are memoised by ``lru_cache`` (up to 5 entries), keyed on the
    argument *as passed*. Repeated calls therefore return the same
    ``LabelData`` instance, so treat it as read-only. A relative path is
    resolved against the current working directory only on a cache miss.

    Args:
        csv_path: Path to the tag CSV; relative paths are resolved against the
            current working directory.

    Returns:
        LabelData holding all tag names in row order plus, per category, the
        row indices of the matching tags.

    Raises:
        FileNotFoundError: If ``csv_path`` does not point to an existing file.
    """
    csv_path = Path(csv_path).resolve()
    if not csv_path.is_file():
        # Report the path actually checked; callers may pass any filename.
        raise FileNotFoundError(f"Tag CSV not found: {csv_path}")

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    category = df["category"]
    return LabelData(
        names=df["name"].tolist(),
        rating=list(np.where(category == 9)[0]),
        general=list(np.where(category == 0)[0]),
        character=list(np.where(category == 4)[0]),
    )
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, flattening any transparency onto white.

    Images that are neither RGB nor RGBA are converted first. They go to RGBA
    when they carry alpha: an ``A`` band (e.g. ``"LA"``) or a
    ``"transparency"`` entry in ``image.info`` (e.g. palette images).
    Otherwise they go straight to RGB. RGBA images are then alpha-composited
    over an opaque white canvas.

    Args:
        image: Any PIL image.

    Returns:
        An image in mode ``"RGB"``. This is the input object itself if it was
        already RGB.
    """
    if image.mode not in ("RGB", "RGBA"):
        # Converting an alpha-bearing mode such as "LA" directly to "RGB"
        # discards the alpha band, leaving transparent pixels with their
        # stored colour (often black). Go via RGBA so they get the same
        # white-background flattening as RGBA inputs.
        has_alpha = "A" in image.getbands() or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")

    if image.mode == "RGBA":
        # opaque white canvas; compositing fills transparent areas with white
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")
    return image
def pil_pad_square(
    image: Image.Image,
    fill: tuple[int, int, int] = (255, 255, 255),
) -> Image.Image:
    """Centre *image* on a square RGB canvas whose side is its longest edge.

    Args:
        image: Image to pad. It is pasted without a mask, so callers wanting
            transparency flattened should convert to RGB beforehand.
        fill: RGB colour of the padding (white by default).

    Returns:
        A new ``"RGB"`` image of size ``(side, side)``, where
        ``side = max(width, height)``. When the padding is odd, the extra
        pixel lands on the right/bottom edge, because the offsets are
        floor-divided.
    """
    width, height = image.size
    side = max(width, height)
    left = (side - width) // 2
    top = (side - height) // 2

    padded = Image.new("RGB", (side, side), fill)
    padded.paste(image, (left, top))
    return padded
def preprocess_image(
    image: Image.Image,
    size_px: int | tuple[int, int],
    upscale: bool = True,
) -> Image.Image:
    """
    Preprocess an image to be square and centered on a white background.

    The image is flattened to RGB (transparency onto white), padded to a
    centred square, then scaled to fit ``size_px``:
    - smaller than the target: resized to exactly ``size_px`` with LANCZOS;
    - larger than the target: shrunk in place with ``thumbnail`` (BICUBIC),
      which preserves aspect ratio.

    Args:
        image: Input PIL image of any mode.
        size_px: Target size; an int means ``(size_px, size_px)``.
        upscale: If ``False``, refuse to enlarge small images.

    Returns:
        The preprocessed RGB image.

    Raises:
        ValueError: If the padded image is smaller than the target and
            ``upscale`` is ``False``.

    NOTE(review): for a non-square ``size_px`` the two paths disagree.
    Upscaling stretches to exactly ``size_px``, while ``thumbnail`` yields
    the largest square that fits inside it. Confirm whether non-square
    targets are meant to be supported.
    """
    target = (size_px, size_px) if isinstance(size_px, int) else size_px
    target_w, target_h = target[0], target[1]

    # flatten transparency onto white, then centre on a square canvas
    squared = pil_pad_square(pil_ensure_rgb(image))

    too_small = squared.size[0] < target_w or squared.size[1] < target_h
    if too_small:
        if upscale is False:
            raise ValueError("Image is smaller than target size, and upscaling is disabled")
        squared = squared.resize(target, Image.LANCZOS)

    too_large = squared.size[0] > target_w or squared.size[1] > target_h
    if too_large:
        # thumbnail() modifies the image in place and returns None
        squared.thumbnail(target, Image.BICUBIC)
    return squared