# unit_test/uniperceiver/datasets/moe_embeddings.py
import torch


def get_moe_embedding(moe_type):
    """Return the hand-crafted MoE gate embeddings for the given MoE type.

    Each entry is a binary float tensor of shape (k, 8) whose pattern marks
    the experts a task's input or target representation is routed to.
    """
    if moe_type == 'attribute':
        task_attribute = {
            # Task inputs, keyed by TASK_TYPE; each maps a data_type
            # ('input', plus 'target' for retrieval tasks) to its embedding.
            'image_classification': {
                'input':
                torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
            },
            'video_classification': {
                'input':
                torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
            },
            'text_mlm': {
                'input':
                torch.tensor([[0, 1, 0, 1, 0, 1, 0, 0]], dtype=torch.float),
            },
            # Captioning tasks carry two gate vectors instead of one.
            'image_caption': {
                'input':
                torch.tensor(
                    [[1, 1, 0, 1, 1, 0, 0, 0], [1, 1, 0, 1, 0, 1, 0, 1]],
                    dtype=torch.float),
            },
            'video_caption': {
                'input':
                torch.tensor(
                    [[1, 1, 0, 1, 1, 0, 0, 0], [1, 1, 0, 1, 0, 1, 0, 1]],
                    dtype=torch.float),
            },
            'image_retrieval': {
                'input':
                torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
                'target':
                torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            },
            'video_retrieval': {
                'input':
                torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
                'target':
                torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            },
            'text_classification': {
                'input':
                torch.tensor([[0, 1, 0, 1, 0, 1, 0, 0]], dtype=torch.float),
            },
            # SHARED_TARGETS: gate embeddings for shared target sets,
            # keyed by dataset / vocabulary name at the top level.
            'ImageNet1k':
            torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'ImageNet22k':
            torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'MomentsInTime':
            torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'Kinetics700':
            torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'Kinetics400':
            torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'Vocab_Word':
            torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'CoLA-target':
            torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'MNLI-target':
            torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'MRPC-target':
            torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'QNLI-target':
            torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'QQP-target':
            torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'RTE-target':
            torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
            'SST-2-target':
            torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
        }
        return task_attribute
    raise NotImplementedError(f'unsupported MOE_TYPE: {moe_type}')
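

# Illustrative sanity check (not part of the original module): the caption
# tasks above carry two 8-dim gate vectors, while every other task and
# every shared target set carries exactly one.
def _check_attribute_shapes():
    attr = get_moe_embedding('attribute')
    assert attr['image_caption']['input'].shape == (2, 8)
    assert attr['image_classification']['input'].shape == (1, 8)
    assert attr['ImageNet1k'].shape == (1, 8)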


def get_embed_with_task_type(moe_type: str, task_type: str, data_type: str):
    """Return the gate embedding for (task_type, data_type).

    Returns None when moe_type is None, i.e. MoE routing is disabled.
    """
    if moe_type is None:
        return None
    embedding_dict = get_moe_embedding(moe_type)
    return embedding_dict[task_type][data_type]
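

# Usage sketch (illustrative): look up the target-side embedding of the
# image-retrieval task; passing moe_type=None skips the lookup entirely.
def _example_task_lookup():
    emb = get_embed_with_task_type('attribute', 'image_retrieval', 'target')
    assert emb.shape == (1, 8)
    assert get_embed_with_task_type(None, 'image_retrieval', 'target') is None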


def get_embed_with_shared_target_name(moe_type: str, set_name: str):
    """Return the gate embedding for a shared target set such as
    'ImageNet1k' or 'Vocab_Word', or None if MoE routing is disabled.
    """
    if moe_type is None:
        return None
    embedding_dict = get_moe_embedding(moe_type)
    return embedding_dict[set_name]
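

# Usage sketch (illustrative): shared target sets live at the top level of
# the embedding dict, keyed by set name rather than nested under a task.
def _example_shared_target_lookup():
    emb = get_embed_with_shared_target_name('attribute', 'Vocab_Word')
    assert emb.shape == (1, 8)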