import torch


def get_moe_embedding(moe_type):
    """Return the fixed task-attribute embedding table for the given MoE type."""
    if moe_type == 'attribute':
        Task_attribute = {
            # Task inputs, keyed by TASK_TYPE and data_type.
            # Each entry holds fixed 8-dimensional binary attribute vectors.
            'image_classification': {
                'input': torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
            },
            'video_classification': {
                'input': torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
            },
            'text_mlm': {
                'input': torch.tensor([[0, 1, 0, 1, 0, 1, 0, 0]], dtype=torch.float),
            },
            'image_caption': {
                'input': torch.tensor(
                    [[1, 1, 0, 1, 1, 0, 0, 0], [1, 1, 0, 1, 0, 1, 0, 1]],
                    dtype=torch.float),
            },
            'video_caption': {
                'input': torch.tensor(
                    [[1, 1, 0, 1, 1, 0, 0, 0], [1, 1, 0, 1, 0, 1, 0, 1]],
                    dtype=torch.float),
            },
            'image_retrieval': {
                'input': torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
                'target': torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            },
            'video_retrieval': {
                'input': torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
                'target': torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            },
            'text_classification': {
                'input': torch.tensor([[0, 1, 0, 1, 0, 1, 0, 0]], dtype=torch.float),
            },
            # SHARED_TARGETS, keyed by dataset / target-set name.
            'ImageNet1k': torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'ImageNet22k': torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'MomentsInTime': torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'Kinetics700': torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'Kinetics400': torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'Vocab_Word': torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'CoLA-target': torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'MNLI-target': torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'MRPC-target': torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'QNLI-target': torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'QQP-target': torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'RTE-target': torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'SST-2-target': torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
        }
        return Task_attribute
    else:
        raise NotImplementedError(f'please check MOE_TYPE {moe_type}')

def get_embed_with_task_type(moe_type: str, task_type: str, data_type: str):
    """Look up the task-attribute embedding for a task input/target by data_type."""
    if moe_type is None:
        return None
    embedding_dict = get_moe_embedding(moe_type)
    return embedding_dict[task_type][data_type]

def get_embed_with_shared_tagert_name(moe_type: str, set_name: str):
    """Look up the attribute embedding of a SHARED_TARGETS set by its name."""
    if moe_type is None:
        return None
    embedding_dict = get_moe_embedding(moe_type)
    return embedding_dict[set_name]
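

# --- Usage sketch (illustrative addition, not part of the original module) ---
# The calls below assume the 'attribute' MoE type defined above; each lookup
# returns one of the fixed (n, 8) float tensors from the table, or None when
# moe_type is disabled.
if __name__ == '__main__':
    # Task-input embedding, selected by task_type and data_type.
    input_emb = get_embed_with_task_type('attribute', 'image_retrieval', 'input')
    print(input_emb.shape)  # torch.Size([1, 8])

    # Shared-target embedding, selected by the target-set name.
    target_emb = get_embed_with_shared_tagert_name('attribute', 'ImageNet1k')
    print(target_emb.shape)  # torch.Size([1, 8])

    # With moe_type=None, the attribute embeddings are disabled entirely.
    assert get_embed_with_task_type(None, 'text_mlm', 'input') is None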