import torch

import comfy.model_management as mm


def get_model_fn(model):
    # Sampler-facing closure that applies classifier-free guidance
    # (upstream argument names: sample, sample_null, cfg_scale).
    def model_fn(z, sigma, positive, negative, cfg):
        model.dit.to(model.device)
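        # Pick the autocast dtype: fp16 when the cuBLAS half-precision matmul
        # patch is enabled on the DiT, otherwise bf16.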
        if getattr(model.dit, "cublas_half_matmul", False):
            autocast_dtype = torch.float16
        else:
            autocast_dtype = torch.bfloat16
        
        with torch.autocast(mm.get_autocast_device(model.device), dtype=autocast_dtype):
            if cfg > 1.0:
                out_cond = model.dit(z, sigma, **positive)
                out_uncond = model.dit(z, sigma, **negative)
            else:
                # No guidance requested; a single conditional pass suffices.
                return model.dit(z, sigma, **positive)

        # Classifier-free guidance: shift the unconditional prediction toward
        # the conditional one by the guidance scale.
        return out_uncond + cfg * (out_cond - out_uncond)
    
    return model_fn


def get_sample_args(model, cond_embeds, uncond_embeds):
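    # y_feat / y_mask are passed as single-element lists, matching the
    # conditioning format the DiT's forward expects.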
    cond_args = {
        "y_mask": [cond_embeds["attention_mask"].to(model.device)],
        "y_feat": [cond_embeds["embeds"].to(model.device)]
    }

    uncond_args = {
        "y_mask": [uncond_embeds["attention_mask"].to(model.device)],
        "y_feat": [uncond_embeds["embeds"].to(model.device)]
    }
    return cond_args, uncond_args


def prepare_conds(positive, negative):
    # For compatibility with ComfyUI's CLIPTextEncode, whose CONDITIONING
    # output is a nested list of [embeds, extras_dict] rather than a plain dict.
    if not isinstance(positive, dict):
        positive = {
            "embeds": positive[0][0],
            "attention_mask": positive[0][1]["attention_mask"].bool(),
        }
    if not isinstance(negative, dict):
        negative = {
            "embeds": negative[0][0],
            "attention_mask": negative[0][1]["attention_mask"].bool(),
        }
    return positive, negative


def generate_eta_values(steps, start_time, end_time, eta, eta_trend):
    # Per-step eta schedule: zero everywhere except [start_time, end_time).
    end_time = min(end_time, steps)
    eta_values = [0.0] * steps
    # Guard the denominator so a single-step window can't divide by zero.
    span = max(end_time - start_time - 1, 1)

    if eta_trend == 'constant':
        for i in range(start_time, end_time):
            eta_values[i] = eta
    elif eta_trend == 'linear_increase':
        # Ramp from 0 up to eta across the window.
        for i in range(start_time, end_time):
            eta_values[i] = eta * (i - start_time) / span
    elif eta_trend == 'linear_decrease':
        # Ramp from eta down to 0 across the window.
        for i in range(start_time, end_time):
            eta_values[i] = eta * (1 - (i - start_time) / span)

    return eta_values
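

# A rough usage sketch of how these helpers compose (hypothetical sampler
# loop; `sigmas`, `z`, and `step` are assumed names, not part of this module):
#
#   positive, negative = prepare_conds(positive, negative)
#   cond_args, uncond_args = get_sample_args(model, positive, negative)
#   model_fn = get_model_fn(model)
#   etas = generate_eta_values(len(sigmas), start_time=0, end_time=len(sigmas),
#                              eta=1.0, eta_trend='linear_decrease')
#   for i, sigma in enumerate(sigmas):
#       denoised = model_fn(z, sigma, cond_args, uncond_args, cfg=4.5)
#       z = step(z, denoised, sigma, etas[i])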