# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union, Tuple, List, Callable, Dict
import torch
from diffusers import StableDiffusionPipeline
import torch.nn.functional as nnf
import numpy as np
import abc
import ptp_scripts.ptp_utils as ptp_utils
import matplotlib.pyplot as plt
from PIL import Image

# When true, controllers skip the first ``num_att_layers`` calls of every step
# and hand ``forward`` the whole tensor instead of only its second half.
LOW_RESOURCE = False


class AttentionControl(abc.ABC):
    """Abstract base for callables that inspect or rewrite attention tensors.

    Every call advances ``cur_att_layer``. Once it equals
    ``num_att_layers + num_uncond_att_layers`` the counter wraps to zero,
    ``cur_step`` is incremented and ``between_steps()`` is invoked.
    """

    def __init__(self):
        self.cur_step = 0
        # Placeholder value; presumably overwritten by whoever attaches the
        # controller to a model -- TODO confirm against the caller.
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def reset(self):
        """Rewind the step and layer counters to zero."""
        self.cur_step = 0
        self.cur_att_layer = 0

    @property
    def num_uncond_att_layers(self):
        """Number of leading calls per step that bypass ``forward``.

        Equals ``num_att_layers`` when ``LOW_RESOURCE`` is set, otherwise 0.
        """
        if LOW_RESOURCE:
            return self.num_att_layers
        return 0

    def step_callback(self, x_t):
        """Hook receiving ``x_t``; the base version returns it unchanged."""
        return x_t

    def between_steps(self):
        """Hook run each time the layer counter wraps; no-op by default."""
        return

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process one attention tensor and return the (possibly new) tensor.

        Args:
            attn: attention tensor handed over by ``__call__``.
            is_cross: flag forwarded untouched from ``__call__``.
            place_in_unet: location label forwarded untouched from
                ``__call__``.
        """
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Route ``attn`` through ``forward`` and advance the counters.

        Args:
            attn: attention tensor; sliced along dim 0 when ``LOW_RESOURCE``
                is false.
            is_cross: passed through to ``forward``.
            place_in_unet: passed through to ``forward``.

        Returns:
            ``forward``'s result when ``LOW_RESOURCE`` is set and the call is
            past the skipped range; otherwise the input tensor, whose second
            half has been overwritten in place when ``LOW_RESOURCE`` is false.
        """
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Only the second half along dim 0 is edited, in place.
                # Presumably the first half is the unconditional batch --
                # TODO confirm against the sampling loop.
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)

        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer != layers_per_step:
            return attn

        # Every expected call of this step has been seen: roll over.
        self.cur_att_layer = 0
        self.cur_step += 1
        self.between_steps()
        return attn


class AttentionStore(AttentionControl):
    """Controller that records the attention tensors passed to ``forward``.

    Tensors are grouped under ``"{place_in_unet}_{cross|self}"`` keys.
    ``step_store`` collects the step in progress; ``attention_store`` holds
    what the most recently completed step collected.
    """

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = self.get_empty_store()

    def reset(self):
        """Reset the counters and drop everything recorded so far."""
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = self.get_empty_store()

    @staticmethod
    def get_empty_store():
        """Return a new dict with a distinct empty list for each store key."""
        store_keys = (
            "down_cross", "mid_cross", "up_cross",
            "down_self", "mid_self", "up_self",
        )
        return {store_key: [] for store_key in store_keys}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Append ``attn`` to the current step's store and return it as is."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # NOTE: a size filter (``attn.shape[1] <= 16 ** 2``, labelled "avoid
        # memory overhead") used to guard this append; it is disabled, so
        # every tensor is stored regardless of size.
        bucket = self.step_store[key]
        bucket.append(attn)
        return attn

    def between_steps(self):
        """Publish the finished step's tensors and start an empty store.

        Each key of ``attention_store`` is replaced by the matching list from
        ``step_store``; nothing is accumulated across steps.
        """
        self.attention_store.update(self.step_store)
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the stored tensors, each divided by ``cur_step``.

        NOTE(review): ``between_steps`` overwrites rather than sums, so the
        result looks like the latest step's tensors scaled by
        ``1 / cur_step`` rather than a mean over steps -- confirm intended.
        """
        steps = self.cur_step
        averaged = {}
        for key, tensors in self.attention_store.items():
            averaged[key] = [tensor / steps for tensor in tensors]
        return averaged