Duplicate from thecho7/deepfake
Co-authored-by: Suho Cho <[email protected]>
- .gitattributes +36 -0
- Dockerfile +54 -0
- LICENSE +21 -0
- README.md +14 -0
- __pycache__/kernel_utils.cpython-310.pyc +0 -0
- app.py +86 -0
- configs/b5.json +28 -0
- configs/b7.json +29 -0
- download_weights.sh +9 -0
- examples/liuujwwgpr.mp4 +3 -0
- examples/nlurbvsozt.mp4 +3 -0
- examples/rfjuhbnlro.mp4 +3 -0
- kernel_utils.py +366 -0
- libs/shape_predictor_68_face_landmarks.dat +3 -0
- requirements.txt +131 -0
- training/__init__.py +0 -0
- training/__pycache__/__init__.cpython-310.pyc +0 -0
- training/__pycache__/__init__.cpython-39.pyc +0 -0
- training/__pycache__/losses.cpython-310.pyc +0 -0
- training/__pycache__/losses.cpython-39.pyc +0 -0
- training/datasets/__init__.py +0 -0
- training/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
- training/datasets/__pycache__/classifier_dataset.cpython-310.pyc +0 -0
- training/datasets/__pycache__/validation_set.cpython-310.pyc +0 -0
- training/datasets/classifier_dataset.py +384 -0
- training/datasets/validation_set.py +60 -0
- training/losses.py +28 -0
- training/pipelines/__init__.py +0 -0
- training/pipelines/train_classifier.py +364 -0
- training/tools/__init__.py +0 -0
- training/tools/__pycache__/__init__.cpython-310.pyc +0 -0
- training/tools/__pycache__/config.cpython-310.pyc +0 -0
- training/tools/__pycache__/schedulers.cpython-310.pyc +0 -0
- training/tools/__pycache__/utils.cpython-310.pyc +0 -0
- training/tools/config.py +43 -0
- training/tools/schedulers.py +46 -0
- training/tools/utils.py +121 -0
- training/transforms/__init__.py +0 -0
- training/transforms/__pycache__/__init__.cpython-310.pyc +0 -0
- training/transforms/__pycache__/albu.cpython-310.pyc +0 -0
- training/transforms/albu.py +100 -0
- training/zoo/__init__.py +0 -0
- training/zoo/__pycache__/__init__.cpython-310.pyc +0 -0
- training/zoo/__pycache__/classifiers.cpython-310.pyc +0 -0
- training/zoo/classifiers.py +172 -0
- training/zoo/unet.py +151 -0
- weights/.gitkeep +0 -0
- weights/b7_ns_best.pth +3 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.dat filter=lfs diff=lfs merge=lfs -text
Dockerfile
ADDED
@@ -0,0 +1,54 @@
ARG PYTORCH="1.10.0"
ARG CUDA="11.3"
ARG CUDNN="8"

FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel

ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"

# Setting noninteractive build, setting up tzdata and configuring timezones
ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=Europe/Berlin
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone

RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxrender-dev libxext6 nano mc glances vim git \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Install cython
RUN conda install cython -y && conda clean --all

# Installing APEX
RUN pip install -U pip
RUN git clone https://github.com/NVIDIA/apex
RUN sed -i 's/check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)/pass/g' apex/setup.py
RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex
RUN apt-get update -y
RUN apt-get install build-essential cmake -y
RUN apt-get install libopenblas-dev liblapack-dev -y
RUN apt-get install libx11-dev libgtk-3-dev -y
RUN pip install dlib
RUN pip install facenet-pytorch
RUN pip install albumentations==1.0.0 timm==0.4.12 pytorch_toolbelt tensorboardx
RUN pip install cython jupyter jupyterlab ipykernel matplotlib tqdm pandas

# Download pretrained ImageNet models
RUN apt install wget
RUN wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth -P /root/.cache/torch/hub/checkpoints/
RUN wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth -P /root/.cache/torch/hub/checkpoints/

# Setting the working directory
WORKDIR /workspace

# Copying the required codebase
COPY . /workspace

RUN chmod 777 preprocess_data.sh
RUN chmod 777 train.sh
RUN chmod 777 predict_submission.sh

ENV PYTHONPATH=.

CMD ["/bin/bash"]
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Selim Seferbekov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,14 @@
---
title: Deepfake
emoji: 🔥
colorFrom: indigo
colorTo: purple
sdk: gradio
sdk_version: 3.29.0
app_file: app.py
pinned: false
license: unlicense
duplicated_from: thecho7/deepfake
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/kernel_utils.cpython-310.pyc
ADDED
Binary file (11.8 kB)
app.py
ADDED
@@ -0,0 +1,86 @@
import argparse
import os
import re
import time

import torch
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
from training.zoo.classifiers import DeepFakeClassifier

import gradio as gr

def model_fn(model_dir):
    model_path = os.path.join(model_dir, 'b7_ns_best.pth')
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")  # default: CPU
    checkpoint = torch.load(model_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)
    model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
    model.eval()
    del checkpoint
    #models.append(model.half())

    return model

def convert_result(pred, class_names=["Real", "Fake"]):
    preds = [pred, 1 - pred]
    assert len(class_names) == len(preds), "Class / Prediction should have the same length"
    return {n: float(p) for n, p in zip(class_names, preds)}

def predict_fn(video):
    start = time.time()
    prediction = predict_on_video(face_extractor=meta["face_extractor"],
                                  video_path=video,
                                  batch_size=meta["fps"],
                                  input_size=meta["input_size"],
                                  models=model,
                                  strategy=meta["strategy"],
                                  apply_compression=False,
                                  device='cpu')

    elapsed_time = round(time.time() - start, 2)

    prediction = convert_result(prediction)

    return prediction, elapsed_time

# Create title, description and article strings
title = "Deepfake Detector (private)"
description = "A video Deepfake Classifier (code: https://github.com/selimsef/dfdc_deepfake_challenge)"

example_list = ["examples/" + str(p) for p in os.listdir("examples/")]

# Environments
model_dir = 'weights'
frames_per_video = 32
video_reader = VideoReader()
video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
face_extractor = FaceExtractor(video_read_fn)
input_size = 380
strategy = confident_strategy
class_names = ["Real", "Fake"]

meta = {"fps": 32,
        "face_extractor": face_extractor,
        "input_size": input_size,
        "strategy": strategy}

model = model_fn(model_dir)

"""
if __name__ == '__main__':
    video_path = "examples/nlurbvsozt.mp4"
    model = model_fn(model_dir)
    a, b = predict_fn(video_path)
    print(a, b)
"""
# Create the Gradio demo
demo = gr.Interface(fn=predict_fn,  # mapping function from input to output
                    inputs=gr.Video(),
                    outputs=[gr.Label(num_top_classes=2, label="Predictions"),
                             gr.Number(label="Prediction time (s)")],  # our fn has two outputs, therefore we declare two outputs
                    examples=example_list,
                    title=title,
                    description=description)

# Launch the demo!
demo.launch(debug=False,)  # Hugging Face Spaces don't need shareable links
configs/b5.json
ADDED
@@ -0,0 +1,28 @@
{
  "network": "DeepFakeClassifier",
  "encoder": "tf_efficientnet_b5_ns",
  "batches_per_epoch": 2500,
  "size": 380,
  "fp16": true,
  "optimizer": {
    "batch_size": 20,
    "type": "SGD",
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "learning_rate": 0.01,
    "nesterov": true,
    "schedule": {
      "type": "poly",
      "mode": "step",
      "epochs": 30,
      "params": {"max_iter": 75100}
    }
  },
  "normalize": {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225]
  },
  "losses": {
    "BinaryCrossentropy": 1
  }
}
configs/b7.json
ADDED
@@ -0,0 +1,29 @@
{
  "network": "DeepFakeClassifier",
  "encoder": "tf_efficientnet_b7_ns",
  "batches_per_epoch": 2500,
  "size": 380,
  "fp16": true,
  "optimizer": {
    "batch_size": 4,
    "type": "SGD",
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "learning_rate": 1e-4,
    "nesterov": true,
    "schedule": {
      "type": "poly",
      "mode": "step",
      "epochs": 20,
      "params": {"max_iter": 100500}
    }
  },
  "normalize": {
    "mean": [0.485, 0.456, 0.406],
    "std": [0.229, 0.224, 0.225]
  },
  "losses": {
    "BinaryCrossentropy": 1
  }
}
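For reference only, and not part of this commit: the configs above are plain JSON, and the loader that consumes them lives in training/tools/config.py, which is listed in this commit but not shown. Below is a minimal sketch of how configs/b7.json could be read and its "optimizer" block mapped onto torch.optim.SGD; load_config and build_sgd are hypothetical helpers, not the project's actual API.

import json

import torch


def load_config(path):
    # Parse a config such as configs/b7.json into a plain dict.
    with open(path) as f:
        return json.load(f)


def build_sgd(model, cfg):
    # Illustrative mapping of the "optimizer" block onto torch.optim.SGD.
    opt = cfg["optimizer"]
    return torch.optim.SGD(model.parameters(),
                           lr=opt["learning_rate"],
                           momentum=opt["momentum"],
                           weight_decay=opt["weight_decay"],
                           nesterov=opt["nesterov"])


if __name__ == "__main__":
    cfg = load_config("configs/b7.json")
    print(cfg["encoder"], cfg["optimizer"]["batch_size"])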
download_weights.sh
ADDED
@@ -0,0 +1,9 @@
tag=0.0.1

wget -O weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36
wget -O weights/final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19
wget -O weights/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29
wget -O weights/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_31 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_31
wget -O weights/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_37 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_37
wget -O weights/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40
wget -O weights/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23
examples/liuujwwgpr.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b3aaefb51aa5720cdabcc68d93da5c6a22573d8da06bdaf5e009c7a370943e85
size 12852441
examples/nlurbvsozt.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:300b7dea93132b512f35de76572e7fcde666c812b91aec6b189dafa6f100c9b5
size 4486723
examples/rfjuhbnlro.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b6d0bb841ebe6a8e20cf265b45356a1ea3fed9837025e8d549b2437290d79273
size 16218775
kernel_utils.py
ADDED
@@ -0,0 +1,366 @@
import os

import cv2
import numpy as np
import torch
from PIL import Image
from albumentations.augmentations.functional import image_compression
from facenet_pytorch.models.mtcnn import MTCNN
from concurrent.futures import ThreadPoolExecutor

from torchvision.transforms import Normalize

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize_transform = Normalize(mean, std)


class VideoReader:
    """Helper class for reading one or more frames from a video file."""

    def __init__(self, verbose=True, insets=(0, 0)):
        """Creates a new VideoReader.

        Arguments:
            verbose: whether to print warnings and error messages
            insets: amount to inset the image by, as a percentage of
                (width, height). This lets you "zoom in" to an image
                to remove unimportant content around the borders.
                Useful for face detection, which may not work if the
                faces are too small.
        """
        self.verbose = verbose
        self.insets = insets

    def read_frames(self, path, num_frames, jitter=0, seed=None):
        """Reads frames that are always evenly spaced throughout the video.

        Arguments:
            path: the video file
            num_frames: how many frames to read, -1 means the entire video
                (warning: this will take up a lot of memory!)
            jitter: if not 0, adds small random offsets to the frame indices;
                this is useful so we don't always land on even or odd frames
            seed: random seed for jittering; if you set this to a fixed value,
                you probably want to set it only on the first video
        """
        assert num_frames > 0

        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0: return None

        frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int32)
        if jitter > 0:
            np.random.seed(seed)
            jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
            frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)

        result = self._read_frames_at_indices(path, capture, frame_idxs)
        capture.release()
        return result

    def read_random_frames(self, path, num_frames, seed=None):
        """Picks the frame indices at random.

        Arguments:
            path: the video file
            num_frames: how many frames to read, -1 means the entire video
                (warning: this will take up a lot of memory!)
        """
        assert num_frames > 0
        np.random.seed(seed)

        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        if frame_count <= 0: return None

        frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
        result = self._read_frames_at_indices(path, capture, frame_idxs)

        capture.release()
        return result

    def read_frames_at_indices(self, path, frame_idxs):
        """Reads frames from a video and puts them into a NumPy array.

        Arguments:
            path: the video file
            frame_idxs: a list of frame indices. Important: should be
                sorted from low-to-high! If an index appears multiple
                times, the frame is still read only once.

        Returns:
            - a NumPy array of shape (num_frames, height, width, 3)
            - a list of the frame indices that were read

        Reading stops if loading a frame fails, in which case the first
        dimension returned may actually be less than num_frames.

        Returns None if an exception is thrown for any reason, or if no
        frames were read.
        """
        assert len(frame_idxs) > 0
        capture = cv2.VideoCapture(path)
        result = self._read_frames_at_indices(path, capture, frame_idxs)
        capture.release()
        return result

    def _read_frames_at_indices(self, path, capture, frame_idxs):
        try:
            frames = []
            idxs_read = []
            for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
                # Get the next frame, but don't decode if we're not using it.
                ret = capture.grab()
                if not ret:
                    if self.verbose:
                        print("Error grabbing frame %d from movie %s" % (frame_idx, path))
                    break

                # Need to look at this frame?
                current = len(idxs_read)
                if frame_idx == frame_idxs[current]:
                    ret, frame = capture.retrieve()
                    if not ret or frame is None:
                        if self.verbose:
                            print("Error retrieving frame %d from movie %s" % (frame_idx, path))
                        break

                    frame = self._postprocess_frame(frame)
                    frames.append(frame)
                    idxs_read.append(frame_idx)

            if len(frames) > 0:
                return np.stack(frames), idxs_read
            if self.verbose:
                print("No frames read from movie %s" % path)
            return None
        except:
            if self.verbose:
                print("Exception while reading movie %s" % path)
            return None

    def read_middle_frame(self, path):
        """Reads the frame from the middle of the video."""
        capture = cv2.VideoCapture(path)
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        result = self._read_frame_at_index(path, capture, frame_count // 2)
        capture.release()
        return result

    def read_frame_at_index(self, path, frame_idx):
        """Reads a single frame from a video.

        If you just want to read a single frame from the video, this is more
        efficient than scanning through the video to find the frame. However,
        for reading multiple frames it's not efficient.

        My guess is that a "streaming" approach is more efficient than a
        "random access" approach because, unless you happen to grab a keyframe,
        the decoder still needs to read all the previous frames in order to
        reconstruct the one you're asking for.

        Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
        or None if reading failed.
        """
        capture = cv2.VideoCapture(path)
        result = self._read_frame_at_index(path, capture, frame_idx)
        capture.release()
        return result

    def _read_frame_at_index(self, path, capture, frame_idx):
        capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = capture.read()
        if not ret or frame is None:
            if self.verbose:
                print("Error retrieving frame %d from movie %s" % (frame_idx, path))
            return None
        else:
            frame = self._postprocess_frame(frame)
            return np.expand_dims(frame, axis=0), [frame_idx]

    def _postprocess_frame(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        if self.insets[0] > 0:
            W = frame.shape[1]
            p = int(W * self.insets[0])
            frame = frame[:, p:-p, :]

        if self.insets[1] > 0:
            H = frame.shape[1]
            q = int(H * self.insets[1])
            frame = frame[q:-q, :, :]

        return frame


class FaceExtractor:
    def __init__(self, video_read_fn):
        self.video_read_fn = video_read_fn
        self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device="cpu")

    def process_videos(self, input_dir, filenames, video_idxs):
        videos_read = []
        frames_read = []
        frames = []
        results = []
        for video_idx in video_idxs:
            # Read the full-size frames from this video.
            filename = filenames[video_idx]
            video_path = os.path.join(input_dir, filename)
            result = self.video_read_fn(video_path)
            # Error? Then skip this video.
            if result is None: continue

            videos_read.append(video_idx)

            # Keep track of the original frames (need them later).
            my_frames, my_idxs = result

            frames.append(my_frames)
            frames_read.append(my_idxs)
            for i, frame in enumerate(my_frames):
                h, w = frame.shape[:2]
                img = Image.fromarray(frame.astype(np.uint8))
                img = img.resize(size=[s // 2 for s in img.size])

                batch_boxes, probs = self.detector.detect(img, landmarks=False)

                faces = []
                scores = []
                if batch_boxes is None:
                    continue
                for bbox, score in zip(batch_boxes, probs):
                    if bbox is not None:
                        xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
                        w = xmax - xmin
                        h = ymax - ymin
                        p_h = h // 3
                        p_w = w // 3
                        crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
                        faces.append(crop)
                        scores.append(score)

                frame_dict = {"video_idx": video_idx,
                              "frame_idx": my_idxs[i],
                              "frame_w": w,
                              "frame_h": h,
                              "faces": faces,
                              "scores": scores}
                results.append(frame_dict)

        return results

    def process_video(self, video_path):
        """Convenience method for doing face extraction on a single video."""
        input_dir = os.path.dirname(video_path)
        filenames = [os.path.basename(video_path)]
        return self.process_videos(input_dir, filenames, [0])


def confident_strategy(pred, t=0.8):
    pred = np.array(pred)
    sz = len(pred)
    fakes = np.count_nonzero(pred > t)
    # 11 frames are detected as fakes with high probability
    if fakes > sz // 2.5 and fakes > 11:
        return np.mean(pred[pred > t])
    elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
        return np.mean(pred[pred < 0.2])
    else:
        return np.mean(pred)

strategy = confident_strategy


def put_to_center(img, input_size):
    img = img[:input_size, :input_size]
    image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
    start_w = (input_size - img.shape[1]) // 2
    start_h = (input_size - img.shape[0]) // 2
    image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
    return image


def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
    h, w = img.shape[:2]
    if max(w, h) == size:
        return img
    if w > h:
        scale = size / w
        h = h * scale
        w = size
    else:
        scale = size / h
        w = w * scale
        h = size
    interpolation = interpolation_up if scale > 1 else interpolation_down
    resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
    return resized


def predict_on_video(face_extractor, video_path, batch_size, input_size, models, strategy=np.mean,
                     apply_compression=False, device='cpu'):
    batch_size *= 4
    try:
        faces = face_extractor.process_video(video_path)
        if len(faces) > 0:
            x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
            n = 0
            for frame_data in faces:
                for face in frame_data["faces"]:
                    resized_face = isotropically_resize_image(face, input_size)
                    resized_face = put_to_center(resized_face, input_size)
                    if apply_compression:
                        resized_face = image_compression(resized_face, quality=90, image_type=".jpg")
                    if n + 1 < batch_size:
                        x[n] = resized_face
                        n += 1
                    else:
                        pass
            if n > 0:
                if device == 'cpu':
                    x = torch.tensor(x, device='cpu').float()
                else:
                    x = torch.tensor(x, device="cuda").float()
                # Preprocess the images.
                x = x.permute((0, 3, 1, 2))
                for i in range(len(x)):
                    x[i] = normalize_transform(x[i] / 255.)
                # Make a prediction, then take the average.
                with torch.no_grad():
                    preds = []
                    models_ = [models]
                    for model in models_:
                        if device == 'cpu':
                            y_pred = model(x[:n])
                        else:
                            y_pred = model(x[:n].half())
                        y_pred = torch.sigmoid(y_pred.squeeze())
                        bpred = y_pred[:n].cpu().numpy()
                        preds.append(strategy(bpred))
                    return np.mean(preds)
    except Exception as e:
        print("Prediction error on video %s: %s" % (video_path, str(e)))

    return 0.5


def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models,
                         strategy=np.mean,
                         apply_compression=False):
    def process_file(i):
        filename = videos[i]
        y_pred = predict_on_video(face_extractor=face_extractor, video_path=os.path.join(test_dir, filename),
                                  input_size=input_size,
                                  batch_size=frames_per_video,
                                  models=models, strategy=strategy, apply_compression=apply_compression)
        return y_pred

    with ThreadPoolExecutor(max_workers=num_workers) as ex:
        predictions = ex.map(process_file, range(len(videos)))
    return list(predictions)
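As an aside, not part of the commit: a minimal sketch of how the face-extraction stage above is driven, using one of the example clips shipped with the Space; the frame count of 32 matches frames_per_video in app.py.

from kernel_utils import VideoReader, FaceExtractor

reader = VideoReader()
read_fn = lambda path: reader.read_frames(path, num_frames=32)
extractor = FaceExtractor(read_fn)

# process_video() returns one dict per sampled frame with the frame index,
# the cropped face images and the MTCNN confidence scores.
for frame_info in extractor.process_video("examples/nlurbvsozt.mp4"):
    print(frame_info["frame_idx"], len(frame_info["faces"]), frame_info["scores"])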
libs/shape_predictor_68_face_landmarks.dat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
size 99693937
requirements.txt
ADDED
@@ -0,0 +1,131 @@
aiofiles==23.1.0
aiohttp==3.8.4
aiosignal==1.3.1
albumentations==1.3.0
altair==5.0.0
anyio==3.6.2
anykeystore==0.2
apex
appdirs==1.4.4
async-timeout==4.0.2
attrs==23.1.0
certifi==2022.12.7
charset-normalizer==2.1.1
click==8.1.3
cmake==3.25.0
contourpy==1.0.7
cryptacular==1.6.2
cycler==0.11.0
defusedxml==0.7.1
dlib==19.24.1
docker-pycreds==0.4.0
facenet-pytorch==2.5.3
fastapi==0.95.1
ffmpy==0.3.0
filelock==3.9.0
fonttools==4.39.4
frozenlist==1.3.3
fsspec==2023.5.0
gitdb==4.0.10
GitPython==3.1.31
gradio==3.30.0
gradio_client==0.2.4
greenlet==2.0.2
h11==0.14.0
httpcore==0.17.0
httpx==0.24.0
huggingface-hub==0.14.1
hupper==1.12
idna==3.4
imageio==2.28.1
Jinja2==3.1.2
joblib==1.2.0
jsonschema==4.17.3
kiwisolver==1.4.4
lazy_loader==0.2
linkify-it-py==2.0.2
lit==15.0.7
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
mdit-py-plugins==0.3.3
mdurl==0.1.2
mpmath==1.2.1
multidict==6.0.4
networkx==3.0
numpy==1.24.3
oauthlib==3.2.2
opencv-python==4.7.0.72
opencv-python-headless==4.7.0.72
orjson==3.8.12
packaging==23.0
pandas==2.0.1
PasteDeploy==3.0.1
pathtools==0.1.2
pbkdf2==1.3
pep517==0.13.0
Pillow==9.3.0
plaster==1.1.2
plaster-pastedeploy==1.0.1
protobuf==3.20.3
psutil==5.9.5
pydantic==1.10.7
pydub==0.25.1
Pygments==2.15.1
pyparsing==3.0.9
pyramid==2.0.1
pyramid-mailer==0.15.1
pyrsistent==0.19.3
python-dateutil==2.8.2
python-multipart==0.0.6
python3-openid==3.2.0
pytorch-toolbelt==0.6.3
pytz==2023.3
PyWavelets==1.4.1
PyYAML==6.0
qudida==0.0.4
repoze.sendmail==4.4.1
requests==2.28.1
requests-oauthlib==1.3.1
scikit-image==0.20.0
scikit-learn==1.2.2
scipy==1.9.0
semantic-version==2.10.0
sentry-sdk==1.22.2
setproctitle==1.3.2
sh==1.14.3
six==1.16.0
smmap==5.0.0
sniffio==1.3.0
SQLAlchemy==1.4.48
starlette==0.26.1
sympy==1.11.1
tensorboardX==2.6
threadpoolctl==3.1.0
tifffile==2023.4.12
timm==0.6.13
toml==0.10.2
tomli==2.0.1
toolz==0.12.0
torch==2.0.1
torchvision==0.15.2
tqdm==4.65.0
transaction==3.1.0
translationstring==1.4
triton==2.0.0
typing_extensions==4.4.0
tzdata==2023.3
uc-micro-py==1.0.2
urllib3==1.26.13
uvicorn==0.22.0
velruse==1.1.1
venusian==3.0.0
wandb==0.15.2
WebOb==1.8.7
websockets==11.0.3
WTForms==3.0.1
wtforms-recaptcha==0.3.2
yarl==1.9.2
zope.deprecation==5.0
zope.interface==6.0
zope.sqlalchemy==2.0
training/__init__.py
ADDED
File without changes
training/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (148 Bytes)
training/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (146 Bytes)
training/__pycache__/losses.cpython-310.pyc
ADDED
Binary file (1.54 kB)
training/__pycache__/losses.cpython-39.pyc
ADDED
Binary file (1.53 kB)
training/datasets/__init__.py
ADDED
File without changes
training/datasets/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (157 Bytes)
training/datasets/__pycache__/classifier_dataset.cpython-310.pyc
ADDED
Binary file (10.8 kB)
training/datasets/__pycache__/validation_set.cpython-310.pyc
ADDED
Binary file (4.99 kB)
training/datasets/classifier_dataset.py
ADDED
@@ -0,0 +1,384 @@
import math
import os
import random
import sys
import traceback

import cv2
import numpy as np
import pandas as pd
import skimage.draw
from albumentations import ImageCompression, OneOf, GaussianBlur, Blur
from albumentations.augmentations.functional import image_compression
from albumentations.augmentations.geometric.functional import rot90
from albumentations.pytorch.functional import img_to_tensor
from scipy.ndimage import binary_erosion, binary_dilation
from skimage import measure
from torch.utils.data import Dataset
import dlib

from training.datasets.validation_set import PUBLIC_SET


def prepare_bit_masks(mask):
    h, w = mask.shape
    mid_w = w // 2
    mid_h = w // 2
    masks = []
    ones = np.ones_like(mask)
    ones[:mid_h] = 0
    masks.append(ones)
    ones = np.ones_like(mask)
    ones[mid_h:] = 0
    masks.append(ones)
    ones = np.ones_like(mask)
    ones[:, :mid_w] = 0
    masks.append(ones)
    ones = np.ones_like(mask)
    ones[:, mid_w:] = 0
    masks.append(ones)
    ones = np.ones_like(mask)
    ones[:mid_h, :mid_w] = 0
    ones[mid_h:, mid_w:] = 0
    masks.append(ones)
    ones = np.ones_like(mask)
    ones[:mid_h, mid_w:] = 0
    ones[mid_h:, :mid_w] = 0
    masks.append(ones)
    return masks


detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor('libs/shape_predictor_68_face_landmarks.dat')


def blackout_convex_hull(img):
    try:
        rect = detector(img)[0]
        sp = predictor(img, rect)
        landmarks = np.array([[p.x, p.y] for p in sp.parts()])
        outline = landmarks[[*range(17), *range(26, 16, -1)]]
        Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
        cropped_img = np.zeros(img.shape[:2], dtype=np.uint8)
        cropped_img[Y, X] = 1
        # if random.random() > 0.5:
        #     img[cropped_img == 0] = 0
        #     #leave only face
        #     return img

        y, x = measure.centroid(cropped_img)
        y = int(y)
        x = int(x)
        first = random.random() > 0.5
        if random.random() > 0.5:
            if first:
                cropped_img[:y, :] = 0
            else:
                cropped_img[y:, :] = 0
        else:
            if first:
                cropped_img[:, :x] = 0
            else:
                cropped_img[:, x:] = 0

        img[cropped_img > 0] = 0
    except Exception as e:
        pass


def dist(p1, p2):
    return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)


def remove_eyes(image, landmarks):
    image = image.copy()
    (x1, y1), (x2, y2) = landmarks[:2]
    mask = np.zeros_like(image[..., 0])
    line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
    w = dist((x1, y1), (x2, y2))
    dilation = int(w // 4)
    line = binary_dilation(line, iterations=dilation)
    image[line, :] = 0
    return image


def remove_nose(image, landmarks):
    image = image.copy()
    (x1, y1), (x2, y2) = landmarks[:2]
    x3, y3 = landmarks[2]
    mask = np.zeros_like(image[..., 0])
    x4 = int((x1 + x2) / 2)
    y4 = int((y1 + y2) / 2)
    line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2)
    w = dist((x1, y1), (x2, y2))
    dilation = int(w // 4)
    line = binary_dilation(line, iterations=dilation)
    image[line, :] = 0
    return image


def remove_mouth(image, landmarks):
    image = image.copy()
    (x1, y1), (x2, y2) = landmarks[-2:]
    mask = np.zeros_like(image[..., 0])
    line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
    w = dist((x1, y1), (x2, y2))
    dilation = int(w // 3)
    line = binary_dilation(line, iterations=dilation)
    image[line, :] = 0
    return image


def remove_landmark(image, landmarks):
    if random.random() > 0.5:
        image = remove_eyes(image, landmarks)
    elif random.random() > 0.5:
        image = remove_mouth(image, landmarks)
    elif random.random() > 0.5:
        image = remove_nose(image, landmarks)
    return image


def change_padding(image, part=5):
    h, w = image.shape[:2]
    # original padding was done with 1/3 from each side, too much
    pad_h = int(((3 / 5) * h) / part)
    pad_w = int(((3 / 5) * w) / part)
    image = image[h // 5 - pad_h:-h // 5 + pad_h, w // 5 - pad_w:-w // 5 + pad_w]
    return image


def blackout_random(image, mask, label):
    binary_mask = mask > 0.4 * 255
    h, w = binary_mask.shape[:2]

    tries = 50
    current_try = 1
    while current_try < tries:
        first = random.random() < 0.5
        if random.random() < 0.5:
            pivot = random.randint(h // 2 - h // 5, h // 2 + h // 5)
            bitmap_msk = np.ones_like(binary_mask)
            if first:
                bitmap_msk[:pivot, :] = 0
            else:
                bitmap_msk[pivot:, :] = 0
        else:
            pivot = random.randint(w // 2 - w // 5, w // 2 + w // 5)
            bitmap_msk = np.ones_like(binary_mask)
            if first:
                bitmap_msk[:, :pivot] = 0
            else:
                bitmap_msk[:, pivot:] = 0

        if label < 0.5 and np.count_nonzero(image * np.expand_dims(bitmap_msk, axis=-1)) / 3 > (h * w) / 5 \
                or np.count_nonzero(binary_mask * bitmap_msk) > 40:
            mask *= bitmap_msk
            image *= np.expand_dims(bitmap_msk, axis=-1)
            break
        current_try += 1
    return image


def blend_original(img):
    img = img.copy()
    h, w = img.shape[:2]
    rect = detector(img)
    if len(rect) == 0:
        return img
    else:
        rect = rect[0]
    sp = predictor(img, rect)
    landmarks = np.array([[p.x, p.y] for p in sp.parts()])
    outline = landmarks[[*range(17), *range(26, 16, -1)]]
    Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
    raw_mask = np.zeros(img.shape[:2], dtype=np.uint8)
    raw_mask[Y, X] = 1
    face = img * np.expand_dims(raw_mask, -1)

    # add warping
    h1 = random.randint(h - h // 2, h + h // 2)
    w1 = random.randint(w - w // 2, w + w // 2)
    while abs(h1 - h) < h // 3 and abs(w1 - w) < w // 3:
        h1 = random.randint(h - h // 2, h + h // 2)
        w1 = random.randint(w - w // 2, w + w // 2)
    face = cv2.resize(face, (w1, h1), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
    face = cv2.resize(face, (w, h), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))

    raw_mask = binary_erosion(raw_mask, iterations=random.randint(4, 10))
    img[raw_mask, :] = face[raw_mask, :]
    if random.random() < 0.2:
        img = OneOf([GaussianBlur(), Blur()], p=0.5)(image=img)["image"]
    # image compression
    if random.random() < 0.5:
        img = ImageCompression(quality_lower=40, quality_upper=95)(image=img)["image"]
    return img


class DeepFakeClassifierDataset(Dataset):

    def __init__(self,
                 data_path="/mnt/sota/datasets/deepfake",
                 fold=0,
                 label_smoothing=0.01,
                 padding_part=3,
                 hardcore=True,
                 crops_dir="crops",
                 folds_csv="folds.csv",
                 normalize={"mean": [0.485, 0.456, 0.406],
                            "std": [0.229, 0.224, 0.225]},
                 rotation=False,
                 mode="train",
                 reduce_val=True,
                 oversample_real=True,
                 transforms=None
                 ):
        super().__init__()
        self.data_root = data_path
        self.fold = fold
        self.folds_csv = folds_csv
        self.mode = mode
        self.rotation = rotation
        self.padding_part = padding_part
        self.hardcore = hardcore
        self.crops_dir = crops_dir
        self.label_smoothing = label_smoothing
        self.normalize = normalize
        self.transforms = transforms
        self.df = pd.read_csv(self.folds_csv)
        self.oversample_real = oversample_real
        self.reduce_val = reduce_val

    def __getitem__(self, index: int):

        while True:
            video, img_file, label, ori_video, frame, fold = self.data[index]
            try:
                if self.mode == "train":
                    label = np.clip(label, self.label_smoothing, 1 - self.label_smoothing)
                img_path = os.path.join(self.data_root, self.crops_dir, video, img_file)
                image = cv2.imread(img_path, cv2.IMREAD_COLOR)
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                mask = np.zeros(image.shape[:2], dtype=np.uint8)
                diff_path = os.path.join(self.data_root, "diffs", video, img_file[:-4] + "_diff.png")
                try:
                    msk = cv2.imread(diff_path, cv2.IMREAD_GRAYSCALE)
                    if msk is not None:
                        mask = msk
                except:
                    print("not found mask", diff_path)
                    pass
                if self.mode == "train" and self.hardcore and not self.rotation:
                    landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
                    if os.path.exists(landmark_path) and random.random() < 0.7:
                        landmarks = np.load(landmark_path)
                        image = remove_landmark(image, landmarks)
                    elif random.random() < 0.2:
                        blackout_convex_hull(image)
                    elif random.random() < 0.1:
                        binary_mask = mask > 0.4 * 255
                        masks = prepare_bit_masks((binary_mask * 1).astype(np.uint8))
                        tries = 6
                        current_try = 1
                        while current_try < tries:
                            bitmap_msk = random.choice(masks)
                            if label < 0.5 or np.count_nonzero(mask * bitmap_msk) > 20:
                                mask *= bitmap_msk
                                image *= np.expand_dims(bitmap_msk, axis=-1)
                                break
                            current_try += 1
                if self.mode == "train" and self.padding_part > 3:
                    image = change_padding(image, self.padding_part)
                valid_label = np.count_nonzero(mask[mask > 20]) > 32 or label < 0.5
                valid_label = 1 if valid_label else 0
                rotation = 0
                if self.transforms:
                    data = self.transforms(image=image, mask=mask)
                    image = data["image"]
                    mask = data["mask"]
                if self.mode == "train" and self.hardcore and self.rotation:
                    # landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
                    dropout = 0.8 if label > 0.5 else 0.6
                    if self.rotation:
                        dropout *= 0.7
                    elif random.random() < dropout:
                        blackout_random(image, mask, label)

                #
                # os.makedirs("../images", exist_ok=True)
                # cv2.imwrite(os.path.join("../images", video+ "_" + str(1 if label > 0.5 else 0) + "_"+img_file), image[...,::-1])

                if self.mode == "train" and self.rotation:
                    rotation = random.randint(0, 3)
                    image = rot90(image, rotation)

                image = img_to_tensor(image, self.normalize)
                return {"image": image, "labels": np.array((label,)), "img_name": os.path.join(video, img_file),
                        "valid": valid_label, "rotations": rotation}
            except Exception as e:
                traceback.print_exc(file=sys.stdout)
                print("Broken image", os.path.join(self.data_root, self.crops_dir, video, img_file))
                index = random.randint(0, len(self.data) - 1)

    def random_blackout_landmark(self, image, mask, landmarks):
        x, y = random.choice(landmarks)
        first = random.random() > 0.5
        # crop half face either vertically or horizontally
        if random.random() > 0.5:
            # width
            if first:
                image[:, :x] = 0
                mask[:, :x] = 0
            else:
                image[:, x:] = 0
                mask[:, x:] = 0
        else:
            # height
            if first:
                image[:y, :] = 0
                mask[:y, :] = 0
            else:
                image[y:, :] = 0
                mask[y:, :] = 0

    def reset(self, epoch, seed):
        self.data = self._prepare_data(epoch, seed)

    def __len__(self) -> int:
        return len(self.data)

    def get_distribution(self):
        return self.n_real, self.n_fake

    def _prepare_data(self, epoch, seed):
        df = self.df
        if self.mode == "train":
            rows = df[df["fold"] != self.fold]
        else:
            rows = df[df["fold"] == self.fold]
        seed = (epoch + 1) * seed
        if self.oversample_real:
            rows = self._oversample(rows, seed)
        if self.mode == "val" and self.reduce_val:
            # every 2nd frame, to speed up validation
            rows = rows[rows["frame"] % 20 == 0]
            # another option is to use public validation set
            #rows = rows[rows["video"].isin(PUBLIC_SET)]

        print(
            "real {} fakes {} mode {}".format(len(rows[rows["label"] == 0]), len(rows[rows["label"] == 1]), self.mode))
        data = rows.values

        self.n_real = len(rows[rows["label"] == 0])
        self.n_fake = len(rows[rows["label"] == 1])
        np.random.seed(seed)
        np.random.shuffle(data)
        return data

    def _oversample(self, rows: pd.DataFrame, seed):
        real = rows[rows["label"] == 0]
        fakes = rows[rows["label"] == 1]
        num_real = real["video"].count()
        if self.mode == "train":
            fakes = fakes.sample(n=num_real, replace=False, random_state=seed)
        return pd.concat([real, fakes])
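For context, and not part of the commit: this dataset is driven by training/pipelines/train_classifier.py, which is listed in this commit but not shown, so the wiring below is only an assumed sketch; the data path and CSV are the defaults from __init__, and the loader settings are illustrative.

from torch.utils.data import DataLoader

from training.datasets.classifier_dataset import DeepFakeClassifierDataset

dataset = DeepFakeClassifierDataset(data_path="/mnt/sota/datasets/deepfake",
                                    fold=0,
                                    crops_dir="crops",
                                    folds_csv="folds.csv",
                                    mode="train")
# reset() samples and shuffles the rows for the coming epoch (oversampling real
# frames to balance the classes); it must be called before iterating.
dataset.reset(epoch=0, seed=111)

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
for batch in loader:
    images, labels = batch["image"], batch["labels"]
    break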
training/datasets/validation_set.py
ADDED
@@ -0,0 +1,60 @@
PUBLIC_SET = {'tjuihawuqm', 'prwsfljdjo', 'scrbqgpvzz', 'ziipxxchai', 'uubgqnvfdl', 'wclvkepakb', 'xjvxtuakyd',
              'qlvsqdroqo', 'bcbqxhziqz', 'yzuestxcbq', 'hxwtsaydal', 'kqlvggiqee', 'vtunvalyji', 'mohiqoogpb',
              'siebfpwuhu', 'cekwtyxdoo', 'hszwwswewp', 'orekjthsef', 'huvlwkxoxm', 'fmhiujydwo', 'lhvjzhjxdp',
              'ibxfxggtqh', 'bofrwgeyjo', 'rmufsuogzn', 'zbgssotnjm', 'dpevefkefv', 'sufvvwmbha', 'ncoeewrdlo',
              'qhsehzgxqj', 'yxadevzohx', 'aomqqjipcp', 'pcyswtgick', 'wfzjxzhdkj', 'rcjfxxhcal', 'lnjkpdviqb',
              'xmkwsnuzyq', 'ouaowjmigq', 'bkuzquigyt', 'vwxednhlwz', 'mszblrdprw', 'blnmxntbey', 'gccnvdoknm',
              'mkzaekkvej', 'hclsparpth', 'eryjktdexi', 'hfsvqabzfq', 'acazlolrpz', 'yoyhmxtrys', 'rerpivllud',
              'elackxuccp', 'zgbhzkditd', 'vjljdfopjg', 'famlupsgqm', 'nymodlmxni', 'qcbkztamqc', 'qclpbcbgeq',
              'lpkgabskbw', 'mnowxangqx', 'czfqlbcfpa', 'qyyhuvqmyf', 'toinozytsp', 'ztyvglkcsf', 'nplviymzlg',
              'opvqdabdap', 'uxuvkrjhws', 'mxahsihabr', 'cqxxumarvp', 'ptbfnkajyi', 'njzshtfmcw', 'dcqodpzomd',
              'ajiyrjfyzp', 'ywauoonmlr', 'gochxzemmq', 'lpgxwdgnio', 'hnfwagcxdf', 'gfcycflhbo', 'gunamloolc',
              'yhjlnisfel', 'srfefmyjvt', 'evysmtpnrf', 'aktnlyqpah', 'gpsxfxrjrr', 'zfobicuigx', 'mnzabbkpmt',
              'rfjuhbnlro', 'zuwwbbusgl', 'csnkohqxdv', 'bzvzpwrabw', 'yietrwuncf', 'wynotylpnm', 'ekboxwrwuv',
              'rcecrgeotc', 'rklawjhbpv', 'ilqwcbprqa', 'jsysgmycsx', 'sqixhnilfm', 'wnlubukrki', 'nikynwcvuh',
              'sjkfxrlxxs', 'btdxnajogv', 'wjhpisoeaj', 'dyjklprkoc', 'qlqhjcshpk', 'jyfvaequfg', 'dozjwhnedd',
              'owaogcehvc', 'oyqgwjdwaj', 'vvfszaosiv', 'kmcdjxmnoa', 'jiswxuqzyz', 'ddtbarpcgo', 'wqysrieiqu',
              'xcruhaccxc', 'honxqdilvv', 'nxgzmgzkfv', 'cxsvvnxpyz', 'demuhxssgl', 'hzoiotcykp', 'fwykevubzy',
              'tejfudfgpq', 'kvmpmhdxly', 'oojxonbgow', 'vurjckblge', 'oysopgovhu', 'khpipxnsvx', 'pqthmvwonf',
              'fddmkqjwsh', 'pcoxcmtroa', 'cnxccbjlct', 'ggzjfrirjh', 'jquevmhdvc', 'ecumyiowzs', 'esmqxszybs',
              'mllzkpgatp', 'ryxaqpfubf', 'hbufmvbium', 'vdtsbqidjb', 'sjwywglgym', 'qxyrtwozyw', 'upmgtackuf',
              'ucthmsajay', 'zgjosltkie', 'snlyjbnpgw', 'nswtvttxre', 'iznnzjvaxc', 'jhczqfefgw', 'htzbnroagi',
              'pdswwyyntw', 'uvrzaczrbx', 'vbcgoyxsvn', 'hzssdinxec', 'novarhxpbj', 'vizerpsvbz', 'jawgcggquk',
              'iorbtaarte', 'yarpxfqejd', 'vhbbwdflyh', 'rrrfjhugvb', 'fneqiqpqvs', 'jytrvwlewz', 'bfjsthfhbd',
              'rxdoimqble', 'ekelfsnqof', 'uqvxjfpwdo', 'cjkctqqakb', 'tynfsthodx', 'yllztsrwjw', 'bktkwbcawi',
              'wcqvzujamg', 'bcvheslzrq', 'aqrsylrzgi', 'sktpeppbkc', 'mkmgcxaztt', 'etdliwticv', 'hqzwudvhih',
              'swsaoktwgi', 'temjefwaas', 'papagllumt', 'xrtvqhdibb', 'oelqpetgwj', 'ggdpclfcgk', 'imdmhwkkni',
              'lebzjtusnr', 'xhtppuyqdr', 'nxzgekegsp', 'waucvvmtkq', 'rnfcjxynfa', 'adohdulfwb', 'tjywwgftmv',
              'fjrueenjyp', 'oaguiggjyv', 'ytopzxrswu', 'yxvmusxvcz', 'rukyxomwcx', 'qdqdsaiitt', 'mxlipjhmqk',
              'voawxrmqyl', 'kezwvsxxzj', 'oocincvedt', 'qooxnxqqjb', 'mwwploizlj', 'yaxgpxhavq', 'uhakqelqri',
              'bvpeerislp', 'bkcyglmfci', 'jyoxdvxpza', 'gkutjglghz', 'knxltsvzyu', 'ybbrkacebd', 'apvzjkvnwn',
              'ahjnxtiamx', 'hsbljbsgxr', 'fnxgqcvlsd', 'xphdfgmfmz', 'scbdenmaed', 'ywxpquomgt', 'yljecirelf',
              'wcvsqnplsk', 'vmxfwxgdei', 'icbsahlivv', 'yhylappzid', 'irqzdokcws', 'petmyhjclt', 'rmlzgerevr',
              'qarqtkvgby', 'nkhzxomani', 'viteugozpv', 'qhkzlnzruj', 'eisofhptvk', 'gqnaxievjx', 'heiyoojifp',
              'zcxcmneefk', 'wvgviwnwob', 'gcdtglsoqj', 'yqhouqakbx', 'fopjiyxiqd', 'hierggamuo', 'ypbtpunjvm',
              'sjinmmbipg', 'kmqkiihrmj', 'wmoqzxddkb', 'lnhkjhyhvw', 'wixbuuzygv', 'fsdrwikhge', 'sfsayjgzrh',
              'pqdeutauqc', 'frqfsucgao', 'pdufsewrec', 'bfdopzvxbi', 'shnsajrsow', 'rvvpazsffd', 'pxcfrszlgi',
              'itfsvvmslp', 'ayipraspbn', 'prhmixykhr', 'doniqevxeg', 'dvtpwatuja', 'jiavqbrkyk', 'ipkpxvwroe',
              'syxobtuucp', 'syuxttuyhm', 'nwvsbmyndn', 'eqslzbqfea', 'ytddugrwph', 'vokrpfjpeb', 'bdshuoldwx',
              'fmvvmcbdrw', 'bnuwxhfahw', 'gbnzicjyhz', 'txnmkabufs', 'gfdjzwnpyp', 'hweshqpfwe', 'dxgnpnowgk',
              'xugmhbetrw', 'rktrpsdlci', 'nthpnwylxo', 'ihglzxzroo', 'ocgdbrgmtq', 'ruhtnngrqv', 'xljemofssi',
              'zxacihctqp', 'ghnpsltzyn', 'lbigytrrtr', 'ndikguxzek', 'mdfndlljvt', 'lyoslorecs', 'oefukgnvel',
              'zmxeiipnqb', 'cosghhimnd', 'alrtntfxtd', 'eywdmustbb', 'ooafcxxfrs', 'fqgypsunzr', 'hevcclcklc',
              'uhrqlmlclw', 'ipvwtgdlre', 'wcssbghcpc', 'didzujjhtg', 'fjxovgmwnm', 'dmmvuaikkv', 'hitfycdavv',
              'zyufpqvpyu', 'coujjnypba', 'temeqbmzxu', 'apedduehoy', 'iksxzpqxzi', 'kwfdyqofzw', 'aassnaulhq',
              'eyguqfmgzh', 'yiykshcbaz', 'sngjsueuhs', 'okgelildpc', 'ztyuiqrhdk', 'tvhjcfnqtg', 'gfgcwxkbjd',
              'lbfqksftuo', 'kowiwvrjht', 'dkuqbduxev', 'mwnibuujwz', 'sodvtfqbpf', 'hsbwhlolsn', 'qsjiypnjwi',
              'blszgmxkvu', 'ystdtnetgj', 'rfwxcinshk', 'vnlzxqwthl', 'ljouzjaqqe', 'gahgyuwzbu', 'xxzefxwyku',
              'xitgdpzbxv', 'sylnrepacf', 'igpvrfjdzc', 'nxnmkytwze', 'psesikjaxx', 'dvwpvqdflx', 'bjyaxvggle',
              'dpmgoiwhuf', 'wadvzjhwtw', 'kcjvhgvhpt', 'eppyqpgewp', 'tyjpjpglgx', 'cekarydqba', 'dvkdfhrpph',
              'cnpanmywno', 'ljauauuyka', 'hicjuubiau', 'cqhwesrciw', 'dnmowthjcj', 'lujvyveojc', 'wndursivcx',
              'espkiocpxq', 'jsbpkpxwew', 'dsnxgrfdmd', 'hyjqolupxn', 'xdezcezszc', 'axfhbpkdlc', 'qqnlrngaft',
              'coqwgzpbhx', 'ncmpqwmnzb', 'sznkemeqro', 'omphqltjdd', 'uoccaiathd', 'jzmzdispyo', 'pxjkzvqomp',
              'udxqbhgvvx', 'dzkyxbbqkr', 'dtozwcapoa', 'qswlzfgcgj', 'tgawasvbbr', 'lmdyicksrv', 'fzvpbrzssi',
              'dxfdovivlw', 'zzmgnglanj', 'vssmlqoiti', 'vajkicalux', 'ekvwecwltj', 'ylxwcwhjjd', 'keioymnobc',
              'usqqvxcjmg', 'phjvutxpoi', 'nycmyuzpml', 'bwdmzwhdnw', 'fxuxxtryjn', 'orixbcfvdz', 'hefisnapds',
              'fpevfidstw', 'halvwiltfs', 'dzojiwfvba', 'ojsxxkalat', 'esjdyghhog', 'ptbnewtvon', 'hcanfkwivl',
              'yronlutbgm', 'llplvmcvbl', 'yxirnfyijn', 'nwvloufjty', 'rtpbawlmxr', 'aayfryxljh', 'zfrrixsimm',
              'txmnoyiyte'}
training/losses.py
ADDED
@@ -0,0 +1,28 @@
from typing import Any

from pytorch_toolbelt.losses import BinaryFocalLoss
from torch import nn
from torch.nn.modules.loss import BCEWithLogitsLoss


class WeightedLosses(nn.Module):
    def __init__(self, losses, weights):
        super().__init__()
        self.losses = losses
        self.weights = weights

    def forward(self, *input: Any, **kwargs: Any):
        cum_loss = 0
        for loss, w in zip(self.losses, self.weights):
            cum_loss += w * loss.forward(*input, **kwargs)
        return cum_loss


class BinaryCrossentropy(BCEWithLogitsLoss):
    pass


class FocalLoss(BinaryFocalLoss):
    def __init__(self, alpha=None, gamma=3, ignore_index=None, reduction="mean", normalized=False,
                 reduced_threshold=None):
        super().__init__(alpha, gamma, ignore_index, reduction, normalized, reduced_threshold)
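For orientation only, a minimal sketch (not part of the diff) of how WeightedLosses combines the classes above; the 0.7/0.3 weights and the batch shapes are made-up illustration values, mirroring the "losses" dict that the training configs pass in.

import torch
from training.losses import BinaryCrossentropy, FocalLoss, WeightedLosses

# Hypothetical weighting: 70% BCE with logits, 30% binary focal loss.
criterion = WeightedLosses([BinaryCrossentropy(), FocalLoss()], [0.7, 0.3])
logits = torch.randn(8, 1)        # raw classifier outputs
labels = torch.rand(8, 1)         # soft labels in [0, 1] (label smoothing happens upstream)
print(criterion(logits, labels))  # weighted sum of the two losses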
training/pipelines/__init__.py
ADDED
File without changes
|
training/pipelines/train_classifier.py
ADDED
@@ -0,0 +1,364 @@
import argparse
import json
import os
from collections import defaultdict

from sklearn.metrics import log_loss
from torch import topk

import sys
print('@@@@@@@@@@@@@@@@@@')
sys.path.append('..')

from training import losses
from training.datasets.classifier_dataset import DeepFakeClassifierDataset
from training.losses import WeightedLosses
from training.tools.config import load_config
from training.tools.utils import create_optimizer, AverageMeter
from training.transforms.albu import IsotropicResize
from training.zoo import classifiers

os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import cv2

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)
import numpy as np
from albumentations import Compose, RandomBrightnessContrast, \
    HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
    ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur

from apex.parallel import DistributedDataParallel, convert_syncbn_model
from tensorboardX import SummaryWriter

from apex import amp

import torch
from torch.backends import cudnn
from torch.nn import DataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.distributed as dist

torch.backends.cudnn.benchmark = True

def create_train_transforms(size=300):
    return Compose([
        ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
        GaussNoise(p=0.1),
        GaussianBlur(blur_limit=3, p=0.05),
        HorizontalFlip(),
        OneOf([
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
            IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
        ], p=1),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
        OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
        ToGray(p=0.2),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
    ]
    )


def create_val_transforms(size=300):
    return Compose([
        IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
        PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
    ])


def main():
    parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
    arg = parser.add_argument
    arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
    arg('--workers', type=int, default=6, help='number of cpu threads to use')
    arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
    arg('--output-dir', type=str, default='weights/')
    arg('--resume', type=str, default='')
    arg('--fold', type=int, default=0)
    arg('--prefix', type=str, default='classifier_')
    arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
    arg('--folds-csv', type=str, default='folds.csv')
    arg('--crops-dir', type=str, default='crops')
    arg('--label-smoothing', type=float, default=0.01)
    arg('--logdir', type=str, default='logs')
    arg('--zero-score', action='store_true', default=False)
    arg('--from-zero', action='store_true', default=False)
    arg('--distributed', action='store_true', default=False)
    arg('--freeze-epochs', type=int, default=0)
    arg("--local_rank", default=0, type=int)
    arg("--seed", default=777, type=int)
    arg("--padding-part", default=3, type=int)
    arg("--opt-level", default='O1', type=str)
    arg("--test_every", type=int, default=1)
    arg("--no-oversample", action="store_true")
    arg("--no-hardcore", action="store_true")
    arg("--only-changed-frames", action="store_true")

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    else:
        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    cudnn.benchmark = True

    conf = load_config(args.config)
    model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])

    model = model.cuda()
    if args.distributed:
        model = convert_syncbn_model(model)
    ohem = conf.get("ohem_samples", None)
    reduction = "mean"
    if ohem:
        reduction = "none"
    loss_fn = []
    weights = []
    for loss_name, weight in conf["losses"].items():
        loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
        weights.append(weight)
    loss = WeightedLosses(loss_fn, weights)
    loss_functions = {"classifier_loss": loss}
    optimizer, scheduler = create_optimizer(conf['optimizer'], model)
    bce_best = 100
    start_epoch = 0
    batch_size = conf['optimizer']['batch_size']

    data_train = DeepFakeClassifierDataset(mode="train",
                                           oversample_real=not args.no_oversample,
                                           fold=args.fold,
                                           padding_part=args.padding_part,
                                           hardcore=not args.no_hardcore,
                                           crops_dir=args.crops_dir,
                                           data_path=args.data_dir,
                                           label_smoothing=args.label_smoothing,
                                           folds_csv=args.folds_csv,
                                           transforms=create_train_transforms(conf["size"]),
                                           normalize=conf.get("normalize", None))
    data_val = DeepFakeClassifierDataset(mode="val",
                                         fold=args.fold,
                                         padding_part=args.padding_part,
                                         crops_dir=args.crops_dir,
                                         data_path=args.data_dir,
                                         folds_csv=args.folds_csv,
                                         transforms=create_val_transforms(conf["size"]),
                                         normalize=conf.get("normalize", None))
    val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
                                 pin_memory=False)
    os.makedirs(args.logdir, exist_ok=True)
    summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume, map_location='cpu')
            state_dict = checkpoint['state_dict']
            state_dict = {k[7:]: w for k, w in state_dict.items()}
            model.load_state_dict(state_dict, strict=False)
            if not args.from_zero:
                start_epoch = checkpoint['epoch']
                if not args.zero_score:
                    bce_best = checkpoint.get('bce_best', 0)
            print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
                  .format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if args.from_zero:
        start_epoch = 0
    current_epoch = start_epoch

    if conf['fp16']:
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.opt_level,
                                          loss_scale='dynamic')

    snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)

    if args.distributed:
        model = DistributedDataParallel(model, delay_allreduce=True)
    else:
        model = DataParallel(model).cuda()
    data_val.reset(1, args.seed)
    max_epochs = conf['optimizer']['schedule']['epochs']
    for epoch in range(start_epoch, max_epochs):
        data_train.reset(epoch, args.seed)
        train_sampler = None
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
            train_sampler.set_epoch(epoch)
        if epoch < args.freeze_epochs:
            print("Freezing encoder!!!")
            model.module.encoder.eval()
            for p in model.module.encoder.parameters():
                p.requires_grad = False
        else:
            model.module.encoder.train()
            for p in model.module.encoder.parameters():
                p.requires_grad = True

        train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
                                       shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
                                       drop_last=True)

        train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                    args.local_rank, args.only_changed_frames)
        model = model.eval()

        if args.local_rank == 0:
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, args.output_dir + '/' + snapshot_name + "_last")
            torch.save({
                'epoch': current_epoch + 1,
                'state_dict': model.state_dict(),
                'bce_best': bce_best,
            }, args.output_dir + snapshot_name + "_{}".format(current_epoch))
            if (epoch + 1) % args.test_every == 0:
                bce_best = evaluate_val(args, val_data_loader, bce_best, model,
                                        snapshot_name=snapshot_name,
                                        current_epoch=current_epoch,
                                        summary_writer=summary_writer)
        current_epoch += 1


def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
    print("Test phase")
    model = model.eval()

    bce, probs, targets = validate(model, data_loader=data_val)
    if args.local_rank == 0:
        summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
        if bce < bce_best:
            print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
            if args.output_dir is not None:
                torch.save({
                    'epoch': current_epoch + 1,
                    'state_dict': model.state_dict(),
                    'bce_best': bce,
                }, args.output_dir + snapshot_name + "_best_dice")
            bce_best = bce
            with open("predictions_{}.json".format(args.fold), "w") as f:
                json.dump({"probs": probs, "targets": targets}, f)
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'bce_best': bce_best,
        }, args.output_dir + snapshot_name + "_last")
        print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
    return bce_best


def validate(net, data_loader, prefix=""):
    probs = defaultdict(list)
    targets = defaultdict(list)

    with torch.no_grad():
        for sample in tqdm(data_loader):
            imgs = sample["image"].cuda()
            img_names = sample["img_name"]
            labels = sample["labels"].cuda().float()
            out = net(imgs)
            labels = labels.cpu().numpy()
            preds = torch.sigmoid(out).cpu().numpy()
            for i in range(out.shape[0]):
                video, img_id = img_names[i].split("/")
                probs[video].append(preds[i].tolist())
                targets[video].append(labels[i].tolist())
    data_x = []
    data_y = []
    for vid, score in probs.items():
        score = np.array(score)
        lbl = targets[vid]

        score = np.mean(score)
        lbl = np.mean(lbl)
        data_x.append(score)
        data_y.append(lbl)
    y = np.array(data_y)
    x = np.array(data_x)
    fake_idx = y > 0.1
    real_idx = y < 0.1
    fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
    real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
    print("{}fake_loss".format(prefix), fake_loss)
    print("{}real_loss".format(prefix), real_loss)

    return (fake_loss + real_loss) / 2, probs, targets


def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
                local_rank, only_valid):
    losses = AverageMeter()
    fake_losses = AverageMeter()
    real_losses = AverageMeter()
    max_iters = conf["batches_per_epoch"]
    print("training epoch {}".format(current_epoch))
    model.train()
    pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
    if conf["optimizer"]["schedule"]["mode"] == "epoch":
        scheduler.step(current_epoch)
    for i, sample in pbar:
        imgs = sample["image"].cuda()
        labels = sample["labels"].cuda().float()
        out_labels = model(imgs)
        if only_valid:
            valid_idx = sample["valid"].cuda().float() > 0
            out_labels = out_labels[valid_idx]
            labels = labels[valid_idx]
            if labels.size(0) == 0:
                continue

        fake_loss = 0
        real_loss = 0
        fake_idx = labels > 0.5
        real_idx = labels <= 0.5

        ohem = conf.get("ohem_samples", None)
        if torch.sum(fake_idx * 1) > 0:
            fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
        if torch.sum(real_idx * 1) > 0:
            real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
        if ohem:
            fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
            real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()

        loss = (fake_loss + real_loss) / 2
        losses.update(loss.item(), imgs.size(0))
        fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
        real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))

        optimizer.zero_grad()
        pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
                          "fake_loss": fake_losses.avg, "real_loss": real_losses.avg})

        if conf['fp16']:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
        optimizer.step()
        torch.cuda.synchronize()
        if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
            scheduler.step(i + current_epoch * max_iters)
        if i == max_iters - 1:
            break
    pbar.close()
    if local_rank == 0:
        for idx, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
        summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)


if __name__ == '__main__':
    main()
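One detail of train_epoch worth calling out is the optional online hard example mining: when the config sets "ohem_samples", the losses are built with reduction="none" and only the hardest per-crop losses are averaged. A minimal standalone sketch of that step (the random tensor and the value 8 are illustrative, not taken from the shipped configs):

import torch
from torch import topk

per_sample_loss = torch.rand(16)   # reduction="none" -> one loss value per crop in the batch
ohem_samples = 8                   # stands in for conf["ohem_samples"]
hard_loss = topk(per_sample_loss, k=min(ohem_samples, per_sample_loss.size(0)), sorted=False)[0].mean()
print(hard_loss)                   # only the hardest crops contribute to the gradient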
training/tools/__init__.py
ADDED
File without changes
|
training/tools/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (154 Bytes).
|
|
training/tools/__pycache__/config.cpython-310.pyc
ADDED
Binary file (1.06 kB).
|
|
training/tools/__pycache__/schedulers.cpython-310.pyc
ADDED
Binary file (3.01 kB).
|
|
training/tools/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.65 kB).
|
|
training/tools/config.py
ADDED
@@ -0,0 +1,43 @@
import json

DEFAULTS = {
    "network": "dpn",
    "encoder": "dpn92",
    "model_params": {},
    "optimizer": {
        "batch_size": 32,
        "type": "SGD",  # supported: SGD, Adam
        "momentum": 0.9,
        "weight_decay": 0,
        "clip": 1.,
        "learning_rate": 0.1,
        "classifier_lr": -1,
        "nesterov": True,
        "schedule": {
            "type": "constant",  # supported: constant, step, multistep, exponential, linear, poly
            "mode": "epoch",  # supported: epoch, step
            "epochs": 10,
            "params": {}
        }
    },
    "normalize": {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225]
    }
}


def _merge(src, dst):
    for k, v in src.items():
        if k in dst:
            if isinstance(v, dict):
                _merge(src[k], dst[k])
        else:
            dst[k] = v


def load_config(config_file, defaults=DEFAULTS):
    with open(config_file, "r") as fd:
        config = json.load(fd)
    _merge(defaults, config)
    return config
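As a sanity check of the merge semantics above: load_config reads a JSON file and fills in any key missing from it with the value from DEFAULTS, leaving keys that are present untouched. A small sketch with a made-up config file (the shipped configs/b5.json and configs/b7.json follow the same shape):

import json
from training.tools.config import load_config

# Hypothetical config written to a temporary path purely for illustration.
with open("/tmp/example_config.json", "w") as f:
    json.dump({"encoder": "tf_efficientnet_b7_ns", "optimizer": {"learning_rate": 0.01}}, f)

conf = load_config("/tmp/example_config.json")
print(conf["encoder"])                # kept from the file
print(conf["optimizer"]["momentum"])  # 0.9, filled in from DEFAULTS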
training/tools/schedulers.py
ADDED
@@ -0,0 +1,46 @@
from bisect import bisect_right

from torch.optim.lr_scheduler import _LRScheduler


class LRStepScheduler(_LRScheduler):
    def __init__(self, optimizer, steps, last_epoch=-1):
        self.lr_steps = steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        pos = max(bisect_right([x for x, y in self.lr_steps], self.last_epoch) - 1, 0)
        return [self.lr_steps[pos][1] if self.lr_steps[pos][0] <= self.last_epoch else base_lr for base_lr in self.base_lrs]


class PolyLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to poly learning rate policy
    """
    def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        self.last_epoch = (self.last_epoch + 1) % self.max_iter
        return [base_lr * ((1 - float(self.last_epoch) / self.max_iter) ** (self.power)) for base_lr in self.base_lrs]


class ExponentialLRScheduler(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super(ExponentialLRScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch <= 0:
            return self.base_lrs
        return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
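For reference, a hedged sketch of driving one of these schedulers directly (all numbers are illustrative); in the pipeline they are normally constructed through create_optimizer from the "schedule" section of the config:

import torch
from torch import nn
from training.tools.schedulers import PolyLR

model = nn.Linear(10, 1)                      # stand-in module, just to have parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = PolyLR(optimizer, max_iter=1000, power=0.9)
for step in range(3):
    optimizer.step()
    scheduler.step()                          # lr decays as (1 - step / max_iter) ** 0.9
    print(optimizer.param_groups[0]["lr"])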
training/tools/utils.py
ADDED
@@ -0,0 +1,121 @@
import cv2
from apex.optimizers import FusedAdam, FusedSGD
from timm.optim import AdamW
from torch import optim
from torch.optim import lr_scheduler
from torch.optim.rmsprop import RMSprop
from torch.optim.adamw import AdamW
from torch.optim.lr_scheduler import MultiStepLR, CyclicLR

from training.tools.schedulers import ExponentialLRScheduler, PolyLR, LRStepScheduler

cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(0)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def create_optimizer(optimizer_config, model, master_params=None):
    """Creates optimizer and schedule from configuration

    Parameters
    ----------
    optimizer_config : dict
        Dictionary containing the configuration options for the optimizer.
    model : Model
        The network model.

    Returns
    -------
    optimizer : Optimizer
        The optimizer.
    scheduler : LRScheduler
        The learning rate scheduler.
    """
    if optimizer_config.get("classifier_lr", -1) != -1:
        # Separate classifier parameters from all others
        net_params = []
        classifier_params = []
        for k, v in model.named_parameters():
            if not v.requires_grad:
                continue
            if k.find("encoder") != -1:
                net_params.append(v)
            else:
                classifier_params.append(v)
        params = [
            {"params": net_params},
            {"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
        ]
    else:
        if master_params:
            params = master_params
        else:
            params = model.parameters()

    if optimizer_config["type"] == "SGD":
        optimizer = optim.SGD(params,
                              lr=optimizer_config["learning_rate"],
                              momentum=optimizer_config["momentum"],
                              weight_decay=optimizer_config["weight_decay"],
                              nesterov=optimizer_config["nesterov"])
    elif optimizer_config["type"] == "FusedSGD":
        optimizer = FusedSGD(params,
                             lr=optimizer_config["learning_rate"],
                             momentum=optimizer_config["momentum"],
                             weight_decay=optimizer_config["weight_decay"],
                             nesterov=optimizer_config["nesterov"])
    elif optimizer_config["type"] == "Adam":
        optimizer = optim.Adam(params,
                               lr=optimizer_config["learning_rate"],
                               weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "FusedAdam":
        optimizer = FusedAdam(params,
                              lr=optimizer_config["learning_rate"],
                              weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "AdamW":
        optimizer = AdamW(params,
                          lr=optimizer_config["learning_rate"],
                          weight_decay=optimizer_config["weight_decay"])
    elif optimizer_config["type"] == "RmsProp":
        optimizer = RMSprop(params,
                            lr=optimizer_config["learning_rate"],
                            weight_decay=optimizer_config["weight_decay"])
    else:
        raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))

    if optimizer_config["schedule"]["type"] == "step":
        scheduler = LRStepScheduler(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "clr":
        scheduler = CyclicLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "multistep":
        scheduler = MultiStepLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "exponential":
        scheduler = ExponentialLRScheduler(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "poly":
        scheduler = PolyLR(optimizer, **optimizer_config["schedule"]["params"])
    elif optimizer_config["schedule"]["type"] == "constant":
        scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
    elif optimizer_config["schedule"]["type"] == "linear":
        def linear_lr(it):
            return it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"]

        scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)

    return optimizer, scheduler
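A minimal sketch of feeding an "optimizer" config section into create_optimizer; the dict below is invented for illustration (the real values live in configs/*.json) and the import assumes apex and timm are installed, as in the Dockerfile:

from torch import nn
from training.tools.utils import create_optimizer

# Hypothetical optimizer section: SGD with a poly learning-rate schedule stepped per batch.
optimizer_config = {
    "type": "SGD", "learning_rate": 0.01, "momentum": 0.9, "weight_decay": 1e-4,
    "nesterov": True, "classifier_lr": -1,
    "schedule": {"type": "poly", "mode": "step", "epochs": 20, "params": {"max_iter": 100500}},
}
model = nn.Linear(10, 1)  # stand-in for a real classifier
optimizer, scheduler = create_optimizer(optimizer_config, model)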
training/transforms/__init__.py
ADDED
File without changes
|
training/transforms/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (159 Bytes).
|
|
training/transforms/__pycache__/albu.cpython-310.pyc
ADDED
Binary file (4.36 kB).
|
|
training/transforms/albu.py
ADDED
@@ -0,0 +1,100 @@
import random

import cv2
import numpy as np
from albumentations import DualTransform, ImageOnlyTransform
from albumentations.augmentations.crops.functional import crop
#from albumentations.augmentations.functional import crop


def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
    h, w = img.shape[:2]
    if max(w, h) == size:
        return img
    if w > h:
        scale = size / w
        h = h * scale
        w = size
    else:
        scale = size / h
        w = w * scale
        h = size
    interpolation = interpolation_up if scale > 1 else interpolation_down
    resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
    return resized


class IsotropicResize(DualTransform):
    def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
                 always_apply=False, p=1):
        super(IsotropicResize, self).__init__(always_apply, p)
        self.max_side = max_side
        self.interpolation_down = interpolation_down
        self.interpolation_up = interpolation_up

    def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
        return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
                                          interpolation_up=interpolation_up)

    def apply_to_mask(self, img, **params):
        return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)

    def get_transform_init_args_names(self):
        return ("max_side", "interpolation_down", "interpolation_up")


class Resize4xAndBack(ImageOnlyTransform):
    def __init__(self, always_apply=False, p=0.5):
        super(Resize4xAndBack, self).__init__(always_apply, p)

    def apply(self, img, **params):
        h, w = img.shape[:2]
        scale = random.choice([2, 4])
        img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
        img = cv2.resize(img, (w, h),
                         interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
        return img


class RandomSizedCropNonEmptyMaskIfExists(DualTransform):

    def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
        super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)

        self.min_max_height = min_max_height
        self.w2h_ratio = w2h_ratio

    def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        cropped = crop(img, x_min, y_min, x_max, y_max)
        return cropped

    @property
    def targets_as_params(self):
        return ["mask"]

    def get_params_dependent_on_targets(self, params):
        mask = params["mask"]
        mask_height, mask_width = mask.shape[:2]
        crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
        w2h_ratio = random.uniform(*self.w2h_ratio)
        crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
        if mask.sum() == 0:
            x_min = random.randint(0, mask_width - crop_width + 1)
            y_min = random.randint(0, mask_height - crop_height + 1)
        else:
            mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
            non_zero_yx = np.argwhere(mask)
            y, x = random.choice(non_zero_yx)
            x_min = x - random.randint(0, crop_width - 1)
            y_min = y - random.randint(0, crop_height - 1)
            x_min = np.clip(x_min, 0, mask_width - crop_width)
            y_min = np.clip(y_min, 0, mask_height - crop_height)

        x_max = x_min + crop_height
        y_max = y_min + crop_width
        y_max = min(mask_height, y_max)
        x_max = min(mask_width, x_max)
        return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}

    def get_transform_init_args_names(self):
        return "min_max_height", "height", "width", "w2h_ratio"
training/zoo/__init__.py
ADDED
File without changes
|
training/zoo/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes).
|
|
training/zoo/__pycache__/classifiers.cpython-310.pyc
ADDED
Binary file (5.55 kB).
|
|
training/zoo/classifiers.py
ADDED
@@ -0,0 +1,172 @@
from functools import partial

import numpy as np
import torch
from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
    tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
from torch import nn
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d

encoder_params = {
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b2_ns": {
        "features": 1408,
        "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
    },
    "tf_efficientnet_b4_ns": {
        "features": 1792,
        "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b4_ns_03d": {
        "features": 1792,
        "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
    },
    "tf_efficientnet_b5_ns_03d": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
    },
    "tf_efficientnet_b5_ns_04d": {
        "features": 2048,
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
    },
    "tf_efficientnet_b6_ns": {
        "features": 2304,
        "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b7_ns": {
        "features": 2560,
        "init_op": partial(tf_efficientnet_b7_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b6_ns_04d": {
        "features": 2304,
        "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
    },
}


def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
    """Creates the SRM kernels for noise analysis."""
    # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR2018
    srm_kernel = torch.from_numpy(np.array([
        [  # srm 1/2 horiz
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
            [0., 1., -2., 1., 0.],  # noqa: E241,E201
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
        ], [  # srm 1/4
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
            [0., -1., 2., -1., 0.],  # noqa: E241,E201
            [0., 2., -4., 2., 0.],  # noqa: E241,E201
            [0., -1., 2., -1., 0.],  # noqa: E241,E201
            [0., 0., 0., 0., 0.],  # noqa: E241,E201
        ], [  # srm 1/12
            [-1., 2., -2., 2., -1.],  # noqa: E241,E201
            [2., -6., 8., -6., 2.],  # noqa: E241,E201
            [-2., 8., -12., 8., -2.],  # noqa: E241,E201
            [2., -6., 8., -6., 2.],  # noqa: E241,E201
            [-1., 2., -2., 2., -1.],  # noqa: E241,E201
        ]
    ])).float()
    srm_kernel[0] /= 2
    srm_kernel[1] /= 4
    srm_kernel[2] /= 12
    return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)


def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
    """Creates a SRM convolution layer for noise analysis."""
    weights = setup_srm_weights(input_channels)
    conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
    with torch.no_grad():
        conv.weight = torch.nn.Parameter(weights, requires_grad=False)
    return conv


class DeepFakeClassifierSRM(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.srm_conv = setup_srm_layer(3)
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        noise = self.srm_conv(x)
        x = self.encoder.forward_features(noise)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class GlobalWeightedAvgPool2d(nn.Module):
    """
    Global Weighted Average Pooling from paper "Global Weighted Average
    Pooling Bridges Pixel-level Localization and Image-level Classification"
    """

    def __init__(self, features: int, flatten=False):
        super().__init__()
        self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
        self.flatten = flatten

    def fscore(self, x):
        m = self.conv(x)
        m = m.sigmoid().exp()
        return m

    def norm(self, x: torch.Tensor):
        return x / x.sum(dim=[2, 3], keepdim=True)

    def forward(self, x):
        input_x = x
        x = self.fscore(x)
        x = self.norm(x)
        x = x * input_x
        x = x.sum(dim=[2, 3], keepdim=not self.flatten)
        return x


class DeepFakeClassifier(nn.Module):
    def __init__(self, encoder, dropout_rate=0.0) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        x = self.encoder.forward_features(x)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


class DeepFakeClassifierGWAP(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.encoder = encoder_params[encoder]["init_op"]()
        self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)

    def forward(self, x):
        x = self.encoder.forward_features(x)
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x
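A small inference sketch for the classifiers defined above, assuming the EfficientNet-B7 encoder that the bundled b7 config and checkpoint name refer to; the crop size and batch are illustrative, and pretrained=True means timm will try to download encoder weights on first use:

import torch
from training.zoo.classifiers import DeepFakeClassifier

model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").eval()
with torch.no_grad():
    crops = torch.rand(2, 3, 380, 380)       # face crops; the adaptive pooling tolerates other sizes
    fake_prob = torch.sigmoid(model(crops))  # one logit per crop -> probability of "fake"
print(fake_prob.squeeze(1))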
training/zoo/unet.py
ADDED
@@ -0,0 +1,151 @@
from functools import partial

import torch
from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
from torch import nn
from torch.nn import Dropout2d, Conv2d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.pooling import AdaptiveAvgPool2d
from torch.nn.modules.upsampling import UpsamplingBilinear2d

encoder_params = {
    "tf_efficientnet_b3_ns": {
        "features": 1536,
        "filters": [40, 32, 48, 136, 1536],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
    },
    "tf_efficientnet_b5_ns": {
        "features": 2048,
        "filters": [48, 40, 64, 176, 2048],
        "decoder_filters": [64, 128, 256, 256],
        "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
    },
}


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


class ConcatBottleneck(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, dec, enc):
        x = torch.cat([dec, enc], dim=1)
        return self.seq(x)


class Decoder(nn.Module):
    def __init__(self, decoder_filters, filters, upsample_filters=None,
                 decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
        super().__init__()
        self.decoder_filters = decoder_filters
        self.filters = filters
        self.decoder_block = decoder_block
        self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))])
        self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
                                          for i, f in enumerate(reversed(decoder_filters))])
        self.dropout = Dropout2d(dropout) if dropout > 0 else None
        self.last_block = None
        if upsample_filters:
            self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
        else:
            self.last_block = UpsamplingBilinear2d(scale_factor=2)

    def forward(self, encoder_results: list):
        x = encoder_results[0]
        bottlenecks = self.bottlenecks
        for idx, bottleneck in enumerate(bottlenecks):
            rev_idx = - (idx + 1)
            x = self.decoder_stages[rev_idx](x)
            x = bottleneck(x, encoder_results[-rev_idx])
        if self.last_block:
            x = self.last_block(x)
        if self.dropout:
            x = self.dropout(x)
        return x

    def _get_decoder(self, layer):
        idx = layer + 1
        if idx == len(self.decoder_filters):
            in_channels = self.filters[idx]
        else:
            in_channels = self.decoder_filters[idx]
        return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])


def _initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
            m.weight.data = nn.init.kaiming_normal_(m.weight.data)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class EfficientUnetClassifier(nn.Module):
    def __init__(self, encoder, dropout_rate=0.5) -> None:
        super().__init__()
        self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
                               filters=encoder_params[encoder]["filters"])
        self.avg_pool = AdaptiveAvgPool2d((1, 1))
        self.dropout = Dropout(dropout_rate)
        self.fc = Linear(encoder_params[encoder]["features"], 1)
        self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
        _initialize_weights(self)
        self.encoder = encoder_params[encoder]["init_op"]()

    def get_encoder_features(self, x):
        encoder_results = []
        x = self.encoder.conv_stem(x)
        x = self.encoder.bn1(x)
        x = self.encoder.act1(x)
        encoder_results.append(x)
        x = self.encoder.blocks[:2](x)
        encoder_results.append(x)
        x = self.encoder.blocks[2:3](x)
        encoder_results.append(x)
        x = self.encoder.blocks[3:5](x)
        encoder_results.append(x)
        x = self.encoder.blocks[5:](x)
        x = self.encoder.conv_head(x)
        x = self.encoder.bn2(x)
        x = self.encoder.act2(x)
        encoder_results.append(x)
        encoder_results = list(reversed(encoder_results))
        return encoder_results

    def forward(self, x):
        encoder_results = self.get_encoder_features(x)
        seg = self.final(self.decoder(encoder_results))
        x = encoder_results[0]
        x = self.avg_pool(x).flatten(1)
        x = self.dropout(x)
        x = self.fc(x)
        return x, seg


if __name__ == '__main__':
    model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
    model.eval()
    with torch.no_grad():
        input = torch.rand(4, 3, 224, 224)
        print(model(input))
weights/.gitkeep
ADDED
File without changes
|
weights/b7_ns_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9db77ab9318863e2f8ab287c8eb83c2232584b82dc2fb41f1d614ddd7900cccb
size 266910617