Spaces:
Running
on
A10G
Running
on
A10G
File size: 5,471 Bytes
1f8beea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
from typing import Literal, Optional, Union, List
import yaml
from pathlib import Path
from pydantic import BaseModel, root_validator
import torch
import copy
# Allowed values for the per-prompt "action" setting (see PromptSettings.action,
# which defaults to "erase").
ACTION_TYPES = Literal["erase", "enhance"]
# SDXL needs two kinds of embeddings per prompt, so they are kept together here.
class PromptEmbedsXL:
    """Container for an SDXL prompt's text embeddings and pooled embeddings.

    The constructor takes positional arguments only: ``args[0]`` is stored as
    ``text_embeds`` and ``args[1]`` as ``pooled_embeds``. Any further
    positional arguments are ignored; fewer than two raises ``IndexError``.
    """

    text_embeds: torch.FloatTensor
    pooled_embeds: torch.FloatTensor

    def __init__(self, *args) -> None:
        text, pooled = args[0], args[1]
        self.text_embeds = text
        self.pooled_embeds = pooled
# Type of one encoded prompt: SD v1.x / v2.x use a plain FloatTensor, while
# SDXL uses PromptEmbedsXL (text embeddings plus pooled embeddings).
PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
class PromptEmbedsCache:  # lets encoded prompts be reused instead of re-encoded
    """Name -> prompt-embedding cache.

    Reading a name that was never stored returns ``None`` instead of raising
    ``KeyError``.
    """

    # Annotation only: the dict itself is created per instance in __init__.
    # A class-level ``{}`` default would be one dict shared by every instance,
    # so separate caches would silently see each other's entries.
    prompts: dict[str, PROMPT_EMBEDDING]

    def __init__(self) -> None:
        self.prompts = {}

    def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
        """Store ``__value`` under ``__name``, replacing any previous entry."""
        self.prompts[__name] = __value

    def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
        """Return the cached embedding for ``__name``, or ``None`` if absent."""
        return self.prompts.get(__name)
class PromptSettings(BaseModel):  # one prompt entry from the YAML file
    """Settings for a single prompt entry, as loaded from the prompts YAML.

    Omitted prompt texts are filled in by ``fill_prompts`` before validation.
    """

    target: str
    positive: str = None  # if None, target will be used
    unconditional: str = ""  # default is ""
    neutral: str = None  # if None, unconditional will be used
    action: ACTION_TYPES = "erase"  # default is "erase"
    guidance_scale: float = 1.0  # default is 1.0
    resolution: int = 512  # default is 512
    dynamic_resolution: bool = False  # default is False
    batch_size: int = 1  # default is 1
    dynamic_crops: bool = False  # default is False. only used when model is XL

    @root_validator(pre=True)
    def fill_prompts(cls, values):
        """Resolve omitted prompt texts from the raw input before field validation.

        Fallbacks: ``positive`` -> ``target``, ``unconditional`` -> ``""``,
        ``neutral`` -> ``unconditional``. A field counts as omitted when its
        key is absent *or* its value is ``None``. An empty YAML value such as
        ``positive:`` loads as ``None``, and checking key presence alone
        would leave it unresolved.

        Raises:
            ValueError: if ``target`` is not provided.
        """
        if "target" not in values:
            raise ValueError("target must be specified")
        if values.get("positive") is None:
            values["positive"] = values["target"]
        if values.get("unconditional") is None:
            values["unconditional"] = ""
        # Resolved after unconditional so that an omitted neutral picks up
        # the final unconditional text.
        if values.get("neutral") is None:
            values["neutral"] = values["unconditional"]
        return values
class PromptEmbedsPair:
    """Encoded prompts for one training entry, plus the options that drive its loss.

    ``loss`` dispatches on ``action`` to ``_erase`` or ``_enhance``. Both
    apply ``loss_fn`` to ``target_latents`` and to ``neutral_latents``
    shifted by ``guidance_scale * (positive_latents - unconditional_latents)``.
    Erase shifts away from that direction; enhance shifts toward it.
    """

    target: PROMPT_EMBEDDING  # concept that should not be generated
    positive: PROMPT_EMBEDDING  # concept to generate
    unconditional: PROMPT_EMBEDDING  # unconditional prompt (default should be empty)
    neutral: PROMPT_EMBEDDING  # base condition (default should be empty)
    guidance_scale: float
    resolution: int
    dynamic_resolution: bool
    batch_size: int
    dynamic_crops: bool
    loss_fn: torch.nn.Module
    action: ACTION_TYPES

    def __init__(
        self,
        loss_fn: torch.nn.Module,
        target: PROMPT_EMBEDDING,
        positive: PROMPT_EMBEDDING,
        unconditional: PROMPT_EMBEDDING,
        neutral: PROMPT_EMBEDDING,
        settings: PromptSettings,
    ) -> None:
        """Store the loss criterion and embeddings, and copy per-prompt options from ``settings``."""
        self.loss_fn = loss_fn
        self.target, self.positive = target, positive
        self.unconditional, self.neutral = unconditional, neutral
        # Same-named attributes are copied verbatim from the settings object.
        for option in (
            "guidance_scale",
            "resolution",
            "dynamic_resolution",
            "batch_size",
            "dynamic_crops",
            "action",
        ):
            setattr(self, option, getattr(settings, option))

    def _erase(
        self,
        target_latents: torch.FloatTensor,  # e.g. "van gogh"
        positive_latents: torch.FloatTensor,  # e.g. "van gogh"
        unconditional_latents: torch.FloatTensor,  # e.g. ""
        neutral_latents: torch.FloatTensor,  # e.g. ""
    ) -> torch.FloatTensor:
        """Loss that trains the target latents to drop the positive concept.

        Regression target: ``neutral - guidance_scale * (positive - unconditional)``.
        """
        concept_direction = positive_latents - unconditional_latents
        goal = neutral_latents - self.guidance_scale * concept_direction
        return self.loss_fn(target_latents, goal)

    def _enhance(
        self,
        target_latents: torch.FloatTensor,  # e.g. "van gogh"
        positive_latents: torch.FloatTensor,  # e.g. "van gogh"
        unconditional_latents: torch.FloatTensor,  # e.g. ""
        neutral_latents: torch.FloatTensor,  # e.g. ""
    ) -> torch.FloatTensor:
        """Loss that trains the target latents to take on the positive concept.

        Regression target: ``neutral + guidance_scale * (positive - unconditional)``.
        """
        concept_direction = positive_latents - unconditional_latents
        goal = neutral_latents + self.guidance_scale * concept_direction
        return self.loss_fn(target_latents, goal)

    def loss(
        self,
        **kwargs,
    ):
        """Compute the loss for ``self.action``; keyword arguments go to ``_erase``/``_enhance``.

        Raises:
            ValueError: if ``action`` is neither "erase" nor "enhance".
        """
        if self.action == "erase":
            return self._erase(**kwargs)
        if self.action == "enhance":
            return self._enhance(**kwargs)
        raise ValueError("action must be erase or enhance")
def _with_attribute(prompt, attribute):
    """Return a deep copy of *prompt* with ``attribute + " "`` prefixed to all four prompt texts.

    Omitted fields (absent or ``None``) are resolved first, following the
    defaults noted on PromptSettings' fields: positive -> target,
    unconditional -> "", neutral -> unconditional. Without this, entries
    that leave those fields out raise KeyError or TypeError.

    Raises:
        KeyError: if *prompt* has no ``target``.
    """
    new_prompt = copy.deepcopy(prompt)
    target = new_prompt["target"]

    positive = new_prompt.get("positive")
    if positive is None:
        positive = target
    unconditional = new_prompt.get("unconditional")
    if unconditional is None:
        unconditional = ""
    neutral = new_prompt.get("neutral")
    if neutral is None:
        neutral = unconditional

    new_prompt["target"] = attribute + " " + target
    new_prompt["positive"] = attribute + " " + positive
    new_prompt["neutral"] = attribute + " " + neutral
    new_prompt["unconditional"] = attribute + " " + unconditional
    return new_prompt


def load_prompts_from_yaml(path, attributes=None):
    """Load prompt entries from a YAML file and build a PromptSettings for each.

    Args:
        path: Path to a YAML file. Its top level is expected to be a list of
            mappings, each holding PromptSettings fields.
        attributes: Optional sequence of strings. When non-empty, every entry
            is expanded into one copy per attribute, with that attribute
            prefixed to the target/positive/neutral/unconditional texts.
            Output order is entry-major: all attributes for the first entry,
            then all for the second, and so on. ``None`` or empty means no
            expansion.

    Returns:
        A list of PromptSettings, one per (expanded) entry.

    Raises:
        ValueError: if the file is empty or holds an empty list.
    """
    # safe_load (not load) so the YAML cannot construct arbitrary Python objects.
    with open(path, "r", encoding="utf-8") as f:
        prompts = yaml.safe_load(f)

    # safe_load returns None for an empty file, so check falsiness, not len().
    if not prompts:
        raise ValueError("prompts file is empty")

    if attributes:
        expanded = [_with_attribute(prompt, att) for prompt in prompts for att in attributes]
    else:
        expanded = prompts

    return [PromptSettings(**prompt) for prompt in expanded]
|