Josue Aaron Soriano Rivero commited on
Commit
2114012
1 Parent(s): 2f41cb5

Problema de use_auth_token

Browse files
Files changed (1) hide show
  1. utils.py +39 -1
utils.py CHANGED
@@ -1,9 +1,47 @@
 
1
  import numpy as np
2
  import torch
3
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
 
 
 
 
 
 
 
 
 
 
4
 
5
  def carga_modelo(nombre_modelo="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
6
- gan = LightweightGAN.from_pretrained(nombre_modelo, version=model_version)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  gan.eval()
8
  return gan
9
 
 
1
+ import json
2
  import numpy as np
3
  import torch
4
  from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ CONFIG_NAME = "config.json"
8
+ revision = None
9
+ cache_dir = None
10
+ force_download = False
11
+ proxies = None
12
+ resume_download = False
13
+ local_files_only = False
14
+ token = None
15
 
16
  def carga_modelo(nombre_modelo="ceyda/butterfly_cropped_uniq1K_512", model_version=None):
17
+ # Load the config
18
+ config_file = hf_hub_download(
19
+ repo_id=str(nombre_modelo),
20
+ filename=CONFIG_NAME,
21
+ revision=revision,
22
+ cache_dir=cache_dir,
23
+ force_download=force_download,
24
+ proxies=proxies,
25
+ resume_download=resume_download,
26
+ token=token,
27
+ local_files_only=local_files_only,
28
+ )
29
+ with open(config_file, "r", encoding="utf-8") as f:
30
+ config = json.load(f)
31
+
32
+ gan = LightweightGAN(latent_dim=256, image_size=512)
33
+ gan = gan._from_pretrained(
34
+ model_id=str(nombre_modelo),
35
+ revision=revision,
36
+ cache_dir=cache_dir,
37
+ force_download=force_download,
38
+ proxies=proxies,
39
+ resume_download=resume_download,
40
+ local_files_only=local_files_only,
41
+ token=token,
42
+ use_auth_token=False,
43
+ config=config, # usually in **model_kwargs
44
+ )
45
  gan.eval()
46
  return gan
47