import abc
from typing import Union, Tuple, List, Dict, Optional

import numpy as np
import torch
import torch.nn.functional as nnf

# cv2, IPython and PIL are only used by the visualisation helpers below.
# They are imported defensively so that the attention controllers and the
# alignment utilities stay importable in environments without them; the
# helpers raise a clear ImportError when they are actually called.
try:
    import cv2
except ImportError:
    cv2 = None
try:
    from IPython.display import display
except ImportError:
    display = None
try:
    from PIL import Image
    from PIL import Image, ImageDraw, ImageFont
except ImportError:
    Image = ImageDraw = ImageFont = None


def _require(module, package: str) -> None:
    """Raise ImportError naming *package* when the optional *module* is missing."""
    if module is None:
        raise ImportError(f"{package} is required for this function but is not installed.")


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0),
                     font_scale: float = 1.0, thickness: int = 3) -> np.ndarray:
    """Return a copy of *image* with a white strip appended below it containing *text*.

    The strip is 20% of the image height and the text is centred horizontally.

    Args:
        image: array that unpacks as ``(h, w, c)``.
        text: caption to draw with OpenCV's Hershey Simplex font.
        text_color: colour tuple handed to ``cv2.putText``.
        font_scale: OpenCV font scale.
        thickness: stroke thickness, used both to measure and to draw the text.

    Returns:
        A ``uint8`` array of shape ``(h + int(h * .2), w, c)``.
    """
    _require(cv2, "opencv-python (cv2)")
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, font_scale, thickness)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    # Draw with the same thickness that was used for measuring; the drawing
    # call previously hard-coded 2 and ignored the ``thickness`` argument.
    cv2.putText(img, text, (text_x, text_y), font, font_scale, text_color, thickness)
    return img


def text_under_image_pil(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0),
                         font_scale: float = 1.0) -> np.ndarray:
    """Draw *text* on *image* with Pillow and return the result as an array.

    The font is loaded from ``./Roboto-Regular.ttf`` (relative to the current
    working directory) at a size of ``font_scale * image_height / 20``.

    Returns:
        An array with the same shape as *image*.
    """
    _require(Image, "Pillow (PIL)")
    image_pil = Image.fromarray(image)
    draw = ImageDraw.Draw(image_pil)
    font_size = int(font_scale * image.shape[0] / 20)
    font_path = "./Roboto-Regular.ttf"
    font = ImageFont.truetype(font_path, font_size)
    # ``ImageDraw.textsize`` was removed in Pillow 10; prefer ``textbbox`` and
    # only fall back to the old API on Pillow versions that lack it.
    if hasattr(draw, "textbbox"):
        left, _, right, _ = draw.textbbox((0, 0), text, font=font)
        text_width = right - left
    else:
        text_width = draw.textsize(text, font=font)[0]
    text_x = (image.shape[1] - text_width) // 2
    # NOTE(review): the text origin is placed at y == image height, i.e. just
    # below the last pixel row, and the canvas is not extended (unlike
    # text_under_image). The caption therefore appears to fall outside the
    # returned image. Left unchanged because fixing it alters the output shape.
    text_y = image.shape[0]
    draw.text((text_x, text_y), text, font=font, fill=text_color)
    return np.array(image_pil)


def _tile_images(images, num_rows: int, offset_ratio: float) -> np.ndarray:
    """Arrange *images* row-major on a white canvas with ``num_rows`` rows.

    A list or a 4-D array is treated as a batch; any other array is treated as
    a single image. The batch is padded with white images up to the next
    multiple of ``num_rows`` so that no image is dropped from the grid.
    """
    if isinstance(images, np.ndarray) and images.ndim != 4:
        images = [images]
    tiles = [image.astype(np.uint8) for image in images]
    # Pad up to a multiple of num_rows. The former ``len % num_rows`` padding
    # over-filled the batch and silently dropped trailing images.
    num_empty = (-len(tiles)) % num_rows
    blank = np.ones(tiles[0].shape, dtype=np.uint8) * 255
    tiles = tiles + [blank] * num_empty

    h, w, c = tiles[0].shape
    offset = int(h * offset_ratio)
    num_cols = len(tiles) // num_rows
    canvas = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        top = i * (h + offset)
        for j in range(num_cols):
            left = j * (w + offset)
            canvas[top: top + h, left: left + w] = tiles[i * num_cols + j]
    return canvas


def _grid_to_pil(canvas: np.ndarray, display_image: bool):
    """Convert the grid to a PIL image and optionally show it via IPython."""
    _require(Image, "Pillow (PIL)")
    pil_img = Image.fromarray(canvas)
    if display_image:
        _require(display, "IPython")
        display(pil_img)
    return pil_img


def view_images(images: Union[np.ndarray, List],
                num_rows: int = 1,
                offset_ratio: float = 0.02,
                display_image: bool = True) -> "Image.Image":
    """Displays a list of images in a grid.

    Args:
        images: a list of images, a 4-D array (batch) or a single image array.
        num_rows: number of grid rows; columns are derived from the batch size.
        offset_ratio: gap between tiles as a fraction of the tile height.
        display_image: when true, the grid is also shown with IPython's ``display``.

    Returns:
        The grid as a PIL image.
    """
    return _grid_to_pil(_tile_images(images, num_rows, offset_ratio), display_image)


def view_images_with_texts(images: Union[np.ndarray, List],
                           texts: Union[str, List[str]],
                           num_rows: int = 1,
                           offset_ratio: float = 0.02,
                           font_scale: float = 1.0,
                           display_image: bool = True) -> "Image.Image":
    """Displays a list of images in a grid with texts below them.

    Args:
        images: a list of images, a 4-D array (batch) or a single image array.
        texts: one caption per image, or a single string reused for every image.
        num_rows: number of grid rows.
        offset_ratio: gap between tiles as a fraction of the captioned tile height.
        font_scale: font scale forwarded to ``text_under_image``.
        display_image: when true, the grid is also shown with IPython's ``display``.

    Returns:
        The grid as a PIL image.
    """
    # A lone image must be wrapped before pairing it with captions; otherwise
    # zip() would iterate over its rows.
    if isinstance(images, np.ndarray) and images.ndim != 4:
        images = [images]
    # Ensure texts is a list
    if isinstance(texts, str):
        texts = [texts] * len(images)
    # Add text under each image
    images_with_texts = [text_under_image(img, txt, font_scale=font_scale)
                         for img, txt in zip(images, texts)]
    return _grid_to_pil(_tile_images(images_with_texts, num_rows, offset_ratio), display_image)


class AttentionControl(abc.ABC):
    """Base class for objects invoked once per attention layer call.

    ``num_att_layers`` starts at -1 and is expected to be set by whoever
    registers the controller (not visible in this file) before the step
    counter can advance.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        """Hook applied to ``x_t``; the base implementation returns it unchanged."""
        return x_t

    def between_steps(self):
        """Hook run each time the layer counter wraps around; no-op here."""
        return

    @property
    def num_uncond_att_layers(self):
        """Number of leading layer calls for which ``forward`` is skipped."""
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process one attention tensor and return its replacement."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Run ``forward`` on the second half of ``attn`` and advance the counters.

        The second half (along dim 0) is overwritten in place with the result
        of ``forward``. NOTE(review): presumably the first half holds the
        unconditional branch of a guidance batch - confirm against the caller.
        """
        if self.cur_att_layer >= self.num_uncond_att_layers:
            h = attn.shape[0]
            attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        # Once every layer has been visited, a step is complete.
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        """Reset the step and layer counters."""
        self.cur_step = 0
        self.cur_att_layer = 0


class EmptyControl(AttentionControl):
    """Controller that leaves every attention tensor untouched."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):
    """Controller that accumulates attention tensors across steps."""

    def __init__(self):
        super(AttentionStore, self).__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        """Return a fresh mapping with one empty list per ``<place>_<kind>`` key."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record ``attn`` for the current step (when small enough) and return it."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        """Fold the tensors recorded during the finished step into the running sum."""
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] += self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated tensors divided by the number of completed steps."""
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]]
                             for key in self.attention_store}
        return average_attention

    def reset(self):
        """Reset the counters and drop everything stored so far."""
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}


class LocalBlend:
    """Blend ``x_t`` of edited prompts with the first prompt outside a word mask."""

    def __init__(self, prompts: List[str], words: List[List[str]], tokenizer, device,
                 dtype=torch.float32, threshold=.3, max_num_words=77):
        """Build the per-prompt word selector.

        Args:
            prompts: one prompt per batch element.
            words: for each prompt, a word (or list of words / word positions)
                whose token indices are selected via ``get_word_inds``.
            tokenizer: forwarded to ``get_word_inds``.
            device, dtype: where and how the selector tensor is stored.
            threshold: cut-off applied to the normalised mask in ``__call__``.
            max_num_words: size of the token axis.
        """
        # Honour the argument; this was previously hard-coded to 77.
        self.max_num_words = max_num_words
        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, self.max_num_words)
        for i, (prompt, words_) in enumerate(zip(prompts, words)):
            if isinstance(words_, str):
                words_ = [words_]
            for word in words_:
                ind = torch.as_tensor(get_word_inds(prompt, word, tokenizer), dtype=torch.long)
                alpha_layers[i, :, :, :, :, ind] = 1
        self.alpha_layers = alpha_layers.to(device, dtype)
        self.threshold = threshold

    def __call__(self, x_t, attention_store):
        """Return ``x_t`` where regions outside the word mask are copied from ``x_t[:1]``.

        The mask is built from ``attention_store["down_cross"][2:4]`` and
        ``attention_store["up_cross"][:3]``, each reshaped to
        ``(num_prompts, -1, 1, 16, 16, max_num_words)``.
        """
        k = 1
        maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, self.max_num_words)
                for item in maps]
        maps = torch.cat(maps, dim=1)
        # Keep only the selected words, then average over dim 1.
        maps = (maps * self.alpha_layers).sum(-1).mean(1)
        mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
        mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
        # Normalise each map by its own spatial maximum before thresholding.
        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
        mask = mask.gt(self.threshold)
        # Combine the first prompt's mask with each of the other prompts' masks.
        mask = (mask[:1] + mask[1:]).float()
        x_t = x_t[:1] + mask * (x_t - x_t[:1])
        return x_t


class AttentionControlEdit(AttentionStore, abc.ABC):
    """Attention store that also overwrites the attention of the edited prompts.

    The first prompt is treated as the base; the remaining ``batch_size - 1``
    prompts receive attention derived from it.
    """

    def __init__(self, prompts, num_steps: int,
                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
                 self_replace_steps: Union[float, Tuple[float, float]],
                 local_blend: Optional[LocalBlend],
                 tokenizer,
                 device,
                 dtype):
        """Precompute the per-step cross-attention weights and self-attention window.

        Args:
            prompts: base prompt followed by the edited prompts.
            num_steps: number of steps the fractional bounds are scaled by.
            cross_replace_steps: forwarded to ``get_time_words_attention_alpha``.
            self_replace_steps: a fraction ``end`` (meaning ``(0, end)``) or a
                ``(start, end)`` pair of fractions of ``num_steps``.
            local_blend: optional ``LocalBlend`` applied in ``step_callback``.
            tokenizer, device, dtype: stored and used for the helper tensors.
        """
        super(AttentionControlEdit, self).__init__()
        # add tokenizer and device here
        self.tokenizer = tokenizer
        self.device = device
        self.dtype = dtype

        self.batch_size = len(prompts)
        self.cross_replace_alpha = get_time_words_attention_alpha(
            prompts, num_steps, cross_replace_steps, self.tokenizer).to(self.device, self.dtype)
        # Accept ints such as 1 as well as floats for the scalar form.
        if isinstance(self_replace_steps, (int, float)):
            self_replace_steps = 0, self_replace_steps
        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
        self.local_blend = local_blend  # defined by the caller and passed in

    def step_callback(self, x_t):
        """Apply the local blend (if any) using the attention stored so far."""
        if self.local_blend is not None:
            x_t = self.local_blend(x_t, self.attention_store)
        return x_t

    def replace_self_attention(self, attn_base, att_replace):
        """Use the base self-attention for the edited prompts when ``shape[2] <= 16 ** 2``."""
        if att_replace.shape[2] <= 16 ** 2:
            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
        else:
            return att_replace

    @abc.abstractmethod
    def replace_cross_attention(self, attn_base, att_replace):
        """Return the cross-attention to use for the edited prompts."""
        raise NotImplementedError

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Store ``attn`` and overwrite the edited prompts' attention.

        Cross-attention is always edited; self-attention only while
        ``cur_step`` lies inside ``num_self_replace``.
        """
        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
        # FIXME (carried over from the original code): not replace correctly
        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
            h = attn.shape[0] // (self.batch_size)
            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
            attn_base, attn_replace = attn[0], attn[1:]
            if is_cross:
                alpha_words = self.cross_replace_alpha[self.cur_step]
                # Per-token blend between the edited and the original attention.
                attn_replace_new = (self.replace_cross_attention(attn_base, attn_replace) * alpha_words
                                    + (1 - alpha_words) * attn_replace)
                attn[1:] = attn_replace_new
            else:
                attn[1:] = self.replace_self_attention(attn_base, attn_replace)
            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
        return attn


class AttentionReplace(AttentionControlEdit):
    """Edit that projects the base cross-attention through a replacement mapper."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None, tokenizer=None, device=None, dtype=torch.float32):
        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
                                               local_blend, tokenizer, device, dtype)
        self.mapper = get_replacement_mapper(prompts, self.tokenizer).to(self.device, dtype=dtype)

    def replace_cross_attention(self, attn_base, att_replace):
        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)


class AttentionRefine(AttentionControlEdit):
    """Edit that re-indexes the base cross-attention through an alignment mapper."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
                 local_blend: Optional[LocalBlend] = None, tokenizer=None, device=None, dtype=torch.float32):
        super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
                                              local_blend, tokenizer, device, dtype)
        mapper, alphas = get_refinement_mapper(prompts, self.tokenizer)
        # The mapper is used as an index tensor in replace_cross_attention, so
        # it must stay integral; casting it to the floating ``dtype`` (as was
        # done before) makes the indexing raise.
        self.mapper = mapper.to(self.device, torch.int64)
        alphas = alphas.to(self.device, self.dtype)
        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])

    def replace_cross_attention(self, attn_base, att_replace):
        # Gather base attention at the aligned token positions, then move the
        # prompt axis introduced by the gather to the front.
        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
        return attn_replace


class AttentionReweight(AttentionControlEdit):
    """Edit that scales the base cross-attention per token with an equalizer."""

    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer,
                 local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None,
                 tokenizer=None, device=None, dtype=torch.float32):
        super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
                                                local_blend, tokenizer, device, dtype)
        self.equalizer = equalizer.to(self.device, self.dtype)
        self.prev_controller = controller

    def replace_cross_attention(self, attn_base, att_replace):
        # Optionally chain after another edit before re-weighting.
        if self.prev_controller is not None:
            attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
        attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
        return attn_replace


def get_equalizer(text: str, word_select: Union[int, str, Tuple[Union[int, str], ...]],
                  values: Union[List[float], Tuple[float, ...]], tokenizer, max_num_words: int = 77):
    """Build a ``(len(values), max_num_words)`` weight matrix for ``AttentionReweight``.

    Every entry is 1 except at the token positions of the selected words,
    where row ``i`` holds ``values[i]``.

    Args:
        text: prompt in which the words are looked up.
        word_select: a word, a word position, or a tuple of them.
        values: one weight per output row.
        tokenizer: forwarded to ``get_word_inds``.
        max_num_words: size of the token axis.
    """
    if isinstance(word_select, (int, str)):
        word_select = (word_select,)
    equalizer = torch.ones(len(values), max_num_words)
    values = torch.tensor(values, dtype=torch.float32)
    for word in word_select:
        inds = torch.as_tensor(get_word_inds(text, word, tokenizer), dtype=torch.long)
        if len(inds) > 0:
            # Broadcast one value per row across all selected token columns.
            equalizer[:, inds] = values[:, None]
    return equalizer


def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
                           word_inds: Optional[torch.Tensor] = None):
    """Set ``alpha`` to 1 inside a step window and 0 outside it, in place.

    Args:
        alpha: tensor indexed as ``[step, prompt, word]``.
        bounds: ``end`` (meaning ``(0, end)``) or ``(start, end)``, as fractions
            of ``alpha.shape[0]``.
        prompt_ind: index along the prompt axis.
        word_inds: word indices to update; all words when ``None``.

    Returns:
        The same ``alpha`` tensor.
    """
    if isinstance(bounds, (int, float)):
        bounds = 0, bounds
    start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
    if word_inds is None:
        word_inds = torch.arange(alpha.shape[2])
    alpha[: start, prompt_ind, word_inds] = 0
    alpha[start: end, prompt_ind, word_inds] = 1
    alpha[end:, prompt_ind, word_inds] = 0
    return alpha


def get_time_words_attention_alpha(prompts, num_steps,
                                   cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
                                   tokenizer, max_num_words=77):
    """Build the per-step, per-prompt, per-token cross-attention blend weights.

    Args:
        prompts: base prompt followed by the edited prompts.
        num_steps: number of steps; the result has ``num_steps + 1`` entries.
        cross_replace_steps: bounds applied to every token, or a dict mapping
            words to their own bounds with an optional ``"default_"`` entry
            (``(0., 1.)`` when absent).
        tokenizer: forwarded to ``get_word_inds`` for word-specific bounds.
        max_num_words: size of the token axis.

    Returns:
        Tensor of shape ``(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)``.
    """
    if isinstance(cross_replace_steps, dict):
        # Work on a copy so the caller's dict does not gain a "default_" key.
        cross_replace_steps = dict(cross_replace_steps)
    else:
        cross_replace_steps = {"default_": cross_replace_steps}
    cross_replace_steps.setdefault("default_", (0., 1.))

    alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
    for i in range(len(prompts) - 1):
        alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], i)
    for key, item in cross_replace_steps.items():
        if key != "default_":
            inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
            for i, ind in enumerate(inds):
                if len(ind) > 0:
                    alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
    alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
    return alpha_time_words


# Sequence aligner


class ScoreParams:
    """Scores used by ``global_align``: gap, match and mismatch."""

    def __init__(self, gap, match, mismatch):
        self.gap = gap
        self.match = match
        self.mismatch = mismatch

    def mis_match_char(self, x, y):
        """Return the match score when ``x == y``, otherwise the mismatch score."""
        if x != y:
            return self.mismatch
        else:
            return self.match


# NOTE: an earlier list-based ``get_matrix`` that took the sequences themselves
# used to precede this definition; it was shadowed by this one and never
# callable, so it has been removed.
def get_matrix(size_x, size_y, gap):
    """Return the ``(size_x + 1, size_y + 1)`` int32 score matrix with gap borders."""
    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    matrix[0, 1:] = (np.arange(size_y) + 1) * gap
    matrix[1:, 0] = (np.arange(size_x) + 1) * gap
    return matrix


def get_traceback_matrix(size_x, size_y):
    """Return the initial traceback matrix.

    Codes as consumed by ``get_aligned_sequences``: 1 = move left, 2 = move up,
    3 = diagonal, 4 = origin.
    """
    matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
    matrix[0, 1:] = 1
    matrix[1:, 0] = 2
    matrix[0, 0] = 4
    return matrix


def global_align(x, y, score):
    """Fill the score and traceback matrices for aligning sequences ``x`` and ``y``.

    Ties are resolved in the order left, up, diagonal.

    Returns:
        ``(matrix, trace_back)``.
    """
    matrix = get_matrix(len(x), len(y), score.gap)
    trace_back = get_traceback_matrix(len(x), len(y))
    for i in range(1, len(x) + 1):
        for j in range(1, len(y) + 1):
            left = matrix[i, j - 1] + score.gap
            up = matrix[i - 1, j] + score.gap
            diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
            matrix[i, j] = max(left, up, diag)
            if matrix[i, j] == left:
                trace_back[i, j] = 1
            elif matrix[i, j] == up:
                trace_back[i, j] = 2
            else:
                trace_back[i, j] = 3
    return matrix, trace_back


def get_aligned_sequences(x, y, trace_back):
    """Walk the traceback matrix from the bottom-right corner.

    Returns:
        ``(x_seq, y_seq, mapper_y_to_x)``. ``x_seq`` / ``y_seq`` are the aligned
        sequences with ``'-'`` for gaps, in end-to-start order (they are not
        reversed). ``mapper_y_to_x`` is an int64 tensor with one
        ``(y_index, x_index)`` row per element of ``y`` in start-to-end order,
        where ``x_index`` is -1 for elements of ``y`` aligned to a gap.
    """
    x_seq = []
    y_seq = []
    i = len(x)
    j = len(y)
    mapper_y_to_x = []
    while i > 0 or j > 0:
        step = trace_back[i, j]
        if step == 3:
            x_seq.append(x[i - 1])
            y_seq.append(y[j - 1])
            i = i - 1
            j = j - 1
            mapper_y_to_x.append((j, i))
        elif step == 1:
            x_seq.append('-')
            y_seq.append(y[j - 1])
            j = j - 1
            mapper_y_to_x.append((j, -1))
        elif step == 2:
            x_seq.append(x[i - 1])
            y_seq.append('-')
            i = i - 1
        elif step == 4:
            break
    mapper_y_to_x.reverse()
    return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)


def get_mapper(x: str, y: str, tokenizer, max_len=77):
    """Align the tokens of prompt ``y`` to those of prompt ``x``.

    Returns:
        ``(mapper, alphas)``, both of length ``max_len``. For each token
        position of ``y``, ``mapper`` holds the aligned position in ``x`` (-1
        when there is none) and ``alphas`` is 1 where an aligned position
        exists and 0 otherwise. Positions past the end of ``y`` map to
        themselves with alpha 1.
    """
    x_seq = tokenizer.encode(x)
    y_seq = tokenizer.encode(y)
    score = ScoreParams(0, 1, -1)
    matrix, trace_back = global_align(x_seq, y_seq, score)
    mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
    alphas = torch.ones(max_len)
    alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
    mapper = torch.zeros(max_len, dtype=torch.int64)
    mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
    mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
    return mapper, alphas


def get_refinement_mapper(prompts, tokenizer, max_len=77):
    """Stack ``get_mapper`` results of every edited prompt against ``prompts[0]``."""
    x_seq = prompts[0]
    mappers, alphas = [], []
    for i in range(1, len(prompts)):
        mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
        mappers.append(mapper)
        alphas.append(alpha)
    return torch.stack(mappers), torch.stack(alphas)


def get_word_inds(text: str, word_place: Union[int, str, List[int]], tokenizer):
    """Return the token indices covering the selected word(s) of ``text``.

    Args:
        text: prompt, split into words on single spaces.
        word_place: a word (every occurrence is selected), a word position, or
            a collection of word positions.
        tokenizer: object providing ``encode`` and ``decode``; the first and
            last encoded tokens are skipped.

    Returns:
        An int64 NumPy array of token indices, offset by one for the skipped
        leading token; empty when nothing matches.
    """
    split_text = text.split(" ")
    if isinstance(word_place, str):
        word_place = [i for i, word in enumerate(split_text) if word_place == word]
    elif isinstance(word_place, int):
        word_place = [word_place]
    out = []
    if len(word_place) > 0:
        words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
        cur_len, ptr = 0, 0
        for i, piece in enumerate(words_encode):
            # Stop once every word has been consumed; leftover tokens would
            # otherwise index past the end of split_text.
            if ptr >= len(split_text):
                break
            cur_len += len(piece)
            if ptr in word_place:
                out.append(i + 1)
            # Advance to the next word once its characters are all accounted for.
            if cur_len >= len(split_text[ptr]):
                ptr += 1
                cur_len = 0
    return np.array(out, dtype=np.int64)
def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
    """Build the token-to-token projection used when swapping words between prompts.

    Both prompts are split on single spaces and must have the same number of
    words. Words that differ are looked up with ``get_word_inds``; their
    tokens are mapped onto each other, every other token onto itself.

    Args:
        x: source prompt.
        y: target prompt.
        tokenizer: forwarded to ``get_word_inds``.
        max_len: size of both axes of the returned matrix.

    Returns:
        A float32 tensor of shape ``(max_len, max_len)``; rows index source
        tokens and columns index target tokens.

    Raises:
        ValueError: if the prompts differ in word count, or if a replaced word
            cannot be located among the tokens of its prompt.
    """
    words_x = x.split(' ')
    words_y = y.split(' ')
    if len(words_x) != len(words_y):
        raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
                         f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
    inds_replace = [i for i, (word_x, word_y) in enumerate(zip(words_x, words_y)) if word_y != word_x]
    inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
    inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
    # An empty index set used to surface later as an opaque IndexError or
    # ZeroDivisionError; report the offending word position instead.
    for word_pos, source, target in zip(inds_replace, inds_source, inds_target):
        if len(source) == 0 or len(target) == 0:
            raise ValueError(f"could not locate the tokens of word {word_pos} in one of the prompts.")

    mapper = np.zeros((max_len, max_len))
    i = j = 0  # current token position in the source / target prompt
    cur_inds = 0  # next replaced word still to be handled
    while i < max_len and j < max_len:
        if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
            inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
            if len(inds_source_) == len(inds_target_):
                # Same token count: pair the tokens one-to-one.
                mapper[inds_source_, inds_target_] = 1
            else:
                # Different token counts: give every target token an equal
                # share of each source token.
                ratio = 1 / len(inds_target_)
                for i_t in inds_target_:
                    mapper[inds_source_, i_t] = ratio
            cur_inds += 1
            i += len(inds_source_)
            j += len(inds_target_)
        elif cur_inds < len(inds_source):
            # Unchanged token located before the next replaced word.
            mapper[i, j] = 1
            i += 1
            j += 1
        else:
            # All replaced words handled: the remaining target tokens map to
            # themselves.
            mapper[j, j] = 1
            i += 1
            j += 1

    return torch.from_numpy(mapper).float()


def get_replacement_mapper(prompts, tokenizer, max_len=77):
    """Stack the replacement mappers of every edited prompt against ``prompts[0]``.

    Returns:
        A tensor of shape ``(len(prompts) - 1, max_len, max_len)``.
    """
    x_seq = prompts[0]
    mappers = [get_replacement_mapper_(x_seq, prompt, tokenizer, max_len) for prompt in prompts[1:]]
    return torch.stack(mappers)