from typing import List, Dict

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

import fairseq
from packaging import version

# class Model(nn.Module):
#     def __init__(self):
#         super().__init__()
#         # The model needs to be an nn.Module for finetuning; this is not required for representation extraction
#         self.model1 = nn.Linear(1, HIDDEN_DIM)
#         self.model2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)

#     def forward(self, wavs, upstream_feature_selection="hidden_states"):
#         # You can do task-specific pre- / post-processing based on upstream_feature_selection
#         hidden = self.model1(wavs)
#         # hidden: (batch_size, max_len, hidden_dim)

#         feature = self.model2(hidden)
#         # feature: (batch_size, max_len, hidden_dim)

#         return [hidden, feature]

class UpstreamExpert(nn.Module):
    def __init__(
        self,
        ckpt: str = "https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt",
        upstream_feature_selection: str = "hidden_states",
        **kwargs):
        """
        Args:
            ckpt:
                The checkpoint path for loading your pretrained weights.
                Should be fixed as model.pt for SUPERB Challenge.
            upstream_feature_selection:
                The value could be 'hidden_states', 'PR', 'SID', 'ER', 'ASR',
                'QbE', 'ASV', 'SD', 'ST', 'SE', 'SS', 'secret', or others
                (new tasks). You can use it to control which task-specific
                pre- / post-processing to do.
        """
        super().__init__()
        self.name = "[Example UpstreamExpert]"
        self.upstream_feature_selection = upstream_feature_selection

        # # You can use ckpt to load your pretrained weights
        # ckpt = torch.load(ckpt, map_location="cpu")
        # self.model = Model()
        # self.model.load_state_dict(ckpt)

        assert version.parse(fairseq.__version__) > version.parse(
            "0.10.2"
        ), "Please install the fairseq master branch."

        models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            [ckpt]
        )
        self.model = models[0]
        self.task = task

    def get_downsample_rates(self, key: str) -> int:
        """
        HuBERT's convolutional feature extractor has a 20 ms stride on 16 kHz
        input, so every key's representation has a downsample rate of 320
        (for comparison, a 10 ms stride would give a downsample rate of 160).
        """
        return 320

    def forward(self, wavs: List[Tensor]) -> Dict[str, List[Tensor]]:
        """
        When the returned Dict contains a List with more than one Tensor,
        those Tensors should have the same shape so that a weighted-sum can
        be trained on them.
        """
        # Seven settings in total: the original wavs, plus silence padded at
        # the front or the end with length 1/5, 1/10, or 1/20 of each wav.
        # Note: copy the list first; appending to an alias of `wavs` while
        # iterating over it would loop forever.
        wavs_silence = list(wavs)

        for ratio in (5, 10, 20):  # silence at the front
            for wav in wavs:
                silence = torch.zeros(len(wav) // ratio, device=wav.device)
                wavs_silence.append(torch.cat((silence, wav)))

        for ratio in (5, 10, 20):  # silence at the end
            for wav in wavs:
                silence = torch.zeros(len(wav) // ratio, device=wav.device)
                wavs_silence.append(torch.cat((wav, silence)))

        wavs = wavs_silence

        device = wavs[0].device
        wav_lengths = torch.LongTensor([len(wav) for wav in wavs]).to(device)
        # True marks the padded positions beyond each wav's true length
        wav_padding_mask = ~torch.lt(
            torch.arange(max(wav_lengths)).unsqueeze(0).to(device),
            wav_lengths.unsqueeze(1),
        )
        padded_wav = pad_sequence(wavs, batch_first=True)

        features, feat_padding_mask = self.model.extract_features(
            padded_wav,
            padding_mask=wav_padding_mask,
            mask=None,
        )
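        # features: (batch_size, num_frames, hidden_dim), with feat_padding_mask
        # marking padded frames, following fairseq's HubertModel.extract_features
        # return convention (the shape is stated as an assumption for hubert_base)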


        # Do not do any task-specific post-processing below. Use the init
        # argument "upstream_feature_selection" to control task-specific
        # pre- / post-processing instead.
        # The "hidden_states" key is used as the default in most cases;
        # other keys in this example are reserved for the SUPERB Challenge.
        # The features are wrapped in a list to match the declared
        # Dict[str, List[Tensor]] interface.
        return {
            "hidden_states": [features],
        }
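

# A minimal usage sketch (an illustration, not part of the SUPERB interface).
# It assumes hubert_base_ls960.pt has been downloaded next to this file; the
# batch below and the printed shape are assumptions, not guarantees.
if __name__ == "__main__":
    expert = UpstreamExpert(ckpt="hubert_base_ls960.pt")
    expert.eval()

    # Two fake 16 kHz waveforms of different lengths
    wavs = [torch.randn(16000), torch.randn(2 * 16000)]
    with torch.no_grad():
        outputs = expert(wavs)

    # forward returns 7x the input batch: the originals plus the six
    # silence-padded variants created inside forward()
    for hidden in outputs["hidden_states"]:
        print(hidden.shape)  # expected: (14, max_frames, 768) for hubert_base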