# Scrape artifacts from the hosting page (commented out so the module parses):
# Spaces:
# Runtime error
# Runtime error
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union, Tuple, List, Callable, Dict
import torch
from diffusers import StableDiffusionPipeline
import torch.nn.functional as nnf
import numpy as np
import abc
import ptp_scripts.ptp_utils as ptp_utils
import matplotlib.pyplot as plt
from PIL import Image
# When True, the unconditional and conditional passes run separately and the
# unconditional attention layers are skipped by AttentionControl.__call__.
LOW_RESOURCE = False


class AttentionControl(abc.ABC):
    """Base hook for intercepting U-Net attention maps during diffusion.

    Subclasses override `forward` to inspect or modify each attention map.
    The instance is invoked once per attention layer per diffusion step and
    tracks its position through `cur_att_layer` and `cur_step`.
    """

    def step_callback(self, x_t):
        """Hook applied to the latent after each diffusion step; identity by default."""
        return x_t

    def between_steps(self):
        """Hook called once all attention layers of a step have been processed."""
        return

    @property
    def num_uncond_att_layers(self):
        # BUGFIX: this was a plain method, but __call__ reads it as an
        # attribute (`self.cur_att_layer >= self.num_uncond_att_layers`),
        # which raised TypeError (int vs. bound method). Made it a property,
        # matching the upstream prompt-to-prompt implementation.
        # In LOW_RESOURCE mode every unconditional layer is skipped; else 0.
        return self.num_att_layers if LOW_RESOURCE else 0

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Process one attention map; must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        """Route one attention map through `forward` and advance the counters."""
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Batch holds [unconditional | conditional] halves stacked on
                # dim 0; only the conditional half is passed to `forward`.
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            # A full step's worth of layers has been seen: rewind the layer
            # counter, advance the step counter, and fire the hook.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        """Rewind the step/layer counters to the start of sampling."""
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        # Expected to be set by the hook-registration code after the
        # attention layers are attached; -1 means "not registered yet".
        self.num_att_layers = -1
        self.cur_att_layer = 0
class AttentionStore(AttentionControl):
    """AttentionControl that records the attention maps seen at each layer."""

    @staticmethod
    def get_empty_store():
        # BUGFIX: declared without `self` yet always invoked as
        # self.get_empty_store(), so the implicit instance argument made
        # every call (including __init__) raise TypeError. @staticmethod
        # restores the upstream prompt-to-prompt behavior.
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        """Record `attn` under "<place>_<cross|self>" and return it unchanged."""
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # if attn.shape[1] <= 16 ** 2: # and attn.shape[1] > 16 ** 2: # avoid memory overhead
        self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        # NOTE(review): this OVERWRITES attention_store with the latest
        # step's maps, while get_average_attention divides by cur_step as if
        # the maps were summed across all steps. One of the two is likely
        # stale relative to the other — confirm which semantics callers
        # expect before changing either.
        for key in self.step_store:
            self.attention_store[key] = self.step_store[key]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the stored maps, each divided by the number of completed steps."""
        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
        return average_attention

    def reset(self):
        """Clear the counters and both attention stores."""
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = self.get_empty_store()

    def __init__(self):
        super(AttentionStore, self).__init__()
        # Maps filled during the current step; merged into attention_store
        # by between_steps.
        self.step_store = self.get_empty_store()
        self.attention_store = self.get_empty_store()