Wolowolo commited on
Commit
dc741a8
·
verified ·
1 Parent(s): 4d10ed1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -92,8 +92,13 @@ def load_model(select_skpt):
92
  global_pool=args.global_pool,
93
  ).to(device)
94
 
95
- args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
96
- args.resume = CKPT_PATH[ckpt]
 
 
 
 
 
97
  checkpoint = torch.load(args.resume, map_location=device)
98
  model.load_state_dict(checkpoint['model'], strict=False)
99
  model.eval()
@@ -246,15 +251,10 @@ CKPT_NAME = [
246
  'DfD-Checkpoint_Fine-tuned_on_FF++',
247
  'FAS-Checkpoint_Fine-tuned_on_MCIO',
248
  ]
249
- # CKPT_PATH = {
250
- # '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_val_loss.pth',
251
- # 'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
252
- # 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
253
- # }
254
  CKPT_PATH = {
255
- '✨Unified-detector_v1_Fine-tuned_on_4_classes': './checkpoints/checkpoint-min_train_loss.pth',
256
- 'DfD-Checkpoint_Fine-tuned_on_FF++': '/mnt/localDisk2/wgj/FSFM/released/FSFM-main/fsfm-3c/finuetune/cross_dataset_DfD/checkpoint/finetuned_models/ft_on_FF++_c23_32frames/pt_from_VF2_ViT-B_epoch600/checkpoint-min_val_loss.pth',
257
- 'FAS-Checkpoint_Fine-tuned_on_MCIO': '/mnt/localDisk2/wgj/FSFM/FSFM-3C/codespace/fsfm-3c/finuetune/cross_dataset_DfD/finetuned_models/FAS_MCIO/checkpoint-199.pth',
258
  }
259
  CKPT_CLASS = {
260
  '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
@@ -321,7 +321,7 @@ with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16p
321
  if __name__ == "__main__":
322
  args = get_args_parser()
323
  args = args.parse_args()
324
- ckpt = 'DfD-Checkpoint_Fine-tuned_on_FF++'
325
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
326
  args.nb_classes = CKPT_CLASS[ckpt]
327
  model = models_vit.__dict__[CKPT_MODEL[ckpt]](
@@ -329,8 +329,13 @@ if __name__ == "__main__":
329
  drop_path_rate=args.drop_path,
330
  global_pool=args.global_pool,
331
  ).to(device)
332
- args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
333
- args.resume = CKPT_PATH[ckpt]
 
 
 
 
 
334
  checkpoint = torch.load(args.resume, map_location=device)
335
  model.load_state_dict(checkpoint['model'], strict=False)
336
  model.eval()
 
92
  global_pool=args.global_pool,
93
  ).to(device)
94
 
95
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
96
+ if os.path.isfile(args.resume) == False:
97
+ hf_hub_download(local_dir=CKPT_SAVE_PATH,
98
+ local_dir_use_symlinks=False,
99
+ repo_id='Wolowolo/fsfm-3c',
100
+ filename=CKPT_PATH[ckpt])
101
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
102
  checkpoint = torch.load(args.resume, map_location=device)
103
  model.load_state_dict(checkpoint['model'], strict=False)
104
  model.eval()
 
251
  'DfD-Checkpoint_Fine-tuned_on_FF++',
252
  'FAS-Checkpoint_Fine-tuned_on_MCIO',
253
  ]
 
 
 
 
 
254
  CKPT_PATH = {
255
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_train_loss.pth',
256
+ 'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
257
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
258
  }
259
  CKPT_CLASS = {
260
  '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
 
321
  if __name__ == "__main__":
322
  args = get_args_parser()
323
  args = args.parse_args()
324
+ ckpt = '✨Unified-detector_v1_Fine-tuned_on_4_classes'
325
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
326
  args.nb_classes = CKPT_CLASS[ckpt]
327
  model = models_vit.__dict__[CKPT_MODEL[ckpt]](
 
329
  drop_path_rate=args.drop_path,
330
  global_pool=args.global_pool,
331
  ).to(device)
332
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
333
+ if os.path.isfile(args.resume) == False:
334
+ hf_hub_download(local_dir=CKPT_SAVE_PATH,
335
+ local_dir_use_symlinks=False,
336
+ repo_id='Wolowolo/fsfm-3c',
337
+ filename=CKPT_PATH[ckpt])
338
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
339
  checkpoint = torch.load(args.resume, map_location=device)
340
  model.load_state_dict(checkpoint['model'], strict=False)
341
  model.eval()