# Duplicate from DragGan/DragGan-Inversion (commit 04d341d, uploaded by Amrrs, 2.33 kB).
from manipulate import Manipulator
import tensorflow as tf
import numpy as np
import torch
import clip
from MapTS import GetBoundary,GetDt
class StyleCLIP():
    """Couple a CLIP model with a ``Manipulator`` to produce edited images.

    On construction the object loads CLIP, builds a ``Manipulator`` for the
    requested dataset, loads ``fs3.npy`` and ``w_plus.npy`` from disk, and
    picks an initial ``beta`` threshold.  Callers then set ``self.target`` /
    ``self.neutral``, call :meth:`GetDt2` to refresh ``self.dt`` and
    ``self.beta``, and call :meth:`GetImg` to obtain an image.

    Attributes set by this class:
        model: First value returned by ``clip.load("ViT-B/32", ...)``.
        M: The ``Manipulator`` instance.
        fs3: Array loaded from ``./npy/<dataset_name>/fs3.npy``.
        c_threshold: 20 for ``'ffhq'``, 100 for any other dataset name.
        target, neutral: Strings passed to ``GetDt`` as class names.
        dt: Value returned by ``GetDt``.
        beta: Threshold passed to ``GetBoundary`` when generating codes.
    """

    def __init__(self, dataset_name='ffhq'):
        """Load CLIP and all data for ``dataset_name``.

        Args:
            dataset_name: Forwarded to ``Manipulator`` and used as the
                sub-directory name under ``./npy/`` and ``./data/``.
        """
        print('load clip')
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # clip.load returns a second value (a preprocessing object) that this
        # class never uses, so it is discarded.
        self.model, _ = clip.load("ViT-B/32", device=device)
        self.LoadData(dataset_name)

    def LoadData(self, dataset_name):
        """Build the ``Manipulator`` and load the arrays for ``dataset_name``.

        Reads ``./npy/<dataset_name>/fs3.npy`` and
        ``./data/<dataset_name>/w_plus.npy`` relative to the current working
        directory, then calls :meth:`SetInitP`.

        Args:
            dataset_name: Dataset identifier (see :meth:`__init__`).
        """
        # Reset Keras/TensorFlow global state before building the Manipulator.
        tf.keras.backend.clear_session()
        manipulator = Manipulator(dataset_name=dataset_name)
        # NOTE: this changes NumPy's process-wide print formatting.
        np.set_printoptions(suppress=True)
        fs3 = np.load('./npy/' + dataset_name + '/fs3.npy')

        self.M = manipulator
        self.fs3 = fs3

        w_plus = np.load('./data/' + dataset_name + '/w_plus.npy')
        # presumably W2S maps W+ latents to per-layer style codes -- confirm
        # against Manipulator.W2S.
        self.M.dlatents = manipulator.W2S(w_plus)

        # Minimum count (second value returned by GetBoundary) that a beta
        # must exceed to be eligible in GetDt2.
        self.c_threshold = 20 if dataset_name == 'ffhq' else 100
        self.SetInitP()

    def SetInitP(self):
        """Reset manipulation parameters to their initial values.

        Sets ``M.alpha`` to ``[3]`` and ``M.num_images`` to 1, clears the
        target/neutral strings, recomputes ``dt``/``beta`` via :meth:`GetDt2`
        and selects the first entry of every ``M.dlatents`` array as
        ``M.dlatent_tmp``.
        """
        self.M.alpha = [3]
        self.M.num_images = 1

        self.target = ''
        self.neutral = ''
        self.GetDt2()

        img_index = 0
        # Slice (rather than index) so each array keeps its leading axis.
        self.M.dlatent_tmp = [
            latent[img_index:img_index + 1] for latent in self.M.dlatents
        ]

    @staticmethod
    def _select_beta(betas, num_cs, c_threshold, fallback=0.1):
        """Return the last beta whose count exceeds ``c_threshold``.

        Args:
            betas: 1-D NumPy array of candidate thresholds.
            num_cs: Sequence of counts, one per entry of ``betas``.
            c_threshold: A beta is eligible only if its count is strictly
                greater than this value.
            fallback: Value returned when no beta is eligible.

        Returns:
            ``betas[k]`` for the largest index ``k`` with
            ``num_cs[k] > c_threshold``, or ``fallback`` if there is none.
        """
        eligible = np.asarray(num_cs) > c_threshold
        if not eligible.any():
            return fallback
        return betas[eligible][-1]

    def GetDt2(self):
        """Recompute ``self.dt`` and ``self.beta`` from target/neutral.

        Calls ``GetDt`` with ``[self.target, self.neutral]`` and the CLIP
        model, then evaluates ``GetBoundary`` for every beta in
        ``np.arange(0.1, 0.3, 0.01)`` and keeps the last beta whose returned
        count exceeds ``self.c_threshold`` (0.1 if none does).
        """
        classnames = [self.target, self.neutral]
        self.dt = GetDt(classnames, self.model)

        betas = np.arange(0.1, 0.3, 0.01)
        num_cs = []
        for beta in betas:
            # Only the count is needed here; the boundary itself is
            # recomputed in GetCode with the chosen beta.
            _, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=beta)
            print(beta)
            num_cs.append(num_c)

        self.beta = self._select_beta(betas, num_cs, self.c_threshold)

    def GetCode(self):
        """Return the codes produced by ``M.MSCode`` for the current state.

        The boundary comes from ``GetBoundary`` evaluated at ``self.beta``
        and is applied to ``M.dlatent_tmp``.
        """
        boundary, _ = GetBoundary(
            self.fs3, self.dt, self.M, threshold=self.beta
        )
        return self.M.MSCode(self.M.dlatent_tmp, boundary)

    def GetImg(self):
        """Generate and return a single image for the current state.

        Returns:
            Element ``[0, 0]`` of the array returned by ``M.GenerateImg``
            (presumably first image, first alpha -- confirm against
            Manipulator.GenerateImg).
        """
        codes = self.GetCode()
        out = self.M.GenerateImg(codes)
        return out[0, 0]
#%%
if __name__ == "__main__":
    # Build the default (ffhq) editor when run as a script / notebook cell.
    # `self` is bound to the same object so that method bodies can be pasted
    # and executed interactively at module level.
    style_clip = self = StyleCLIP()