Realcat commited on
Commit
20ee7b7
·
1 Parent(s): b11e1d7

update: download liftfeat model

Browse files
Files changed (1) hide show
  1. imcui/hloc/extractors/liftfeat.py +10 -3
imcui/hloc/extractors/liftfeat.py CHANGED
@@ -4,26 +4,33 @@ from pathlib import Path
4
  import torch
5
  import random
6
  from ..utils.base_model import BaseModel
7
- from .. import logger
8
 
9
  fire_path = Path(__file__).parent / "../../third_party/LiftFeat"
10
  sys.path.append(str(fire_path))
11
 
12
- from models.liftfeat_wrapper import LiftFeat, MODEL_PATH
13
 
14
 
15
  class Liftfeat(BaseModel):
16
  default_conf = {
17
  "keypoint_threshold": 0.05,
18
  "max_keypoints": 5000,
 
19
  }
20
 
21
  required_inputs = ["image"]
22
 
23
  def _init(self, conf):
24
  logger.info("Loading LiftFeat model...")
 
 
 
 
 
 
25
  self.net = LiftFeat(
26
- weight=MODEL_PATH,
27
  detect_threshold=self.conf["keypoint_threshold"],
28
  top_k=self.conf["max_keypoints"],
29
  )
 
4
  import torch
5
  import random
6
  from ..utils.base_model import BaseModel
7
+ from .. import logger, MODEL_REPO_ID
8
 
9
  fire_path = Path(__file__).parent / "../../third_party/LiftFeat"
10
  sys.path.append(str(fire_path))
11
 
12
+ from models.liftfeat_wrapper import LiftFeat
13
 
14
 
15
  class Liftfeat(BaseModel):
16
  default_conf = {
17
  "keypoint_threshold": 0.05,
18
  "max_keypoints": 5000,
19
+ "model_name": "LiftFeat.pth",
20
  }
21
 
22
  required_inputs = ["image"]
23
 
24
  def _init(self, conf):
25
  logger.info("Loading LiftFeat model...")
26
+ model_path = self._download_model(
27
+ repo_id=MODEL_REPO_ID,
28
+ filename="{}/{}".format(
29
+ Path(__file__).stem, self.conf["model_name"]
30
+ ),
31
+ )
32
  self.net = LiftFeat(
33
+ weight=model_path,
34
  detect_threshold=self.conf["keypoint_threshold"],
35
  top_k=self.conf["max_keypoints"],
36
  )