yeq6x commited on
Commit
d7efa60
·
1 Parent(s): b6c9f2e
Files changed (3) hide show
  1. scripts/anime.py +1 -1
  2. scripts/data.py +1 -1
  3. scripts/model.py +1 -4
scripts/anime.py CHANGED
@@ -19,7 +19,7 @@ model = None
19
  def init_model(use_local=False):
20
  global model
21
  model_opt = "default"
22
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23
  model = create_model(model_opt, use_local).to(device)
24
  model.eval()
25
 
 
19
  def init_model(use_local=False):
20
  global model
21
  model_opt = "default"
22
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # issue: nevetherless, use_gpu is False, it still uses GPU
23
  model = create_model(model_opt, use_local).to(device)
24
  model.eval()
25
 
scripts/data.py CHANGED
@@ -40,7 +40,7 @@ def get_transform(load_size=0, grayscale=False, method=bic, convert=True):
40
  transform_list.append(transforms.Grayscale(1))
41
  if load_size > 0:
42
  osize = [load_size, load_size]
43
- transform_list.append(transforms.Resize(osize, method))
44
  if convert:
45
  # transform_list += [transforms.ToTensor()]
46
  if grayscale:
 
40
  transform_list.append(transforms.Grayscale(1))
41
  if load_size > 0:
42
  osize = [load_size, load_size]
43
+ transform_list.append(transforms.Resize(osize, method, antialias=False))
44
  if convert:
45
  # transform_list += [transforms.ToTensor()]
46
  if grayscale:
scripts/model.py CHANGED
@@ -154,8 +154,7 @@ def create_model(model, use_local):
154
 
155
  import os
156
  if model == 'default':
157
- model_path = (lambda filename, subfolder: os.path.join(subfolder, filename) if use_local else download_file(filename, subfolder)) \
158
- ("netG.pth", "models/Anime2Sketch")
159
  # model_path = ((filename, subfolder) => if (use_local) os.path.join(subfolder, filename) else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch") // JavaScript
160
 
161
  ckpt = torch.load(model_path)
@@ -176,8 +175,6 @@ def create_model(model, use_local):
176
  base = base.model[3]
177
 
178
  net.load_state_dict(ckpt)
179
-
180
- os.chdir(cwd) # 元のディレクトリに戻る
181
 
182
  else:
183
  raise ValueError(f"model should be one of ['default', 'improved'], but got {model}")
 
154
 
155
  import os
156
  if model == 'default':
157
+ model_path = (lambda filename, subfolder: os.path.join(subfolder, filename) if use_local else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch")
 
158
  # model_path = ((filename, subfolder) => if (use_local) os.path.join(subfolder, filename) else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch") // JavaScript
159
 
160
  ckpt = torch.load(model_path)
 
175
  base = base.model[3]
176
 
177
  net.load_state_dict(ckpt)
 
 
178
 
179
  else:
180
  raise ValueError(f"model should be one of ['default', 'improved'], but got {model}")