# echen01 — working demo (commit 2fec875, 2.33 kB)
from manipulate import Manipulator
import tensorflow as tf
import numpy as np
import torch
import clip
from MapTS import GetBoundary,GetDt
class StyleCLIP:
    """Text-driven latent editing demo combining CLIP with a ``Manipulator``.

    Construction loads the CLIP ``ViT-B/32`` model and the per-dataset data,
    then :meth:`SetInitP` initialises an edit with empty target/neutral
    prompts on the first image.  :meth:`GetDt2` recomputes the CLIP direction
    from ``self.target`` / ``self.neutral`` and picks ``self.beta``;
    :meth:`GetImg` renders the edited image.
    """

    def __init__(self, dataset_name='ffhq'):
        """Load CLIP, then the Manipulator and latents for *dataset_name*."""
        print('load clip')
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # clip.load returns (model, preprocess); only the model is used here.
        self.model, _ = clip.load("ViT-B/32", device=device)
        self.LoadData(dataset_name)

    def LoadData(self, dataset_name):
        """Build the Manipulator and load the latent data for *dataset_name*.

        Reads ``./npy/<dataset_name>/fs3.npy`` and
        ``./data/<dataset_name>/w_plus.npy`` (paths relative to the current
        working directory), converts the W+ latents with ``M.W2S`` into
        ``self.M.dlatents``, sets ``self.c_threshold`` (20 for ``'ffhq'``,
        100 otherwise) and finally resets the edit state via :meth:`SetInitP`.
        """
        # Reset Keras global state before constructing the Manipulator.
        tf.keras.backend.clear_session()
        M = Manipulator(dataset_name=dataset_name)
        # NOTE: process-wide side effect on NumPy's print formatting.
        np.set_printoptions(suppress=True)
        fs3 = np.load(f'./npy/{dataset_name}/fs3.npy')
        self.M = M
        self.fs3 = fs3
        w_plus = np.load(f'./data/{dataset_name}/w_plus.npy')
        self.M.dlatents = M.W2S(w_plus)
        # GetDt2 only accepts a beta whose GetBoundary count exceeds this value.
        self.c_threshold = 20 if dataset_name == 'ffhq' else 100
        self.SetInitP()

    def SetInitP(self):
        """Reset the edit parameters and select the first image's latents.

        Sets ``M.alpha = [3]`` and ``M.num_images = 1``, clears the
        target/neutral prompts, recomputes ``dt``/``beta`` via
        :meth:`GetDt2`, and stores image 0 of every entry of
        ``M.dlatents`` in ``M.dlatent_tmp``.
        """
        self.M.alpha = [3]
        self.M.num_images = 1
        self.target = ''
        self.neutral = ''
        self.GetDt2()
        img_index = 0
        # Slice rather than index so each entry keeps its leading axis.
        self.M.dlatent_tmp = [
            tmp[img_index:img_index + 1] for tmp in self.M.dlatents
        ]

    def GetDt2(self):
        """Recompute the CLIP direction ``self.dt`` and choose ``self.beta``.

        ``self.dt`` is ``GetDt([self.target, self.neutral], self.model)``.
        Candidate betas ``np.arange(0.1, 0.3, 0.01)`` (about 0.10 ... 0.29,
        ascending) are each passed to ``GetBoundary`` as ``threshold``; every
        candidate is printed.  ``self.beta`` becomes the largest candidate
        whose returned count exceeds ``self.c_threshold``, or ``0.1`` when no
        candidate does.
        """
        classnames = [self.target, self.neutral]
        self.dt = GetDt(classnames, self.model)
        betas = np.arange(0.1, 0.3, 0.01)
        num_cs = []
        for beta in betas:
            # Only the count matters here; GetCode recomputes the boundary.
            _, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=beta)
            print(beta)
            num_cs.append(num_c)
        select = np.array(num_cs) > self.c_threshold
        # betas ascend, so the last selected entry is the largest qualifying one.
        self.beta = betas[select][-1] if select.any() else 0.1

    def GetCode(self):
        """Return ``M.MSCode`` of ``M.dlatent_tmp`` along the boundary at ``self.beta``."""
        boundary_tmp2, _ = GetBoundary(
            self.fs3, self.dt, self.M, threshold=self.beta
        )
        return self.M.MSCode(self.M.dlatent_tmp, boundary_tmp2)

    def GetImg(self):
        """Generate the edited output and return its ``[0, 0]`` element.

        NOTE(review): with ``M.num_images = 1`` and ``M.alpha = [3]`` set by
        SetInitP this is presumably the single image at the single alpha —
        confirm the axis order of ``M.GenerateImg``'s output.
        """
        codes = self.GetCode()
        out = self.M.GenerateImg(codes)
        return out[0, 0]
#%%
# Script / notebook-cell entry point.  Both module-level names are bound to
# the same instance; the ``self`` alias is presumably kept so method bodies
# can be pasted into an interactive console unchanged.
if __name__ == "__main__":
    self = style_clip = StyleCLIP()