File size: 2,102 Bytes
7e0bf18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import torch


"""
This function takes in a task name and returns the direction in the embedding space that transforms class A to class B for the given task.

Parameters:
task_name (str): name of the task for which direction is to be constructed.

Returns:
torch.Tensor: A tensor representing the direction in the embedding space that transforms class A to class B.

Examples:
>>> construct_direction("cat2dog")
"""
def construct_direction(task_name):
    emb_dir = f"assets/embeddings_sd_1.4"
    if task_name=="cat2dog":    
        embs_a = torch.load(os.path.join(emb_dir, f"cat.pt"))
        embs_b = torch.load(os.path.join(emb_dir, f"dog.pt"))
        return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
    elif task_name=="dog2cat":    
        embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
        embs_b = torch.load(os.path.join(emb_dir, f"cat.pt"))
        return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
    elif task_name=="horse2zebra":
        embs_a = torch.load(os.path.join(emb_dir, f"horse.pt"))
        embs_b = torch.load(os.path.join(emb_dir, f"zebra.pt"))
        return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
    elif task_name=="zebra2horse":
        embs_a = torch.load(os.path.join(emb_dir, f"zebra.pt"))
        embs_b = torch.load(os.path.join(emb_dir, f"horse.pt"))
        return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
    elif task_name=="horse2llama":
        embs_a = torch.load(os.path.join(emb_dir, f"horse.pt"))
        embs_b = torch.load(os.path.join(emb_dir, f"llama.pt"))
        return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
    elif task_name=="dog2capy":
        embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
        embs_b = torch.load(os.path.join(emb_dir, f"capy.pt"))
        return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
    elif task_name=='dogglasses':
        embs_a = torch.load(os.path.join(emb_dir, f"dog.pt"))
        embs_b = torch.load(os.path.join(emb_dir, f"dogs_with_glasses.pt"))
        return (embs_b.mean(0)-embs_a.mean(0)).unsqueeze(0)
    else:
        raise NotImplementedError