import collections.abc
import re

import torch
import torch.nn.functional as F

# Module-level constants used by default_collate below; these mirror the ones
# defined in torch.utils.data._utils.collate.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")

def compute_similarity(logit_scale, audio_embeddings, text_embeddings):
    r"""Compute similarity between text and audio embeddings"""
    audio_embeddings = audio_embeddings/torch.norm(audio_embeddings, dim=-1, keepdim=True)
    text_embeddings = text_embeddings/torch.norm(text_embeddings, dim=-1, keepdim=True)

    similarity = logit_scale*text_embeddings @ audio_embeddings.T
    return similarity.T

def compute_logit(logit_scale, audio_embeddings, text_embeddings):
    r"""Compute scaled audio-to-text and text-to-audio logits (no normalization is applied here)"""
    logits_per_audio = logit_scale * audio_embeddings @ text_embeddings.T
    logits_per_text = logit_scale * text_embeddings @ audio_embeddings.T
    return logits_per_audio, logits_per_text

def laion_compute_similarity(logit_scale, audio_embeddings, text_embeddings):
    r"""Compute similarity between text and audio embeddings"""
    audio_embeddings = F.normalize(audio_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)

    similarity = logit_scale*audio_embeddings @ text_embeddings.T
    return similarity
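
# Shape note, as a hedged sketch with assumed toy sizes (these tensors are not
# part of this file): with 4 audio and 10 text embeddings of dimension 512,
# both helpers above return an [n_audio, n_text] = [4, 10] matrix of scaled
# cosine similarities, one row per audio clip.
#
#   audio = torch.randn(4, 512)
#   text = torch.randn(10, 512)
#   sim = compute_similarity(100.0, audio, text)              # [4, 10]
#   sim_laion = laion_compute_similarity(100.0, audio, text)  # [4, 10]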

def freeze_branch_parameters(named_parameters, branch_name, freeze_flag):
    r"""Set requires_grad=False on every parameter whose name contains branch_name, if freeze_flag is True"""
    branch_parameters = [p for n, p in named_parameters if branch_name in n]
    if freeze_flag:
        print(f"Freezing {branch_name.capitalize()} parameters.")
        for param in branch_parameters:
            param.requires_grad = False
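
# Hedged usage sketch: assuming `model` is an audio-text model whose
# audio-branch parameter names contain the substring 'audio' (neither the
# model nor that naming convention is defined in this file), the branch could
# be frozen with
#
#   freeze_branch_parameters(model.named_parameters(), 'audio', freeze_flag=True)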

def format_emotion(emotion):
    if emotion == 'no_agreement':
        return 'there is no clear emotion.'
    else:
        return f'this person is feeling {emotion}.'


def preprocess_text(text_queries, tokenizer):
    r"""Load list of class labels and return tokenized text"""
    token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
    tokenized_texts = []
    for ttext in text_queries:
        tok = tokenizer.encode_plus(
            text=ttext, add_special_tokens=True, max_length=77,
            padding='max_length', truncation=True, return_tensors="pt")
        for key in token_keys:
            # Flatten the [1, seq_len] tensors and move them to the GPU
            # (assumes a CUDA device is available).
            tok[key] = tok[key].reshape(-1).cuda()
        tokenized_texts.append(tok)
    return default_collate(tokenized_texts)

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(
                    default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError(
                'each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
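
# Minimal end-to-end sketch, hedged: it assumes `tokenizer` is a Hugging Face
# tokenizer, `model` is a CLAP-style audio-text model exposing
# get_text_embeddings / get_audio_embeddings and a learned logit_scale, and a
# CUDA device is available (preprocess_text moves tokens to the GPU). None of
# these names, nor `audio_batch`, are defined in this file.
#
#   labels = [format_emotion(e) for e in ['happy', 'sad', 'no_agreement']]
#   tokens = preprocess_text(labels, tokenizer)
#   text_emb = model.get_text_embeddings(tokens)
#   audio_emb = model.get_audio_embeddings(audio_batch)
#   similarity = compute_similarity(model.logit_scale.exp(), audio_emb, text_emb)
#   probs = F.softmax(similarity, dim=-1)  # per-clip probabilities over labels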