HoneyTian's picture
first commit
bd94e77
raw
history blame contribute delete
384 Bytes
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
def as_complex(x: torch.Tensor) -> torch.Tensor:
    """Return ``x`` as a complex tensor.

    A tensor that is already complex is returned unchanged. Otherwise the last
    dimension of ``x`` must have length 2 and hold the (real, imaginary)
    parts; the result drops that dimension and has the matching complex dtype.

    When the memory layout of ``x`` allows it, the result is a view that
    shares storage with ``x`` (via ``torch.view_as_complex``); otherwise ``x``
    is first copied into a fresh, standard-contiguous tensor.

    Args:
        x: A complex tensor, or a real tensor whose last dimension has size 2.

    Returns:
        ``x`` itself if it is already complex, else a complex tensor of shape
        ``x.shape[:-1]``.

    Raises:
        ValueError: If ``x`` is real and is 0-dimensional or its last
            dimension does not have length 2. Other errors (e.g. an
            unsupported dtype) propagate from ``torch.view_as_complex``.
    """
    if torch.is_complex(x):
        return x
    # Test dim() first: indexing ``x.shape[-1]`` on a 0-d tensor raises IndexError.
    if x.dim() == 0 or x.shape[-1] != 2:
        raise ValueError(f"Last dimension need to be of length 2 (re + im), but got {x.shape}")
    # torch.view_as_complex requires stride 1 on the last dimension, even
    # strides on every other dimension, and an even storage offset. A unit
    # last stride alone is not sufficient (e.g. ``t[:, :2]`` of an (N, 3)
    # tensor has an odd row stride). ``contiguous()`` is not sufficient either,
    # because it returns the tensor unchanged when it already counts as
    # contiguous, even with an odd storage offset. An explicit clone into the
    # standard contiguous layout always satisfies all three requirements.
    strides = x.stride()
    if (
        strides[-1] != 1
        or any(s % 2 != 0 for s in strides[:-1])
        or x.storage_offset() % 2 != 0
    ):
        x = x.clone(memory_format=torch.contiguous_format)
    return torch.view_as_complex(x)
if __name__ == "__main__":
    # No script entry point is implemented; executing this file directly does nothing.
    ...