File size: 1,694 Bytes
c025a3d
 
ca154e9
 
 
c025a3d
ca154e9
 
 
 
c025a3d
ca154e9
c025a3d
 
ca154e9
 
c025a3d
 
 
ca154e9
 
 
 
 
 
 
 
 
 
c025a3d
 
 
ca154e9
 
c025a3d
 
 
 
ca154e9
 
 
 
 
 
 
 
 
c025a3d
 
 
ca154e9
 
 
 
 
 
 
 
 
c025a3d
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import numpy as np
from typing import List, Dict, Optional
from PIL.Image import Image as PILImage
from IP_Adapter import IPAdapterXL

def compute_dataset_embeds_svd(
    all_embeds: np.ndarray,
    rank: int
) -> np.ndarray:
    """Build a projection matrix from the leading right-singular vectors.

    Runs a reduced SVD on ``all_embeds``, keeps the first ``rank`` rows of
    the right-singular-vector factor ``Vh`` and returns ``Vh[:rank].T @
    Vh[:rank]``. Right-multiplying a row-vector embedding by the result
    projects it onto the span of the kept vectors.

    Args:
        all_embeds: Matrix to decompose (presumably one embedding per row;
            confirm against the caller).
        rank: Number of leading singular vectors to keep. ``0`` yields an
            all-zero matrix; a value larger than the number of available
            vectors keeps all of them.

    Returns:
        The square projection matrix whose side equals the last dimension
        of ``all_embeds``.

    Raises:
        ValueError: If ``rank`` is negative. A negative slice bound would
            otherwise drop trailing vectors instead of keeping leading ones
            and silently produce a projector onto the wrong subspace.
    """
    if rank < 0:
        raise ValueError(f"rank must be non-negative, got {rank}")

    # Only the right-singular vectors are needed; NumPy returns them sorted
    # by descending singular value, so the leading rows are the dominant ones.
    _, _, right_vectors = np.linalg.svd(all_embeds, full_matrices=False)

    top_vectors = right_vectors[:rank]
    return top_vectors.T @ top_vectors

def get_projected_embedding(
    embed: np.ndarray,
    projection_matrix: np.ndarray
) -> np.ndarray:
    """Return ``embed`` right-multiplied by ``projection_matrix``.

    Args:
        embed: Embedding array used as the left operand of the matrix product.
        projection_matrix: Matrix used as the right operand.

    Returns:
        The matrix product ``embed @ projection_matrix``.
    """
    projected = embed @ projection_matrix
    return projected

def get_embedding_composition(
    embed: np.ndarray,
    projections_data: List[Dict[str, np.ndarray]]
) -> np.ndarray:
    """Swap projected components of ``embed`` for those of other embeddings.

    Starting from a copy of ``embed``, each entry of ``projections_data``
    contributes one update: the projection of the original ``embed`` under
    the entry's ``"projection_matrix"`` is subtracted, and the projection of
    the entry's own ``"embed"`` under the same matrix is added.

    Args:
        embed: Base embedding; it is copied, not modified.
        projections_data: Dicts providing the keys ``"projection_matrix"``
            and ``"embed"``.

    Returns:
        The composed embedding (a new array).
    """
    result = embed.copy()

    for entry in projections_data:
        matrix = entry["projection_matrix"]

        # Remove the base embedding's component under this projection.
        # Note the projection is always taken of the original `embed`,
        # not of the running result.
        removed = get_projected_embedding(embed, matrix)
        result -= removed

        # Insert the replacement embedding's component in its place.
        injected = get_projected_embedding(entry["embed"], matrix)
        result += injected

    return result


def get_modified_images_embeds_composition(
    embed: np.ndarray,
    projections_data: List[Dict[str, np.ndarray]],
    ip_model: IPAdapterXL,
    prompt: Optional[str] = None,
    scale: float = 1.0,
    num_samples: int = 3,
    seed: int = 420,
    *,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5
) -> List[PILImage]:
    """Compose an embedding and generate images from it with ``ip_model``.

    The composed embedding is obtained from ``get_embedding_composition``,
    wrapped as a torch tensor, and passed to ``ip_model.generate`` as
    ``clip_image_embeds`` together with the sampling options below.

    Args:
        embed: Base embedding passed to ``get_embedding_composition``.
        projections_data: Projection entries passed to
            ``get_embedding_composition``.
        ip_model: Object whose ``generate`` method produces the images.
        prompt: Forwarded to ``generate`` as ``prompt``.
        scale: Forwarded to ``generate`` as ``scale``.
        num_samples: Forwarded to ``generate`` as ``num_samples``.
        seed: Forwarded to ``generate`` as ``seed``.
        num_inference_steps: Forwarded to ``generate``; keyword-only,
            defaults to the previously hard-coded value of 50.
        guidance_scale: Forwarded to ``generate``; keyword-only, defaults
            to the previously hard-coded value of 7.5.

    Returns:
        Whatever ``ip_model.generate`` returns (annotated as a list of PIL
        images).
    """
    composed = get_embedding_composition(embed, projections_data)
    # from_numpy shares memory with `composed`; that array is a fresh copy
    # made by the composition step, so the caller's `embed` is not aliased.
    clip_embeds = torch.from_numpy(composed)

    return ip_model.generate(
        clip_image_embeds=clip_embeds,
        prompt=prompt,
        num_samples=num_samples,
        num_inference_steps=num_inference_steps,
        seed=seed,
        guidance_scale=guidance_scale,
        scale=scale,
    )