Spaces:
Paused
Paused
File size: 1,940 Bytes
ef6c3c2 74045df ef6c3c2 74045df ef6c3c2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
import torch
@dataclass
class ModelConfig:
# Image dimensions
width: int = 512
height: int = 512
latents_width: int = 64 # width // 8
latents_height: int = 64 # height // 8
# Model architecture parameters
n_embd: int = 1280
n_head: int = 8
d_context: int = 768
# UNet parameters
n_time: int = 1280
n_channels: int = 4
n_residual_blocks: int = 2
# Attention parameters
attention_heads: int = 8
attention_dim: int = 1280
@dataclass
class DiffusionConfig:
# Sampling parameters
n_inference_steps: int = 50
guidance_scale: float = 7.5
strength: float = 0.8
# Sampler configuration
sampler_name: str = "ddpm"
beta_start: float = 0.00085
beta_end: float = 0.0120
beta_schedule: str = "linear"
# Conditioning parameters
do_cfg: bool = True
cfg_scale: float = 7.5
@dataclass
class DeviceConfig:
device: Optional[str] = None
idle_device: Optional[str] = None
def __post_init__(self):
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if self.idle_device is None:
self.idle_device = "cpu"
@dataclass
class Config:
model: ModelConfig = field(default_factory=ModelConfig)
diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
device: DeviceConfig = field(default_factory=DeviceConfig)
# Additional settings
seed: Optional[int] = None
tokenizer: Optional[Any] = None
models: Dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
# Update latent dimensions based on image dimensions
self.model.latents_width = self.model.width // 8
self.model.latents_height = self.model.height // 8
# Default configuration instance
default_config = Config() |