Update app.py
Browse files
app.py
CHANGED
@@ -92,8 +92,13 @@ def load_model(select_skpt):
|
|
92 |
global_pool=args.global_pool,
|
93 |
).to(device)
|
94 |
|
95 |
-
args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
|
96 |
-
args.resume
|
|
|
|
|
|
|
|
|
|
|
97 |
checkpoint = torch.load(args.resume, map_location=device)
|
98 |
model.load_state_dict(checkpoint['model'], strict=False)
|
99 |
model.eval()
|
@@ -246,15 +251,10 @@ CKPT_NAME = [
|
|
246 |
'DfD-Checkpoint_Fine-tuned_on_FF++',
|
247 |
'FAS-Checkpoint_Fine-tuned_on_MCIO',
|
248 |
]
|
249 |
-
# CKPT_PATH = {
|
250 |
-
# '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_val_loss.pth',
|
251 |
-
# 'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
|
252 |
-
# 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
|
253 |
-
# }
|
254 |
CKPT_PATH = {
|
255 |
-
'✨Unified-detector_v1_Fine-tuned_on_4_classes': '
|
256 |
-
'DfD-Checkpoint_Fine-tuned_on_FF++': '
|
257 |
-
'FAS-Checkpoint_Fine-tuned_on_MCIO': '/
|
258 |
}
|
259 |
CKPT_CLASS = {
|
260 |
'✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
|
@@ -321,7 +321,7 @@ with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16p
|
|
321 |
if __name__ == "__main__":
|
322 |
args = get_args_parser()
|
323 |
args = args.parse_args()
|
324 |
-
ckpt = '
|
325 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
326 |
args.nb_classes = CKPT_CLASS[ckpt]
|
327 |
model = models_vit.__dict__[CKPT_MODEL[ckpt]](
|
@@ -329,8 +329,13 @@ if __name__ == "__main__":
|
|
329 |
drop_path_rate=args.drop_path,
|
330 |
global_pool=args.global_pool,
|
331 |
).to(device)
|
332 |
-
args.resume = os.path.join(CKPT_SAVE_PATH, ckpt)
|
333 |
-
args.resume
|
|
|
|
|
|
|
|
|
|
|
334 |
checkpoint = torch.load(args.resume, map_location=device)
|
335 |
model.load_state_dict(checkpoint['model'], strict=False)
|
336 |
model.eval()
|
|
|
92 |
global_pool=args.global_pool,
|
93 |
).to(device)
|
94 |
|
95 |
+
args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
|
96 |
+
if os.path.isfile(args.resume) == False:
|
97 |
+
hf_hub_download(local_dir=CKPT_SAVE_PATH,
|
98 |
+
local_dir_use_symlinks=False,
|
99 |
+
repo_id='Wolowolo/fsfm-3c',
|
100 |
+
filename=CKPT_PATH[ckpt])
|
101 |
+
args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
|
102 |
checkpoint = torch.load(args.resume, map_location=device)
|
103 |
model.load_state_dict(checkpoint['model'], strict=False)
|
104 |
model.eval()
|
|
|
251 |
'DfD-Checkpoint_Fine-tuned_on_FF++',
|
252 |
'FAS-Checkpoint_Fine-tuned_on_MCIO',
|
253 |
]
|
|
|
|
|
|
|
|
|
|
|
254 |
CKPT_PATH = {
|
255 |
+
'✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_train_loss.pth',
|
256 |
+
'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
|
257 |
+
'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
|
258 |
}
|
259 |
CKPT_CLASS = {
|
260 |
'✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
|
|
|
321 |
if __name__ == "__main__":
|
322 |
args = get_args_parser()
|
323 |
args = args.parse_args()
|
324 |
+
ckpt = '✨Unified-detector_v1_Fine-tuned_on_4_classes'
|
325 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
326 |
args.nb_classes = CKPT_CLASS[ckpt]
|
327 |
model = models_vit.__dict__[CKPT_MODEL[ckpt]](
|
|
|
329 |
drop_path_rate=args.drop_path,
|
330 |
global_pool=args.global_pool,
|
331 |
).to(device)
|
332 |
+
args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
|
333 |
+
if os.path.isfile(args.resume) == False:
|
334 |
+
hf_hub_download(local_dir=CKPT_SAVE_PATH,
|
335 |
+
local_dir_use_symlinks=False,
|
336 |
+
repo_id='Wolowolo/fsfm-3c',
|
337 |
+
filename=CKPT_PATH[ckpt])
|
338 |
+
args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
|
339 |
checkpoint = torch.load(args.resume, map_location=device)
|
340 |
model.load_state_dict(checkpoint['model'], strict=False)
|
341 |
model.eval()
|