Spaces:
Runtime error
Runtime error
File size: 2,623 Bytes
2fec875 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
import numpy as np
import torch
import clip
from PIL import Image
import copy
from manipulate import Manipulator
import argparse
def GetImgF(out,model,preprocess):
    """Encode a grid of images into image features with ``model``.

    Args:
        out: Array whose first two axes index a grid of images and whose
            remaining axes describe one image accepted by
            ``PIL.Image.fromarray`` (presumably ``(H, W, 3)`` uint8 --
            TODO confirm against the producer of ``out``).
        model: Module exposing ``encode_image`` and ``parameters()``; the
            preprocessed batch is moved to the device of its first parameter.
        preprocess: Callable mapping a PIL image to a tensor.

    Returns:
        NumPy array shaped ``out.shape[:2] + (feature_dim,)`` where
        ``feature_dim`` is whatever width ``model.encode_image`` returns.
    """
    grid_shape = list(out.shape[:2])
    # Collapse the two leading grid axes so every image can be visited once.
    flat_images = out.reshape([-1] + list(out.shape[2:]))

    # Take the device from the model itself rather than from a module-level
    # global, so the function also works when this file is imported.
    first_param = next(iter(model.parameters()), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    batch = torch.cat(
        [preprocess(Image.fromarray(img)).unsqueeze(0) for img in flat_images]
    ).to(target_device)

    with torch.no_grad():
        features = model.encode_image(batch)

    features = features.cpu().numpy()
    # Use the encoder's real output width instead of assuming 512.
    return features.reshape(grid_shape + [features.shape[-1]])
def GetFs(fs):
    """Turn paired feature vectors into one unit direction per leading entry.

    ``fs`` is indexed as ``fs[a, b, k, :]`` with ``k`` in ``{0, 1}``; the
    original code comment describes index 1 as the ``+5*sigma`` edit and
    index 0 as the ``-5*sigma`` edit.

    Steps: L2-normalise every feature vector, take ``k=1`` minus ``k=0``,
    normalise that difference, average over axis 1, and normalise the average.

    Returns:
        Array of shape ``(fs.shape[0], fs.shape[-1])`` of unit-length rows.
    """
    def _unit(vectors):
        # L2-normalise along the last axis.
        return vectors / np.linalg.norm(vectors, axis=-1, keepdims=True)

    unit_feats = _unit(fs)
    low_side = unit_feats[:, :, 0, :]
    high_side = unit_feats[:, :, 1, :]
    per_pair_direction = _unit(high_side - low_side)  # 5*sigma - (-5)*sigma
    averaged = per_pair_direction.mean(axis=1)
    return _unit(averaged)
#%%
if __name__ == "__main__":
    # Script entry point: for every channel of the selected layers, set that
    # channel to its mean, render the two edited variants through the
    # Manipulator, encode them with CLIP, and save the resulting per-channel
    # directions to ./npy/<dataset_name>/fs3.npy.
    parser = argparse.ArgumentParser(
        description='Compute per-channel CLIP image-feature directions for a dataset.')
    parser.add_argument('--dataset_name',type=str,default='cat',
                        help='name of dataset, for example, ffhq')
    args = parser.parse_args()
    dataset_name=args.dataset_name
    #%%
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)
    #%%
    M=Manipulator(dataset_name=dataset_name)
    np.set_printoptions(suppress=True)
    print(M.dataset_name)
    #%%
    img_sindex=0
    num_images=100
    start=img_sindex*num_images
    # Keep the slice of latents used for every edit.
    dlatents_o=[layer[start:(start+num_images)] for layer in M.dlatents]
    #%%
    all_f=[]
    M.alpha=[-5,5] #ffhq 5
    M.step=2
    M.num_images=num_images
    select=np.array(M.mindexs)<=16 #below or equal to 128 resolution
    mindexs2=np.array(M.mindexs)[select]
    for lindex in mindexs2: #ignore ToRGB layers
        print(lindex)
        num_c=M.dlatents[lindex].shape[1]
        for cindex in range(num_c):
            # copy.copy() duplicates only the list, not the arrays inside it.
            # Writing into M.dlatents[lindex] in place would therefore also
            # modify dlatents_o and leave every previously visited channel
            # stuck at its mean. Edit a private copy of this one layer.
            M.dlatents=copy.copy(dlatents_o)
            layer_codes=copy.deepcopy(dlatents_o[lindex])
            layer_codes[:,cindex]=M.code_mean[lindex][cindex]
            M.dlatents[lindex]=layer_codes
            M.manipulate_layers=[lindex]
            codes,out=M.EditOneC(cindex)
            image_features1=GetImgF(out,model,preprocess)
            all_f.append(image_features1)
    all_f=np.array(all_f)
    fs3=GetFs(all_f)
    #%%
    file_path='./npy/'+M.dataset_name+'/'
    np.save(file_path+'fs3',fs3)
|