zideliu committed on
Commit 27f50b2 · 1 Parent(s): 79c0c01

Update clip

Files changed (1)
  1. app.py +45 -35
app.py CHANGED
@@ -15,33 +15,6 @@ from libs.muse import MUSE
 import utils
 import numpy as np
 from PIL import Image
-print("cuda available:",torch.cuda.is_available())
-print("cuda device count:",torch.cuda.device_count())
-print("cuda device name:",torch.cuda.get_device_name(0))
-# print(os.system("nvidia-smi"))
-print(os.system("nvcc --version"))
-
-empty_context = np.load("assets/contexts/empty_context.npy")
-
-print("downloading cc3m-285000.ckpt")
-os.makedirs("assets/ckpts/cc3m-285000.ckpt",exist_ok=True)
-
-wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/lr_scheduler.pth","assets/ckpts/cc3m-285000.ckpt/lr_scheduler.pth")
-wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/optimizer.pth","assets/ckpts/cc3m-285000.ckpt/optimizer.pth")
-wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/nnet.pth","assets/ckpts/cc3m-285000.ckpt/nnet.pth")
-wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/nnet_ema.pth","assets/ckpts/cc3m-285000.ckpt/nnet_ema.pth")
-wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/step.pth","assets/ckpts/cc3m-285000.ckpt/step.pth")
-wget.download("https://huggingface.co/zideliu/vqgan/resolve/main/vqgan_jax_strongaug.ckpt","assets/vqgan_jax_strongaug.ckpt")
-os.system("ls assets/ckpts/cc3m-285000.ckpt")
-def set_seed(seed: int):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
-
-def d(**kwargs):
-    """Helper of creating a config dict."""
-    return ml_collections.ConfigDict(initial_dictionary=kwargs)
-
 
 def get_config():
     config = ml_collections.ConfigDict()
@@ -98,6 +71,50 @@ def get_config():
     )
     return config
 
+print("cuda available:",torch.cuda.is_available())
+print("cuda device count:",torch.cuda.device_count())
+print("cuda device name:",torch.cuda.get_device_name(0))
+# print(os.system("nvidia-smi"))
+print(os.system("nvcc --version"))
+
+empty_context = np.load("assets/contexts/empty_context.npy")
+
+config = get_config()
+device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+print(device)
+# Load open_clip and vq model
+print("GPU memory:",torch.cuda.memory_allocated(0)/1024/1024/1024,"GB")
+prompt_model,_,_ = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k',device=device)
+print("GPU memory:",torch.cuda.memory_allocated(0)/1024/1024/1024,"GB")
+
+prompt_model = prompt_model.to(device)
+prompt_model.eval()
+tokenizer = open_clip.get_tokenizer('ViT-bigG-14')
+
+
+
+
+print("downloading cc3m-285000.ckpt")
+os.makedirs("assets/ckpts/cc3m-285000.ckpt",exist_ok=True)
+
+wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/lr_scheduler.pth","assets/ckpts/cc3m-285000.ckpt/lr_scheduler.pth")
+wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/optimizer.pth","assets/ckpts/cc3m-285000.ckpt/optimizer.pth")
+wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/nnet.pth","assets/ckpts/cc3m-285000.ckpt/nnet.pth")
+wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/nnet_ema.pth","assets/ckpts/cc3m-285000.ckpt/nnet_ema.pth")
+wget.download("https://huggingface.co/nzl-thu/MUSE/resolve/main/assets/ckpts/cc3m-285000.ckpt/step.pth","assets/ckpts/cc3m-285000.ckpt/step.pth")
+wget.download("https://huggingface.co/zideliu/vqgan/resolve/main/vqgan_jax_strongaug.ckpt","assets/vqgan_jax_strongaug.ckpt")
+os.system("ls assets/ckpts/cc3m-285000.ckpt")
+def set_seed(seed: int):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+def d(**kwargs):
+    """Helper of creating a config dict."""
+    return ml_collections.ConfigDict(initial_dictionary=kwargs)
+
+
 def cfg_nnet(x, context, scale=None,lambdaA=None,lambdaB=None):
     _cond = nnet_ema(x, context=context)
     _cond_w_adapter = nnet_ema(x,context=context,use_adapter=True)
@@ -113,14 +130,7 @@ def unprocess(x):
     x.clamp_(0., 1.)
     return x
 
-config = get_config()
-device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-print(device)
-# Load open_clip and vq model
-prompt_model,_,_ = open_clip.create_model_and_transforms('ViT-bigG-14', 'laion2b_s39b_b160k',device='cuda')
-prompt_model = prompt_model.to(device)
-prompt_model.eval()
-tokenizer = open_clip.get_tokenizer('ViT-bigG-14')
+
 
 vq_model = taming.models.vqgan.get_model('vq-f16-jax.yaml')
 vq_model.eval()
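
For readers reproducing this setup outside the Space, below is a minimal, self-contained sketch of the start-up pattern the new app.py follows: resolve the device once, pass it to open_clip instead of the hard-coded 'cuda', log allocated GPU memory around the load, and fetch the checkpoint files with wget. The log_gpu_memory helper and the os.path.exists guard are illustrative additions, not part of this commit.

import os

import open_clip
import torch
import wget


def log_gpu_memory(tag: str) -> None:
    # Illustrative helper mirroring the inline GPU-memory prints added in this commit.
    if torch.cuda.is_available():
        gib = torch.cuda.memory_allocated(0) / 1024 ** 3
        print(f"GPU memory ({tag}): {gib:.2f} GB")


# Resolve the device once and reuse it (the commit replaces device='cuda'
# in create_model_and_transforms with this variable).
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

log_gpu_memory("before CLIP")
prompt_model, _, _ = open_clip.create_model_and_transforms(
    'ViT-bigG-14', 'laion2b_s39b_b160k', device=device)
prompt_model = prompt_model.to(device).eval()
tokenizer = open_clip.get_tokenizer('ViT-bigG-14')
log_gpu_memory("after CLIP")

# Checkpoint download step from app.py; the existence check is an
# illustrative addition to avoid re-downloading on restart.
ckpt_dir = "assets/ckpts/cc3m-285000.ckpt"
base_url = "https://huggingface.co/nzl-thu/MUSE/resolve/main/" + ckpt_dir
os.makedirs(ckpt_dir, exist_ok=True)
for name in ("lr_scheduler.pth", "optimizer.pth", "nnet.pth", "nnet_ema.pth", "step.pth"):
    target = os.path.join(ckpt_dir, name)
    if not os.path.exists(target):
        wget.download(base_url + "/" + name, target)

Passing device=device rather than the hard-coded 'cuda' lets the CLIP model load on CPU-only hosts as well, which is the main behavioral change in this commit.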