Spaces:
Runtime error
Runtime error
#!/usr/bin/python | |
#****************************************************************# | |
# ScriptName: fft_pytorch.py | |
# Author: Anonymous_123 | |
# Create Date: 2022-08-15 11:33 | |
# Modify Author: Anonymous_123 | |
# Modify Date: 2022-08-18 17:46 | |
# Function: | |
#***************************************************************# | |
import torch | |
import torch.nn as nn | |
import torch.fft as fft | |
import cv2 | |
import numpy as np | |
import torchvision.transforms as transforms | |
from PIL import Image | |
def lowpass(input, limit): | |
pass1 = torch.abs(fft.rfftfreq(input.shape[-1])) < limit | |
pass2 = torch.abs(fft.fftfreq(input.shape[-2])) < limit | |
kernel = torch.outer(pass2, pass1) | |
fft_input = fft.rfft2(input) | |
return fft.irfft2(fft_input*kernel, s=input.shape[-2:]) | |
class HighFrequencyLoss(nn.Module): | |
def __init__(self, size=(224,224)): | |
super(HighFrequencyLoss, self).__init__() | |
''' | |
self.h,self.w = size | |
self.lpf = torch.zeros((self.h,1)) | |
R = (self.h+self.w)//8 | |
for x in range(self.w): | |
for y in range(self.h): | |
if ((x-(self.w-1)/2)**2 + (y-(self.h-1)/2)**2) < (R**2): | |
self.lpf[y,x] = 1 | |
self.hpf = 1-self.lpf | |
''' | |
def forward(self, x): | |
f = fft.fftn(x, dim=(2,3)) | |
loss = torch.abs(f).mean() | |
# f = torch.roll(f,(self.h//2,self.w//2),dims=(2,3)) | |
# f_l = torch.mean(f * self.lpf) | |
# f_h = torch.mean(f * self.hpf) | |
return loss | |
if __name__ == '__main__': | |
import pdb | |
pdb.set_trace() | |
HF = HighFrequencyLoss() | |
transform = transforms.Compose([transforms.ToTensor()]) | |
# img = cv2.imread('test_imgs/ILSVRC2012_val_00001935.JPEG') | |
img = cv2.imread('../tmp.jpg') | |
H,W,C = img.shape | |
imgs = [] | |
for i in range(10): | |
img_ = img[:, 224*i:224*(i+1), :] | |
print(img_.shape) | |
img_tensor = transform(Image.fromarray(img_[:,:,::-1])).unsqueeze(0) | |
loss = HF(img_tensor).item() | |
cv2.putText(img_, str(loss)[:6], (5,50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2) | |
imgs.append(img_) | |
cv2.imwrite('tmp.jpg', cv2.hconcat(imgs)) | |