File size: 6,138 Bytes
239ee43 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import json
from pydantic import BaseModel, validator
from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
from enum import Enum
from imagen_pytorch.imagen_pytorch import Imagen, Unet, Unet3D, NullUnet
from imagen_pytorch.trainer import ImagenTrainer
from imagen_pytorch.elucidated_imagen import ElucidatedImagen
from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim
# helper functions
def exists(val):
    """Return True unless *val* is None.

    Falsy values such as 0, '' or [] still count as existing.
    """
    is_missing = val is None
    return not is_missing
def default(val, d):
    """Return *val* when it is not None, otherwise the fallback *d*."""
    if exists(val):
        return val
    return d
def ListOrTuple(inner_type):
    """Build a typing annotation accepting a list or a tuple of *inner_type*.

    ``Tuple[inner_type, ...]`` is the variable-length form, so tuples of any
    length match. A bare ``Tuple[inner_type]`` would describe only a
    one-element tuple.
    """
    return Union[List[inner_type], Tuple[inner_type, ...]]
def SingleOrList(inner_type):
    """Build an annotation accepting one *inner_type* value or a list/tuple of them."""
    collection_type = ListOrTuple(inner_type)
    return Union[inner_type, collection_type]
# noise schedule
class NoiseSchedule(Enum):
    """Noise schedule names that ImagenConfig.noise_schedules validates against.

    Each member's value equals its name.
    """

    cosine = 'cosine'
    linear = 'linear'
class AllowExtraBaseModel(BaseModel):
    """Pydantic base model that keeps undeclared keys and stores enums as raw values."""

    class Config:
        # store enum fields by their .value rather than as enum members
        use_enum_values = True
        # retain keys that are not declared as fields instead of rejecting them
        extra = "allow"
# imagen pydantic classes
class NullUnetConfig(BaseModel):
    """Config entry for a unet slot that is filled with a ``NullUnet``."""

    # required marker field; create() itself does not read its value
    is_null: bool

    def create(self):
        """Instantiate a ``NullUnet`` without passing any config fields to it."""
        null_unet = NullUnet()
        return null_unet
class UnetConfig(AllowExtraBaseModel):
    """Pydantic config for a 2D ``Unet``.

    Every field is passed to the ``Unet`` constructor by ``create``. Because the
    base model allows extra keys, undeclared keys are kept in ``self.dict()``
    and forwarded as well.
    """

    dim: int
    dim_mults: ListOrTuple(int)
    # default is computed once, when this class body executes
    text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
    # None is a legal value, so the annotation says so explicitly; a plain
    # `int` would make pydantic v2 reject an explicit `cond_dim=None`
    cond_dim: Optional[int] = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    def create(self):
        """Instantiate a ``Unet`` from this config's fields."""
        return Unet(**self.dict())
class Unet3DConfig(AllowExtraBaseModel):
    """Pydantic config for a ``Unet3D``.

    Every field is passed to the ``Unet3D`` constructor by ``create``. Because
    the base model allows extra keys, undeclared keys are kept in
    ``self.dict()`` and forwarded as well.
    """

    dim: int
    dim_mults: ListOrTuple(int)
    # default is computed once, when this class body executes
    text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
    # None is a legal value, so the annotation says so explicitly; a plain
    # `int` would make pydantic v2 reject an explicit `cond_dim=None`
    cond_dim: Optional[int] = None
    channels: int = 3
    attn_dim_head: int = 32
    attn_heads: int = 16

    def create(self):
        """Instantiate a ``Unet3D`` from this config's fields."""
        return Unet3D(**self.dict())
class ImagenConfig(AllowExtraBaseModel):
    """Pydantic config that builds an ``Imagen`` model and its unets.

    Every field except ``unets`` and ``video`` is forwarded to the ``Imagen``
    constructor by ``create``, including any extra keys the base model keeps.
    """

    unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    # must contain exactly one entry per unet (checked by check_image_sizes)
    image_sizes: ListOrTuple(int)
    # when True, every non-null unet is built as Unet3D instead of Unet
    video: bool = False
    timesteps: SingleOrList(int) = 1000
    # a plain string default, matching the raw values stored under use_enum_values
    noise_schedules: SingleOrList(NoiseSchedule) = 'cosine'
    text_encoder_name: str = DEFAULT_T5_NAME
    channels: int = 3
    loss_type: str = 'l2'
    cond_drop_prob: float = 0.5

    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
        """Require one image size per unet.

        Raises:
            ValueError: if the lengths of ``image_sizes`` and ``unets`` differ.
        """
        unets = values.get('unets')
        # `unets` is absent from `values` when its own validation failed; skip the
        # cross-check so pydantic reports that error instead of a TypeError from len(None)
        if unets is not None and len(image_sizes) != len(unets):
            raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
        return image_sizes

    def create(self):
        """Build the unets and wrap them in an ``Imagen``.

        The serialized config is attached to the returned model as ``_config``.
        """
        decoder_kwargs = self.dict()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        unets = []
        # zip with self.unets to recover each entry's config class, which the
        # plain kwargs dict from self.dict() no longer carries
        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet
            unets.append(unet_klass(**unet_kwargs))

        imagen = Imagen(unets, **decoder_kwargs)
        imagen._config = self.dict().copy()
        return imagen
class ElucidatedImagenConfig(AllowExtraBaseModel):
    """Pydantic config that builds an ``ElucidatedImagen`` model and its unets.

    Every field except ``unets`` and ``video`` is forwarded to the
    ``ElucidatedImagen`` constructor by ``create``, including any extra keys
    the base model keeps.
    """

    unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
    # must contain exactly one entry per unet (checked by check_image_sizes)
    image_sizes: ListOrTuple(int)
    # when True, every non-null unet is built as Unet3D instead of Unet
    video: bool = False
    text_encoder_name: str = DEFAULT_T5_NAME
    channels: int = 3
    cond_drop_prob: float = 0.5
    num_sample_steps: SingleOrList(int) = 32
    # The sampler hyperparameters below are all validated as float. An `int`
    # annotation made pydantic coerce (truncate) inputs such as 7.5 and could
    # not hold float('inf'). The defaults are kept exactly as before.
    sigma_min: SingleOrList(float) = 0.002
    sigma_max: SingleOrList(float) = 80
    sigma_data: SingleOrList(float) = 0.5
    rho: SingleOrList(float) = 7
    P_mean: SingleOrList(float) = -1.2
    P_std: SingleOrList(float) = 1.2
    S_churn: SingleOrList(float) = 80
    S_tmin: SingleOrList(float) = 0.05
    S_tmax: SingleOrList(float) = 50
    S_noise: SingleOrList(float) = 1.003

    @validator('image_sizes')
    def check_image_sizes(cls, image_sizes, values):
        """Require one image size per unet.

        Raises:
            ValueError: if the lengths of ``image_sizes`` and ``unets`` differ.
        """
        unets = values.get('unets')
        # `unets` is absent from `values` when its own validation failed; skip the
        # cross-check so pydantic reports that error instead of a TypeError from len(None)
        if unets is not None and len(image_sizes) != len(unets):
            raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
        return image_sizes

    def create(self):
        """Build the unets and wrap them in an ``ElucidatedImagen``.

        The serialized config is attached to the returned model as ``_config``.
        """
        decoder_kwargs = self.dict()
        unets_kwargs = decoder_kwargs.pop('unets')
        is_video = decoder_kwargs.pop('video', False)

        unets = []
        # zip with self.unets to recover each entry's config class, which the
        # plain kwargs dict from self.dict() no longer carries
        for unet, unet_kwargs in zip(self.unets, unets_kwargs):
            if isinstance(unet, NullUnetConfig):
                unet_klass = NullUnet
            elif is_video:
                unet_klass = Unet3D
            else:
                unet_klass = Unet
            unets.append(unet_klass(**unet_kwargs))

        imagen = ElucidatedImagen(unets, **decoder_kwargs)
        imagen._config = self.dict().copy()
        return imagen
class ImagenTrainerConfig(AllowExtraBaseModel):
    """Pydantic config that builds an ``ImagenTrainer`` together with its model.

    ``imagen`` holds the raw kwargs for ``ImagenConfig``, or for
    ``ElucidatedImagenConfig`` when ``elucidated`` is True. ``video`` is
    forwarded into that model config. All remaining fields, including any
    extra keys, are passed to ``ImagenTrainer``.
    """

    imagen: dict
    elucidated: bool = False
    video: bool = False
    use_ema: bool = True
    lr: SingleOrList(float) = 1e-4
    eps: SingleOrList(float) = 1e-8
    beta1: float = 0.9
    beta2: float = 0.99
    max_grad_norm: Optional[float] = None
    group_wd_params: bool = True
    warmup_steps: SingleOrList(Optional[int]) = None
    cosine_decay_max_steps: SingleOrList(Optional[int]) = None

    def create(self):
        """Build the Imagen model from ``imagen`` and wrap it in an ``ImagenTrainer``."""
        trainer_kwargs = self.dict()

        imagen_config = trainer_kwargs.pop('imagen')
        elucidated = trainer_kwargs.pop('elucidated')
        # `video` configures the model rather than the trainer, so it is popped
        # here and forwarded into the Imagen config instead of to ImagenTrainer
        video = trainer_kwargs.pop('video')
        # also honour a `video` flag set directly inside the `imagen` dict, so the
        # trainer-level default (False) does not silently override it
        video = video or imagen_config.get('video', False)

        imagen_config_klass = ElucidatedImagenConfig if elucidated else ImagenConfig
        imagen = imagen_config_klass(**{**imagen_config, 'video': video}).create()

        return ImagenTrainer(imagen, **trainer_kwargs)
|