# Spaces: Paused
from dataclasses import dataclass, field | |
from typing import Optional, Dict, Any | |
import torch | |
@dataclass
class ModelConfig:
    """Image-size and architecture hyperparameters.

    Declared as a dataclass so that any field can be overridden through the
    constructor, e.g. ``ModelConfig(width=768)``. All fields have defaults,
    so ``ModelConfig()`` is valid.

    Note: ``latents_width`` / ``latents_height`` are plain fields with fixed
    defaults; this class does not recompute them when ``width`` / ``height``
    are overridden.
    """

    # Image dimensions
    width: int = 512
    height: int = 512
    latents_width: int = 64  # width // 8
    latents_height: int = 64  # height // 8

    # Model architecture parameters
    n_embd: int = 1280
    n_head: int = 8
    d_context: int = 768

    # UNet parameters
    n_time: int = 1280
    n_channels: int = 4
    n_residual_blocks: int = 2

    # Attention parameters
    attention_heads: int = 8
    attention_dim: int = 1280
@dataclass
class DiffusionConfig:
    """Sampling, sampler and conditioning settings.

    Declared as a dataclass so that any field can be overridden through the
    constructor, e.g. ``DiffusionConfig(n_inference_steps=20)``. All fields
    have defaults, so ``DiffusionConfig()`` is valid.
    """

    # Sampling parameters
    n_inference_steps: int = 50
    guidance_scale: float = 7.5
    strength: float = 0.8

    # Sampler configuration
    sampler_name: str = "ddpm"
    beta_start: float = 0.00085
    beta_end: float = 0.0120
    beta_schedule: str = "linear"

    # Conditioning parameters
    do_cfg: bool = True
    cfg_scale: float = 7.5
@dataclass
class DeviceConfig:
    """Device selection with defaults resolved at construction time.

    Must be a dataclass: ``__post_init__`` is only invoked by the
    dataclass-generated ``__init__``. Without the decorator both fields
    would silently remain ``None``.
    """

    # Left as None to have it resolved in __post_init__.
    device: Optional[str] = None
    # Left as None to fall back to "cpu" in __post_init__.
    idle_device: Optional[str] = None

    def __post_init__(self) -> None:
        """Fill in any device that the caller did not specify."""
        if self.device is None:
            # Use CUDA when torch reports it as available, otherwise CPU.
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.idle_device is None:
            self.idle_device = "cpu"
@dataclass
class Config:
    """Top-level configuration grouping the model, diffusion and device settings.

    Must be a dataclass: the ``field(default_factory=...)`` defaults and
    ``__post_init__`` only take effect through the dataclass-generated
    ``__init__``. Each instance gets its own freshly built sub-configs and
    its own ``models`` dict.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    device: DeviceConfig = field(default_factory=DeviceConfig)

    # Additional settings
    seed: Optional[int] = None
    tokenizer: Optional[Any] = None
    models: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Derive the latent dimensions from the image dimensions.

        Overwrites ``model.latents_width`` / ``model.latents_height`` with
        the image size integer-divided by 8, replacing whatever values the
        supplied ``model`` object carried.
        """
        # Update latent dimensions based on image dimensions
        self.model.latents_width = self.model.width // 8
        self.model.latents_height = self.model.height // 8
# Default configuration instance: built once, at import time, with every
# field left at its default value.
default_config = Config()