from typing import List
import torch
import torch.nn as nn
from torch import Tensor
from transformers import Wav2Vec2Processor, Wav2Vec2Model
SAMPLE_RATE = 16000  # wav2vec2 models expect 16 kHz audio


class UpstreamExpert(nn.Module):
    def __init__(self, ckpt: str = None, model_config: str = None, **kwargs):
        super().__init__()
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        self.model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

    def get_downsample_rates(self, key: str) -> int:
        # The wav2vec2 convolutional feature extractor strides 320 samples
        # per output frame (16 kHz audio -> 50 Hz representations).
        return 320
    def forward(self, wavs: List[Tensor]):
        # Expand the batch to 7 settings in total: the original waveforms,
        # plus silence of 1/5, 1/10, and 1/20 of each waveform's length,
        # prepended ("front") or appended ("end").
        # Note: start from a copy; appending to `wavs` itself while iterating
        # over it would grow the list forever.
        wavs_silence = list(wavs)
        for position, fraction in [
            ("front", 5), ("front", 10), ("front", 20),
            ("end", 5), ("end", 10), ("end", 20),
        ]:
            for wav in wavs:
                silence = torch.zeros(len(wav) // fraction, dtype=wav.dtype, device=wav.device)
                if position == "front":
                    wavs_silence.append(torch.cat((silence, wav)))
                else:
                    wavs_silence.append(torch.cat((wav, silence)))
        wavs = wavs_silence
        device = wavs[0].device
        processor_outputs = self.processor(
            [wav.cpu().numpy() for wav in wavs],
            return_tensors="pt",
            sampling_rate=SAMPLE_RATE,
            padding="longest",
        )
        # The feature extractor for wav2vec2-base-960h may not return an
        # attention mask; fall back to None in that case.
        attention_mask = processor_outputs.get("attention_mask", None)
        if isinstance(attention_mask, torch.Tensor):
            attention_mask = attention_mask.to(device)
        model_outputs = self.model(
            processor_outputs.input_values.to(device),
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return {
            "last_hidden_state": model_outputs.last_hidden_state,
            "hidden_states": model_outputs.hidden_states,
        }
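

# A minimal usage sketch (an assumption, not part of the original file): this
# class appears to follow the s3prl upstream interface, where forward() takes
# a list of 1-D waveform tensors at 16 kHz. The random inputs below are
# illustrative only.
if __name__ == "__main__":
    expert = UpstreamExpert()
    expert.eval()
    with torch.no_grad():
        # Two fake utterances: 1 second and 2 seconds of noise.
        wavs = [torch.randn(SAMPLE_RATE), torch.randn(SAMPLE_RATE * 2)]
        outputs = expert(wavs)
    # forward() expands the batch 7x (original + six silence-padded variants),
    # so the leading dimension here is 14, not 2.
    print(outputs["last_hidden_state"].shape)  # (14, frames, 768)
    print(len(outputs["hidden_states"]))       # 13: embedding output + 12 transformer layers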