|
import argparse |
|
import sys |
|
|
|
from PIL import Image |
|
from typing import List, Optional, Tuple |
|
|
|
|
|
# A pixel coordinate as (x, y): index 0 is the column, index 1 the row.
Pos = Tuple[int, int]

# A size as (width, height), in pixels or in down-sampled grid cells.
Dim = Tuple[int, int]
|
|
|
|
|
class Box:
    """Axis-aligned rectangle of pixels whose two corners are both *inclusive*.

    ``min`` is the top-left and ``max`` the bottom-right (x, y) pixel that
    still belongs to the box, so a box whose corners coincide is 1x1.

    Note: ``as_tuple`` reports the inclusive ``max`` corner.  PIL's
    ``Image.crop`` treats the right/lower edge as exclusive, so callers that
    crop with it must add 1 to keep the last column and row.
    """

    def __init__(self, min: Pos, max: Pos) -> None:
        # The parameter names shadow the builtins min/max; they are part of
        # the public signature and are therefore left as-is.
        self._min = min
        self._max = max

    def min(self) -> Tuple[int, int]:
        """Return the inclusive top-left (x, y) corner."""
        return self._min

    def max(self) -> Tuple[int, int]:
        """Return the inclusive bottom-right (x, y) corner."""
        return self._max

    def width(self) -> int:
        """Return the number of pixel columns covered."""
        left, right = self._min[0], self._max[0]
        return 1 + right - left

    def height(self) -> int:
        """Return the number of pixel rows covered."""
        top, bottom = self._min[1], self._max[1]
        return 1 + bottom - top

    def dimensions(self) -> Tuple[int, int]:
        """Return ``(width, height)``."""
        return self.width(), self.height()

    def as_tuple(self) -> Tuple[int, int, int, int]:
        """Return ``(min_x, min_y, max_x, max_y)`` with an inclusive max corner."""
        (left, top), (right, bottom) = self._min, self._max
        return (left, top, right, bottom)
|
|
|
|
|
class DownBox(Box):
    """A pixel :class:`Box` tagged with its (column, row) cell in the down-sampled grid.

    ``down_pos`` is the position this box occupies in the down-sampled image,
    where each box becomes a single pixel.
    """

    def __init__(self, min: Pos, max: Pos, down_pos: Pos) -> None:
        self._down_pos = down_pos
        super().__init__(min, max)

    def down_pos(self) -> Tuple[int, int]:
        """Return the (x, y) grid coordinate of this box in the down image."""
        return self._down_pos
|
|
|
|
|
class ExtractedBoxes:
    """Container for the boxes found in a control image.

    ``extract_boxes`` fills it in row-major order, so the last box is the
    bottom-right grid cell.
    """

    def __init__(self, boxes: List[DownBox]) -> None:
        self._boxes = boxes

    def boxes(self) -> List[DownBox]:
        """Return the stored list itself (not a copy)."""
        return self._boxes

    def down_dimensions(self) -> Dim:
        """Return the (width, height) of the down-sampled grid.

        Derived from the last box's ``down_pos`` plus one on each axis; an
        empty container yields ``(0, 0)``.
        """
        if not self._boxes:
            return (0, 0)
        last_x, last_y = self._boxes[-1].down_pos()
        return (last_x + 1, last_y + 1)
|
|
|
|
|
def average_box_dimensions(boxes: List[DownBox]) -> Dim:
    """Estimate the typical (width, height) of a list of boxes.

    - exactly one box: that box's own dimensions;
    - 2..16 boxes: the per-axis arithmetic mean, rounded down;
    - more than 16 boxes: the per-axis median (the upper median when the
      count is even).

    Raises:
        AssertionError: if ``boxes`` is empty.
    """
    assert len(boxes) > 0
    count = len(boxes)
    if count == 1:
        return boxes[0].dimensions()

    widths = [b.width() for b in boxes]
    heights = [b.height() for b in boxes]

    if count > 16:
        # With many samples the median is less sensitive to a few odd-sized
        # boxes (e.g. cells cropped at the image border) than the mean.
        middle = count // 2
        return (sorted(widths)[middle], sorted(heights)[middle])

    return (sum(widths) // count, sum(heights) // count)
|
|
|
|
|
def get_trimmed(boxes: List[DownBox], *, outlier_dist: int = 1) -> Tuple[Box, Box]:
    """Compute the grid region left after dropping cropped cells at the corners.

    A box is an *outlier* when its width or height differs from the typical
    box size (``average_box_dimensions``) by more than ``outlier_dist``
    pixels.  Such boxes are treated as cells cut off by the image border.

    Only the two ends of the list are considered:

    - If the first box is an outlier, the region starts at the first
      non-outlier box.
    - If the last box is an outlier, the region ends at the last non-outlier
      box.
    - If no non-outlier box exists, that end is left untrimmed.

    Args:
        boxes: Non-empty list of boxes in row-major order, as produced by
            ``extract_boxes``.
        outlier_dist: Largest per-axis deviation, in pixels, from the typical
            box size that still counts as a full cell.  Defaults to 1.

    Returns:
        ``(box_out, box_down)``: the kept region in full-resolution pixel
        coordinates and in down-sampled grid coordinates.  Both boxes use
        inclusive ``max`` corners.

    Raises:
        AssertionError: if ``boxes`` is empty.
    """
    assert len(boxes) > 0
    avg_w, avg_h = average_box_dimensions(boxes)

    def is_outlier(box: DownBox) -> bool:
        width, height = box.dimensions()
        return abs(width - avg_w) > outlier_dist or abs(height - avg_h) > outlier_dist

    first, last = boxes[0], boxes[-1]

    min_out = (0, 0)
    min_down = (0, 0)
    max_out = last.max()
    max_down = last.down_pos()

    if is_outlier(first):
        # Scan forward past the cropped leading cells.
        keep = next((b for b in boxes[1:] if not is_outlier(b)), None)
        if keep is not None:
            min_out, min_down = keep.min(), keep.down_pos()

    if is_outlier(last):
        # Scan backward past the cropped trailing cells.
        keep = next((b for b in reversed(boxes[:-1]) if not is_outlier(b)), None)
        if keep is not None:
            max_out, max_down = keep.max(), keep.down_pos()

    return (Box(min_out, max_out), Box(min_down, max_down))
|
|
|
|
|
def calc_face_box(control_image: Image.Image, min_pos: Pos) -> Box:
    """Measure the uniformly coloured cell whose top-left pixel is ``min_pos``.

    The scan runs right along row ``min_pos[1]`` and down along column
    ``min_pos[0]``. It continues while pixels equal the colour at ``min_pos``,
    stopping at the image border. The two run lengths give the cell's width
    and height; the cell interior is not inspected.

    Raises:
        AssertionError: if either run is only a single pixel long.
    """
    start_x, start_y = min_pos
    img_w, img_h = control_image.size
    colour = control_image.getpixel(min_pos)

    # Extend to the right along the top row while the colour matches.
    end_x = start_x
    while end_x + 1 < img_w and control_image.getpixel((end_x + 1, start_y)) == colour:
        end_x += 1

    # Extend downward along the left column while the colour matches.
    end_y = start_y
    while end_y + 1 < img_h and control_image.getpixel((start_x, end_y + 1)) == colour:
        end_y += 1

    # Cells must span at least two pixels on each axis.
    assert end_x - start_x > 0
    assert end_y - start_y > 0

    return Box(min_pos, (end_x, end_y))
|
|
|
|
|
def extract_boxes(control_image: Image.Image) -> ExtractedBoxes:
    """Tile the control image into cells, returned in row-major order.

    Each row is walked left to right from x = 0 using ``calc_face_box``. Every
    box found is tagged with its (column, row) grid index. The next row
    starts directly below the last box of the current row, so all boxes in a
    row are assumed to share that height. The walk must end exactly on the
    right and bottom image edges, which is asserted.
    """
    img_w, img_h = control_image.size
    assert img_w > 0
    assert img_h > 0

    found: List[DownBox] = []
    top = 0
    row = 0
    while top < img_h:
        left = 0
        col = 0
        while left < img_w:
            face = calc_face_box(control_image, (left, top))
            found.append(DownBox(face.min(), face.max(), (col, row)))
            left += face.width()
            col += 1
        assert left == img_w
        # Advance by the height of the row's final box.
        top += found[-1].height()
        row += 1
    assert top == img_h

    return ExtractedBoxes(found)
|
|
|
|
|
def downsample_one(input_image: Image.Image, box: Box, sample_radius: Optional[int], downsampler: Image.Resampling) -> Tuple[int, int, int]:
    """Reduce the pixels of one box to a single RGB colour.

    A square window around the box's centre pixel is cut from ``input_image``
    and resized to 1x1 with ``downsampler``. The centre sits at offset
    ``width // 2`` / ``height // 2`` from the box's top-left corner.

    The window extends ``sample_radius`` pixels on each side of the centre,
    clamped to the box, so it is at most ``2 * sample_radius + 1`` pixels per
    axis. With ``sample_radius=None`` the whole box is sampled.

    Args:
        input_image: Image to sample; converted to RGB per window if needed.
        box: Cell to sample, with inclusive corners.
        sample_radius: Half-width of the sampling window, or ``None`` for the
            whole box.
        downsampler: PIL resampling filter used for the 1x1 reduction.

    Returns:
        The ``(r, g, b)`` colour of the window.

    Raises:
        ValueError: if ``sample_radius`` is negative.
    """
    if sample_radius is not None and sample_radius < 0:
        raise ValueError("sample_radius must be non-negative")

    (min_x, min_y), (max_x, max_y) = box.min(), box.max()
    box_width = box.width()
    box_height = box.height()
    center_x = min_x + box_width // 2
    center_y = min_y + box_height // 2

    if sample_radius is None:
        radius_x = box_width // 2
        radius_y = box_height // 2
    else:
        radius_x = min(sample_radius, box_width // 2)
        radius_y = min(sample_radius, box_height // 2)

    # Box corners are inclusive, but PIL's crop() excludes the right/lower
    # edge, so the far edge is centre + radius + 1, clamped to max + 1.
    # Staying inside the box avoids mixing in neighbouring cells and avoids
    # PIL's zero (black) fill for areas outside the image.
    window = (
        max(min_x, center_x - radius_x),
        max(min_y, center_y - radius_y),
        min(max_x + 1, center_x + radius_x + 1),
        min(max_y + 1, center_y + radius_y + 1),
    )
    sample = input_image.crop(window)
    if sample.mode != "RGB":
        # Other modes do not yield (r, g, b) pixels. For "1"/"P" images,
        # resize() always uses NEAREST instead of averaging.
        sample = sample.convert("RGB")

    rgb_value = sample.resize((1, 1), downsampler).getpixel((0, 0))
    assert isinstance(rgb_value, tuple) and len(rgb_value) == 3
    return rgb_value
|
|
|
|
|
class ImageRef:
    """Mutable holder for an image.

    Wrapping the image lets a callee swap in a new image object, e.g. the
    result of a crop. The caller still sees the final image via ``ref``.
    """

    def __init__(self, ref: Image.Image) -> None:
        self.ref: Image.Image = ref
|
|
|
|
|
def downsample_all(*, input_image: Image.Image, output_image: Optional[ImageRef], down_image: Optional[ImageRef], boxes: List[DownBox], sample_radius: Optional[int], downsampler: Image.Resampling, trim_cropped_edges: bool) -> None:
    """Write each box's sampled colour into the requested output images.

    For every box, ``downsample_one`` picks one RGB colour:

    - ``output_image`` (full resolution) has the whole box filled with it.
    - ``down_image`` gets a single pixel at the box's ``down_pos``.

    When ``trim_cropped_edges`` is set and there are boxes, both images are
    cropped to the region reported by ``get_trimmed``. The cropped image
    replaces ``ref`` on each holder.

    Raises:
        AssertionError: if neither output image is given.
    """
    assert output_image is not None or down_image is not None

    for box in boxes:
        rgb_value = downsample_one(input_image, box, sample_radius, downsampler)
        if output_image is not None:
            solid_color_image = Image.new("RGB", box.dimensions(), rgb_value)
            output_image.ref.paste(solid_color_image, box.min())
        if down_image is not None:
            # Each box maps to exactly one pixel of the down-sampled image.
            down_image.ref.putpixel(box.down_pos(), rgb_value)

    if trim_cropped_edges and boxes:
        def to_pil_box(region: Box) -> Tuple[int, int, int, int]:
            # Box corners are inclusive; PIL's crop() excludes right/lower.
            left, top, right, bottom = region.as_tuple()
            return (left, top, right + 1, bottom + 1)

        trimmed_out, trimmed_down = get_trimmed(boxes)
        if output_image is not None:
            output_image.ref = output_image.ref.crop(to_pil_box(trimmed_out))
        if down_image is not None:
            down_image.ref = down_image.ref.crop(to_pil_box(trimmed_down))
|
|
|
|
|
def str2bool(value) -> bool:
    """Parse a boolean command-line value for argparse.

    ``bool`` inputs pass through unchanged. Strings are matched
    case-insensitively against yes/true/t/y/1 (True) and no/false/f/n/0
    (False).

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(value, bool):
        return value
    table = {
        "yes": True, "true": True, "t": True, "y": True, "1": True,
        "no": False, "false": False, "f": False, "n": False, "0": False,
    }
    verdict = table.get(value.lower())
    if verdict is None:
        raise argparse.ArgumentTypeError("Boolean value expected.")
    return verdict
|
|
|
|
|
def main(cli_args: List[str]) -> None:
    """Command-line entry point: average ``--input`` over the cells of ``--control``.

    The control image is binarised and split into uniformly coloured cells
    (``extract_boxes``). Each cell of the input image is reduced to one
    colour, which is written to one or both outputs:

    - ``--output-up``: each cell filled at full resolution.
    - ``--output-down``: one pixel per cell.

    Raises:
        ValueError: if no output path is given, or if the control and input
            images differ in size.
        SystemExit: from argparse on invalid arguments (including a negative
            ``--sample-radius``).
    """
    parser = argparse.ArgumentParser(description="Downsample and rescale image.")
    parser.add_argument("--control", required=True, help="Path to control image.")
    parser.add_argument("--input", required=True, help="Path to input image.")
    parser.add_argument("--output-up", help="Path to save the output image, upscaled to the original size.")
    parser.add_argument("--output-down", help="Path to save the output image, kept at the downsampled size.")
    parser.add_argument("--sample-radius", type=int, default=None, help="Half-width in pixels of the square window sampled around each cell's centre (default: whole cell).")
    parser.add_argument("--downsampler", choices=["box", "bilinear", "bicubic", "hamming", "lanczos"], default="box", help="Downsampler to use.")
    parser.add_argument("--trim-cropped-edges", type=str2bool, default=False, help="Drop mapped checker grid elements that are cropped in the control image.")

    args = parser.parse_args(cli_args)

    # Validate the arguments before doing any image I/O.
    if args.sample_radius is not None and args.sample_radius < 0:
        parser.error("--sample-radius must be non-negative")
    if not args.output_up and not args.output_down:
        raise ValueError("At least one of --output-up and --output-down must be specified.")

    with Image.open(args.control) as raw_control:
        # Threshold without dithering. Dithering turns grey or anti-aliased
        # cell colours into speckle, which breaks cell detection.
        control_image = raw_control.convert("1", dither=Image.Dither.NONE)
    with Image.open(args.input) as raw_input:
        # Normalise to RGB so every sampled cell yields an (r, g, b) colour.
        input_image = raw_input.convert("RGB")

    if control_image.size != input_image.size:
        raise ValueError("Control image and input image must have the same dimensions.")

    downsampler = Image.Resampling[args.downsampler.upper()]
    extracted_boxes = extract_boxes(control_image)

    output_image: Optional[ImageRef] = None
    if args.output_up:
        output_image = ImageRef(Image.new("RGB", input_image.size))

    down_image: Optional[ImageRef] = None
    if args.output_down:
        down_image = ImageRef(Image.new("RGB", extracted_boxes.down_dimensions()))

    downsample_all(input_image=input_image, output_image=output_image, down_image=down_image, boxes=extracted_boxes.boxes(), sample_radius=args.sample_radius, downsampler=downsampler, trim_cropped_edges=args.trim_cropped_edges)

    if output_image is not None:
        output_image.ref.save(args.output_up)
    if down_image is not None:
        down_image.ref.save(args.output_down)
|
|
|
|
|
if __name__ == "__main__":
    # Hand everything after the program name to main()'s argparse parser.
    main(cli_args=sys.argv[1:])
|
|