Spaces:
Running
Running
File size: 2,018 Bytes
e86d760 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
"""
import math
import torch
def overlap_and_add(signal: torch.Tensor, frame_step: int) -> torch.Tensor:
    """
    Reconstruct a signal from a framed representation by overlap-add.

    Adds the (potentially overlapping) frames stored in the two inner-most
    dimensions of ``signal``, offsetting each subsequent frame by
    ``frame_step`` samples.

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py

    :param signal: Tensor, shape: [..., frames, frame_length]. Rank must be
        at least 2. Non-contiguous tensors (e.g. the result of a transpose)
        are accepted.
    :param frame_step: int, offset in samples between the starts of
        consecutive frames. Must be positive; expected to be less than or
        equal to frame_length.
    :return: Tensor, shape: [..., output_size], with the same dtype and
        device as ``signal``, where
        output_size = (frames - 1) * frame_step + frame_length
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Cut every frame into equal "subframes" whose length is the greatest
    # common divisor of frame_length and frame_step. Both the frame length and
    # the frame offset are then whole numbers of subframes, so the overlap-add
    # reduces to summing whole rows into an output made of subframe rows.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # reshape() instead of view(): view() raises on non-contiguous input,
    # while reshape() is still a no-copy view whenever the layout allows it.
    # Sizes are spelled out (no -1) so the shape stays well-defined even when
    # an outer dimension is empty.
    subframe_signal = signal.reshape(
        *outer_dimensions, frames * subframes_per_frame, subframe_length
    )

    # Row i of the unfolded index holds the output-subframe slots covered by
    # frame i; flattened, entry k is the destination row of input subframe k.
    # Created directly on signal's device; arange over ints is already int64.
    frame = torch.arange(output_subframes, device=signal.device)
    frame = frame.unfold(0, subframes_per_frame, subframe_step).reshape(-1)

    # index_add_ accumulates rows that share a destination, which is exactly
    # the summation of overlapping regions.
    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)

    # Merge the subframe rows back into a single time axis.
    return result.view(*outer_dimensions, output_size)
# Script entry point: intentionally a no-op, so running this file directly
# does nothing.
if __name__ == "__main__":
    pass
|