#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch


def _is_complex_viewable(x: torch.Tensor) -> bool:
    """Return True if ``torch.view_as_complex`` accepts the memory layout of *x*.

    ``view_as_complex`` reinterprets each adjacent (re, im) pair in the last
    dimension as one complex element. That requires the last dimension to be
    densely packed (stride 1). It also requires every other stride and the
    storage offset to be multiples of 2, so that each pair starts on a
    complex-element boundary.
    """
    *outer_strides, last_stride = x.stride()
    return (
        last_stride == 1
        and x.storage_offset() % 2 == 0
        and all(s % 2 == 0 for s in outer_strides)
    )


def as_complex(x: torch.Tensor) -> torch.Tensor:
    """Convert a real ``(..., 2)`` tensor of (re, im) pairs to a complex tensor.

    Args:
        x: Either a complex tensor, which is returned unchanged, or a real
            tensor whose last dimension has length 2 and holds the real and
            imaginary parts.

    Returns:
        A complex tensor whose shape is ``x.shape[:-1]``. It is a zero-copy
        view of ``x`` when the memory layout allows one. Otherwise it is a
        view of a freshly allocated contiguous copy.

    Raises:
        ValueError: If ``x`` is real and is 0-dimensional or its last
            dimension is not of length 2.
    """
    if torch.is_complex(x):
        return x
    # Check ndim first so a 0-dim tensor raises the documented ValueError
    # rather than an IndexError from x.shape[-1].
    if x.dim() == 0 or x.shape[-1] != 2:
        raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}")
    if not _is_complex_viewable(x):
        # .contiguous() alone is insufficient: an already-contiguous slice can
        # still start at an odd storage offset, which view_as_complex rejects.
        # clone() always allocates fresh storage starting at offset 0.
        x = x.clone(memory_format=torch.contiguous_format)
    return torch.view_as_complex(x)


if __name__ == '__main__':
    pass