|
import logging |
|
# Configure the root logger when this module is imported: INFO level, with
# each record rendered as "<timestamp> - <level name> - <message>".
# NOTE(review): this is an import-time side effect; logging.basicConfig is a
# no-op when the root logger already has handlers configured.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Module-level logger named after this module.
log = logging.getLogger(__name__)
|
|
|
import torch |
|
|
|
from ..utils.callback_utils import get_callback_fn |
|
from ..utils.latent_utils import add_latent_noise |
|
from ..utils.sampling_utils import get_model_fn, get_sample_args |
|
|
|
|
|
def run_sampler(model, latents, positive, negative, sigmas, cfg, sampler_fn, add_noise=False, seed=0):
    """Run ``sampler_fn`` over ``latents`` and return the sampled result.

    The global torch RNG (CPU and CUDA) is seeded with ``seed`` before any
    work is done, and a dedicated ``torch.Generator`` on ``model.device`` is
    seeded with the same value for the optional noise step.

    Args:
        model: Project model wrapper. This function reads ``model.device``,
            ``model.dit`` and ``model.offload_device`` and passes the object
            to the project helpers ``add_latent_noise``, ``get_sample_args``,
            ``get_model_fn`` and ``get_callback_fn``.
        latents: Starting latent tensor. It is never modified in place here:
            it is either cloned or handed to ``add_latent_noise``.
        positive: Positive conditioning, normalised by ``get_sample_args``.
        negative: Negative conditioning, normalised by ``get_sample_args``.
        sigmas: Sigma schedule; must support ``len()``. Forwarded unchanged
            to ``sampler_fn``.
        cfg: Guidance value, forwarded to the sampler as
            ``extra_args["cfg"]``.
        sampler_fn: Callable invoked as
            ``sampler_fn(model_fn, z, sigmas, callback=..., extra_args=...)``.
        add_noise: When true, the starting point is produced by
            ``add_latent_noise``; otherwise it is a clone of ``latents``.
        seed: Integer seed for the global RNGs and the local generator.

    Returns:
        Whatever ``sampler_fn`` returns.

    Raises:
        Any exception raised by the helpers or by ``sampler_fn`` propagates
        unchanged; ``model.dit`` is still moved to ``model.offload_device``
        before it does.
    """
    # Seed every RNG source that may be consulted so runs are reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    generator = torch.Generator(device=model.device)
    generator.manual_seed(seed)

    try:
        if add_noise:
            z = add_latent_noise(model, latents.shape, sigmas, latents, generator)
        else:
            # Clone so the caller's tensor is not aliased by the sampler input.
            z = latents.clone()

        positive, negative = get_sample_args(model, positive, negative)
        model_fn = get_model_fn(model)

        # len(sigmas) - 1 is passed as the second argument of get_callback_fn
        # (presumably the step count - TODO confirm against the helper).
        callback_fn = get_callback_fn(model, len(sigmas) - 1)
        extra_args = {
            "positive": positive,
            "negative": negative,
            "cfg": cfg,
        }
        z = sampler_fn(model_fn, z, sigmas, callback=callback_fn, extra_args=extra_args)
    finally:
        # Always offload, even when sampling fails or is interrupted, so the
        # offload step is not skipped on the error path.
        model.dit.to(model.offload_device)

    return z
|
|