|
import torch |
|
|
|
def get_moe_embedding(moe_type):
    """Build the attribute-embedding table for a MoE routing scheme.

    Only ``moe_type == 'attribute'`` is supported. A brand-new dict, with
    brand-new tensor objects, is built on every call. It holds two kinds of
    entries:

    * task names (``'image_classification'``, ``'image_caption'``, ...) mapping
      to a dict with an ``'input'`` tensor; the two retrieval tasks also
      carry a ``'target'`` tensor;
    * shared target-set names (``'ImageNet1k'``, ``'SST-2-target'``, ...)
      mapping directly to a tensor.

    Every tensor has dtype ``torch.float`` and 8 columns. Most have a single
    row; the caption tasks' ``'input'`` has two rows.

    Args:
        moe_type: Name of the embedding scheme.

    Returns:
        The embedding table described above.

    Raises:
        NotImplementedError: If ``moe_type`` is anything other than
            ``'attribute'``.
    """
    if moe_type != 'attribute':
        raise NotImplementedError(f'please check MOE_TYPE {moe_type}')

    def _attr(*rows):
        # One tensor row per positional list; a new tensor on every call so
        # that no two entries share storage.
        return torch.tensor(list(rows), dtype=torch.float)

    # Row vectors reused across several entries below.
    visual_input = [1, 0, 0, 1, 1, 0, 0, 0]
    text_input = [0, 1, 0, 1, 0, 1, 0, 0]
    caption_first_row = [1, 1, 0, 1, 1, 0, 0, 0]
    caption_second_row = [1, 1, 0, 1, 0, 1, 0, 1]
    visual_target = [1, 0, 0, 1, 0, 1, 1, 0]
    text_target = [1, 1, 0, 1, 0, 1, 1, 0]

    table = {
        'image_classification': {"input": _attr(visual_input)},
        'video_classification': {"input": _attr(visual_input)},
        'text_mlm': {"input": _attr(text_input)},
        'image_caption': {
            "input": _attr(caption_first_row, caption_second_row),
        },
        'video_caption': {
            "input": _attr(caption_first_row, caption_second_row),
        },
        'image_retrieval': {
            'input': _attr(visual_input),
            'target': _attr(visual_target),
        },
        'video_retrieval': {
            'input': _attr(visual_input),
            'target': _attr(visual_target),
        },
        'text_classification': {"input": _attr(text_input)},
    }

    # Shared target sets map straight to a tensor (no "input"/"target" level).
    # Insertion order matches the original literal.
    for set_name in ("ImageNet1k", "ImageNet22k", "MomentsInTime",
                     "Kinetics700", "Kinetics400"):
        table[set_name] = _attr(visual_target)
    for set_name in ("Vocab_Word", "CoLA-target", "MNLI-target",
                     "MRPC-target", "QNLI-target", "QQP-target",
                     "RTE-target", "SST-2-target"):
        table[set_name] = _attr(text_target)

    return table
|
|
|
|
|
|
|
def get_embed_with_task_type(moe_type: str, task_type: str, data_type: str):
    """Fetch a per-task embedding from the MoE embedding table.

    Args:
        moe_type: Embedding scheme passed to ``get_moe_embedding``; ``None``
            disables the lookup entirely.
        task_type: Task key in the table, e.g. ``'image_caption'``.
        data_type: Entry within that task, e.g. ``'input'`` or ``'target'``.

    Returns:
        ``get_moe_embedding(moe_type)[task_type][data_type]``, or ``None``
        when ``moe_type`` is ``None``.

    Raises:
        NotImplementedError: Propagated from ``get_moe_embedding`` for an
            unsupported ``moe_type``.
        KeyError: If ``task_type`` is not a key of the table. A missing
            ``data_type`` under a task dict also raises ``KeyError``.
    """
    if moe_type is not None:
        task_entry = get_moe_embedding(moe_type)[task_type]
        return task_entry[data_type]
    return None
|
|
|
|
|
def get_embed_with_shared_tagert_name(moe_type: str, set_name: str,):
    """Fetch the embedding for a shared target set (e.g. ``'ImageNet1k'``).

    The function name keeps its original spelling ("tagert", sic) so that
    existing callers keep working.

    Args:
        moe_type: Embedding scheme passed to ``get_moe_embedding``; ``None``
            disables the lookup entirely.
        set_name: Top-level key in the table, e.g. ``'Kinetics400'``.

    Returns:
        ``get_moe_embedding(moe_type)[set_name]``, or ``None`` when
        ``moe_type`` is ``None``.

    Raises:
        NotImplementedError: Propagated from ``get_moe_embedding`` for an
            unsupported ``moe_type``.
        KeyError: If ``set_name`` is not a key of the table.
    """
    if moe_type is not None:
        table = get_moe_embedding(moe_type)
        return table[set_name]
    return None
|
|