File size: 1,711 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch


from ..utils.latent_guide import LatentGuide


class AddLatentGuideNode:
    """Place a guide frame into a latent and register it as a model patch.

    The first slice along dim 2 of ``image_latent`` is either written over
    position ``index`` of ``latent`` or inserted at ``index``. A clone of the
    model is then patched with a ``LatentGuide`` under the
    ``'guiding_latents'`` key.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's required inputs.

        ``index`` ranges from -1 to 9999. ``insert`` selects insertion
        (True) or in-place replacement (False, the default).
        """
        return {"required": {"model": ("MODEL",),
                             "latent": ("LATENT",),
                             "image_latent": ("LATENT",),
                             "index": ("INT", {"default": 0, "min": -1, "max": 9999, "step": 1}),
                             "insert": ("BOOLEAN", { "default": False }),
                }}

    RETURN_TYPES = ("MODEL", "LATENT")

    CATEGORY = "ltxtricks"
    FUNCTION = "generate"

    def generate(self, model, latent, image_latent, index, insert):
        """Put the guide frame into ``latent`` and patch a cloned ``model``.

        Args:
            model: Model object exposing ``clone()`` and ``set_model_patch()``.
            latent: Dict whose ``'samples'`` tensor receives the guide frame.
                Dim 2 is indexed as the frame axis (presumably temporal;
                confirm against the sampler).
            image_latent: Dict whose ``'samples'`` tensor supplies the guide.
                Only its first slice along dim 2 is placed into ``latent``.
                The whole tensor is handed to ``LatentGuide``.
            index: Target position along dim 2. In replace mode, negative
                values count from the end, as in normal tensor indexing. In
                insert mode, a negative or past-the-end value appends the
                frame after the last one.
            insert: If True, insert a new slice. If False, overwrite the
                existing slice at ``index``.

        Returns:
            Tuple of ``(patched_model_clone, {"samples": new_latent})``.

        Raises:
            IndexError: In replace mode, when ``index`` is outside dim 2.
        """
        image_latent = image_latent['samples']
        # Clone so the in-place write below never touches the caller's tensor.
        latent = latent['samples'].clone()
        guide_index = index

        if insert:
            num_frames = latent.shape[2]
            if index < 0 or index >= num_frames:
                # Out-of-range (including -1) means "append after the last frame".
                pos = num_frames
                # Point the guide at where the frame actually landed. Passing
                # an overshooting index (e.g. 9999) through unchanged would
                # reference a position that does not exist in the new latent.
                guide_index = -1 if index < 0 else num_frames
            else:
                pos = index
            # One concat covers every case. The head slice is empty when
            # pos == 0 and the tail slice is empty when pos == num_frames.
            latent = torch.cat(
                [latent[:, :, :pos], image_latent[:, :, 0:1], latent[:, :, pos:]],
                dim=2,
            )
        else:
            # Overwrite one existing slice in place (on the clone made above).
            latent[:, :, index] = image_latent[:, :, 0]

        # Clone so the incoming model object itself is not patched.
        model = model.clone()
        guiding_latent = LatentGuide(image_latent, guide_index)
        model.set_model_patch(guiding_latent, 'guiding_latents')

        return (model, {"samples": latent}, )