import logging

import torch

from ..utils.callback_utils import get_callback_fn
from ..utils.latent_utils import add_latent_noise
from ..utils.sampling_utils import get_model_fn, get_sample_args

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)


def run_sampler(model, latents, positive, negative, sigmas, cfg, sampler_fn, add_noise=False, seed=0):
    """Denoise `latents` with `sampler_fn` over `sigmas`, then offload the diffusion model."""
    # Seed all relevant RNGs so the run is reproducible.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    generator = torch.Generator(device=model.device)
    generator.manual_seed(seed)

    # Prepare the latents: optionally add noise scaled by the sigma schedule,
    # otherwise start from a copy of the input latents.
    latent_shape = latents.shape
    if add_noise:
        z = add_latent_noise(model, latent_shape, sigmas, latents, generator)
    else:
        z = latents.clone()

    # Prepare the model function and conditioning arguments.
    positive, negative = get_sample_args(model, positive, negative)
    model_fn = get_model_fn(model)

    # Run the sampler, reporting progress once per sigma step via the callback.
    callback_fn = get_callback_fn(model, len(sigmas) - 1)
    extra_args = {
        "positive": positive,
        "negative": negative,
        "cfg": cfg,
    }
    z = sampler_fn(model_fn, z, sigmas, callback=callback_fn, extra_args=extra_args)

    # Cleanup: move the diffusion transformer back to the offload device.
    model.dit.to(model.offload_device)
    return z
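

# Illustrative usage sketch (assumptions, not part of this module): the model
# wrapper, conditioning tensors, sigma schedule, and sampler callable all come
# from the surrounding pipeline; the names below are placeholders.
#
#   denoised = run_sampler(
#       model=model,            # wrapper exposing .device, .dit and .offload_device
#       latents=latents,        # latent tensor to denoise
#       positive=positive,      # positive conditioning
#       negative=negative,      # negative conditioning
#       sigmas=sigmas,          # noise schedule (one callback step per sigma interval)
#       cfg=7.0,                # classifier-free guidance scale
#       sampler_fn=sampler_fn,  # sampler callable: fn(model_fn, z, sigmas, callback=..., extra_args=...)
#       add_noise=True,         # noise the input latents before sampling
#       seed=42,
#   )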