Spaces:
Runtime error
Runtime error
File size: 2,102 Bytes
7e0bf18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
import os
import torch


# Supported task names -> (source, target) embedding file stems inside the
# embedding directory. The direction for a task is mean(target) - mean(source).
_TASK_EMBEDDING_FILES = {
    "cat2dog": ("cat", "dog"),
    "dog2cat": ("dog", "cat"),
    "horse2zebra": ("horse", "zebra"),
    "zebra2horse": ("zebra", "horse"),
    "horse2llama": ("horse", "llama"),
    "dog2capy": ("dog", "capy"),
    "dogglasses": ("dog", "dogs_with_glasses"),
}


def construct_direction(task_name, emb_dir="assets/embeddings_sd_1.4", map_location=None):
    """Return the embedding-space direction that transforms class A to class B.

    The direction is the difference of the two classes' mean embeddings,
    ``mean(B) - mean(A)``, where each class's embeddings are loaded with
    ``torch.load`` from ``<emb_dir>/<name>.pt``.

    Parameters:
        task_name (str): name of the task for which the direction is
            constructed, e.g. ``"cat2dog"``. Supported names are the keys of
            ``_TASK_EMBEDDING_FILES``.
        emb_dir (str): directory holding the ``.pt`` embedding files.
            Defaults to ``"assets/embeddings_sd_1.4"`` (a path relative to
            the current working directory).
        map_location: forwarded to ``torch.load``; ``None`` (the default)
            keeps torch's default device placement.

    Returns:
        torch.Tensor: ``(embs_b.mean(0) - embs_a.mean(0)).unsqueeze(0)``, i.e.
        the mean difference with a new leading dimension of size 1.

    Raises:
        NotImplementedError: if ``task_name`` is not a supported task.

    Examples:
        >>> construct_direction("cat2dog")  # doctest: +SKIP
    """
    try:
        src_name, dst_name = _TASK_EMBEDDING_FILES[task_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the original if/elif chain
        # also rejected with NotImplementedError.
        supported = ", ".join(sorted(_TASK_EMBEDDING_FILES))
        raise NotImplementedError(
            f"Unknown task {task_name!r}; supported tasks: {supported}"
        ) from None

    embs_a = torch.load(
        os.path.join(emb_dir, f"{src_name}.pt"), map_location=map_location
    )
    embs_b = torch.load(
        os.path.join(emb_dir, f"{dst_name}.pt"), map_location=map_location
    )
    # Reduce over dim 0 (presumably the sample dimension -- TODO confirm the
    # stored tensor layout), then add a leading dimension of size 1.
    return (embs_b.mean(0) - embs_a.mean(0)).unsqueeze(0)
|