wzhouxiff's picture
init
2890711
raw
history blame
11.8 kB
import pathlib
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Optional
from omegaconf import OmegaConf
from sgm.inference.helpers import (Img2ImgDiscretizationWrapper, do_img2img,
do_sample)
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler)
from sgm.util import load_model_from_config
class ModelArchitecture(str, Enum):
    """Identifiers of the model variants this module knows how to sample.

    Every member is also a ``str``, so it compares equal to its plain
    string value. The members are the keys of ``model_specs``.
    """

    # Stable Diffusion 2.1 (two variants).
    SD_2_1 = "stable-diffusion-v2-1"
    SD_2_1_768 = "stable-diffusion-v2-1-768"

    # SDXL 0.9 base / refiner pair.
    SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
    SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"

    # SDXL 1.0 base / refiner pair.
    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
class Sampler(str, Enum):
    """Names of the sampler implementations selectable via ``SamplingParams``.

    Each value is spelled like the sampler class of the same name imported
    from ``sgm.modules.diffusionmodules.sampling``; ``get_sampler_config``
    maps a member to an instance of that class.
    """

    # Samplers that take the s_churn / s_tmin / s_tmax / s_noise settings.
    EULER_EDM = "EulerEDMSampler"
    HEUN_EDM = "HeunEDMSampler"

    # Samplers that take the eta / s_noise settings.
    EULER_ANCESTRAL = "EulerAncestralSampler"
    DPMPP2S_ANCESTRAL = "DPMPP2SAncestralSampler"

    # Takes no sampler-specific settings.
    DPMPP2M = "DPMPP2MSampler"

    # Takes the ``order`` setting.
    LINEAR_MULTISTEP = "LinearMultistepSampler"
class Discretization(str, Enum):
    """Noise-schedule discretizations selectable via ``SamplingParams``.

    ``get_discretization_config`` turns a member into the config dict that
    targets the discretizer class of the same name.
    """

    # Configured without extra parameters.
    LEGACY_DDPM = "LegacyDDPMDiscretization"

    # Configured with sigma_min, sigma_max and rho.
    EDM = "EDMDiscretization"
class Guider(str, Enum):
    """Guidance strategies selectable via ``SamplingParams``.

    ``get_guider_config`` turns a member into the config dict that targets
    the guider class of the same name.
    """

    # Configured with a guidance scale and a thresholding config.
    VANILLA = "VanillaCFG"

    # Configured without any parameters.
    IDENTITY = "IdentityGuider"
class Thresholder(str, Enum):
    """Thresholding options for the vanilla guider.

    Only one option exists; ``get_guider_config`` maps it to the
    ``NoDynamicThresholding`` config.
    """

    # The value is the four-character string "None", not Python's ``None``.
    NONE = "None"
@dataclass
class SamplingParams:
    """User-tunable settings for a single sampling run.

    The whole object is flattened with ``dataclasses.asdict`` into the value
    dict handed to the sampling helpers by ``text_to_image`` and
    ``image_to_image``, and it drives ``get_sampler_config``,
    ``get_discretization_config`` and ``get_guider_config``.
    """

    # Output size; text_to_image also forwards these as target_width/height.
    width: int = 1024
    height: int = 1024

    # Number of sampler steps (passed as ``num_steps``).
    steps: int = 50

    # Which sampler / discretization / guider / thresholder to configure.
    sampler: Sampler = Sampler.DPMPP2M
    discretization: Discretization = Discretization.LEGACY_DDPM
    guider: Guider = Guider.VANILLA
    thresholder: Thresholder = Thresholder.NONE

    # Guidance scale; only read when the vanilla guider is selected.
    scale: float = 6.0

    # Not read directly in this module; forwarded through the value dict.
    aesthetic_score: float = 5.0
    negative_aesthetic_score: float = 5.0

    # image_to_image wraps the discretization only when this is below 1.0.
    img2img_strength: float = 1.0

    # Not read directly in this module; forwarded through the value dict.
    orig_width: int = 1024
    orig_height: int = 1024
    crop_coords_top: int = 0
    crop_coords_left: int = 0

    # Only read when the EDM discretization is selected.
    sigma_min: float = 0.0292
    sigma_max: float = 14.6146
    rho: float = 3.0

    # Only read by the Euler/Heun EDM samplers (s_noise also by the
    # ancestral samplers).
    s_churn: float = 0.0
    s_tmin: float = 0.0
    s_tmax: float = 999.0
    s_noise: float = 1.0

    # Only read by the ancestral samplers.
    eta: float = 1.0

    # Only read by the linear multistep sampler.
    order: int = 4
@dataclass
class SamplingSpec:
    """Fixed, per-architecture facts used by ``SamplingPipeline``.

    One instance per supported architecture is stored in ``model_specs``.
    All fields are required; there are no defaults.
    """

    # Size recorded for the architecture. Not read by the pipeline code in
    # this module, which takes the output size from SamplingParams instead.
    width: int
    height: int

    # Passed straight through to ``do_sample`` by text_to_image.
    channels: int
    factor: int

    # When False, the pipeline asks the helpers to force a zero
    # unconditional embedding for the "txt" key; when True it does not.
    is_legacy: bool

    # File names, joined onto the pipeline's config_path / model_path.
    config: str
    ckpt: str

    # NOTE(review): not read anywhere in the visible code - confirm usage.
    is_guided: bool
# Registry mapping every supported architecture to its SamplingSpec.
#
# All entries share channels=4, factor=8 and is_guided=True and are square,
# so the table lists only what differs between them:
#   (architecture, side length, is_legacy, config file, checkpoint file)
model_specs = {
    arch: SamplingSpec(
        width=side,
        height=side,
        channels=4,
        factor=8,
        is_legacy=legacy,
        config=config_name,
        ckpt=ckpt_name,
        is_guided=True,
    )
    for arch, side, legacy, config_name, ckpt_name in (
        (
            ModelArchitecture.SD_2_1,
            512,
            True,
            "sd_2_1.yaml",
            "v2-1_512-ema-pruned.safetensors",
        ),
        (
            ModelArchitecture.SD_2_1_768,
            768,
            True,
            "sd_2_1_768.yaml",
            "v2-1_768-ema-pruned.safetensors",
        ),
        (
            ModelArchitecture.SDXL_V0_9_BASE,
            1024,
            False,
            "sd_xl_base.yaml",
            "sd_xl_base_0.9.safetensors",
        ),
        (
            ModelArchitecture.SDXL_V0_9_REFINER,
            1024,
            True,
            "sd_xl_refiner.yaml",
            "sd_xl_refiner_0.9.safetensors",
        ),
        (
            ModelArchitecture.SDXL_V1_BASE,
            1024,
            False,
            "sd_xl_base.yaml",
            "sd_xl_base_1.0.safetensors",
        ),
        (
            ModelArchitecture.SDXL_V1_REFINER,
            1024,
            True,
            "sd_xl_refiner.yaml",
            "sd_xl_refiner_1.0.safetensors",
        ),
    )
}
class SamplingPipeline:
    """Load one supported model and expose sampling entry points for it.

    The model is loaded eagerly in ``__init__``. Each entry point builds a
    sampler from a ``SamplingParams`` object, assembles the value dict the
    ``sgm.inference.helpers`` functions expect, and returns whatever the
    helper returns.
    """

    def __init__(
        self,
        model_id: ModelArchitecture,
        model_path="checkpoints",
        config_path="configs/inference",
        device="cuda",
        use_fp16=True,
    ) -> None:
        """Resolve the spec for ``model_id`` and load the model.

        Args:
            model_id: Key into ``model_specs``.
            model_path: Directory that holds the checkpoint file.
            config_path: Directory that holds the model config file.
            device: Device the loaded model is moved to.
            use_fp16: Whether to convert parts of the model to half
                precision after loading.

        Raises:
            ValueError: If ``model_id`` is not in ``model_specs`` or the
                model could not be loaded.
        """
        if model_id not in model_specs:
            raise ValueError(f"Model {model_id} not supported")
        self.model_id = model_id
        self.specs = model_specs[self.model_id]
        self.config = str(pathlib.Path(config_path, self.specs.config))
        self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
        self.device = device
        self.model = self._load_model(device=device, use_fp16=use_fp16)

    def _load_model(self, device="cuda", use_fp16=True):
        """Load the model from ``self.config`` / ``self.ckpt``.

        Returns:
            The loaded model, moved to ``device``.

        Raises:
            ValueError: If the loader returned ``None``.
        """
        config = OmegaConf.load(self.config)
        model = load_model_from_config(config, self.ckpt)
        if model is None:
            raise ValueError(f"Model {self.model_id} could not be loaded")
        model.to(device)
        if use_fp16:
            # Only the conditioner and the inner model are converted; the
            # rest of the model keeps its loaded precision.
            model.conditioner.half()
            model.model.half()
        return model

    def _uc_zero_embedding_keys(self):
        """Return the keys whose unconditional embedding is forced to zero.

        Non-legacy specs force the "txt" key; legacy specs force none. A
        fresh list is returned on every call.
        """
        return [] if self.specs.is_legacy else ["txt"]

    def text_to_image(
        self,
        params: SamplingParams,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        """Sample from a text prompt.

        Args:
            params: Sampling settings; also flattened into the value dict.
            prompt: Positive prompt.
            negative_prompt: Negative prompt.
            samples: Number of samples, passed through to ``do_sample``.
            return_latents: Passed through to ``do_sample``.

        Returns:
            The result of ``do_sample``.
        """
        sampler = get_sampler_config(params)
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = params.width
        value_dict["target_height"] = params.height
        return do_sample(
            self.model,
            sampler,
            value_dict,
            samples,
            params.height,
            params.width,
            self.specs.channels,
            self.specs.factor,
            force_uc_zero_embeddings=self._uc_zero_embedding_keys(),
            return_latents=return_latents,
            filter=None,
        )

    def image_to_image(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: str = "",
        samples: int = 1,
        return_latents: bool = False,
    ):
        """Sample starting from an existing image.

        Args:
            params: Sampling settings; also flattened into the value dict.
            image: Input whose ``shape[2]`` / ``shape[3]`` are used as
                height / width (presumably a batch-channel-height-width
                tensor; confirm against the caller).
            prompt: Positive prompt.
            negative_prompt: Negative prompt.
            samples: Number of samples, passed through to ``do_img2img``.
            return_latents: Passed through to ``do_img2img``.

        Returns:
            The result of ``do_img2img``.
        """
        sampler = get_sampler_config(params)
        if params.img2img_strength < 1.0:
            # Wrap the discretization only for partial strength; at 1.0 (or
            # above) the sampler's own discretization is used unchanged.
            sampler.discretization = Img2ImgDiscretizationWrapper(
                sampler.discretization,
                strength=params.img2img_strength,
            )
        height, width = image.shape[2], image.shape[3]
        value_dict = asdict(params)
        value_dict["prompt"] = prompt
        value_dict["negative_prompt"] = negative_prompt
        value_dict["target_width"] = width
        value_dict["target_height"] = height
        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            force_uc_zero_embeddings=self._uc_zero_embedding_keys(),
            return_latents=return_latents,
            filter=None,
        )

    def refiner(
        self,
        params: SamplingParams,
        image,
        prompt: str,
        negative_prompt: Optional[str] = None,
        samples: int = 1,
        return_latents: bool = False,
    ):
        """Run a refinement pass over ``image`` without re-encoding it.

        ``params`` is used only to build the sampler; the value dict is
        assembled here with fixed crop coordinates and aesthetic scores.

        Args:
            params: Sampling settings used to build the sampler.
            image: Input whose ``shape[2]`` / ``shape[3]`` are scaled by the
                spec's ``factor`` to obtain the original/target size
                (presumably latents, since ``skip_encode=True`` is passed;
                confirm against the caller).
            prompt: Positive prompt.
            negative_prompt: Negative prompt; ``None`` is treated as the
                empty string, matching the default of the other entry
                points.
            samples: Number of samples, passed through to ``do_img2img``.
            return_latents: Passed through to ``do_img2img``.

        Returns:
            The result of ``do_img2img``.
        """
        sampler = get_sampler_config(params)
        # Use the spec's factor rather than a hard-coded 8 so this agrees
        # with text_to_image; every registered spec currently has factor 8.
        factor = self.specs.factor
        full_width = image.shape[3] * factor
        full_height = image.shape[2] * factor
        value_dict = {
            "orig_width": full_width,
            "orig_height": full_height,
            "target_width": full_width,
            "target_height": full_height,
            "prompt": prompt,
            # Never forward None as prompt text to the helpers.
            "negative_prompt": "" if negative_prompt is None else negative_prompt,
            "crop_coords_top": 0,
            "crop_coords_left": 0,
            "aesthetic_score": 6.0,
            "negative_aesthetic_score": 2.5,
        }
        return do_img2img(
            image,
            self.model,
            sampler,
            value_dict,
            samples,
            skip_encode=True,
            return_latents=return_latents,
            filter=None,
        )
def get_guider_config(params: SamplingParams):
    """Build the guider config dict selected by ``params.guider``.

    Args:
        params: Sampling settings; ``guider`` picks the guider, and for the
            vanilla guider ``scale`` and ``thresholder`` are read as well.

    Returns:
        A dict with a ``"target"`` entry naming the guider class and, for
        the vanilla guider, a ``"params"`` entry holding the scale and the
        thresholding config.

    Raises:
        NotImplementedError: If the guider, or the thresholder of the
            vanilla guider, is not one of the supported options.
    """
    if params.guider == Guider.IDENTITY:
        return {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

    if params.guider != Guider.VANILLA:
        # Name the rejected value, as the other config builders do.
        raise NotImplementedError(f"unknown guider {params.guider}")

    scale = params.scale
    thresholder = params.thresholder
    if thresholder != Thresholder.NONE:
        raise NotImplementedError(f"unknown thresholder {thresholder}")

    dyn_thresh_config = {
        "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
    }
    return {
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
    }
def get_discretization_config(params: SamplingParams):
    """Build the discretization config dict selected by ``params``.

    Args:
        params: Sampling settings; ``discretization`` picks the discretizer
            and, for EDM, ``sigma_min`` / ``sigma_max`` / ``rho`` are read.

    Returns:
        A dict with a ``"target"`` entry naming the discretizer class and,
        for EDM only, a ``"params"`` entry with the three sigma settings.

    Raises:
        ValueError: If ``params.discretization`` is not a known option.
    """
    if params.discretization == Discretization.LEGACY_DDPM:
        # The legacy discretizer is configured without parameters.
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }

    if params.discretization == Discretization.EDM:
        edm_params = {
            "sigma_min": params.sigma_min,
            "sigma_max": params.sigma_max,
            "rho": params.rho,
        }
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": edm_params,
        }

    raise ValueError(f"unknown discretization {params.discretization}")
def get_sampler_config(params: SamplingParams):
    """Instantiate the sampler selected by ``params.sampler``.

    The discretization and guider configs are built first (so their errors
    surface before the sampler is checked), then passed together with
    ``params.steps`` and ``verbose=True`` to the chosen sampler class along
    with that sampler's own settings.

    Args:
        params: Sampling settings.

    Returns:
        A new sampler instance.

    Raises:
        ValueError: If ``params.sampler`` is not a known option, or the
            discretization is unknown.
        NotImplementedError: If the guider or thresholder is unsupported.
    """
    discretization_config = get_discretization_config(params)
    guider_config = get_guider_config(params)

    # Keyword arguments every sampler receives.
    shared = {
        "num_steps": params.steps,
        "discretization_config": discretization_config,
        "guider_config": guider_config,
        "verbose": True,
    }

    # Sampler-specific settings are read only inside the matching branch so
    # that unrelated attributes of ``params`` are never touched.
    if params.sampler == Sampler.EULER_EDM:
        return EulerEDMSampler(
            s_churn=params.s_churn,
            s_tmin=params.s_tmin,
            s_tmax=params.s_tmax,
            s_noise=params.s_noise,
            **shared,
        )
    if params.sampler == Sampler.HEUN_EDM:
        return HeunEDMSampler(
            s_churn=params.s_churn,
            s_tmin=params.s_tmin,
            s_tmax=params.s_tmax,
            s_noise=params.s_noise,
            **shared,
        )
    if params.sampler == Sampler.EULER_ANCESTRAL:
        return EulerAncestralSampler(
            eta=params.eta,
            s_noise=params.s_noise,
            **shared,
        )
    if params.sampler == Sampler.DPMPP2S_ANCESTRAL:
        return DPMPP2SAncestralSampler(
            eta=params.eta,
            s_noise=params.s_noise,
            **shared,
        )
    if params.sampler == Sampler.DPMPP2M:
        return DPMPP2MSampler(**shared)
    if params.sampler == Sampler.LINEAR_MULTISTEP:
        return LinearMultistepSampler(order=params.order, **shared)

    raise ValueError(f"unknown sampler {params.sampler}!")