# NOTE: removed non-code residue (a file-viewer header, commit hash, and
# line-number gutter) that was accidentally captured along with this source.
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math

import torch
import torchvision.transforms as T
import numpy as np
from scepter.modules.annotator.registry import ANNOTATORS
from scepter.modules.utils.config import Config
from PIL import Image


def edit_preprocess(processor, device, edit_image, edit_mask):
    """Run an annotator over ``edit_image`` and composite the result back.

    Args:
        processor: Annotator config dict (registry spec) or None to skip.
        device: Device the annotator module is moved to for inference.
        edit_image: PIL image to process, or None to skip.
        edit_mask: PIL mask selecting where the annotated result is kept;
            if None, the fully annotated image is returned.

    Returns:
        The (possibly partially) annotated PIL image, or ``edit_image``
        unchanged when there is nothing to do.
    """
    if edit_image is None or processor is None:
        return edit_image
    cfg = Config(cfg_dict=processor, load=False)
    annotator = ANNOTATORS.build(cfg).to(device)
    try:
        annotated = annotator(np.asarray(edit_image))
    finally:
        # Always move the module off the accelerator, even if inference
        # raised, so a failure does not leak device memory.
        annotator = annotator.to("cpu")
        del annotator
    annotated = Image.fromarray(annotated)
    if edit_mask is None:
        # Fix: Image.composite requires a mask; with no mask the whole
        # image is treated as editable instead of crashing.
        return annotated
    # Keep the annotated pixels where the mask is set, original elsewhere.
    return Image.composite(annotated, edit_image, edit_mask)

class ACEPlusImageProcessor():
    """Prepares (and un-prepares) image/mask pairs for ACE++ editing.

    ``preprocess`` normalizes the edit image and mask, optionally prepends a
    reference image on the left, and resizes everything so the token count
    stays within ``max_seq_len`` with dimensions divisible by ``d``.
    ``postprocess`` crops the reference strip back off and restores the
    original output size.
    """

    def __init__(self, max_aspect_ratio=4, d=16, max_seq_len=1024):
        # max_aspect_ratio: H/W (or W/H) beyond which images are center-cropped.
        # d: both output dimensions are forced to multiples of this patch size.
        # max_seq_len: budget on (H/d)*(W/d) patches (times 2, see preprocess).
        self.max_aspect_ratio = max_aspect_ratio
        self.d = d
        self.max_seq_len = max_seq_len
        # PIL -> float tensor in [-1, 1].
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def image_check(self, image):
        """Clamp an image's aspect ratio and convert it to a normalized tensor.

        Returns None when ``image`` is None; otherwise a (3, H, W) tensor.
        """
        if image is None:
            return image
        W, H = image.size  # PIL reports (width, height)
        if H / W > self.max_aspect_ratio:
            image = T.CenterCrop([int(self.max_aspect_ratio * W), W])(image)
        elif W / H > self.max_aspect_ratio:
            image = T.CenterCrop([H, int(self.max_aspect_ratio * H)])(image)
        return self.transforms(image)

    def preprocess(self,
                   reference_image=None,
                   edit_image=None,
                   edit_mask=None,
                   height=1024,
                   width=1024,
                   repainting_scale=1.0):
        """Build the model input tensors.

        Args:
            reference_image: Optional PIL image concatenated left of the edit
                image (its mask region is all zeros, i.e. never repainted).
            edit_image: Optional PIL image to edit; when None a blank
                ``height`` x ``width`` canvas with a full mask is used
                (pure reference-conditioned generation).
            edit_mask: Optional PIL mask (values > 128 are editable). When
                None but ``edit_image`` is given, the whole image is
                treated as editable.
            height, width: Canvas size used only when ``edit_image`` is None.
            repainting_scale: 1.0 fully blanks the masked region of the edit
                image; 0.0 leaves it untouched.

        Returns:
            (edit_image, edit_mask, out_h, out_w, slice_w) where out_h/out_w
            are the pre-resize edit-image dims for ``postprocess`` and
            slice_w is the (scaled) width of the reference strip.
        """
        reference_image = self.image_check(reference_image)
        edit_image = self.image_check(edit_image)
        if edit_image is None:
            # Reference-only generation: blank canvas, everything editable.
            edit_image = torch.zeros([3, height, width])
            edit_mask = torch.ones([1, height, width])
        elif edit_mask is None:
            # Fix: previously np.asarray(None) made the comparison below
            # raise; a missing mask now means "edit everything".
            edit_mask = torch.ones([1, *edit_image.shape[-2:]])
        else:
            edit_mask = np.asarray(edit_mask)
            edit_mask = np.where(edit_mask > 128, 1, 0)
            # An all-zero mask would edit nothing; fall back to a full mask.
            edit_mask = edit_mask.astype(
                np.float32) if np.any(edit_mask) else np.ones_like(edit_mask).astype(
                np.float32)
            edit_mask = torch.tensor(edit_mask).unsqueeze(0)

        # Blank out the editable region proportionally to repainting_scale.
        edit_image = edit_image * (1 - edit_mask * repainting_scale)

        # Remember the pre-resize size so postprocess can restore it.
        out_h, out_w = edit_image.shape[-2:]

        assert edit_mask is not None
        if reference_image is not None:
            # Scale the reference to the edit image's height, then paste it
            # on the left; its mask region is zero (never repainted).
            _, H, W = reference_image.shape
            _, eH, eW = edit_image.shape
            scale = eH / H
            tH, tW = eH, int(W * scale)
            reference_image = T.Resize((tH, tW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(reference_image)
            edit_image = torch.cat([reference_image, edit_image], dim=-1)
            edit_mask = torch.cat([torch.zeros([1, reference_image.shape[1], reference_image.shape[2]]), edit_mask], dim=-1)
            slice_w = reference_image.shape[-1]
        else:
            slice_w = 0

        # Downscale (never upscale) so the patch count fits the budget, then
        # snap all dims down to multiples of d.
        H, W = edit_image.shape[-2:]
        scale = min(1.0, math.sqrt(self.max_seq_len * 2 / ((H / self.d) * (W / self.d))))
        rH = int(H * scale) // self.d * self.d
        rW = int(W * scale) // self.d * self.d
        slice_w = int(slice_w * scale) // self.d * self.d

        edit_image = T.Resize((rH, rW), interpolation=T.InterpolationMode.BILINEAR, antialias=True)(edit_image)
        # Nearest interpolation keeps the mask binary; antialias is not
        # passed here (it is ignored for nearest and only emits a warning).
        edit_mask = T.Resize((rH, rW), interpolation=T.InterpolationMode.NEAREST_EXACT)(edit_mask)

        return edit_image, edit_mask, out_h, out_w, slice_w

    def postprocess(self, image, slice_w, out_w, out_h):
        """Crop off the reference strip and restore the original output size.

        Args:
            image: Generated PIL image (reference strip on the left).
            slice_w: Width of the reference strip as returned by preprocess;
                0 means no reference was used.
            out_w, out_h: Target output size (the pre-resize edit-image dims).
        """
        w, h = image.size
        if slice_w > 0:
            # The extra 20px skips the seam between reference and result
            # (NOTE(review): magic margin — presumably tuned empirically;
            # confirm against the generation layout).
            output_image = image.crop((slice_w + 20, 0, w, h))
            output_image = output_image.resize((out_w, out_h))
        else:
            output_image = image
        return output_image