Rodrigo_Cobo
commited on
Commit
•
272c5b4
1
Parent(s):
83c98ec
add the option to work in CPU
Browse files- .gitignore +2 -1
- WiggleGAN.py +6 -2
.gitignore
CHANGED
@@ -5,4 +5,5 @@ Lib/*
|
|
5 |
logs/*
|
6 |
WiggleGAN_mod.py
|
7 |
WiggleGAN_noCycle.py
|
8 |
-
pyvenv.cfg
|
|
|
|
5 |
logs/*
|
6 |
WiggleGAN_mod.py
|
7 |
WiggleGAN_noCycle.py
|
8 |
+
pyvenv.cfg
|
9 |
+
py
|
WiggleGAN.py
CHANGED
@@ -783,9 +783,13 @@ class WiggleGAN(object):
|
|
783 |
def load(self):
|
784 |
save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
|
785 |
|
786 |
-
|
|
|
|
|
|
|
|
|
787 |
if not self.wiggle:
|
788 |
-
self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_D.pkl')))
|
789 |
|
790 |
def wiggleEf(self):
|
791 |
seed, epoch = self.seed_load.split('_')
|
|
|
783 |
def load(self):
|
784 |
save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
|
785 |
|
786 |
+
map_loc=None
|
787 |
+
if not torch.cuda.is_available():
|
788 |
+
map_loc='cpu'
|
789 |
+
|
790 |
+
self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_G.pkl'), map_location=map_loc))
|
791 |
if not self.wiggle:
|
792 |
+
self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_D.pkl'), map_location=map_loc))
|
793 |
|
794 |
def wiggleEf(self):
|
795 |
seed, epoch = self.seed_load.split('_')
|