#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/kaituoxu/Conv-TasNet/blob/master/src/utils.py
"""
import math

import torch


def overlap_and_add(signal: torch.Tensor, frame_step: int) -> torch.Tensor:
    """
    Reconstruct a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Based on
    https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py

    The input does not need to be contiguous in memory; it is copied only
    when its layout makes a zero-copy reshape impossible.

    :param signal: Tensor, shape: [..., frames, frame_length].
        Rank must be at least 2.
    :param frame_step: int, offset between the starts of consecutive frames.
        Must be less than or equal to frame_length.
    :return: Tensor, shape: [..., output_size], on the same device and with
        the same dtype as `signal`, containing the overlap-added frames of
        signal's inner-most two dimensions.
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Split every frame into equally sized "subframes" whose length divides
    # both frame_length and frame_step. Each subframe then lands on exactly
    # one output slot, so the overlap-add becomes a single index_add_.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    # reshape (not view) so that non-contiguous inputs are accepted; explicit
    # sizes (not -1) keep the shape unambiguous when an outer dimension is 0.
    subframe_signal = signal.reshape(
        *outer_dimensions, frames * subframes_per_frame, subframe_length
    )

    # Row i of the unfolded index holds the output subframe slots covered by
    # frame i; flattened, it lines up one-to-one with subframe_signal's rows.
    # arange with integer arguments is already int64, as index_add_ requires.
    slot_index = torch.arange(output_subframes, device=signal.device)
    slot_index = slot_index.unfold(0, subframes_per_frame, subframe_step)
    slot_index = slot_index.reshape(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    # Overlapping frames hit the same slot more than once and are summed.
    result.index_add_(-2, slot_index, subframe_signal)
    # output_subframes * subframe_length == output_size by construction.
    return result.view(*outer_dimensions, output_size)


if __name__ == "__main__":
    pass