import torch
import comfy.model_management as mm

def get_model_fn(model):
    """Wrap the DiT in a sampler-facing callable that applies classifier-free guidance."""
    # positive / negative / cfg correspond to the upstream sample / sample_null /
    # cfg_scale arguments.
    def model_fn(z, sigma, positive, negative, cfg):
        model.dit.to(model.device)
        # Models patched for cuBLAS half-precision matmul autocast to fp16;
        # everything else uses bf16.
        if hasattr(model.dit, "cublas_half_matmul") and model.dit.cublas_half_matmul:
            autocast_dtype = torch.float16
        else:
            autocast_dtype = torch.bfloat16
        with torch.autocast(mm.get_autocast_device(model.device), dtype=autocast_dtype):
            if cfg > 1.0:
                out_cond = model.dit(z, sigma, **positive)
                out_uncond = model.dit(z, sigma, **negative)
            else:
                # No guidance requested; a single conditional pass suffices.
                out_cond = model.dit(z, sigma, **positive)
                return out_cond
        # Classifier-free guidance: move the prediction away from the
        # unconditional output by a factor of cfg.
        return out_uncond + cfg * (out_cond - out_uncond)
    return model_fn
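
# Illustrative usage (a sketch, not part of this file: `sigmas`, `z`,
# `cond_args`, `uncond_args`, and the Euler-style update are assumptions).
# A rectified-flow sampler would call model_fn once per step and integrate
# the prediction it returns:
#
#   model_fn = get_model_fn(model)
#   for i in range(len(sigmas) - 1):
#       pred = model_fn(z, sigmas[i], cond_args, uncond_args, cfg)
#       z = z + (sigmas[i + 1] - sigmas[i]) * pred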

def get_sample_args(model, cond_embeds, uncond_embeds):
    # Wrap each tensor in a single-element list to match the DiT's
    # y_mask / y_feat keyword inputs.
    cond_args = {
        "y_mask": [cond_embeds["attention_mask"].to(model.device)],
        "y_feat": [cond_embeds["embeds"].to(model.device)],
    }
    uncond_args = {
        "y_mask": [uncond_embeds["attention_mask"].to(model.device)],
        "y_feat": [uncond_embeds["embeds"].to(model.device)],
    }
    return cond_args, uncond_args
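
# Illustrative call (the shapes below are assumptions modeled on a T5-style
# text encoder, not a spec from this file):
#
#   cond_embeds = {"embeds": torch.randn(1, 256, 4096),
#                  "attention_mask": torch.ones(1, 256, dtype=torch.bool)}
#   uncond_embeds = {"embeds": torch.zeros(1, 256, 4096),
#                    "attention_mask": torch.zeros(1, 256, dtype=torch.bool)}
#   cond_args, uncond_args = get_sample_args(model, cond_embeds, uncond_embeds)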

def prepare_conds(positive, negative):
    # For compatibility with ComfyUI's CLIPTextEncode node, whose conditioning
    # comes as [[embeds, extras_dict]] rather than a plain dict.
    if not isinstance(positive, dict):
        positive = {
            "embeds": positive[0][0],
            "attention_mask": positive[0][1]["attention_mask"].bool(),
        }
    if not isinstance(negative, dict):
        negative = {
            "embeds": negative[0][0],
            "attention_mask": negative[0][1]["attention_mask"].bool(),
        }
    return positive, negative
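
# How the helpers above chain together before sampling (a sketch with
# hypothetical variable names; positive/negative may come straight from a
# CLIPTextEncode node):
#
#   positive, negative = prepare_conds(positive, negative)
#   cond_args, uncond_args = get_sample_args(model, positive, negative)
#   model_fn = get_model_fn(model)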

def generate_eta_values(steps, start_time, end_time, eta, eta_trend):
    """Build a per-step eta schedule that is non-zero only on [start_time, end_time)."""
    end_time = min(end_time, steps)
    eta_values = [0.0] * steps
    # Guard the ramp denominator so a single-step window does not divide by zero.
    span = max(end_time - start_time - 1, 1)
    if eta_trend == 'constant':
        for i in range(start_time, end_time):
            eta_values[i] = eta
    elif eta_trend == 'linear_increase':
        # Ramp from 0 at start_time up to eta at end_time - 1.
        for i in range(start_time, end_time):
            progress = (i - start_time) / span
            eta_values[i] = eta * progress
    elif eta_trend == 'linear_decrease':
        # Ramp from eta at start_time down to 0 at end_time - 1.
        for i in range(start_time, end_time):
            progress = 1 - (i - start_time) / span
            eta_values[i] = eta * progress
    return eta_values
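
# Quick sanity check (illustrative, values rounded): a linear_increase schedule
# ramps eta from 0 at start_time to the full value at end_time - 1 and is zero
# elsewhere:
#
#   >>> generate_eta_values(10, 2, 6, 0.8, 'linear_increase')
#   [0.0, 0.0, 0.0, 0.267, 0.533, 0.8, 0.0, 0.0, 0.0, 0.0]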