import os
import sys

# Install dependencies that are not pre-installed in the runtime environment.
os.system(f'{sys.executable} -m pip install grad-cam dlib')

import argparse
import traceback

import cv2
import dlib
import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image

import models_vit
from util.datasets import build_dataset
from engine_finetune import test_two_class, test_multi_class
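
# NOTE: models_vit, util.datasets, and engine_finetune are local modules from the
# FSFM-3C codebase and must ship alongside this script.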


def reshape_transform(tensor, height=14, width=14):
    """Drop the [CLS] token and fold the ViT token sequence into a 2D feature map."""
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    # (B, H, W, C) -> (B, C, H, W), the layout expected by pytorch_grad_cam.
    result = result.transpose(2, 3).transpose(1, 2)
    return result
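
# Shape sketch: for a ViT-B/16 on 224x224 inputs the block output is
# (B, 1 + 14*14, 768), so reshape_transform(torch.zeros(1, 197, 768)).shape
# is (1, 768, 14, 14).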


def get_args_parser():
    parser = argparse.ArgumentParser('FSFM3C fine-tuning & testing for image classification', add_help=False)
    parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU')
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations')
    parser.add_argument('--model', default='vit_large_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument('--normalize_from_IMN', action='store_true',
                        help='use the ImageNet mean and std for normalization')
    parser.set_defaults(normalize_from_IMN=True)
    parser.add_argument('--apply_simple_augment', action='store_true', help='apply simple data augmentation')
    parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', help='Drop path rate')
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm')
    parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay')
    parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate')
    parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', help='base learning rate')
    parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay')
    parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', help='lower lr bound')
    parser.add_argument('--warmup_epochs', type=int, default=5, metavar='N', help='epochs to warm up LR')
    parser.add_argument('--color_jitter', type=float, default=None, metavar='PCT', help='Color jitter factor')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='AutoAugment policy')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing')
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', help='Random erase probability')
    parser.add_argument('--remode', type=str, default='pixel', help='Random erase mode')
    parser.add_argument('--recount', type=int, default=1, help='Random erase count')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase the first augmentation split')
    parser.add_argument('--mixup', type=float, default=0, help='mixup alpha')
    parser.add_argument('--cutmix', type=float, default=0, help='cutmix alpha')
    parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio')
    parser.add_argument('--mixup_prob', type=float, default=1.0, help='Probability of performing mixup or cutmix')
    parser.add_argument('--mixup_switch_prob', type=float, default=0.5, help='Probability of switching to cutmix')
    parser.add_argument('--mixup_mode', type=str, default='batch', help='How to apply mixup/cutmix params')
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')
    parser.add_argument('--global_pool', action='store_true')
    parser.set_defaults(global_pool=True)
    parser.add_argument('--cls_token', action='store_false', dest='global_pool',
                        help='Use the class token for classification')
    parser.add_argument('--data_path', default='/datasets01/imagenet_full_size/061417/', type=str, help='dataset path')
    parser.add_argument('--nb_classes', default=1000, type=int, help='number of classification types')
    parser.add_argument('--output_dir', default='', help='path where to save outputs')
    parser.add_argument('--log_dir', default='', help='path for the TensorBoard log')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.set_defaults(eval=True)
    parser.add_argument('--dist_eval', action='store_true', default=False, help='Enable distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='URL used to set up distributed training')
    return parser
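
# Usage sketch: this parser mirrors the FSFM fine-tuning script; the demo below
# only relies on a few fields (nb_classes, drop_path, global_pool, data_path,
# batch_size, num_workers, pin_mem, resume) and parses the defaults at startup.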


def load_model(select_skpt):
    global ckpt, device, model, checkpoint
    if select_skpt not in CKPT_NAME:
        return gr.update(), "Select a correct model"
    ckpt = select_skpt
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.nb_classes = CKPT_CLASS[ckpt]
    model = models_vit.__dict__[CKPT_MODEL[ckpt]](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    ).to(device)

    # Download the checkpoint from the Hub on first use, then load it.
    args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
    if not os.path.isfile(args.resume):
        hf_hub_download(local_dir=CKPT_SAVE_PATH,
                        local_dir_use_symlinks=False,
                        repo_id='Wolowolo/fsfm-3c',
                        filename=CKPT_PATH[ckpt])
    checkpoint = torch.load(args.resume, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()

    # Rebuild Grad-CAM against the freshly loaded model; the first LayerNorm of
    # the last transformer block serves as the target layer.
    global cam
    cam = GradCAM(model=model,
                  target_layers=[model.blocks[-1].norm1],
                  reshape_transform=reshape_transform)
    return gr.update(), f"[Loaded Model Successfully] {args.resume}"
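
# Example: load_model('DfD-Checkpoint_Fine-tuned_on_FF++') fetches the checkpoint
# from the Wolowolo/fsfm-3c Hub repo on first use and caches it under ./checkpoints.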


def get_boundingbox(face, width, height, minsize=None):
    """Expand a dlib face rectangle into a square crop clamped to the image bounds."""
    x1, y1, x2, y2 = face.left(), face.top(), face.right(), face.bottom()
    # Enlarge the square by 30% around the face so the crop includes some context.
    size_bb = int(max(x2 - x1, y2 - y1) * 1.3)
    if minsize and size_bb < minsize:
        size_bb = minsize
    center_x, center_y = (x1 + x2) // 2, (y1 + y2) // 2
    x1, y1 = max(int(center_x - size_bb // 2), 0), max(int(center_y - size_bb // 2), 0)
    size_bb = min(width - x1, size_bb)
    size_bb = min(height - y1, size_bb)
    return x1, y1, size_bb
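
# Example: a dlib rect (100, 120, 200, 220) in a 640x480 frame yields (85, 105, 130),
# i.e. the 100px face enlarged by 1.3x, re-centered, and kept inside the frame.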


def extract_face(frame):
    # The detector is re-created per call to keep the original behavior; it could
    # be hoisted to module scope for speed.
    face_detector = dlib.get_frontal_face_detector()
    image = np.array(frame.convert('RGB'))
    faces = face_detector(image, 1)  # upsample once to catch smaller faces
    if faces:
        # Keep only the first detected face, treated as the main face.
        face = faces[0]
        x, y, size = get_boundingbox(face, image.shape[1], image.shape[0])
        cropped_face = image[y:y + size, x:x + size]
        return Image.fromarray(cropped_face)
    return None


def get_frame_index_uniform_sample(total_frame_num, extract_frame_num):
    # Evenly spaced frame indices across the whole clip, endpoints included.
    return np.linspace(0, total_frame_num - 1, num=extract_frame_num, dtype=int).tolist()
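
# Example: get_frame_index_uniform_sample(100, 5) -> [0, 24, 49, 74, 99]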


def extract_face_from_fixed_num_frames(src_video, dst_path, num_frames=None):
    """Sample frames from a video, crop the face in each, and save 224x224 PNGs."""
    video_capture = cv2.VideoCapture(src_video)
    total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_indices = get_frame_index_uniform_sample(total_frames, num_frames) if num_frames else range(total_frames)
    for frame_index in frame_indices:
        video_capture.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = video_capture.read()
        if not ret:
            continue
        image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img = extract_face(image)
        if img:
            img = img.resize((224, 224), Image.BICUBIC)
            # Saved under class folder '0' so build_dataset can read the directory
            # in ImageFolder style.
            img.save(os.path.join(dst_path, '0', f"frame_{frame_index}.png"))
    video_capture.release()
    return frame_indices
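
# Example: with num_frames=4 on a 120-frame clip, faces are cropped from frames
# [0, 39, 79, 119] and written to <dst_path>/0/frame_<idx>.png.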


class TargetCategory:
    """Grad-CAM target that returns the logit of a fixed class index."""

    def __init__(self, category_index):
        self.category_index = category_index

    def __call__(self, output):
        return output[self.category_index]
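
# This mirrors pytorch_grad_cam's ClassifierOutputTarget: Grad-CAM backpropagates
# from the selected class logit to the target layer.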


def preprocess_image_cam(pil_img,
                         mean=[0.5482207536697388, 0.42340534925460815, 0.3654651641845703],
                         std=[0.2789176106452942, 0.2438540756702423, 0.23493893444538116]):
    """NumPy counterpart of pytorch_grad_cam's preprocess_image, kept for reference;
    the detection path below uses the library version directly."""
    img_np = np.array(pil_img).astype(np.float32) / 255.0
    img_np = (img_np - mean) / std
    img_np = np.transpose(img_np, (2, 0, 1))  # HWC -> CHW
    img_np = np.expand_dims(img_np, axis=0)   # add batch dimension
    return img_np


def FSFM3C_image_detection(image):
    # Each request gets a fresh numbered folder with an ImageFolder-style '0' class dir.
    frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
    os.makedirs(frame_path, exist_ok=True)
    os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
    img = extract_face(image)
    if img is None:
        return 'No face detected, please upload a clear face!', None, None
    img = img.resize((224, 224), Image.BICUBIC)
    img.save(os.path.join(frame_path, '0', "frame_0.png"))
    args.data_path = frame_path
    args.batch_size = 1
    dataset_val = build_dataset(is_train=False, args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size,
                                                  num_workers=args.num_workers, pin_memory=args.pin_mem,
                                                  drop_last=False)

    if CKPT_CLASS[ckpt] > 2:
        frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
        class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated',
                       'Spoofing or Presentation-attack']
        avg_video_pred = np.mean(video_pred_list, axis=0)
        max_prob_index = np.argmax(avg_video_pred)
        max_prob_class = class_names[max_prob_index]
        probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
        image_results = (f"The largest face in this image may be {max_prob_class} with probability: \n"
                         f"[{', '.join(probabilities)}]")

        # Build the Grad-CAM input with the same normalization as the dataset.
        input_tensor = preprocess_image(img,
                                        mean=[0.5482207536697388, 0.42340534925460815, 0.3654651641845703],
                                        std=[0.2789176106452942, 0.2438540756702423, 0.23493893444538116])
        if torch.cuda.is_available():
            input_tensor = input_tensor.cuda()

        category_names_to_index = {
            'Real or Bonafide': 0,
            'Deepfake': 1,
            'Diffusion or AIGC generated': 2,
            'Spoofing or Presentation-attack': 3,
        }
        target_category = TargetCategory(category_names_to_index[max_prob_class])

        # Build Grad-CAM on the current model; this also covers the case where
        # load_model has not been called yet.
        cam = GradCAM(model=model,
                      target_layers=[model.blocks[-1].norm1],
                      reshape_transform=reshape_transform)
        grayscale_cam = cam(input_tensor=input_tensor, targets=[target_category],
                            aug_smooth=False, eigen_smooth=True)
        # NOTE: the map is inverted, following the original implementation.
        grayscale_cam = 1 - grayscale_cam[0, :]
        img = np.array(img)
        if img.shape[2] == 4:  # drop the alpha channel if present
            img = img[:, :, :3]
        img = img.astype(np.float32) / 255.0
        visualization = show_cam_on_image(img, grayscale_cam)
        visualization = cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR)

        cam_path = os.path.join(CAM_SAVE_PATH, str(len(os.listdir(CAM_SAVE_PATH))))
        os.makedirs(cam_path, exist_ok=True)
        output_path = os.path.join(cam_path, "output_heatmap.png")
        cv2.imwrite(output_path, visualization)
        return image_results, output_path, probabilities[max_prob_index]

    if CKPT_CLASS[ckpt] == 2:
        frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
        # prob > 0.5 is treated as Real/Bonafide; otherwise the probability is
        # flipped so the reported number is the predicted label's confidence.
        if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
            prob = sum(video_pred_list) / len(video_pred_list)
            label = "Deepfake" if prob <= 0.5 else "Real"
            prob = prob if label == "Real" else 1 - prob
        if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
            prob = sum(video_pred_list) / len(video_pred_list)
            label = "Spoofing" if prob <= 0.5 else "Bonafide"
            prob = prob if label == "Bonafide" else 1 - prob
        image_results = f"The largest face in this image may be {label} with probability {prob * 100:.1f}%"
        return image_results, None, None


def FSFM3C_video_detection(video, num_frames):
    try:
        frame_path = os.path.join(FRAME_SAVE_PATH, str(len(os.listdir(FRAME_SAVE_PATH))))
        os.makedirs(frame_path, exist_ok=True)
        os.makedirs(os.path.join(frame_path, '0'), exist_ok=True)
        frame_indices = extract_face_from_fixed_num_frames(video, frame_path, num_frames=num_frames)
        args.data_path = frame_path
        args.batch_size = num_frames
        dataset_val = build_dataset(is_train=False, args=args)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
        data_loader_val = torch.utils.data.DataLoader(dataset_val, sampler=sampler_val, batch_size=args.batch_size,
                                                      num_workers=args.num_workers, pin_memory=args.pin_mem,
                                                      drop_last=False)

        if CKPT_CLASS[ckpt] > 2:
            frame_preds_list, video_pred_list = test_multi_class(data_loader_val, model, device)
            class_names = ['Real or Bonafide', 'Deepfake', 'Diffusion or AIGC generated',
                           'Spoofing or Presentation-attack']
            # Average the per-frame predictions into a single video-level score.
            avg_video_pred = np.mean(video_pred_list, axis=0)
            max_prob_index = np.argmax(avg_video_pred)
            max_prob_class = class_names[max_prob_index]
            probabilities = [f"{class_names[i]}: {prob * 100:.1f}%" for i, prob in enumerate(avg_video_pred)]
            frame_results = {f"frame_{frame_indices[i]}": [f"{class_names[j]}: {prob * 100:.1f}%"
                                                           for j, prob in enumerate(frame_preds_list[i])]
                             for i in range(len(frame_indices))}
            video_results = (
                f"The largest face in this video may be {max_prob_class} with probability: \n"
                f"[{', '.join(probabilities)}]\n \n"
                f"The frame-level detection results ['frame_index': 'probabilities']: \n{frame_results}")
            return video_results

        if CKPT_CLASS[ckpt] == 2:
            frame_preds_list, video_pred_list = test_two_class(data_loader_val, model, device)
            # As in the image path, flip the positive-class probability when the
            # predicted label is the negative class.
            if ckpt == 'DfD-Checkpoint_Fine-tuned_on_FF++':
                prob = sum(video_pred_list) / len(video_pred_list)
                label = "Deepfake" if prob <= 0.5 else "Real"
                prob = prob if label == "Real" else 1 - prob
                frame_probs = frame_preds_list if label == "Real" else [1 - p for p in frame_preds_list]
                frame_results = {f"frame_{frame_indices[i]}": f"{frame_probs[i] * 100:.1f}%"
                                 for i in range(len(frame_indices))}
            if ckpt == 'FAS-Checkpoint_Fine-tuned_on_MCIO':
                prob = sum(video_pred_list) / len(video_pred_list)
                label = "Spoofing" if prob <= 0.5 else "Bonafide"
                prob = prob if label == "Bonafide" else 1 - prob
                frame_probs = frame_preds_list if label == "Bonafide" else [1 - p for p in frame_preds_list]
                frame_results = {f"frame_{frame_indices[i]}": f"{frame_probs[i] * 100:.1f}%"
                                 for i in range(len(frame_indices))}

            video_results = (f"The largest face in this video may be {label} with probability {prob * 100:.1f}%\n \n"
                             f"The frame-level detection results ['frame_index': 'predicted_label_probability']: \n"
                             f"{frame_results}")
            return video_results
    except Exception:
        traceback.print_exc()  # keep the stack trace in the server log for debugging
        return "Error occurred. Please provide a clear face video or reduce the number of frames."


# Working directories, created next to this script.
P = os.path.abspath(__file__)
FRAME_SAVE_PATH = os.path.join(os.path.dirname(P), 'frame')
CAM_SAVE_PATH = os.path.join(os.path.dirname(P), 'cam')
CKPT_SAVE_PATH = os.path.join(os.path.dirname(P), 'checkpoints')
os.makedirs(FRAME_SAVE_PATH, exist_ok=True)
os.makedirs(CAM_SAVE_PATH, exist_ok=True)
os.makedirs(CKPT_SAVE_PATH, exist_ok=True)

# Available checkpoints: display name -> path inside the Hub repo, number of
# classes, and backbone architecture.
CKPT_NAME = [
    '✨Unified-detector_v1_Fine-tuned_on_4_classes',
    'DfD-Checkpoint_Fine-tuned_on_FF++',
    'FAS-Checkpoint_Fine-tuned_on_MCIO',
]
CKPT_PATH = {
    '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'finetuned_models/Unified-detector/v1_Fine-tuned_on_4_classes/checkpoint-min_train_loss.pth',
    'DfD-Checkpoint_Fine-tuned_on_FF++': 'finetuned_models/FF++_c23_32frames/checkpoint-min_val_loss.pth',
    'FAS-Checkpoint_Fine-tuned_on_MCIO': 'finetuned_models/MCIO_protocol/Both_MCIO/checkpoint-min_val_loss.pth',
}
CKPT_CLASS = {
    '✨Unified-detector_v1_Fine-tuned_on_4_classes': 4,
    'DfD-Checkpoint_Fine-tuned_on_FF++': 2,
    'FAS-Checkpoint_Fine-tuned_on_MCIO': 2,
}
CKPT_MODEL = {
    '✨Unified-detector_v1_Fine-tuned_on_4_classes': 'vit_base_patch16',
    'DfD-Checkpoint_Fine-tuned_on_FF++': 'vit_base_patch16',
    'FAS-Checkpoint_Fine-tuned_on_MCIO': 'vit_base_patch16',
}
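
# To register another fine-tuned checkpoint (sketch), add matching entries to
# CKPT_NAME, CKPT_PATH, CKPT_CLASS, and CKPT_MODEL; CKPT_PATH is resolved inside
# the Wolowolo/fsfm-3c Hub repo by load_model.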


with gr.Blocks(css=".custom-label { font-weight: bold !important; font-size: 16px !important; }") as demo:
    gr.HTML(
        "<h1 style='text-align: center;'>🦱 Real Facial Image&Video Detection <br> Against Face Forgery (Deepfake/Diffusion) and Spoofing (Presentation-attacks)</h1>")
    gr.Markdown(
        "<b>☉ Powered by fine-tuned ViT models pre-trained with [FSFM-3C](https://fsfm-3c.github.io/)</b> <br> "
        "<b>☉ We do not and cannot access or store the data you upload!</b> <br> "
        "<b>☉ Releases (continuously updated by [Gaojian Wang/汪高健](https://scholar.google.com/citations?user=tpP4cFQAAAAJ&hl=zh-CN&oi=ao), [Tong Wu/吴桐](https://github.com/Coco-T-T), [Xingtang Luo/罗兴塘](https://github.com/Rox-C))</b> <br> <b>[V1.0] 2025/02/22-Current🎉</b>: "
        "1) Added <b>[✨Unified-detector_v1] for Physical-Digital Face Attack&Forgery Detection, a ViT-B/16-224 (FSFM pre-trained) detector that identifies Real&Bonafide, Deepfake, Diffusion&AIGC, and Spoofing&Presentation-attack facial images or videos</b>; 2) Added selection of the number of video frames (uniformly sampling 1-32 frames; more frames may be time-consuming on this page without GPU acceleration); 3) Fixed some errors of V0.1. <br>"
        "<b>[V0.1] 2024/12-2025/02/21</b>: "
        "Created this page with basic detectors [DfD-Checkpoint_Fine-tuned_on_FF++, FAS-Checkpoint_Fine-tuned_on_MCIO] that follow the paper implementation. <br> ")
    gr.Markdown(
        "- Please <b>provide a facial image or video (<100s)</b> and <b>select the model</b> for detection: <br> <b>[SUGGEST] [✨Unified-detector_v1_Fine-tuned_on_4_classes]</b>: an FSFM pre-trained ViT-B/16-224 for detecting Real/Deepfake/Diffusion/Spoofing facial images & videos <br> <b>[DfD-Checkpoint_Fine-tuned_on_FF++]</b>: for deepfake detection, FSFM ViT-B/16-224 fine-tuned on the FF++_c23 train & val sets (4 manipulations, 32 frames per video) <br> <b>[FAS-Checkpoint_Fine-tuned_on_MCIO]</b>: for face anti-spoofing, FSFM ViT-B/16-224 fine-tuned on the MCIO datasets (2 frames per video)")

    with gr.Row():
        ckpt_select_dropdown = gr.Dropdown(
            label="Select the Model for Detection ⬇️",
            elem_classes="custom-label",
            choices=['Choose Model Here 🖱️'] + CKPT_NAME + ['continuously updating...'],
            multiselect=False,
            value='Choose Model Here 🖱️',
            interactive=True,
        )
        model_loading_status = gr.Textbox(label="Model Loading Status")
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown(
                "### Image Detection (Fast try: copy an image from [whichfaceisreal](https://whichfaceisreal.com/))")
            image = gr.Image(label="Upload/Capture/Paste your image", type="pil")
            image_submit_btn = gr.Button("Submit")
            output_results_image = gr.Textbox(label="Detection Result")
            with gr.Row():
                output_heatmap = gr.Image(label="Grad_CAM")
                output_max_prob_class = gr.Textbox(label="Detected Class")
        with gr.Column(scale=5):
            gr.Markdown("### Video Detection")
            video = gr.Video(label="Upload/Capture your video")
            frame_slider = gr.Slider(minimum=1, maximum=32, step=1, value=32, label="Number of Frames for Detection")
            video_submit_btn = gr.Button("Submit")
            output_results_video = gr.Textbox(label="Detection Result")

    # Visitor map badge.
    gr.HTML(
        '<div style="display: flex; justify-content: center; gap: 20px; margin-bottom: 20px;">'
        '<a href="https://mapmyvisitors.com/web/1bxvi" title="Visit tracker">'
        '<img src="https://mapmyvisitors.com/map.png?d=FYhBoxLDEaFAxdfRzk5TuchYOBGrnSa98Ky59EkEEpY&cl=ffffff">'
        '</a>'
        '</div>'
    )

    ckpt_select_dropdown.change(
        fn=load_model,
        inputs=[ckpt_select_dropdown],
        outputs=[ckpt_select_dropdown, model_loading_status],
    )
    image_submit_btn.click(
        fn=FSFM3C_image_detection,
        inputs=[image],
        outputs=[output_results_image, output_heatmap, output_max_prob_class],
    )
    video_submit_btn.click(
        fn=FSFM3C_video_detection,
        inputs=[video, frame_slider],
        outputs=[output_results_video],
    )


if __name__ == "__main__":
    args = get_args_parser().parse_args()

    # Load the default (unified 4-class) model at startup.
    ckpt = '✨Unified-detector_v1_Fine-tuned_on_4_classes'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.nb_classes = CKPT_CLASS[ckpt]
    model = models_vit.__dict__[CKPT_MODEL[ckpt]](
        num_classes=args.nb_classes,
        drop_path_rate=args.drop_path,
        global_pool=args.global_pool,
    ).to(device)
    args.resume = os.path.join(CKPT_SAVE_PATH, CKPT_PATH[ckpt])
    if not os.path.isfile(args.resume):
        hf_hub_download(local_dir=CKPT_SAVE_PATH,
                        local_dir_use_symlinks=False,
                        repo_id='Wolowolo/fsfm-3c',
                        filename=CKPT_PATH[ckpt])
    checkpoint = torch.load(args.resume, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=False)
    model.eval()

    gr.close_all()
    demo.queue()
    demo.launch()
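    # Assumption: default launch() settings suit Hugging Face Spaces; for
    # self-hosting, launch(server_name="0.0.0.0") is a common alternative.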