import os

import torch

# Directory holding the pre-computed embedding tensors. The path is relative,
# so it is resolved against the current working directory at call time.
_EMB_DIR = "assets/embeddings_sd_1.4"

# Task name -> (source-class file, target-class file) inside ``_EMB_DIR``.
# Adding a new editing task only requires a new entry here.
_TASK_EMBEDDINGS = {
    "cat2dog": ("cat.pt", "dog.pt"),
    "dog2cat": ("dog.pt", "cat.pt"),
    "horse2zebra": ("horse.pt", "zebra.pt"),
    "zebra2horse": ("zebra.pt", "horse.pt"),
    "horse2llama": ("horse.pt", "llama.pt"),
    "dog2capy": ("dog.pt", "capy.pt"),
    "dogglasses": ("dog.pt", "dogs_with_glasses.pt"),
}


def construct_direction(task_name: str) -> torch.Tensor:
    """Return the embedding-space direction that maps class A to class B.

    The direction is the difference between the mean target-class embedding
    and the mean source-class embedding, where each mean is taken over
    dimension 0 of the tensor loaded from the corresponding ``.pt`` file.

    Args:
        task_name: Name of the editing task. Must be one of the keys of
            ``_TASK_EMBEDDINGS`` (for example ``"cat2dog"``).

    Returns:
        ``mean(target, dim=0) - mean(source, dim=0)`` with a new leading
        dimension of size 1 added.

    Raises:
        NotImplementedError: If ``task_name`` is not a supported task.

    Example:
        >>> direction = construct_direction("cat2dog")
    """
    try:
        source_file, target_file = _TASK_EMBEDDINGS[task_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which previously also fell
        # through to NotImplementedError via the ``==`` comparisons.
        supported = ", ".join(sorted(_TASK_EMBEDDINGS))
        raise NotImplementedError(
            f"Unknown task {task_name!r}; supported tasks: {supported}"
        ) from None

    # NOTE(security): torch.load unpickles the file contents; only load
    # embedding files that come from a trusted source.
    source_embs = torch.load(os.path.join(_EMB_DIR, source_file))
    target_embs = torch.load(os.path.join(_EMB_DIR, target_file))

    # Assumes dimension 0 indexes the individual embeddings being averaged
    # -- TODO confirm against how the .pt files are generated.
    return (target_embs.mean(0) - source_embs.mean(0)).unsqueeze(0)