zino36 commited on
Commit
773c11a
·
verified ·
1 Parent(s): bce5a1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -5
app.py CHANGED
@@ -33,6 +33,13 @@
33
  # --------------------------------------------------------
34
  # gradio demo executable
35
  # --------------------------------------------------------
 
 
 
 
 
 
 
36
  import os
37
  import torch
38
  import tempfile
@@ -46,17 +53,24 @@ from mast3r.utils.misc import hash_md5
46
  import matplotlib.pyplot as pl
47
  pl.ion()
48
 
49
- torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
 
 
 
 
50
 
51
  if __name__ == '__main__':
52
  parser = get_args_parser()
53
  args = parser.parse_args()
54
 
55
- # Set default value for `args.weights` if not provided
 
 
 
 
 
56
  if args.weights is None:
57
- # Set a default model_name if weights are not provided
58
- args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric' # Default model_name
59
- args.weights = "naver/" + args.model_name # Construct default weights path
60
 
61
  if args.server_name is not None:
62
  server_name = args.server_name
@@ -66,12 +80,14 @@ if __name__ == '__main__':
66
  # Use the provided or default weights_path
67
  weights_path = args.weights
68
 
 
69
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
70
  chkpt_tag = hash_md5(weights_path)
71
 
72
  def get_context(tmp_dir):
73
  return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
74
  else nullcontext(tmp_dir)
 
75
  with get_context(args.tmp_dir) as tmpdirname:
76
  cache_path = os.path.join(tmpdirname, chkpt_tag)
77
  os.makedirs(cache_path, exist_ok=True)
 
33
  # --------------------------------------------------------
34
  # gradio demo executable
35
  # --------------------------------------------------------
36
+ #!/usr/bin/env python3
37
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
38
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
39
+ #
40
+ # --------------------------------------------------------
41
+ # gradio demo executable
42
+ # --------------------------------------------------------
43
  import os
44
  import torch
45
  import tempfile
 
53
  import matplotlib.pyplot as pl
54
  pl.ion()
55
 
56
+ torch.backends.cuda.matmul.allow_tf32 = True # for GPU >= Ampere and PyTorch >= 1.12
57
+
58
+ def get_default_weights_path(model_name):
59
+ # Construct default weights path based on model_name
60
+ return f"naver/{model_name}"
61
 
62
  if __name__ == '__main__':
63
  parser = get_args_parser()
64
  args = parser.parse_args()
65
 
66
+ # Ensure at least one of weights or model_name is provided
67
+ if args.weights is None and args.model_name is None:
68
+ # Provide a default model_name if both weights and model_name are not provided
69
+ args.model_name = 'MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
70
+
71
+ # If weights are not provided but model_name is, construct weights_path
72
  if args.weights is None:
73
+ args.weights = get_default_weights_path(args.model_name)
 
 
74
 
75
  if args.server_name is not None:
76
  server_name = args.server_name
 
80
  # Use the provided or default weights_path
81
  weights_path = args.weights
82
 
83
+ # Load the model with the weights_path
84
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(args.device)
85
  chkpt_tag = hash_md5(weights_path)
86
 
87
  def get_context(tmp_dir):
88
  return tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') if tmp_dir is None \
89
  else nullcontext(tmp_dir)
90
+
91
  with get_context(args.tmp_dir) as tmpdirname:
92
  cache_path = os.path.join(tmpdirname, chkpt_tag)
93
  os.makedirs(cache_path, exist_ok=True)