Spaces:
Runtime error
Runtime error
from manipulate import Manipulator | |
import tensorflow as tf | |
import numpy as np | |
import torch | |
import clip | |
from MapTS import GetBoundary,GetDt | |
class StyleCLIP:
    """Text-guided image manipulation built on CLIP and a ``Manipulator``.

    Typical use: construct, set ``target`` / ``neutral`` prompt strings, call
    ``GetDt2`` to recompute the CLIP text direction and pick the threshold
    ``beta``, then call ``GetImg`` to render the edited image.
    """

    # Search grid for ``beta`` in GetDt2, passed to np.arange(start, stop, step)
    # (stop is exclusive, so this yields 0.10 ... 0.29).
    BETA_RANGE = (0.1, 0.3, 0.01)
    # Fallback beta when no grid value yields num_c above ``c_threshold``.
    DEFAULT_BETA = 0.1

    def __init__(self, dataset_name='ffhq'):
        """Load CLIP ViT-B/32 and the latent data for *dataset_name*.

        Args:
            dataset_name: sub-directory name under ``./npy`` and ``./data``.
        """
        print('load clip')
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # clip.load returns (model, image_preprocess_transform); keep both.
        self.model, self.preprocess = clip.load("ViT-B/32", device=self.device)
        self.LoadData(dataset_name)

    def LoadData(self, dataset_name):
        """(Re)build the manipulator and load per-dataset arrays.

        Reads ``./npy/<dataset_name>/fs3.npy`` into ``self.fs3`` and
        ``./data/<dataset_name>/w_plus.npy``, which is converted with
        ``Manipulator.W2S`` into ``self.M.dlatents``. Sets ``c_threshold``
        (20 for 'ffhq', 100 otherwise) and resets the edit via ``SetInitP``.
        Paths are relative to the current working directory.
        """
        # Reset global Keras/TF state before constructing a new Manipulator
        # (e.g. when switching datasets).
        tf.keras.backend.clear_session()
        M = Manipulator(dataset_name=dataset_name)
        np.set_printoptions(suppress=True)
        fs3 = np.load(f'./npy/{dataset_name}/fs3.npy')
        self.M = M
        self.fs3 = fs3
        w_plus = np.load(f'./data/{dataset_name}/w_plus.npy')
        self.M.dlatents = M.W2S(w_plus)
        self.c_threshold = 20 if dataset_name == 'ffhq' else 100
        self.SetInitP()

    def SetInitP(self, img_index=0):
        """Reset edit parameters and select the source latent.

        Sets ``M.alpha = [3]``, ``M.num_images = 1``, clears ``target`` and
        ``neutral``, recomputes the direction/beta via ``GetDt2``, and stores
        a one-element slice of every array in ``M.dlatents`` as
        ``M.dlatent_tmp``.

        Args:
            img_index: which entry of each ``M.dlatents`` array to edit
                (default 0, the previous hard-coded value).
        """
        self.M.alpha = [3]
        self.M.num_images = 1
        self.target = ''
        self.neutral = ''
        self.GetDt2()
        # Slice rather than index so each entry keeps its leading axis.
        self.M.dlatent_tmp = [
            tmp[img_index:img_index + 1] for tmp in self.M.dlatents
        ]

    def GetDt2(self):
        """Recompute ``self.dt`` from ``[target, neutral]`` and choose ``self.beta``.

        For each beta in ``np.arange(*BETA_RANGE)`` calls ``GetBoundary`` and
        collects its second return value ``num_c`` (presumably the number of
        selected channels — confirm in MapTS). ``self.beta`` becomes the largest
        beta whose ``num_c`` exceeds ``self.c_threshold``, or ``DEFAULT_BETA``
        if none does.
        """
        self.dt = GetDt([self.target, self.neutral], self.model)
        betas = np.arange(*self.BETA_RANGE)
        num_cs = []
        for beta in betas:
            _, num_c = GetBoundary(self.fs3, self.dt, self.M, threshold=beta)
            print(beta)
            num_cs.append(num_c)
        self.beta = self._select_beta(
            betas, num_cs, self.c_threshold, self.DEFAULT_BETA)

    @staticmethod
    def _select_beta(betas, num_cs, c_threshold, default=0.1):
        """Return the last beta whose count exceeds *c_threshold*, else *default*.

        Args:
            betas: candidate thresholds, in the same order as *num_cs*.
            num_cs: counts returned by GetBoundary for each beta.
            c_threshold: strict lower bound a count must exceed.
            default: value returned when no count qualifies (or inputs are empty).
        """
        select = np.asarray(num_cs) > c_threshold
        if not select.any():
            return default
        # betas come from an increasing arange, so the last hit is the largest.
        return np.asarray(betas)[select][-1]

    def GetCode(self):
        """Return edited style codes for ``M.dlatent_tmp``.

        Uses the current ``self.dt`` and ``self.beta``; call ``GetDt2`` first
        after changing ``target`` / ``neutral``.
        """
        boundary, _ = GetBoundary(self.fs3, self.dt, self.M, threshold=self.beta)
        return self.M.MSCode(self.M.dlatent_tmp, boundary)

    def GetImg(self):
        """Render the current edit and return a single image.

        Returns ``out[0, 0]`` from ``M.GenerateImg``. NOTE(review): presumably
        the first alpha / first image — confirm against GenerateImg's layout.
        """
        codes = self.GetCode()
        out = self.M.GenerateImg(codes)
        return out[0, 0]
#%%
if __name__ == "__main__":
    # Build the editor and also alias it as ``self`` -- presumably so method
    # bodies can be run line-by-line in this interactive ``#%%`` cell.
    self = style_clip = StyleCLIP()