advcloud
commited on
Commit
·
9d224d7
1
Parent(s):
6af8543
first commit
Browse files
model.py
CHANGED
@@ -37,8 +37,6 @@ sys.path.insert(0, mapper_dir.as_posix())
|
|
37 |
from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
|
38 |
from mapper.hairclip_mapper import HairCLIPMapper
|
39 |
|
40 |
-
HF_TOKEN = os.environ['HF_TOKEN']
|
41 |
-
|
42 |
|
43 |
class Model:
|
44 |
def __init__(self, device: Union[torch.device, str]):
|
@@ -51,15 +49,15 @@ class Model:
|
|
51 |
@staticmethod
|
52 |
def _create_dlib_landmark_model():
|
53 |
path = huggingface_hub.hf_hub_download(
|
54 |
-
'
|
55 |
-
'
|
56 |
-
|
57 |
return dlib.shape_predictor(path)
|
58 |
|
59 |
def _load_e4e(self) -> nn.Module:
|
60 |
-
ckpt_path = huggingface_hub.hf_hub_download('
|
61 |
-
'
|
62 |
-
|
63 |
ckpt = torch.load(ckpt_path, map_location='cpu')
|
64 |
opts = ckpt['opts']
|
65 |
opts['device'] = self.device.type
|
@@ -71,9 +69,9 @@ class Model:
|
|
71 |
return model
|
72 |
|
73 |
def _load_hairclip(self) -> nn.Module:
|
74 |
-
ckpt_path = huggingface_hub.hf_hub_download('
|
75 |
-
'hairclip.pt'
|
76 |
-
|
77 |
ckpt = torch.load(ckpt_path, map_location='cpu')
|
78 |
opts = ckpt['opts']
|
79 |
opts['device'] = self.device.type
|
|
|
37 |
from mapper.datasets.latents_dataset_inference import LatentsDatasetInference
|
38 |
from mapper.hairclip_mapper import HairCLIPMapper
|
39 |
|
|
|
|
|
40 |
|
41 |
class Model:
|
42 |
def __init__(self, device: Union[torch.device, str]):
|
|
|
49 |
@staticmethod
|
50 |
def _create_dlib_landmark_model():
|
51 |
path = huggingface_hub.hf_hub_download(
|
52 |
+
'aijack/jojogan',
|
53 |
+
'face_landmarks.dat'
|
54 |
+
)
|
55 |
return dlib.shape_predictor(path)
|
56 |
|
57 |
def _load_e4e(self) -> nn.Module:
|
58 |
+
ckpt_path = huggingface_hub.hf_hub_download('aijack/e4e',
|
59 |
+
'e4e.pt'
|
60 |
+
)
|
61 |
ckpt = torch.load(ckpt_path, map_location='cpu')
|
62 |
opts = ckpt['opts']
|
63 |
opts['device'] = self.device.type
|
|
|
69 |
return model
|
70 |
|
71 |
def _load_hairclip(self) -> nn.Module:
|
72 |
+
ckpt_path = huggingface_hub.hf_hub_download('aijack/hair',
|
73 |
+
'hairclip.pt'
|
74 |
+
)
|
75 |
ckpt = torch.load(ckpt_path, map_location='cpu')
|
76 |
opts = ckpt['opts']
|
77 |
opts['device'] = self.device.type
|