Files changed (1)
  1. app.py +415 -345
app.py CHANGED
@@ -1,346 +1,416 @@
1
- # -*- coding: utf-8 -*-
2
- # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
3
- # --------------------------------------------------------
4
- # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
5
- # You can find the license in the LICENSE file in the root directory of this source tree.
6
- # --------------------------------------------------------
7
-
8
- import sys
9
- import os
10
- os.system(f'pip install dlib')
11
- import dlib
12
- import argparse
13
- import numpy as np
14
- from PIL import Image
15
- import cv2
16
- import torch
17
- from huggingface_hub import hf_hub_download
18
- import gradio as gr
19
-
20
- import models_vit
21
- from util.datasets import build_dataset
22
- from engine_finetune import test_two_class, test_multi_class
23
-
24
-
25
- def get_args_parser():
26
- parser = argparse.ArgumentParser('FSFM3C fine-tuning&Testing for image classification', add_help=False)
27
- parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
28
- parser.add_argument('--epochs', default=50, type=int)
29
- parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations')
30
- parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', help='Name of model to train')
31
- parser.add_argument('--input_size', default=224, type=int, help='images input size')
32
- parser.add_argument('--normalize_from_IMN', action='store_true', help='cal mean and std from imagenet')
33
- parser.set_defaults(normalize_from_IMN=True)
34
- parser.add_argument('--apply_simple_augment', action='store_true', help='apply simple data augment')
35
- parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')
36
- parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm')
37
- parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay')
38
- parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate')
39
- parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate')
40
- parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay')
41
- parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound')
42
- parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')
43
- parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Color jitter factor')
44
- parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy')
45
- parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing')
46
- parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob')
47
- parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode')
48
- parser.add_argument('--recount', type=int, default=1, help='Random erase count')
49
- parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first augmentation split')
50
- parser.add_argument('--mixup', type=float, default=0, help='mixup alpha')
51
- parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha')
52
- parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio')
53
- parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix')
54
- parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix')
55
- parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params')
56
- parser.add_argument('--finetune', default='', help='finetune from checkpoint')
57
- parser.add_argument('--global_pool', action='store_true')
58
- parser.set_defaults(global_pool=True)
59
- parser.add_argument('--cls_token', action='store_false', dest='global_pool', help='Use class token for classification')
60
- parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
61
- parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types')
62
- parser.add_argument('--output_dir', default='', help='path where to save')
63
- parser.add_argument('--log_dir', default='', help='path where to tensorboard log')
64
- parser.add_argument('--device', default='cuda', help='device to use for training / testing')
65
- parser.add_argument('--seed', default=0, type=int)
66
- parser.add_argument('--resume', default='', help='resume from checkpoint')
67
- parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
68
- parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
69
- parser.set_defaults(eval=True)
70
- parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
71
- parser.add_argument('--num_workers', default=10, type=int)
72
- parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
73
- parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
74
- parser.set_defaults(pin_mem=True)
75
- parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
76
- parser.add_argument('--local_rank', default=-1, type=int)
77
- parser.add_argument('--dist_on_itp', action='store_true')
78
- parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
79
- return parser
80
-
81
-
82
- def load_model(select_skpt):
83
- global ckpt, device, model, checkpoint
84
- if select_skpt not in CKPT_NAME:
85
- return gr.update(), "Select a correct model"
86
- ckpt = select_skpt
87
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
88
- args.nb_classes = CKPT_CLASS[ckpt]
89
- model = models_vit.__dict__[CKPT_MODEL[ckpt]](
90
- num_classes=args.nb_classes,
91
- drop_path_rate=args.drop_path,
92
- global_pool=args.global_pool,
93
- ).to(device)
94
-
95
- args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
96
- if os.path.isfile(args.resume) == False:
97
- hf_hub_download(local_dir=CKPT_SAVE_PATH,
98
- local_dir_use_symlinks=False,
99
- repo_id='Wolowolo/fsfm-3c',
100
- filename=CKPT_PATH[ckpt])
101
- args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
102
- checkpoint = torch.load(args.resume, map_location=device)
103
- model.load_state_dict(checkpoint['model'], strict=False)
104
- model.eval()
105
- return gr.update(), f"[Loaded Model Successfully:] {args.resume}] "
106
-
107
-
108
- def get_boundingbox(face, width, height, minsize=None):
109
- x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
110
- size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
111
- if minsize and size_bb < minsize:
112
- size_bb = minsize
113
- center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
114
- x1, y1 = max(int(center_x - size_bb // 2), 0), max(int(center_y - size_bb // 2), 0)
115
- size_bb = min(width - x1, size_bb)
116
- size_bb = min(height - y1, size_bb)
117
- return x1, y1, size_bb
118
-
119
-
120
- def extract_face(frame):
121
- face_detector = dlib.get_frontal_face_detector()
122
- image = np.array(frame.convert('RGB'))
123
- faces = face_detector(image, 1)
124
- if faces:
125
- face = faces[0]
126
- x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
127
- cropped_face = image[y:y + size, x:x + size]
128
- return Image.fromarray(cropped_face)
129
- return None
130
-
131
-
132
- def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
133
- return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()
134
-
135
-
136
- def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None):
137
- video_capture = cv2.VideoCapture(src_video)
138
- total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
139
- frame_indices = get_frame_index_uniform_sample(total_frames, num_frames) if num_frames else range(total_frames)
140
- for frame_index in frame_indices:
141
- video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
142
- ret, frame = video_capture.read()
143
- if not ret:
144
- continue
145
- image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
146
- img = extract_face(image)
147
- if img:
148
- img = img.resize((224, 224), Image.BICUBIC)
149
- save_img_name = f"frame_{frame_index}.png"
150
- img.save(os.path.join(dst_path, '0', save_img_name))
151
- video_capture.release()
152
- return frame_indices
153
-
154
-
155
- def FSFM3C_image_detection(image):
156
- frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
157
- os.makedirs(frame_path, exist_ok=True)
158
- os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
159
- img = extract_face(image)
160
- if img is None:
161
- return 'No face detected, please upload a clear face!'
162
- img = img.resize((224, 224), Image.BICUBIC)
163
- img.save(os.path.join(frame_path, '0', "frame_0.png"))
164
- args.data_path = frame_path
165
- args.batch_size = 1
166
- dataset_val = build_dataset(is_train=False, args=args)
167
- sampler_val = torch.utils.data.SequentialSampler(dataset_val)
168
- data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
169
-
170
- if CKPT_CLASS[ckpt] > 2:
171
- frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
172
- class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
173
- avg_video_pred = np.mean(video_pred_list, axis=0)
174
- max_prob_index = np.argmax(avg_video_pred)
175
- max_prob_class = class_names[max_prob_index]
176
- probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
177
- image_results = f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]"
178
- return image_results
179
-
180
- if CKPT_CLASS[ckpt] == 2:
181
- frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
182
- if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
183
- prob = sum(video_pred_list) / len(video_pred_list)
184
- label = "Deepfake" if prob <= 0.5 else "Real"
185
- prob = prob if label == "Real" else 1 - prob
186
- if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
187
- prob = sum(video_pred_list) / len(video_pred_list)
188
- label = "Spoofing" if prob <= 0.5 else "Bonafide"
189
- prob = prob if label == "Bonafide" else 1 - prob
190
- image_results = f"The largest face in this image may be {label} with probability {prob * 100:.1f}%"
191
- return image_results
192
-
193
-
194
- def FSFM3C_video_detection(video, num_frames):
195
- try:
196
- frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
197
- os.makedirs(frame_path, exist_ok=True)
198
- os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
199
- frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames)
200
- args.data_path = frame_path
201
- args.batch_size = num_frames
202
- dataset_val = build_dataset(is_train=False, args=args)
203
- sampler_val = torch.utils.data.SequentialSampler(dataset_val)
204
- data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
205
-
206
- if CKPT_CLASS[ckpt] > 2:
207
- frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
208
- class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
209
- avg_video_pred = np.mean(video_pred_list, axis=0)
210
- max_prob_index = np.argmax(avg_video_pred)
211
- max_prob_class = class_names[max_prob_index]
212
- probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
213
-
214
- frame_results = {f"frame_{frame_indices[i]}": [f"{class_names[j]}: {prob * 100:.1f}%" for j, prob in enumerate(frame_preds_list[i])] for i in range(len(frame_indices))}
215
- video_results = (f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]\n \n"
216
- f"The frame-level detection results ['frame_index': 'probabilities']: \n{frame_results}")
217
- return video_results
218
-
219
- if CKPT_CLASS[ckpt] == 2:
220
- frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
221
- if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
222
- prob = sum(video_pred_list) / len(video_pred_list)
223
- label = "Deepfake" if prob <= 0.5 else "Real"
224
- prob = prob if label == "Real" else 1 - prob
225
- frame_results = {f"frame_{frame_indices[i]}": f"{(frame_preds_list[i]) * 100:.1f}%" for i in
226
- range(len(frame_indices))} if label == "Real" else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in
227
- range(len(frame_indices))}
228
-
229
- if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
230
- prob = sum(video_pred_list) / len(video_pred_list)
231
- label = "Spoofing" if prob <= 0.5 else "Bonafide"
232
- prob = prob if label == "Bonafide" else 1 - prob
233
- frame_results = {f"frame_{frame_indices[i]}": f"{(frame_preds_list[i]) * 100:.1f}%" for i in
234
- range(len(frame_indices))} if label == "Bonafide" else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in
235
- range(len(frame_indices))}
236
-
237
- video_results = (f"The largest face in this image may be {label} with probability {prob * 100:.1f}%\n \n"
238
- f"The frame-level detection results ['frame_index': 'real_face_probability']: \n{frame_results}")
239
- return video_results
240
- except Exception as e:
241
- return f"Error occurred. Please provide a clear face video or reduce the number of frames."
242
-
243
- # Paths and Constants
244
- P = os.path.abspath(__file__)
245
- FRAME_SAVE_PATH = os.path.join(os.path.dirname(P), 'frame')
246
- CKPT_SAVE_PATH = os.path.join(os.path.dirname(P), 'checkpoints')
247
- os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
248
- os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
249
- CKPT_NAME = [
250
- '✨Unified-detector_v1_Fine-tuned_on_4_classes',
251
- 'DfD-Checkpoint_Fine-tuned_on_FF++',
252
- 'FAS-Checkpoint_Fine-tuned_on_MCIO',
253
- ]
254
- CKPT_PATH = {
255
- '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_train_loss.pth',
256
- 'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
257
- 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
258
- }
259
- CKPT_CLASS = {
260
- '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
261
- 'DfD-Checkpoint_Fine-tuned_on_FF++': 2,
262
- 'FAS-Checkpoint_Fine-tuned_on_MCIO': 2
263
- }
264
- CKPT_MODEL = {
265
- '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'vit_base_patch16',
266
- 'DfD-Checkpoint_Fine-tuned_on_FF++': 'vit_base_patch16',
267
- 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'vit_base_patch16',
268
- }
269
-
270
-
271
- with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16px !important; }") as demo:
272
- gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery (Deepfake/Diffusion) and Spoofing (Presentation-attacks)</h1>")
273
- gr.Markdown("<b>☉ Powered by the fine-tuned ViT models that is pre-trained from [FSFM-3C](https://fsfm-3c.github.io/)</b> <br> "
274
- "<b>☉ We do not and cannot access or store the data you have uploaded!</b> <br> "
275
- "<b>☉ Release (Continuously updating) </b> <br> <b>[V1.0] 2025/02/22-Current🎉</b>: "
276
- "1) Updated <b>[✨Unified-detector_v1] for Unified Physical-Digital Face Attack&Forgery Detection, a ViT-B/16-224 (FSFM Pre-trained) detector that could identify Real&Bonafide, Deepfake, Diffusion&AIGC, Spooing&Presentation-attacks facial images or videos </b> ; 2) Provided the selection of the number of video frames (uniformly sampling 1-32 frames, more frames may time-consuming for this page without GPU acceleration); 3) Fixed some errors of V0.1 including loading and prediction. <br>"
277
- "<b>[V0.1] 2024/12-2025/02/21</b>: "
278
- "Create this page with basic detectors [DfD-Checkpoint_Fine-tuned_on_FF++, FAS-Checkpoint_Fine-tuned_on_MCIO] that follow the paper implementation. <br> ")
279
- gr.Markdown("- Please <b>provide a facial image or video(<100s)</b>, and <b>select the model</b> for detection: <br> <b>[SUGGEST] [✨Unified-detector_v1_Fine-tuned_on_4_classes]</b> a (FSFM Pre-trained) ViT-B/16-224 for Both Real/Deepfake/Diffusion/Spoofing facial images&videos Detection <br> <b>[DfD-Checkpoint_Fine-tuned_on_FF++]</b> for deepfake detection, FSFM ViT-B/16-224 fine-tuned on the FF++_c23 train&val sets (4 manipulations, 32 frames per video) <br> <b>[FAS-Checkpoint_Fine-tuned_on_MCIO]</b> for face anti-spoofing, FSFM ViT-B/16-224 fine-tuned on the MCIO datasets (2 frames per video)")
280
-
281
-
282
- with gr.Row():
283
- ckpt_select_dropdown = gr.Dropdown(
284
- label="Select the Model for Detection ⬇️",
285
- elem_classes="custom-label",
286
- choices=['Choose Model Here 🖱️'] + CKPT_NAME + ['continuously updating...'],
287
- multiselect=False,
288
- value='Choose Model Here 🖱️',
289
- interactive=True,
290
- )
291
- model_loading_status = gr.Textbox(label="Model Loading Status")
292
- with gr.Row():
293
- with gr.Column(scale=5):
294
- gr.Markdown("### Image Detection (Fast Try: copying image from [whichfaceisreal](https://whichfaceisreal.com/))")
295
- image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
296
- image_submit_btn = gr.Button("Submit")
297
- output_results_image = gr.Textbox(label="Detection Result")
298
- with gr.Column(scale=5):
299
- gr.Markdown("### Video Detection")
300
- video = gr.Video(label="Upload/Capture your video")
301
- frame_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Number of Frames for Detection")
302
- video_submit_btn = gr.Button("Submit")
303
- output_results_video = gr.Textbox(label="Detection Result")
304
-
305
- ckpt_select_dropdown.change(
306
- fn=load_model,
307
- inputs=[ckpt_select_dropdown],
308
- outputs=[ckpt_select_dropdown, model_loading_status],
309
- )
310
- image_submit_btn.click(
311
- fn=FSFM3C_image_detection,
312
- inputs=[image],
313
- outputs=[output_results_image],
314
- )
315
- video_submit_btn.click(
316
- fn=FSFM3C_video_detection,
317
- inputs=[video, frame_slider],
318
- outputs=[output_results_video],
319
- )
320
-
321
-
322
- if __name__ == "__main__":
323
- args = get_args_parser()
324
- args = args.parse_args()
325
- ckpt = '✨Unified-detector_v1_Fine-tuned_on_4_classes'
326
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
327
- args.nb_classes = CKPT_CLASS[ckpt]
328
- model = models_vit.__dict__[CKPT_MODEL[ckpt]](
329
- num_classes=args.nb_classes,
330
- drop_path_rate=args.drop_path,
331
- global_pool=args.global_pool,
332
- ).to(device)
333
- args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
334
- if os.path.isfile(args.resume) == False:
335
- hf_hub_download(local_dir=CKPT_SAVE_PATH,
336
- local_dir_use_symlinks=False,
337
- repo_id='Wolowolo/fsfm-3c',
338
- filename=CKPT_PATH[ckpt])
339
- args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
340
- checkpoint = torch.load(args.resume, map_location=device)
341
- model.load_state_dict(checkpoint['model'], strict=False)
342
- model.eval()
343
-
344
- gr.close_all()
345
- demo.queue()
346
  demo.launch()
 
1
+ # -*- coding: utf-8 -*-
2
+ # Author: Gaojian Wang@ZJUICSR; TongWu@ZJUICSR
3
+ # --------------------------------------------------------
4
+ # This source code is licensed under the Attribution-NonCommercial 4.0 International License.
5
+ # You can find the license in the LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+
8
+ import sys
9
+ import os
10
+ os.system('pip install dlib')
11
+ import dlib
12
+ import argparse
13
+ import numpy as np
14
+ from PIL import Image
15
+ import cv2
16
+ import torch
17
+ from huggingface_hub import hf_hub_download
18
+ import gradio as gr
19
+
20
+ import models_vit
21
+ from util.datasets import build_dataset
22
+ from engine_finetune import test_two_class, test_multi_class
23
+ import matplotlib.pyplot as plt
24
+ from torchvision import transforms
25
+ import traceback
26
+ from pytorch_grad_cam import (
27
+ GradCAM, ScoreCAM,
28
+ XGradCAM, EigenCAM
29
+ )
30
+ from pytorch_grad_cam import GuidedBackpropReLUModel
31
+ from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
32
+
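+ # Grad-CAM expects a (B, C, H, W) activation map, but a ViT block outputs (B, 1+N, C) tokens;
+ # drop the class token and fold the remaining 14x14 patch tokens (ViT-B/16 @ 224px) back into a spatial grid.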
33
+ def reshape_transform(tensor, height=14, width=14):
34
+ result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
35
+ result = result.transpose(2, 3).transpose(1, 2)
36
+ return result
37
+
38
+ def get_args_parser():
39
+ parser = argparse.ArgumentParser('FSFM3C fine-tuning&Testing for image classification', add_help=False)
40
+ parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
41
+ parser.add_argument('--epochs', default=50, type=int)
42
+ parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations')
43
+ parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL', help='Name of model to train')
44
+ parser.add_argument('--input_size', default=224, type=int, help='images input size')
45
+ parser.add_argument('--normalize_from_IMN', action='store_true', help='cal mean and std from imagenet')
46
+ parser.set_defaults(normalize_from_IMN=True)
47
+ parser.add_argument('--apply_simple_augment', action='store_true', help='apply simple data augment')
48
+ parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')
49
+ parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm')
50
+ parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay')
51
+ parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate')
52
+ parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate')
53
+ parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay')
54
+ parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound')
55
+ parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warmup LR')
56
+ parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Color jitter factor')
57
+ parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', help='Use AutoAugment policy')
58
+ parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing')
59
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase prob')
60
+ parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode')
61
+ parser.add_argument('--recount', type=int, default=1, help='Random erase count')
62
+ parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first augmentation split')
63
+ parser.add_argument('--mixup', type=float, default=0, help='mixup alpha')
64
+ parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha')
65
+ parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio')
66
+ parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix')
67
+ parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix')
68
+ parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params')
69
+ parser.add_argument('--finetune', default='', help='finetune from checkpoint')
70
+ parser.add_argument('--global_pool', action='store_true')
71
+ parser.set_defaults(global_pool=True)
72
+ parser.add_argument('--cls_token', action='store_false', dest='global_pool', help='Use class token for classification')
73
+ parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
74
+ parser.add_argument('--nb_classes', default=1000, type=int, help='number of the classification types')
75
+ parser.add_argument('--output_dir', default='', help='path where to save')
76
+ parser.add_argument('--log_dir', default='', help='path where to tensorboard log')
77
+ parser.add_argument('--device', default='cuda', help='device to use for training / testing')
78
+ parser.add_argument('--seed', default=0, type=int)
79
+ parser.add_argument('--resume', default='', help='resume from checkpoint')
80
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
81
+ parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
82
+ parser.set_defaults(eval=True)
83
+ parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation')
84
+ parser.add_argument('--num_workers', default=10, type=int)
85
+ parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
86
+ parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
87
+ parser.set_defaults(pin_mem=True)
88
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
89
+ parser.add_argument('--local_rank', default=-1, type=int)
90
+ parser.add_argument('--dist_on_itp', action='store_true')
91
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
92
+ return parser
93
+
94
+
95
+ def load_model(select_skpt):
96
+ global ckpt, device, model, checkpoint
97
+ if select_skpt not in CKPT_NAME:
98
+ return gr.update(), "Select a correct model"
99
+ ckpt = select_skpt
100
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
+ args.nb_classes = CKPT_CLASS[ckpt]
102
+ model = models_vit.__dict__[CKPT_MODEL[ckpt]](
103
+ num_classes=args.nb_classes,
104
+ drop_path_rate=args.drop_path,
105
+ global_pool=args.global_pool,
106
+ ).to(device)
107
+
108
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
109
+ if os.path.isfile(args.resume) == False:
110
+ hf_hub_download(local_dir=CKPT_SAVE_PATH,
111
+ local_dir_use_symlinks=False,
112
+ repo_id='Wolowolo/fsfm-3c',
113
+ filename=CKPT_PATH[ckpt])
114
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
115
+ checkpoint = torch.load(args.resume, map_location=device)
116
+ model.load_state_dict(checkpoint['model'], strict=False)
117
+ model.eval()
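+ # rebuild the Grad-CAM extractor for the newly selected model; the last block's first LayerNorm
+ # is the target layer commonly recommended for ViTs in the pytorch_grad_cam docs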
118
+ global cam
119
+ cam = GradCAM(model = model,
120
+ target_layers=[model.blocks[-1].norm1],
121
+ reshape_transform=reshape_transform
122
+ )
123
+ return gr.update(), f"[Loaded Model Successfully] {args.resume}"
124
+
125
+
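+ # Enlarge the dlib face box by 1.3x around its center and clamp it to the image bounds;
+ # returns the top-left corner and the side length of the (square) crop.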
126
+ def get_boundingbox(face, width, height, minsize=None):
127
+ x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
128
+ size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
129
+ if minsize and size_bb < minsize:
130
+ size_bb = minsize
131
+ center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
132
+ x1, y1 = max(int(center_x - size_bb // 2), 0), max(int(center_y - size_bb // 2), 0)
133
+ size_bb = min(width - x1, size_bb)
134
+ size_bb = min(height - y1, size_bb)
135
+ return x1, y1, size_bb
136
+
137
+
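+ # Crop the first face returned by dlib's frontal detector (not necessarily the largest one,
+ # despite the UI wording); returns None when no face is found.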
138
+ def extract_face(frame):
139
+ face_detector = dlib.get_frontal_face_detector()
140
+ image = np.array(frame.convert('RGB'))
141
+ faces = face_detector(image, 1)
142
+ if faces:
143
+ face = faces[0]
144
+ x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
145
+ cropped_face = image[y:y + size, x:x + size]
146
+ return Image.fromarray(cropped_face)
147
+ return None
148
+
149
+
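+ # Evenly spaced frame indices across the clip, first and last frame included.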
150
+ def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
151
+ return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()
152
+
153
+
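+ # Uniformly sample num_frames frames, crop the face in each, and save 224x224 crops under
+ # <dst_path>/0/ (the '0' subfolder acts as a dummy class label for build_dataset).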
154
+ def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None):
155
+ video_capture = cv2.VideoCapture(src_video)
156
+ total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
157
+ frame_indices = get_frame_index_uniform_sample(total_frames, num_frames) if num_frames else range(total_frames)
158
+ for frame_index in frame_indices:
159
+ video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
160
+ ret, frame = video_capture.read()
161
+ if not ret:
162
+ continue
163
+ image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
164
+ img = extract_face(image)
165
+ if img:
166
+ img = img.resize((224, 224), Image.BICUBIC)
167
+ save_img_name = f"frame_{frame_index}.png"
168
+ img.save(os.path.join(dst_path, '0', save_img_name))
169
+ video_capture.release()
170
+ return frame_indices
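+
+ # Callable used as a pytorch_grad_cam "target": given the model output for one sample,
+ # it returns the logit of the chosen class so the CAM is computed w.r.t. that class.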
171
+ class TargetCategory:
172
+ def __init__(self, category_index):
173
+ self.category_index = category_index
174
+
175
+ def __call__(self, output):
176
+ return output[self.category_index]
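+
+ # Manual preprocessing helper (normalize + CHW + batch dim); it appears unused at the moment,
+ # since the CAM path calls pytorch_grad_cam's preprocess_image instead.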
177
+ def preprocess_image_cam(pil_img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
178
+ # convert the PIL image to a numpy array
179
+ img_np = np.array(pil_img)
180
+
181
+ # scale pixel values to [0, 1]
182
+ img_np = img_np.astype(np.float32) / 255.0
183
+
184
+ # normalize with the given mean and std
185
+ img_np = (img_np - mean) / std
186
+
187
+ # reorder dimensions to (C, H, W) to match the model input
188
+ img_np = np.transpose(img_np, (2, 0, 1))
189
+
190
+ # add a batch dimension (B, C, H, W)
191
+ img_np = np.expand_dims(img_np, axis=0)
192
+
193
+ return img_np
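+
+ # Image pipeline: crop the face, run the selected detector and, for the 4-class model,
+ # additionally return a Grad-CAM heatmap and the predicted class for display in the UI.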
194
+ def FSFM3C_image_detection(image):
195
+ frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
196
+ os.makedirs(frame_path, exist_ok=True)
197
+ os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
198
+ img = extract_face(image)
199
+ if img is None:
200
+ return 'No face detected, please upload a clear face!'
201
+ img = img.resize((224, 224), Image.BICUBIC)
202
+ img.save(os.path.join(frame_path, '0', "frame_0.png"))
203
+ args.data_path = frame_path
204
+ args.batch_size = 1
205
+ dataset_val = build_dataset(is_train=False, args=args)
206
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
207
+ data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
208
+
209
+ if CKPT_CLASS[ckpt] > 2:
210
+ frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
211
+ class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
212
+ avg_video_pred = np.mean(video_pred_list, axis=0)
213
+ max_prob_index = np.argmax(avg_video_pred)
214
+ max_prob_class = class_names[max_prob_index]
215
+ probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
216
+ image_results = f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]"
217
+
218
+ # Generate CAM heatmap for the detected class
219
+ use_cuda = torch.cuda.is_available()  # do not hard-code True: this page may run on CPU only
220
+ input_tensor = preprocess_image(img, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
221
+ if use_cuda:
222
+ input_tensor = input_tensor.cuda()
223
+
224
+ # Dynamically determine the target category based on the maximum probability class
225
+ category_names_to_index = {
226
+ 'Real or Bonafide': 0,
227
+ 'Deepfake': 1,
228
+ 'Diffusion or AIGC generated': 2,
229
+ 'Spoofing or Presentation-attack': 3
230
+ }
231
+ target_category = TargetCategory(category_names_to_index[max_prob_class])
232
+
233
+ grayscale_cam = cam(input_tensor=input_tensor, targets=[target_category])
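+ # note: the CAM is inverted here (1 - cam); keep this only if low activations are meant to mark the decisive regions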
234
+ grayscale_cam = 1 - grayscale_cam[0, :]
235
+ img = np.array(img)
236
+ if img.shape[2] == 4:
237
+ img = img[:, :, :3]
238
+ img = img.astype(np.float32) / 255.0
239
+ visualization = show_cam_on_image(img, grayscale_cam)
240
+ visualization = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)
241
+
242
+ # Add text overlay to the heatmap
243
+ # text = f"Detected: {max_prob_class}"
244
+ # cv2.putText(visualization, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
245
+ os.makedirs("./CAM_images", exist_ok=True)  # cv2.imwrite does not create missing directories
+ output_path = "./CAM_images/output_heatmap.png"
246
+ cv2.imwrite(output_path, visualization)
247
+ return image_results, output_path, probabilities[max_prob_index]
248
+
249
+ if CKPT_CLASS[ckpt] == 2:
250
+ frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
251
+ if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
252
+ prob = sum(video_pred_list) / len(video_pred_list)
253
+ label = "Deepfake" if prob <= 0.5 else "Real"
254
+ prob = prob if label == "Real" else 1 - prob
255
+ if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
256
+ prob = sum(video_pred_list) / len(video_pred_list)
257
+ label = "Spoofing" if prob <= 0.5 else "Bonafide"
258
+ prob = prob if label == "Bonafide" else 1 - prob
259
+ image_results = f"The largest face in this image may be {label} with probability {prob * 100:.1f}%"
260
+ return image_results, None ,None
261
+
262
+
263
+ def FSFM3C_video_detection(video, num_frames):
264
+ try:
265
+ frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
266
+ os.makedirs(frame_path, exist_ok=True)
267
+ os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
268
+ frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames)
269
+ args.data_path = frame_path
270
+ args.batch_size = num_frames
271
+ dataset_val = build_dataset(is_train=False, args=args)
272
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
273
+ data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=False)
274
+
275
+ if CKPT_CLASS[ckpt] > 2:
276
+ frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
277
+ class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated', 'Spoofing or Presentation-attack']
278
+ avg_video_pred = np.mean(video_pred_list, axis=0)
279
+ max_prob_index = np.argmax(avg_video_pred)
280
+ max_prob_class = class_names[max_prob_index]
281
+ probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
282
+
283
+ frame_results = {f"frame_{frame_indices[i]}": [f"{class_names[j]}: {prob * 100:.1f}%" for j, prob in enumerate(frame_preds_list[i])] for i in range(len(frame_indices))}
284
+ video_results = (f"The largest face in this image may be {max_prob_class} with probability: \n [{', '.join(probabilities)}]\n \n"
285
+ f"The frame-level detection results ['frame_index': 'probabilities']: \n{frame_results}")
286
+ return video_results
287
+
288
+ if CKPT_CLASS[ckpt] == 2:
289
+ frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
290
+ if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
291
+ prob = sum(video_pred_list) / len(video_pred_list)
292
+ label = "Deepfake" if prob <= 0.5 else "Real"
293
+ prob = prob if label == "Real" else 1 - prob
294
+ frame_results = {f"frame_{frame_indices[i]}": f"{(frame_preds_list[i]) * 100:.1f}%" for i in
295
+ range(len(frame_indices))} if label == "Real" else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in
296
+ range(len(frame_indices))}
297
+
298
+ if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
299
+ prob = sum(video_pred_list) / len(video_pred_list)
300
+ label = "Spoofing" if prob <= 0.5 else "Bonafide"
301
+ prob = prob if label == "Bonafide" else 1 - prob
302
+ frame_results = {f"frame_{frame_indices[i]}": f"{(frame_preds_list[i]) * 100:.1f}%" for i in
303
+ range(len(frame_indices))} if label == "Bonafide" else {f"frame_{frame_indices[i]}": f"{(1 - frame_preds_list[i]) * 100:.1f}%" for i in
304
+ range(len(frame_indices))}
305
+
306
+ video_results = (f"The largest face in this image may be {label} with probability {prob * 100:.1f}%\n \n"
307
+ f"The frame-level detection results ['frame_index': 'real_face_probability']: \n{frame_results}")
308
+ return video_results
309
+ except Exception as e:
310
+ return f"Error occurred: {e}. Please provide a clear face video or reduce the number of frames."
311
+
312
+ # Paths and Constants
313
+ P = os.path.abspath(__file__)
314
+ FRAME_SAVE_PATH = os.path.join(os.path.dirname(P), 'frame')
315
+ CKPT_SAVE_PATH = os.path.join(os.path.dirname(P), 'checkpoints')
316
+ os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
317
+ os.makedirs(CKPT_SAVE_PATH, exist_ok=True)
318
+ CKPT_NAME = [
319
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes',
320
+ 'DfD-Checkpoint_Fine-tuned_on_FF++',
321
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO',
322
+ ]
323
+ CKPT_PATH = {
324
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_train_loss.pth',
325
+ 'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
326
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
327
+ }
328
+ CKPT_CLASS = {
329
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
330
+ 'DfD-Checkpoint_Fine-tuned_on_FF++': 2,
331
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO': 2
332
+ }
333
+ CKPT_MODEL = {
334
+ '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'vit_base_patch16',
335
+ 'DfD-Checkpoint_Fine-tuned_on_FF++': 'vit_base_patch16',
336
+ 'FAS-Checkpoint_Fine-tuned_on_MCIO': 'vit_base_patch16',
337
+ }
338
+
339
+ with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16px !important; }") as demo:
340
+ gr.HTML("<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery (Deepfake/Diffusion) and Spoofing (Presentation-attacks)</h1>")
341
+ gr.Markdown("<b>☉ Powered by the fine-tuned ViT models that are pre-trained with [FSFM-3C](https://fsfm-3c.github.io/)</b> <br> "
342
+ "<b>☉ We do not and cannot access or store the data you have uploaded!</b> <br> "
343
+ "<b>☉ Release (Continuously updating) </b> <br> <b>[V1.0] 2025/02/22-Current🎉</b>: "
344
+ "1) Updated <b>[✨Unified-detector_v1] for Unified Physical-Digital Face Attack&Forgery Detection, a ViT-B/16-224 (FSFM Pre-trained) detector that can identify Real&Bonafide, Deepfake, Diffusion&AIGC, Spoofing&Presentation-attack facial images or videos</b>; 2) Provided the selection of the number of video frames (uniformly sampling 1-32 frames; more frames may be time-consuming for this page without GPU acceleration); 3) Fixed some errors of V0.1 including loading and prediction. <br>"
345
+ "<b>[V0.1] 2024/12-2025/02/21</b>: "
346
+ "Created this page with basic detectors [DfD-Checkpoint_Fine-tuned_on_FF++, FAS-Checkpoint_Fine-tuned_on_MCIO] that follow the paper implementation. <br> ")
347
+ gr.Markdown("- Please <b>provide a facial image or video (<100s)</b>, and <b>select the model</b> for detection: <br> <b>[SUGGEST] [✨Unified-detector_v1_Fine-tuned_on_4_classes]</b> a (FSFM Pre-trained) ViT-B/16-224 for Real/Deepfake/Diffusion/Spoofing facial image&video detection <br> <b>[DfD-Checkpoint_Fine-tuned_on_FF++]</b> for deepfake detection, FSFM ViT-B/16-224 fine-tuned on the FF++_c23 train&val sets (4 manipulations, 32 frames per video) <br> <b>[FAS-Checkpoint_Fine-tuned_on_MCIO]</b> for face anti-spoofing, FSFM ViT-B/16-224 fine-tuned on the MCIO datasets (2 frames per video)")
348
+
349
+ with gr.Row():
350
+ ckpt_select_dropdown = gr.Dropdown(
351
+ label="Select the Model for Detection ⬇️",
352
+ elem_classes="custom-label",
353
+ choices=['Choose Model Here 🖱️'] + CKPT_NAME + ['continuously updating...'],
354
+ multiselect=False,
355
+ value='Choose Model Here 🖱️',
356
+ interactive=True,
357
+ )
358
+ model_loading_status = gr.Textbox(label="Model Loading Status")
359
+ with gr.Row():
360
+ with gr.Column(scale=5):
361
+ gr.Markdown("### Image Detection (Fast Try: copying image from [whichfaceisreal](https://whichfaceisreal.com/))")
362
+ image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
363
+ image_submit_btn = gr.Button("Submit")
364
+ output_results_image = gr.Textbox(label="Detection Result")
365
+
366
+ with gr.Row():
367
+ output_heatmap = gr.Image(label="Grad_CAM")
368
+ output_max_prob_class = gr.Textbox(label="Detected Class")
369
+ with gr.Column(scale=5):
370
+ gr.Markdown("### Video Detection")
371
+ video = gr.Video(label="Upload/Capture your video")
372
+ frame_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Number of Frames for Detection")
373
+ video_submit_btn = gr.Button("Submit")
374
+ output_results_video = gr.Textbox(label="Detection Result")
375
+
376
+ ckpt_select_dropdown.change(
377
+ fn=load_model,
378
+ inputs=[ckpt_select_dropdown],
379
+ outputs=[ckpt_select_dropdown, model_loading_status],
380
+ )
381
+ image_submit_btn.click(
382
+ fn=FSFM3C_image_detection,
383
+ inputs=[image],
384
+ outputs=[output_results_image, output_heatmap, output_max_prob_class],
385
+ )
386
+ video_submit_btn.click(
387
+ fn=FSFM3C_video_detection,
388
+ inputs=[video, frame_slider],
389
+ outputs=[output_results_video],
390
+ )
391
+
392
+ if __name__ == "__main__":
393
+ args = get_args_parser()
394
+ args = args.parse_args()
395
+ ckpt = '✨Unified-detector_v1_Fine-tuned_on_4_classes'
396
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
397
+ args.nb_classes = CKPT_CLASS[ckpt]
398
+ model = models_vit.__dict__[CKPT_MODEL[ckpt]](
399
+ num_classes=args.nb_classes,
400
+ drop_path_rate=args.drop_path,
401
+ global_pool=args.global_pool,
402
+ ).to(device)
403
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
404
+ if os.path.isfile(args.resume) == False:
405
+ hf_hub_download(local_dir=CKPT_SAVE_PATH,
406
+ local_dir_use_symlinks=False,
407
+ repo_id='Wolowolo/fsfm-3c',
408
+ filename=CKPT_PATH[ckpt])
409
+ args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
410
+ checkpoint = torch.load(args.resume, map_location=device)
411
+ model.load_state_dict(checkpoint['model'], strict=False)
412
+ model.eval()
+ # build the Grad-CAM extractor for the default model as well; otherwise `cam` stays undefined
+ # until a model is re-selected from the dropdown, and image detection would raise a NameError
+ cam = GradCAM(model=model,
+ target_layers=[model.blocks[-1].norm1],
+ reshape_transform=reshape_transform
+ )
413
+
414
+ gr.close_all()
415
+ demo.queue()
416
  demo.launch()