junkmind thecho7 commited on
Commit
f69dbe4
·
0 Parent(s):

Duplicate from thecho7/deepfake

Browse files

Co-authored-by: Suho Cho <[email protected]>

Files changed (48) hide show
  1. .gitattributes +36 -0
  2. Dockerfile +54 -0
  3. LICENSE +21 -0
  4. README.md +14 -0
  5. __pycache__/kernel_utils.cpython-310.pyc +0 -0
  6. app.py +86 -0
  7. configs/b5.json +28 -0
  8. configs/b7.json +29 -0
  9. download_weights.sh +9 -0
  10. examples/liuujwwgpr.mp4 +3 -0
  11. examples/nlurbvsozt.mp4 +3 -0
  12. examples/rfjuhbnlro.mp4 +3 -0
  13. kernel_utils.py +366 -0
  14. libs/shape_predictor_68_face_landmarks.dat +3 -0
  15. requirements.txt +131 -0
  16. training/__init__.py +0 -0
  17. training/__pycache__/__init__.cpython-310.pyc +0 -0
  18. training/__pycache__/__init__.cpython-39.pyc +0 -0
  19. training/__pycache__/losses.cpython-310.pyc +0 -0
  20. training/__pycache__/losses.cpython-39.pyc +0 -0
  21. training/datasets/__init__.py +0 -0
  22. training/datasets/__pycache__/__init__.cpython-310.pyc +0 -0
  23. training/datasets/__pycache__/classifier_dataset.cpython-310.pyc +0 -0
  24. training/datasets/__pycache__/validation_set.cpython-310.pyc +0 -0
  25. training/datasets/classifier_dataset.py +384 -0
  26. training/datasets/validation_set.py +60 -0
  27. training/losses.py +28 -0
  28. training/pipelines/__init__.py +0 -0
  29. training/pipelines/train_classifier.py +364 -0
  30. training/tools/__init__.py +0 -0
  31. training/tools/__pycache__/__init__.cpython-310.pyc +0 -0
  32. training/tools/__pycache__/config.cpython-310.pyc +0 -0
  33. training/tools/__pycache__/schedulers.cpython-310.pyc +0 -0
  34. training/tools/__pycache__/utils.cpython-310.pyc +0 -0
  35. training/tools/config.py +43 -0
  36. training/tools/schedulers.py +46 -0
  37. training/tools/utils.py +121 -0
  38. training/transforms/__init__.py +0 -0
  39. training/transforms/__pycache__/__init__.cpython-310.pyc +0 -0
  40. training/transforms/__pycache__/albu.cpython-310.pyc +0 -0
  41. training/transforms/albu.py +100 -0
  42. training/zoo/__init__.py +0 -0
  43. training/zoo/__pycache__/__init__.cpython-310.pyc +0 -0
  44. training/zoo/__pycache__/classifiers.cpython-310.pyc +0 -0
  45. training/zoo/classifiers.py +172 -0
  46. training/zoo/unet.py +151 -0
  47. weights/.gitkeep +0 -0
  48. weights/b7_ns_best.pth +3 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
36
+ *.dat filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG PYTORCH="1.10.0"
2
+ ARG CUDA="11.3"
3
+ ARG CUDNN="8"
4
+
5
+ FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6
+
7
+ ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
8
+ ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
9
+
10
+ # Setting noninteractive build, setting up tzdata and configuring timezones
11
+ ENV DEBIAN_FRONTEND=noninteractive
12
+ ENV TZ=Europe/Berlin
13
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
14
+
15
+ RUN apt-get update && apt-get install -y libglib2.0-0 libsm6 libxrender-dev libxext6 nano mc glances vim git \
16
+ && apt-get clean \
17
+ && rm -rf /var/lib/apt/lists/*
18
+
19
+ # Install cython
20
+ RUN conda install cython -y && conda clean --all
21
+
22
+ # Installing APEX
23
+ RUN pip install -U pip
24
+ RUN git clone https://github.com/NVIDIA/apex
25
+ RUN sed -i 's/check_cuda_torch_binary_vs_bare_metal(torch.utils.cpp_extension.CUDA_HOME)/pass/g' apex/setup.py
26
+ RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./apex
27
+ RUN apt-get update -y
28
+ RUN apt-get install build-essential cmake -y
29
+ RUN apt-get install libopenblas-dev liblapack-dev -y
30
+ RUN apt-get install libx11-dev libgtk-3-dev -y
31
+ RUN pip install dlib
32
+ RUN pip install facenet-pytorch
33
+ RUN pip install albumentations==1.0.0 timm==0.4.12 pytorch_toolbelt tensorboardx
34
+ RUN pip install cython jupyter jupyterlab ipykernel matplotlib tqdm pandas
35
+
36
+ # download pretraned Imagenet models
37
+ RUN apt install wget
38
+ RUN wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth -P /root/.cache/torch/hub/checkpoints/
39
+ RUN wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth -P /root/.cache/torch/hub/checkpoints/
40
+
41
+ # Setting the working directory
42
+ WORKDIR /workspace
43
+
44
+ # Copying the required codebase
45
+ COPY . /workspace
46
+
47
+ RUN chmod 777 preprocess_data.sh
48
+ RUN chmod 777 train.sh
49
+ RUN chmod 777 predict_submission.sh
50
+
51
+ ENV PYTHONPATH=.
52
+
53
+ CMD ["/bin/bash"]
54
+
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Selim Seferbekov
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Deepfake
3
+ emoji: 🔥
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.29.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: unlicense
11
+ duplicated_from: thecho7/deepfake
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__pycache__/kernel_utils.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import re
4
+ import time
5
+
6
+ import torch
7
+ from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
8
+ from training.zoo.classifiers import DeepFakeClassifier
9
+
10
+ import gradio as gr
11
+
12
+ def model_fn(model_dir):
13
+ model_path = os.path.join(model_dir, 'b7_ns_best.pth')
14
+ model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns") # default: CPU
15
+ checkpoint = torch.load(model_path, map_location="cpu")
16
+ state_dict = checkpoint.get("state_dict", checkpoint)
17
+ model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=True)
18
+ model.eval()
19
+ del checkpoint
20
+ #models.append(model.half())
21
+
22
+ return model
23
+
24
+ def convert_result(pred, class_names=["Real", "Fake"]):
25
+ preds = [pred, 1 - pred]
26
+ assert len(class_names) == len(preds), "Class / Prediction should have the same length"
27
+ return {n: float(p) for n, p in zip(class_names, preds)}
28
+
29
+ def predict_fn(video):
30
+ start = time.time()
31
+ prediction = predict_on_video(face_extractor=meta["face_extractor"],
32
+ video_path=video,
33
+ batch_size=meta["fps"],
34
+ input_size=meta["input_size"],
35
+ models=model,
36
+ strategy=meta["strategy"],
37
+ apply_compression=False,
38
+ device='cpu')
39
+
40
+ elapsed_time = round(time.time() - start, 2)
41
+
42
+ prediction = convert_result(prediction)
43
+
44
+ return prediction, elapsed_time
45
+
46
+ # Create title, description and article strings
47
+ title = "Deepfake Detector (private)"
48
+ description = "A video Deepfake Classifier (code: https://github.com/selimsef/dfdc_deepfake_challenge)"
49
+
50
+ example_list = ["examples/" + str(p) for p in os.listdir("examples/")]
51
+
52
+ # Environments
53
+ model_dir = 'weights'
54
+ frames_per_video = 32
55
+ video_reader = VideoReader()
56
+ video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
57
+ face_extractor = FaceExtractor(video_read_fn)
58
+ input_size = 380
59
+ strategy = confident_strategy
60
+ class_names = ["Real", "Fake"]
61
+
62
+ meta = {"fps": 32,
63
+ "face_extractor": face_extractor,
64
+ "input_size": input_size,
65
+ "strategy": strategy}
66
+
67
+ model = model_fn(model_dir)
68
+
69
+ """
70
+ if __name__ == '__main__':
71
+ video_path = "examples/nlurbvsozt.mp4"
72
+ model = model_fn(model_dir)
73
+ a, b = predict_fn(video_path)
74
+ print(a, b)
75
+ """
76
+ # Create the Gradio demo
77
+ demo = gr.Interface(fn=predict_fn, # mapping function from input to output
78
+ inputs=gr.Video(),
79
+ outputs=[gr.Label(num_top_classes=2, label="Predictions"), # what are the outputs?
80
+ gr.Number(label="Prediction time (s)")], # our fn has two outputs, therefore we have two outputs
81
+ examples=example_list,
82
+ title=title,
83
+ description=description)
84
+
85
+ # Launch the demo!
86
+ demo.launch(debug=False,) # Hugging face space don't need shareable_links
configs/b5.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "network": "DeepFakeClassifier",
3
+ "encoder": "tf_efficientnet_b5_ns",
4
+ "batches_per_epoch": 2500,
5
+ "size": 380,
6
+ "fp16": true,
7
+ "optimizer": {
8
+ "batch_size": 20,
9
+ "type": "SGD",
10
+ "momentum": 0.9,
11
+ "weight_decay": 1e-4,
12
+ "learning_rate": 0.01,
13
+ "nesterov": true,
14
+ "schedule": {
15
+ "type": "poly",
16
+ "mode": "step",
17
+ "epochs": 30,
18
+ "params": {"max_iter": 75100}
19
+ }
20
+ },
21
+ "normalize": {
22
+ "mean": [0.485, 0.456, 0.406],
23
+ "std": [0.229, 0.224, 0.225]
24
+ },
25
+ "losses": {
26
+ "BinaryCrossentropy": 1
27
+ }
28
+ }
configs/b7.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "network": "DeepFakeClassifier",
3
+ "encoder": "tf_efficientnet_b7_ns",
4
+ "batches_per_epoch": 2500,
5
+ "size": 380,
6
+ "fp16": true,
7
+ "optimizer": {
8
+ "batch_size": 4,
9
+ "type": "SGD",
10
+ "momentum": 0.9,
11
+ "weight_decay": 1e-4,
12
+ "learning_rate": 1e-4,
13
+ "nesterov": true,
14
+ "schedule": {
15
+ "type": "poly",
16
+ "mode": "step",
17
+ "epochs": 20,
18
+ "params": {"max_iter": 100500}
19
+ }
20
+ },
21
+ "normalize": {
22
+ "mean": [0.485, 0.456, 0.406],
23
+ "std": [0.229, 0.224, 0.225]
24
+ },
25
+ "losses": {
26
+ "BinaryCrossentropy": 1
27
+ }
28
+ }
29
+
download_weights.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ tag=0.0.1
2
+
3
+ wget -O weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36
4
+ wget -O weights/final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_555_DeepFakeClassifier_tf_efficientnet_b7_ns_0_19
5
+ wget -O weights/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_29
6
+ wget -O weights/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_31 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_777_DeepFakeClassifier_tf_efficientnet_b7_ns_0_31
7
+ wget -O weights/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_37 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_37
8
+ wget -O weights/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_888_DeepFakeClassifier_tf_efficientnet_b7_ns_0_40
9
+ wget -O weights/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23 https://github.com/selimsef/dfdc_deepfake_challenge/releases/download/$tag/final_999_DeepFakeClassifier_tf_efficientnet_b7_ns_0_23
examples/liuujwwgpr.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3aaefb51aa5720cdabcc68d93da5c6a22573d8da06bdaf5e009c7a370943e85
3
+ size 12852441
examples/nlurbvsozt.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:300b7dea93132b512f35de76572e7fcde666c812b91aec6b189dafa6f100c9b5
3
+ size 4486723
examples/rfjuhbnlro.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b6d0bb841ebe6a8e20cf265b45356a1ea3fed9837025e8d549b2437290d79273
3
+ size 16218775
kernel_utils.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from albumentations.augmentations.functional import image_compression
8
+ from facenet_pytorch.models.mtcnn import MTCNN
9
+ from concurrent.futures import ThreadPoolExecutor
10
+
11
+ from torchvision.transforms import Normalize
12
+
13
+ mean = [0.485, 0.456, 0.406]
14
+ std = [0.229, 0.224, 0.225]
15
+ normalize_transform = Normalize(mean, std)
16
+
17
+
18
+ class VideoReader:
19
+ """Helper class for reading one or more frames from a video file."""
20
+
21
+ def __init__(self, verbose=True, insets=(0, 0)):
22
+ """Creates a new VideoReader.
23
+
24
+ Arguments:
25
+ verbose: whether to print warnings and error messages
26
+ insets: amount to inset the image by, as a percentage of
27
+ (width, height). This lets you "zoom in" to an image
28
+ to remove unimportant content around the borders.
29
+ Useful for face detection, which may not work if the
30
+ faces are too small.
31
+ """
32
+ self.verbose = verbose
33
+ self.insets = insets
34
+
35
+ def read_frames(self, path, num_frames, jitter=0, seed=None):
36
+ """Reads frames that are always evenly spaced throughout the video.
37
+
38
+ Arguments:
39
+ path: the video file
40
+ num_frames: how many frames to read, -1 means the entire video
41
+ (warning: this will take up a lot of memory!)
42
+ jitter: if not 0, adds small random offsets to the frame indices;
43
+ this is useful so we don't always land on even or odd frames
44
+ seed: random seed for jittering; if you set this to a fixed value,
45
+ you probably want to set it only on the first video
46
+ """
47
+ assert num_frames > 0
48
+
49
+ capture = cv2.VideoCapture(path)
50
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
51
+ if frame_count <= 0: return None
52
+
53
+ frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=np.int32)
54
+ if jitter > 0:
55
+ np.random.seed(seed)
56
+ jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs))
57
+ frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1)
58
+
59
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
60
+ capture.release()
61
+ return result
62
+
63
+ def read_random_frames(self, path, num_frames, seed=None):
64
+ """Picks the frame indices at random.
65
+
66
+ Arguments:
67
+ path: the video file
68
+ num_frames: how many frames to read, -1 means the entire video
69
+ (warning: this will take up a lot of memory!)
70
+ """
71
+ assert num_frames > 0
72
+ np.random.seed(seed)
73
+
74
+ capture = cv2.VideoCapture(path)
75
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
76
+ if frame_count <= 0: return None
77
+
78
+ frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames))
79
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
80
+
81
+ capture.release()
82
+ return result
83
+
84
+ def read_frames_at_indices(self, path, frame_idxs):
85
+ """Reads frames from a video and puts them into a NumPy array.
86
+
87
+ Arguments:
88
+ path: the video file
89
+ frame_idxs: a list of frame indices. Important: should be
90
+ sorted from low-to-high! If an index appears multiple
91
+ times, the frame is still read only once.
92
+
93
+ Returns:
94
+ - a NumPy array of shape (num_frames, height, width, 3)
95
+ - a list of the frame indices that were read
96
+
97
+ Reading stops if loading a frame fails, in which case the first
98
+ dimension returned may actually be less than num_frames.
99
+
100
+ Returns None if an exception is thrown for any reason, or if no
101
+ frames were read.
102
+ """
103
+ assert len(frame_idxs) > 0
104
+ capture = cv2.VideoCapture(path)
105
+ result = self._read_frames_at_indices(path, capture, frame_idxs)
106
+ capture.release()
107
+ return result
108
+
109
+ def _read_frames_at_indices(self, path, capture, frame_idxs):
110
+ try:
111
+ frames = []
112
+ idxs_read = []
113
+ for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1):
114
+ # Get the next frame, but don't decode if we're not using it.
115
+ ret = capture.grab()
116
+ if not ret:
117
+ if self.verbose:
118
+ print("Error grabbing frame %d from movie %s" % (frame_idx, path))
119
+ break
120
+
121
+ # Need to look at this frame?
122
+ current = len(idxs_read)
123
+ if frame_idx == frame_idxs[current]:
124
+ ret, frame = capture.retrieve()
125
+ if not ret or frame is None:
126
+ if self.verbose:
127
+ print("Error retrieving frame %d from movie %s" % (frame_idx, path))
128
+ break
129
+
130
+ frame = self._postprocess_frame(frame)
131
+ frames.append(frame)
132
+ idxs_read.append(frame_idx)
133
+
134
+ if len(frames) > 0:
135
+ return np.stack(frames), idxs_read
136
+ if self.verbose:
137
+ print("No frames read from movie %s" % path)
138
+ return None
139
+ except:
140
+ if self.verbose:
141
+ print("Exception while reading movie %s" % path)
142
+ return None
143
+
144
+ def read_middle_frame(self, path):
145
+ """Reads the frame from the middle of the video."""
146
+ capture = cv2.VideoCapture(path)
147
+ frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
148
+ result = self._read_frame_at_index(path, capture, frame_count // 2)
149
+ capture.release()
150
+ return result
151
+
152
+ def read_frame_at_index(self, path, frame_idx):
153
+ """Reads a single frame from a video.
154
+
155
+ If you just want to read a single frame from the video, this is more
156
+ efficient than scanning through the video to find the frame. However,
157
+ for reading multiple frames it's not efficient.
158
+
159
+ My guess is that a "streaming" approach is more efficient than a
160
+ "random access" approach because, unless you happen to grab a keyframe,
161
+ the decoder still needs to read all the previous frames in order to
162
+ reconstruct the one you're asking for.
163
+
164
+ Returns a NumPy array of shape (1, H, W, 3) and the index of the frame,
165
+ or None if reading failed.
166
+ """
167
+ capture = cv2.VideoCapture(path)
168
+ result = self._read_frame_at_index(path, capture, frame_idx)
169
+ capture.release()
170
+ return result
171
+
172
+ def _read_frame_at_index(self, path, capture, frame_idx):
173
+ capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
174
+ ret, frame = capture.read()
175
+ if not ret or frame is None:
176
+ if self.verbose:
177
+ print("Error retrieving frame %d from movie %s" % (frame_idx, path))
178
+ return None
179
+ else:
180
+ frame = self._postprocess_frame(frame)
181
+ return np.expand_dims(frame, axis=0), [frame_idx]
182
+
183
+ def _postprocess_frame(self, frame):
184
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
185
+
186
+ if self.insets[0] > 0:
187
+ W = frame.shape[1]
188
+ p = int(W * self.insets[0])
189
+ frame = frame[:, p:-p, :]
190
+
191
+ if self.insets[1] > 0:
192
+ H = frame.shape[1]
193
+ q = int(H * self.insets[1])
194
+ frame = frame[q:-q, :, :]
195
+
196
+ return frame
197
+
198
+
199
+ class FaceExtractor:
200
+ def __init__(self, video_read_fn):
201
+ self.video_read_fn = video_read_fn
202
+ self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device="cpu")
203
+
204
+ def process_videos(self, input_dir, filenames, video_idxs):
205
+ videos_read = []
206
+ frames_read = []
207
+ frames = []
208
+ results = []
209
+ for video_idx in video_idxs:
210
+ # Read the full-size frames from this video.
211
+ filename = filenames[video_idx]
212
+ video_path = os.path.join(input_dir, filename)
213
+ result = self.video_read_fn(video_path)
214
+ # Error? Then skip this video.
215
+ if result is None: continue
216
+
217
+ videos_read.append(video_idx)
218
+
219
+ # Keep track of the original frames (need them later).
220
+ my_frames, my_idxs = result
221
+
222
+ frames.append(my_frames)
223
+ frames_read.append(my_idxs)
224
+ for i, frame in enumerate(my_frames):
225
+ h, w = frame.shape[:2]
226
+ img = Image.fromarray(frame.astype(np.uint8))
227
+ img = img.resize(size=[s // 2 for s in img.size])
228
+
229
+ batch_boxes, probs = self.detector.detect(img, landmarks=False)
230
+
231
+ faces = []
232
+ scores = []
233
+ if batch_boxes is None:
234
+ continue
235
+ for bbox, score in zip(batch_boxes, probs):
236
+ if bbox is not None:
237
+ xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
238
+ w = xmax - xmin
239
+ h = ymax - ymin
240
+ p_h = h // 3
241
+ p_w = w // 3
242
+ crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w]
243
+ faces.append(crop)
244
+ scores.append(score)
245
+
246
+ frame_dict = {"video_idx": video_idx,
247
+ "frame_idx": my_idxs[i],
248
+ "frame_w": w,
249
+ "frame_h": h,
250
+ "faces": faces,
251
+ "scores": scores}
252
+ results.append(frame_dict)
253
+
254
+ return results
255
+
256
+ def process_video(self, video_path):
257
+ """Convenience method for doing face extraction on a single video."""
258
+ input_dir = os.path.dirname(video_path)
259
+ filenames = [os.path.basename(video_path)]
260
+ return self.process_videos(input_dir, filenames, [0])
261
+
262
+
263
+
264
+ def confident_strategy(pred, t=0.8):
265
+ pred = np.array(pred)
266
+ sz = len(pred)
267
+ fakes = np.count_nonzero(pred > t)
268
+ # 11 frames are detected as fakes with high probability
269
+ if fakes > sz // 2.5 and fakes > 11:
270
+ return np.mean(pred[pred > t])
271
+ elif np.count_nonzero(pred < 0.2) > 0.9 * sz:
272
+ return np.mean(pred[pred < 0.2])
273
+ else:
274
+ return np.mean(pred)
275
+
276
+ strategy = confident_strategy
277
+
278
+
279
+ def put_to_center(img, input_size):
280
+ img = img[:input_size, :input_size]
281
+ image = np.zeros((input_size, input_size, 3), dtype=np.uint8)
282
+ start_w = (input_size - img.shape[1]) // 2
283
+ start_h = (input_size - img.shape[0]) // 2
284
+ image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img
285
+ return image
286
+
287
+
288
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
289
+ h, w = img.shape[:2]
290
+ if max(w, h) == size:
291
+ return img
292
+ if w > h:
293
+ scale = size / w
294
+ h = h * scale
295
+ w = size
296
+ else:
297
+ scale = size / h
298
+ w = w * scale
299
+ h = size
300
+ interpolation = interpolation_up if scale > 1 else interpolation_down
301
+ resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
302
+ return resized
303
+
304
+
305
+ def predict_on_video(face_extractor, video_path, batch_size, input_size, models, strategy=np.mean,
306
+ apply_compression=False, device='cpu'):
307
+ batch_size *= 4
308
+ try:
309
+ faces = face_extractor.process_video(video_path)
310
+ if len(faces) > 0:
311
+ x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8)
312
+ n = 0
313
+ for frame_data in faces:
314
+ for face in frame_data["faces"]:
315
+ resized_face = isotropically_resize_image(face, input_size)
316
+ resized_face = put_to_center(resized_face, input_size)
317
+ if apply_compression:
318
+ resized_face = image_compression(resized_face, quality=90, image_type=".jpg")
319
+ if n + 1 < batch_size:
320
+ x[n] = resized_face
321
+ n += 1
322
+ else:
323
+ pass
324
+ if n > 0:
325
+ if device == 'cpu':
326
+ x = torch.tensor(x, device='cpu').float()
327
+ else:
328
+ x = torch.tensor(x, device="cuda").float()
329
+ # Preprocess the images.
330
+ x = x.permute((0, 3, 1, 2))
331
+ for i in range(len(x)):
332
+ x[i] = normalize_transform(x[i] / 255.)
333
+ # Make a prediction, then take the average.
334
+ with torch.no_grad():
335
+ preds = []
336
+ models_ = [models]
337
+ for model in models_:
338
+ if device == 'cpu':
339
+ y_pred = model(x[:n])
340
+ else:
341
+ y_pred = model(x[:n].half())
342
+ y_pred = torch.sigmoid(y_pred.squeeze())
343
+ bpred = y_pred[:n].cpu().numpy()
344
+ preds.append(strategy(bpred))
345
+ return np.mean(preds)
346
+ except Exception as e:
347
+ print("Prediction error on video %s: %s" % (video_path, str(e)))
348
+
349
+ return 0.5
350
+
351
+
352
+ def predict_on_video_set(face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models,
353
+ strategy=np.mean,
354
+ apply_compression=False):
355
+ def process_file(i):
356
+ filename = videos[i]
357
+ y_pred = predict_on_video(face_extractor=face_extractor, video_path=os.path.join(test_dir, filename),
358
+ input_size=input_size,
359
+ batch_size=frames_per_video,
360
+ models=models, strategy=strategy, apply_compression=apply_compression)
361
+ return y_pred
362
+
363
+ with ThreadPoolExecutor(max_workers=num_workers) as ex:
364
+ predictions = ex.map(process_file, range(len(videos)))
365
+ return list(predictions)
366
+
libs/shape_predictor_68_face_landmarks.dat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f
3
+ size 99693937
requirements.txt ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ albumentations==1.3.0
5
+ altair==5.0.0
6
+ anyio==3.6.2
7
+ anykeystore==0.2
8
+ apex
9
+ appdirs==1.4.4
10
+ async-timeout==4.0.2
11
+ attrs==23.1.0
12
+ certifi==2022.12.7
13
+ charset-normalizer==2.1.1
14
+ click==8.1.3
15
+ cmake==3.25.0
16
+ contourpy==1.0.7
17
+ cryptacular==1.6.2
18
+ cycler==0.11.0
19
+ defusedxml==0.7.1
20
+ dlib==19.24.1
21
+ docker-pycreds==0.4.0
22
+ facenet-pytorch==2.5.3
23
+ fastapi==0.95.1
24
+ ffmpy==0.3.0
25
+ filelock==3.9.0
26
+ fonttools==4.39.4
27
+ frozenlist==1.3.3
28
+ fsspec==2023.5.0
29
+ gitdb==4.0.10
30
+ GitPython==3.1.31
31
+ gradio==3.30.0
32
+ gradio_client==0.2.4
33
+ greenlet==2.0.2
34
+ h11==0.14.0
35
+ httpcore==0.17.0
36
+ httpx==0.24.0
37
+ huggingface-hub==0.14.1
38
+ hupper==1.12
39
+ idna==3.4
40
+ imageio==2.28.1
41
+ Jinja2==3.1.2
42
+ joblib==1.2.0
43
+ jsonschema==4.17.3
44
+ kiwisolver==1.4.4
45
+ lazy_loader==0.2
46
+ linkify-it-py==2.0.2
47
+ lit==15.0.7
48
+ markdown-it-py==2.2.0
49
+ MarkupSafe==2.1.2
50
+ matplotlib==3.7.1
51
+ mdit-py-plugins==0.3.3
52
+ mdurl==0.1.2
53
+ mpmath==1.2.1
54
+ multidict==6.0.4
55
+ networkx==3.0
56
+ numpy==1.24.3
57
+ oauthlib==3.2.2
58
+ opencv-python==4.7.0.72
59
+ opencv-python-headless==4.7.0.72
60
+ orjson==3.8.12
61
+ packaging==23.0
62
+ pandas==2.0.1
63
+ PasteDeploy==3.0.1
64
+ pathtools==0.1.2
65
+ pbkdf2==1.3
66
+ pep517==0.13.0
67
+ Pillow==9.3.0
68
+ plaster==1.1.2
69
+ plaster-pastedeploy==1.0.1
70
+ protobuf==3.20.3
71
+ psutil==5.9.5
72
+ pydantic==1.10.7
73
+ pydub==0.25.1
74
+ Pygments==2.15.1
75
+ pyparsing==3.0.9
76
+ pyramid==2.0.1
77
+ pyramid-mailer==0.15.1
78
+ pyrsistent==0.19.3
79
+ python-dateutil==2.8.2
80
+ python-multipart==0.0.6
81
+ python3-openid==3.2.0
82
+ pytorch-toolbelt==0.6.3
83
+ pytz==2023.3
84
+ PyWavelets==1.4.1
85
+ PyYAML==6.0
86
+ qudida==0.0.4
87
+ repoze.sendmail==4.4.1
88
+ requests==2.28.1
89
+ requests-oauthlib==1.3.1
90
+ scikit-image==0.20.0
91
+ scikit-learn==1.2.2
92
+ scipy==1.9.0
93
+ semantic-version==2.10.0
94
+ sentry-sdk==1.22.2
95
+ setproctitle==1.3.2
96
+ sh==1.14.3
97
+ six==1.16.0
98
+ smmap==5.0.0
99
+ sniffio==1.3.0
100
+ SQLAlchemy==1.4.48
101
+ starlette==0.26.1
102
+ sympy==1.11.1
103
+ tensorboardX==2.6
104
+ threadpoolctl==3.1.0
105
+ tifffile==2023.4.12
106
+ timm==0.6.13
107
+ toml==0.10.2
108
+ tomli==2.0.1
109
+ toolz==0.12.0
110
+ torch==2.0.1
111
+ torchvision==0.15.2
112
+ tqdm==4.65.0
113
+ transaction==3.1.0
114
+ translationstring==1.4
115
+ triton==2.0.0
116
+ typing_extensions==4.4.0
117
+ tzdata==2023.3
118
+ uc-micro-py==1.0.2
119
+ urllib3==1.26.13
120
+ uvicorn==0.22.0
121
+ velruse==1.1.1
122
+ venusian==3.0.0
123
+ wandb==0.15.2
124
+ WebOb==1.8.7
125
+ websockets==11.0.3
126
+ WTForms==3.0.1
127
+ wtforms-recaptcha==0.3.2
128
+ yarl==1.9.2
129
+ zope.deprecation==5.0
130
+ zope.interface==6.0
131
+ zope.sqlalchemy==2.0
training/__init__.py ADDED
File without changes
training/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (148 Bytes). View file
 
training/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (146 Bytes). View file
 
training/__pycache__/losses.cpython-310.pyc ADDED
Binary file (1.54 kB). View file
 
training/__pycache__/losses.cpython-39.pyc ADDED
Binary file (1.53 kB). View file
 
training/datasets/__init__.py ADDED
File without changes
training/datasets/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (157 Bytes). View file
 
training/datasets/__pycache__/classifier_dataset.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
training/datasets/__pycache__/validation_set.cpython-310.pyc ADDED
Binary file (4.99 kB). View file
 
training/datasets/classifier_dataset.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ import sys
5
+ import traceback
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import pandas as pd
10
+ import skimage.draw
11
+ from albumentations import ImageCompression, OneOf, GaussianBlur, Blur
12
+ from albumentations.augmentations.functional import image_compression
13
+ from albumentations.augmentations.geometric.functional import rot90
14
+ from albumentations.pytorch.functional import img_to_tensor
15
+ from scipy.ndimage import binary_erosion, binary_dilation
16
+ from skimage import measure
17
+ from torch.utils.data import Dataset
18
+ import dlib
19
+
20
+ from training.datasets.validation_set import PUBLIC_SET
21
+
22
+
23
+ def prepare_bit_masks(mask):
24
+ h, w = mask.shape
25
+ mid_w = w // 2
26
+ mid_h = w // 2
27
+ masks = []
28
+ ones = np.ones_like(mask)
29
+ ones[:mid_h] = 0
30
+ masks.append(ones)
31
+ ones = np.ones_like(mask)
32
+ ones[mid_h:] = 0
33
+ masks.append(ones)
34
+ ones = np.ones_like(mask)
35
+ ones[:, :mid_w] = 0
36
+ masks.append(ones)
37
+ ones = np.ones_like(mask)
38
+ ones[:, mid_w:] = 0
39
+ masks.append(ones)
40
+ ones = np.ones_like(mask)
41
+ ones[:mid_h, :mid_w] = 0
42
+ ones[mid_h:, mid_w:] = 0
43
+ masks.append(ones)
44
+ ones = np.ones_like(mask)
45
+ ones[:mid_h, mid_w:] = 0
46
+ ones[mid_h:, :mid_w] = 0
47
+ masks.append(ones)
48
+ return masks
49
+
50
+
51
+ detector = dlib.get_frontal_face_detector()
52
+ predictor = dlib.shape_predictor('libs/shape_predictor_68_face_landmarks.dat')
53
+
54
+
55
+ def blackout_convex_hull(img):
56
+ try:
57
+ rect = detector(img)[0]
58
+ sp = predictor(img, rect)
59
+ landmarks = np.array([[p.x, p.y] for p in sp.parts()])
60
+ outline = landmarks[[*range(17), *range(26, 16, -1)]]
61
+ Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
62
+ cropped_img = np.zeros(img.shape[:2], dtype=np.uint8)
63
+ cropped_img[Y, X] = 1
64
+ # if random.random() > 0.5:
65
+ # img[cropped_img == 0] = 0
66
+ # #leave only face
67
+ # return img
68
+
69
+ y, x = measure.centroid(cropped_img)
70
+ y = int(y)
71
+ x = int(x)
72
+ first = random.random() > 0.5
73
+ if random.random() > 0.5:
74
+ if first:
75
+ cropped_img[:y, :] = 0
76
+ else:
77
+ cropped_img[y:, :] = 0
78
+ else:
79
+ if first:
80
+ cropped_img[:, :x] = 0
81
+ else:
82
+ cropped_img[:, x:] = 0
83
+
84
+ img[cropped_img > 0] = 0
85
+ except Exception as e:
86
+ pass
87
+
88
+
89
+ def dist(p1, p2):
90
+ return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
91
+
92
+
93
+ def remove_eyes(image, landmarks):
94
+ image = image.copy()
95
+ (x1, y1), (x2, y2) = landmarks[:2]
96
+ mask = np.zeros_like(image[..., 0])
97
+ line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
98
+ w = dist((x1, y1), (x2, y2))
99
+ dilation = int(w // 4)
100
+ line = binary_dilation(line, iterations=dilation)
101
+ image[line, :] = 0
102
+ return image
103
+
104
+
105
+ def remove_nose(image, landmarks):
106
+ image = image.copy()
107
+ (x1, y1), (x2, y2) = landmarks[:2]
108
+ x3, y3 = landmarks[2]
109
+ mask = np.zeros_like(image[..., 0])
110
+ x4 = int((x1 + x2) / 2)
111
+ y4 = int((y1 + y2) / 2)
112
+ line = cv2.line(mask, (x3, y3), (x4, y4), color=(1), thickness=2)
113
+ w = dist((x1, y1), (x2, y2))
114
+ dilation = int(w // 4)
115
+ line = binary_dilation(line, iterations=dilation)
116
+ image[line, :] = 0
117
+ return image
118
+
119
+
120
+ def remove_mouth(image, landmarks):
121
+ image = image.copy()
122
+ (x1, y1), (x2, y2) = landmarks[-2:]
123
+ mask = np.zeros_like(image[..., 0])
124
+ line = cv2.line(mask, (x1, y1), (x2, y2), color=(1), thickness=2)
125
+ w = dist((x1, y1), (x2, y2))
126
+ dilation = int(w // 3)
127
+ line = binary_dilation(line, iterations=dilation)
128
+ image[line, :] = 0
129
+ return image
130
+
131
+
132
+ def remove_landmark(image, landmarks):
133
+ if random.random() > 0.5:
134
+ image = remove_eyes(image, landmarks)
135
+ elif random.random() > 0.5:
136
+ image = remove_mouth(image, landmarks)
137
+ elif random.random() > 0.5:
138
+ image = remove_nose(image, landmarks)
139
+ return image
140
+
141
+
142
+ def change_padding(image, part=5):
143
+ h, w = image.shape[:2]
144
+ # original padding was done with 1/3 from each side, too much
145
+ pad_h = int(((3 / 5) * h) / part)
146
+ pad_w = int(((3 / 5) * w) / part)
147
+ image = image[h // 5 - pad_h:-h // 5 + pad_h, w // 5 - pad_w:-w // 5 + pad_w]
148
+ return image
149
+
150
+
151
+ def blackout_random(image, mask, label):
152
+ binary_mask = mask > 0.4 * 255
153
+ h, w = binary_mask.shape[:2]
154
+
155
+ tries = 50
156
+ current_try = 1
157
+ while current_try < tries:
158
+ first = random.random() < 0.5
159
+ if random.random() < 0.5:
160
+ pivot = random.randint(h // 2 - h // 5, h // 2 + h // 5)
161
+ bitmap_msk = np.ones_like(binary_mask)
162
+ if first:
163
+ bitmap_msk[:pivot, :] = 0
164
+ else:
165
+ bitmap_msk[pivot:, :] = 0
166
+ else:
167
+ pivot = random.randint(w // 2 - w // 5, w // 2 + w // 5)
168
+ bitmap_msk = np.ones_like(binary_mask)
169
+ if first:
170
+ bitmap_msk[:, :pivot] = 0
171
+ else:
172
+ bitmap_msk[:, pivot:] = 0
173
+
174
+ if label < 0.5 and np.count_nonzero(image * np.expand_dims(bitmap_msk, axis=-1)) / 3 > (h * w) / 5 \
175
+ or np.count_nonzero(binary_mask * bitmap_msk) > 40:
176
+ mask *= bitmap_msk
177
+ image *= np.expand_dims(bitmap_msk, axis=-1)
178
+ break
179
+ current_try += 1
180
+ return image
181
+
182
+
183
+ def blend_original(img):
184
+ img = img.copy()
185
+ h, w = img.shape[:2]
186
+ rect = detector(img)
187
+ if len(rect) == 0:
188
+ return img
189
+ else:
190
+ rect = rect[0]
191
+ sp = predictor(img, rect)
192
+ landmarks = np.array([[p.x, p.y] for p in sp.parts()])
193
+ outline = landmarks[[*range(17), *range(26, 16, -1)]]
194
+ Y, X = skimage.draw.polygon(outline[:, 1], outline[:, 0])
195
+ raw_mask = np.zeros(img.shape[:2], dtype=np.uint8)
196
+ raw_mask[Y, X] = 1
197
+ face = img * np.expand_dims(raw_mask, -1)
198
+
199
+ # add warping
200
+ h1 = random.randint(h - h // 2, h + h // 2)
201
+ w1 = random.randint(w - w // 2, w + w // 2)
202
+ while abs(h1 - h) < h // 3 and abs(w1 - w) < w // 3:
203
+ h1 = random.randint(h - h // 2, h + h // 2)
204
+ w1 = random.randint(w - w // 2, w + w // 2)
205
+ face = cv2.resize(face, (w1, h1), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
206
+ face = cv2.resize(face, (w, h), interpolation=random.choice([cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC]))
207
+
208
+ raw_mask = binary_erosion(raw_mask, iterations=random.randint(4, 10))
209
+ img[raw_mask, :] = face[raw_mask, :]
210
+ if random.random() < 0.2:
211
+ img = OneOf([GaussianBlur(), Blur()], p=0.5)(image=img)["image"]
212
+ # image compression
213
+ if random.random() < 0.5:
214
+ img = ImageCompression(quality_lower=40, quality_upper=95)(image=img)["image"]
215
+ return img
216
+
217
+
218
+ class DeepFakeClassifierDataset(Dataset):
219
+
220
+ def __init__(self,
221
+ data_path="/mnt/sota/datasets/deepfake",
222
+ fold=0,
223
+ label_smoothing=0.01,
224
+ padding_part=3,
225
+ hardcore=True,
226
+ crops_dir="crops",
227
+ folds_csv="folds.csv",
228
+ normalize={"mean": [0.485, 0.456, 0.406],
229
+ "std": [0.229, 0.224, 0.225]},
230
+ rotation=False,
231
+ mode="train",
232
+ reduce_val=True,
233
+ oversample_real=True,
234
+ transforms=None
235
+ ):
236
+ super().__init__()
237
+ self.data_root = data_path
238
+ self.fold = fold
239
+ self.folds_csv = folds_csv
240
+ self.mode = mode
241
+ self.rotation = rotation
242
+ self.padding_part = padding_part
243
+ self.hardcore = hardcore
244
+ self.crops_dir = crops_dir
245
+ self.label_smoothing = label_smoothing
246
+ self.normalize = normalize
247
+ self.transforms = transforms
248
+ self.df = pd.read_csv(self.folds_csv)
249
+ self.oversample_real = oversample_real
250
+ self.reduce_val = reduce_val
251
+
252
+ def __getitem__(self, index: int):
253
+
254
+ while True:
255
+ video, img_file, label, ori_video, frame, fold = self.data[index]
256
+ try:
257
+ if self.mode == "train":
258
+ label = np.clip(label, self.label_smoothing, 1 - self.label_smoothing)
259
+ img_path = os.path.join(self.data_root, self.crops_dir, video, img_file)
260
+ image = cv2.imread(img_path, cv2.IMREAD_COLOR)
261
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
262
+ mask = np.zeros(image.shape[:2], dtype=np.uint8)
263
+ diff_path = os.path.join(self.data_root, "diffs", video, img_file[:-4] + "_diff.png")
264
+ try:
265
+ msk = cv2.imread(diff_path, cv2.IMREAD_GRAYSCALE)
266
+ if msk is not None:
267
+ mask = msk
268
+ except:
269
+ print("not found mask", diff_path)
270
+ pass
271
+ if self.mode == "train" and self.hardcore and not self.rotation:
272
+ landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
273
+ if os.path.exists(landmark_path) and random.random() < 0.7:
274
+ landmarks = np.load(landmark_path)
275
+ image = remove_landmark(image, landmarks)
276
+ elif random.random() < 0.2:
277
+ blackout_convex_hull(image)
278
+ elif random.random() < 0.1:
279
+ binary_mask = mask > 0.4 * 255
280
+ masks = prepare_bit_masks((binary_mask * 1).astype(np.uint8))
281
+ tries = 6
282
+ current_try = 1
283
+ while current_try < tries:
284
+ bitmap_msk = random.choice(masks)
285
+ if label < 0.5 or np.count_nonzero(mask * bitmap_msk) > 20:
286
+ mask *= bitmap_msk
287
+ image *= np.expand_dims(bitmap_msk, axis=-1)
288
+ break
289
+ current_try += 1
290
+ if self.mode == "train" and self.padding_part > 3:
291
+ image = change_padding(image, self.padding_part)
292
+ valid_label = np.count_nonzero(mask[mask > 20]) > 32 or label < 0.5
293
+ valid_label = 1 if valid_label else 0
294
+ rotation = 0
295
+ if self.transforms:
296
+ data = self.transforms(image=image, mask=mask)
297
+ image = data["image"]
298
+ mask = data["mask"]
299
+ if self.mode == "train" and self.hardcore and self.rotation:
300
+ # landmark_path = os.path.join(self.data_root, "landmarks", ori_video, img_file[:-4] + ".npy")
301
+ dropout = 0.8 if label > 0.5 else 0.6
302
+ if self.rotation:
303
+ dropout *= 0.7
304
+ elif random.random() < dropout:
305
+ blackout_random(image, mask, label)
306
+
307
+ #
308
+ # os.makedirs("../images", exist_ok=True)
309
+ # cv2.imwrite(os.path.join("../images", video+ "_" + str(1 if label > 0.5 else 0) + "_"+img_file), image[...,::-1])
310
+
311
+ if self.mode == "train" and self.rotation:
312
+ rotation = random.randint(0, 3)
313
+ image = rot90(image, rotation)
314
+
315
+ image = img_to_tensor(image, self.normalize)
316
+ return {"image": image, "labels": np.array((label,)), "img_name": os.path.join(video, img_file),
317
+ "valid": valid_label, "rotations": rotation}
318
+ except Exception as e:
319
+ traceback.print_exc(file=sys.stdout)
320
+ print("Broken image", os.path.join(self.data_root, self.crops_dir, video, img_file))
321
+ index = random.randint(0, len(self.data) - 1)
322
+
323
+ def random_blackout_landmark(self, image, mask, landmarks):
324
+ x, y = random.choice(landmarks)
325
+ first = random.random() > 0.5
326
+ # crop half face either vertically or horizontally
327
+ if random.random() > 0.5:
328
+ # width
329
+ if first:
330
+ image[:, :x] = 0
331
+ mask[:, :x] = 0
332
+ else:
333
+ image[:, x:] = 0
334
+ mask[:, x:] = 0
335
+ else:
336
+ # height
337
+ if first:
338
+ image[:y, :] = 0
339
+ mask[:y, :] = 0
340
+ else:
341
+ image[y:, :] = 0
342
+ mask[y:, :] = 0
343
+
344
+ def reset(self, epoch, seed):
345
+ self.data = self._prepare_data(epoch, seed)
346
+
347
+ def __len__(self) -> int:
348
+ return len(self.data)
349
+
350
+ def get_distribution(self):
351
+ return self.n_real, self.n_fake
352
+
353
+ def _prepare_data(self, epoch, seed):
354
+ df = self.df
355
+ if self.mode == "train":
356
+ rows = df[df["fold"] != self.fold]
357
+ else:
358
+ rows = df[df["fold"] == self.fold]
359
+ seed = (epoch + 1) * seed
360
+ if self.oversample_real:
361
+ rows = self._oversample(rows, seed)
362
+ if self.mode == "val" and self.reduce_val:
363
+ # every 2nd frame, to speed up validation
364
+ rows = rows[rows["frame"] % 20 == 0]
365
+ # another option is to use public validation set
366
+ #rows = rows[rows["video"].isin(PUBLIC_SET)]
367
+
368
+ print(
369
+ "real {} fakes {} mode {}".format(len(rows[rows["label"] == 0]), len(rows[rows["label"] == 1]), self.mode))
370
+ data = rows.values
371
+
372
+ self.n_real = len(rows[rows["label"] == 0])
373
+ self.n_fake = len(rows[rows["label"] == 1])
374
+ np.random.seed(seed)
375
+ np.random.shuffle(data)
376
+ return data
377
+
378
+ def _oversample(self, rows: pd.DataFrame, seed):
379
+ real = rows[rows["label"] == 0]
380
+ fakes = rows[rows["label"] == 1]
381
+ num_real = real["video"].count()
382
+ if self.mode == "train":
383
+ fakes = fakes.sample(n=num_real, replace=False, random_state=seed)
384
+ return pd.concat([real, fakes])
training/datasets/validation_set.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ PUBLIC_SET = {'tjuihawuqm', 'prwsfljdjo', 'scrbqgpvzz', 'ziipxxchai', 'uubgqnvfdl', 'wclvkepakb', 'xjvxtuakyd',
4
+ 'qlvsqdroqo', 'bcbqxhziqz', 'yzuestxcbq', 'hxwtsaydal', 'kqlvggiqee', 'vtunvalyji', 'mohiqoogpb',
5
+ 'siebfpwuhu', 'cekwtyxdoo', 'hszwwswewp', 'orekjthsef', 'huvlwkxoxm', 'fmhiujydwo', 'lhvjzhjxdp',
6
+ 'ibxfxggtqh', 'bofrwgeyjo', 'rmufsuogzn', 'zbgssotnjm', 'dpevefkefv', 'sufvvwmbha', 'ncoeewrdlo',
7
+ 'qhsehzgxqj', 'yxadevzohx', 'aomqqjipcp', 'pcyswtgick', 'wfzjxzhdkj', 'rcjfxxhcal', 'lnjkpdviqb',
8
+ 'xmkwsnuzyq', 'ouaowjmigq', 'bkuzquigyt', 'vwxednhlwz', 'mszblrdprw', 'blnmxntbey', 'gccnvdoknm',
9
+ 'mkzaekkvej', 'hclsparpth', 'eryjktdexi', 'hfsvqabzfq', 'acazlolrpz', 'yoyhmxtrys', 'rerpivllud',
10
+ 'elackxuccp', 'zgbhzkditd', 'vjljdfopjg', 'famlupsgqm', 'nymodlmxni', 'qcbkztamqc', 'qclpbcbgeq',
11
+ 'lpkgabskbw', 'mnowxangqx', 'czfqlbcfpa', 'qyyhuvqmyf', 'toinozytsp', 'ztyvglkcsf', 'nplviymzlg',
12
+ 'opvqdabdap', 'uxuvkrjhws', 'mxahsihabr', 'cqxxumarvp', 'ptbfnkajyi', 'njzshtfmcw', 'dcqodpzomd',
13
+ 'ajiyrjfyzp', 'ywauoonmlr', 'gochxzemmq', 'lpgxwdgnio', 'hnfwagcxdf', 'gfcycflhbo', 'gunamloolc',
14
+ 'yhjlnisfel', 'srfefmyjvt', 'evysmtpnrf', 'aktnlyqpah', 'gpsxfxrjrr', 'zfobicuigx', 'mnzabbkpmt',
15
+ 'rfjuhbnlro', 'zuwwbbusgl', 'csnkohqxdv', 'bzvzpwrabw', 'yietrwuncf', 'wynotylpnm', 'ekboxwrwuv',
16
+ 'rcecrgeotc', 'rklawjhbpv', 'ilqwcbprqa', 'jsysgmycsx', 'sqixhnilfm', 'wnlubukrki', 'nikynwcvuh',
17
+ 'sjkfxrlxxs', 'btdxnajogv', 'wjhpisoeaj', 'dyjklprkoc', 'qlqhjcshpk', 'jyfvaequfg', 'dozjwhnedd',
18
+ 'owaogcehvc', 'oyqgwjdwaj', 'vvfszaosiv', 'kmcdjxmnoa', 'jiswxuqzyz', 'ddtbarpcgo', 'wqysrieiqu',
19
+ 'xcruhaccxc', 'honxqdilvv', 'nxgzmgzkfv', 'cxsvvnxpyz', 'demuhxssgl', 'hzoiotcykp', 'fwykevubzy',
20
+ 'tejfudfgpq', 'kvmpmhdxly', 'oojxonbgow', 'vurjckblge', 'oysopgovhu', 'khpipxnsvx', 'pqthmvwonf',
21
+ 'fddmkqjwsh', 'pcoxcmtroa', 'cnxccbjlct', 'ggzjfrirjh', 'jquevmhdvc', 'ecumyiowzs', 'esmqxszybs',
22
+ 'mllzkpgatp', 'ryxaqpfubf', 'hbufmvbium', 'vdtsbqidjb', 'sjwywglgym', 'qxyrtwozyw', 'upmgtackuf',
23
+ 'ucthmsajay', 'zgjosltkie', 'snlyjbnpgw', 'nswtvttxre', 'iznnzjvaxc', 'jhczqfefgw', 'htzbnroagi',
24
+ 'pdswwyyntw', 'uvrzaczrbx', 'vbcgoyxsvn', 'hzssdinxec', 'novarhxpbj', 'vizerpsvbz', 'jawgcggquk',
25
+ 'iorbtaarte', 'yarpxfqejd', 'vhbbwdflyh', 'rrrfjhugvb', 'fneqiqpqvs', 'jytrvwlewz', 'bfjsthfhbd',
26
+ 'rxdoimqble', 'ekelfsnqof', 'uqvxjfpwdo', 'cjkctqqakb', 'tynfsthodx', 'yllztsrwjw', 'bktkwbcawi',
27
+ 'wcqvzujamg', 'bcvheslzrq', 'aqrsylrzgi', 'sktpeppbkc', 'mkmgcxaztt', 'etdliwticv', 'hqzwudvhih',
28
+ 'swsaoktwgi', 'temjefwaas', 'papagllumt', 'xrtvqhdibb', 'oelqpetgwj', 'ggdpclfcgk', 'imdmhwkkni',
29
+ 'lebzjtusnr', 'xhtppuyqdr', 'nxzgekegsp', 'waucvvmtkq', 'rnfcjxynfa', 'adohdulfwb', 'tjywwgftmv',
30
+ 'fjrueenjyp', 'oaguiggjyv', 'ytopzxrswu', 'yxvmusxvcz', 'rukyxomwcx', 'qdqdsaiitt', 'mxlipjhmqk',
31
+ 'voawxrmqyl', 'kezwvsxxzj', 'oocincvedt', 'qooxnxqqjb', 'mwwploizlj', 'yaxgpxhavq', 'uhakqelqri',
32
+ 'bvpeerislp', 'bkcyglmfci', 'jyoxdvxpza', 'gkutjglghz', 'knxltsvzyu', 'ybbrkacebd', 'apvzjkvnwn',
33
+ 'ahjnxtiamx', 'hsbljbsgxr', 'fnxgqcvlsd', 'xphdfgmfmz', 'scbdenmaed', 'ywxpquomgt', 'yljecirelf',
34
+ 'wcvsqnplsk', 'vmxfwxgdei', 'icbsahlivv', 'yhylappzid', 'irqzdokcws', 'petmyhjclt', 'rmlzgerevr',
35
+ 'qarqtkvgby', 'nkhzxomani', 'viteugozpv', 'qhkzlnzruj', 'eisofhptvk', 'gqnaxievjx', 'heiyoojifp',
36
+ 'zcxcmneefk', 'wvgviwnwob', 'gcdtglsoqj', 'yqhouqakbx', 'fopjiyxiqd', 'hierggamuo', 'ypbtpunjvm',
37
+ 'sjinmmbipg', 'kmqkiihrmj', 'wmoqzxddkb', 'lnhkjhyhvw', 'wixbuuzygv', 'fsdrwikhge', 'sfsayjgzrh',
38
+ 'pqdeutauqc', 'frqfsucgao', 'pdufsewrec', 'bfdopzvxbi', 'shnsajrsow', 'rvvpazsffd', 'pxcfrszlgi',
39
+ 'itfsvvmslp', 'ayipraspbn', 'prhmixykhr', 'doniqevxeg', 'dvtpwatuja', 'jiavqbrkyk', 'ipkpxvwroe',
40
+ 'syxobtuucp', 'syuxttuyhm', 'nwvsbmyndn', 'eqslzbqfea', 'ytddugrwph', 'vokrpfjpeb', 'bdshuoldwx',
41
+ 'fmvvmcbdrw', 'bnuwxhfahw', 'gbnzicjyhz', 'txnmkabufs', 'gfdjzwnpyp', 'hweshqpfwe', 'dxgnpnowgk',
42
+ 'xugmhbetrw', 'rktrpsdlci', 'nthpnwylxo', 'ihglzxzroo', 'ocgdbrgmtq', 'ruhtnngrqv', 'xljemofssi',
43
+ 'zxacihctqp', 'ghnpsltzyn', 'lbigytrrtr', 'ndikguxzek', 'mdfndlljvt', 'lyoslorecs', 'oefukgnvel',
44
+ 'zmxeiipnqb', 'cosghhimnd', 'alrtntfxtd', 'eywdmustbb', 'ooafcxxfrs', 'fqgypsunzr', 'hevcclcklc',
45
+ 'uhrqlmlclw', 'ipvwtgdlre', 'wcssbghcpc', 'didzujjhtg', 'fjxovgmwnm', 'dmmvuaikkv', 'hitfycdavv',
46
+ 'zyufpqvpyu', 'coujjnypba', 'temeqbmzxu', 'apedduehoy', 'iksxzpqxzi', 'kwfdyqofzw', 'aassnaulhq',
47
+ 'eyguqfmgzh', 'yiykshcbaz', 'sngjsueuhs', 'okgelildpc', 'ztyuiqrhdk', 'tvhjcfnqtg', 'gfgcwxkbjd',
48
+ 'lbfqksftuo', 'kowiwvrjht', 'dkuqbduxev', 'mwnibuujwz', 'sodvtfqbpf', 'hsbwhlolsn', 'qsjiypnjwi',
49
+ 'blszgmxkvu', 'ystdtnetgj', 'rfwxcinshk', 'vnlzxqwthl', 'ljouzjaqqe', 'gahgyuwzbu', 'xxzefxwyku',
50
+ 'xitgdpzbxv', 'sylnrepacf', 'igpvrfjdzc', 'nxnmkytwze', 'psesikjaxx', 'dvwpvqdflx', 'bjyaxvggle',
51
+ 'dpmgoiwhuf', 'wadvzjhwtw', 'kcjvhgvhpt', 'eppyqpgewp', 'tyjpjpglgx', 'cekarydqba', 'dvkdfhrpph',
52
+ 'cnpanmywno', 'ljauauuyka', 'hicjuubiau', 'cqhwesrciw', 'dnmowthjcj', 'lujvyveojc', 'wndursivcx',
53
+ 'espkiocpxq', 'jsbpkpxwew', 'dsnxgrfdmd', 'hyjqolupxn', 'xdezcezszc', 'axfhbpkdlc', 'qqnlrngaft',
54
+ 'coqwgzpbhx', 'ncmpqwmnzb', 'sznkemeqro', 'omphqltjdd', 'uoccaiathd', 'jzmzdispyo', 'pxjkzvqomp',
55
+ 'udxqbhgvvx', 'dzkyxbbqkr', 'dtozwcapoa', 'qswlzfgcgj', 'tgawasvbbr', 'lmdyicksrv', 'fzvpbrzssi',
56
+ 'dxfdovivlw', 'zzmgnglanj', 'vssmlqoiti', 'vajkicalux', 'ekvwecwltj', 'ylxwcwhjjd', 'keioymnobc',
57
+ 'usqqvxcjmg', 'phjvutxpoi', 'nycmyuzpml', 'bwdmzwhdnw', 'fxuxxtryjn', 'orixbcfvdz', 'hefisnapds',
58
+ 'fpevfidstw', 'halvwiltfs', 'dzojiwfvba', 'ojsxxkalat', 'esjdyghhog', 'ptbnewtvon', 'hcanfkwivl',
59
+ 'yronlutbgm', 'llplvmcvbl', 'yxirnfyijn', 'nwvloufjty', 'rtpbawlmxr', 'aayfryxljh', 'zfrrixsimm',
60
+ 'txmnoyiyte'}
training/losses.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ from pytorch_toolbelt.losses import BinaryFocalLoss
4
+ from torch import nn
5
+ from torch.nn.modules.loss import BCEWithLogitsLoss
6
+
7
+
8
+ class WeightedLosses(nn.Module):
9
+ def __init__(self, losses, weights):
10
+ super().__init__()
11
+ self.losses = losses
12
+ self.weights = weights
13
+
14
+ def forward(self, *input: Any, **kwargs: Any):
15
+ cum_loss = 0
16
+ for loss, w in zip(self.losses, self.weights):
17
+ cum_loss += w * loss.forward(*input, **kwargs)
18
+ return cum_loss
19
+
20
+
21
+ class BinaryCrossentropy(BCEWithLogitsLoss):
22
+ pass
23
+
24
+
25
+ class FocalLoss(BinaryFocalLoss):
26
+ def __init__(self, alpha=None, gamma=3, ignore_index=None, reduction="mean", normalized=False,
27
+ reduced_threshold=None):
28
+ super().__init__(alpha, gamma, ignore_index, reduction, normalized, reduced_threshold)
training/pipelines/__init__.py ADDED
File without changes
training/pipelines/train_classifier.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from collections import defaultdict
5
+
6
+ from sklearn.metrics import log_loss
7
+ from torch import topk
8
+
9
+ import sys
10
+ print('@@@@@@@@@@@@@@@@@@')
11
+ sys.path.append('..')
12
+
13
+ from training import losses
14
+ from training.datasets.classifier_dataset import DeepFakeClassifierDataset
15
+ from training.losses import WeightedLosses
16
+ from training.tools.config import load_config
17
+ from training.tools.utils import create_optimizer, AverageMeter
18
+ from training.transforms.albu import IsotropicResize
19
+ from training.zoo import classifiers
20
+
21
+ os.environ["MKL_NUM_THREADS"] = "1"
22
+ os.environ["NUMEXPR_NUM_THREADS"] = "1"
23
+ os.environ["OMP_NUM_THREADS"] = "1"
24
+
25
+ import cv2
26
+
27
+ cv2.ocl.setUseOpenCL(False)
28
+ cv2.setNumThreads(0)
29
+ import numpy as np
30
+ from albumentations import Compose, RandomBrightnessContrast, \
31
+ HorizontalFlip, FancyPCA, HueSaturationValue, OneOf, ToGray, \
32
+ ShiftScaleRotate, ImageCompression, PadIfNeeded, GaussNoise, GaussianBlur
33
+
34
+ from apex.parallel import DistributedDataParallel, convert_syncbn_model
35
+ from tensorboardX import SummaryWriter
36
+
37
+ from apex import amp
38
+
39
+ import torch
40
+ from torch.backends import cudnn
41
+ from torch.nn import DataParallel
42
+ from torch.utils.data import DataLoader
43
+ from tqdm import tqdm
44
+ import torch.distributed as dist
45
+
46
+ torch.backends.cudnn.benchmark = True
47
+
48
+ def create_train_transforms(size=300):
49
+ return Compose([
50
+ ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
51
+ GaussNoise(p=0.1),
52
+ GaussianBlur(blur_limit=3, p=0.05),
53
+ HorizontalFlip(),
54
+ OneOf([
55
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
56
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_LINEAR),
57
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_LINEAR, interpolation_up=cv2.INTER_LINEAR),
58
+ ], p=1),
59
+ PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
60
+ OneOf([RandomBrightnessContrast(), FancyPCA(), HueSaturationValue()], p=0.7),
61
+ ToGray(p=0.2),
62
+ ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=10, border_mode=cv2.BORDER_CONSTANT, p=0.5),
63
+ ]
64
+ )
65
+
66
+
67
+ def create_val_transforms(size=300):
68
+ return Compose([
69
+ IsotropicResize(max_side=size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC),
70
+ PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT),
71
+ ])
72
+
73
+
74
+ def main():
75
+ parser = argparse.ArgumentParser("PyTorch Xview Pipeline")
76
+ arg = parser.add_argument
77
+ arg('--config', metavar='CONFIG_FILE', help='path to configuration file')
78
+ arg('--workers', type=int, default=6, help='number of cpu threads to use')
79
+ arg('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3')
80
+ arg('--output-dir', type=str, default='weights/')
81
+ arg('--resume', type=str, default='')
82
+ arg('--fold', type=int, default=0)
83
+ arg('--prefix', type=str, default='classifier_')
84
+ arg('--data-dir', type=str, default="/mnt/sota/datasets/deepfake")
85
+ arg('--folds-csv', type=str, default='folds.csv')
86
+ arg('--crops-dir', type=str, default='crops')
87
+ arg('--label-smoothing', type=float, default=0.01)
88
+ arg('--logdir', type=str, default='logs')
89
+ arg('--zero-score', action='store_true', default=False)
90
+ arg('--from-zero', action='store_true', default=False)
91
+ arg('--distributed', action='store_true', default=False)
92
+ arg('--freeze-epochs', type=int, default=0)
93
+ arg("--local_rank", default=0, type=int)
94
+ arg("--seed", default=777, type=int)
95
+ arg("--padding-part", default=3, type=int)
96
+ arg("--opt-level", default='O1', type=str)
97
+ arg("--test_every", type=int, default=1)
98
+ arg("--no-oversample", action="store_true")
99
+ arg("--no-hardcore", action="store_true")
100
+ arg("--only-changed-frames", action="store_true")
101
+
102
+ args = parser.parse_args()
103
+ os.makedirs(args.output_dir, exist_ok=True)
104
+ if args.distributed:
105
+ torch.cuda.set_device(args.local_rank)
106
+ torch.distributed.init_process_group(backend='nccl', init_method='env://')
107
+ else:
108
+ os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
109
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
110
+
111
+ cudnn.benchmark = True
112
+
113
+ conf = load_config(args.config)
114
+ model = classifiers.__dict__[conf['network']](encoder=conf['encoder'])
115
+
116
+ model = model.cuda()
117
+ if args.distributed:
118
+ model = convert_syncbn_model(model)
119
+ ohem = conf.get("ohem_samples", None)
120
+ reduction = "mean"
121
+ if ohem:
122
+ reduction = "none"
123
+ loss_fn = []
124
+ weights = []
125
+ for loss_name, weight in conf["losses"].items():
126
+ loss_fn.append(losses.__dict__[loss_name](reduction=reduction).cuda())
127
+ weights.append(weight)
128
+ loss = WeightedLosses(loss_fn, weights)
129
+ loss_functions = {"classifier_loss": loss}
130
+ optimizer, scheduler = create_optimizer(conf['optimizer'], model)
131
+ bce_best = 100
132
+ start_epoch = 0
133
+ batch_size = conf['optimizer']['batch_size']
134
+
135
+ data_train = DeepFakeClassifierDataset(mode="train",
136
+ oversample_real=not args.no_oversample,
137
+ fold=args.fold,
138
+ padding_part=args.padding_part,
139
+ hardcore=not args.no_hardcore,
140
+ crops_dir=args.crops_dir,
141
+ data_path=args.data_dir,
142
+ label_smoothing=args.label_smoothing,
143
+ folds_csv=args.folds_csv,
144
+ transforms=create_train_transforms(conf["size"]),
145
+ normalize=conf.get("normalize", None))
146
+ data_val = DeepFakeClassifierDataset(mode="val",
147
+ fold=args.fold,
148
+ padding_part=args.padding_part,
149
+ crops_dir=args.crops_dir,
150
+ data_path=args.data_dir,
151
+ folds_csv=args.folds_csv,
152
+ transforms=create_val_transforms(conf["size"]),
153
+ normalize=conf.get("normalize", None))
154
+ val_data_loader = DataLoader(data_val, batch_size=batch_size * 2, num_workers=args.workers, shuffle=False,
155
+ pin_memory=False)
156
+ os.makedirs(args.logdir, exist_ok=True)
157
+ summary_writer = SummaryWriter(args.logdir + '/' + conf.get("prefix", args.prefix) + conf['encoder'] + "_" + str(args.fold))
158
+ if args.resume:
159
+ if os.path.isfile(args.resume):
160
+ print("=> loading checkpoint '{}'".format(args.resume))
161
+ checkpoint = torch.load(args.resume, map_location='cpu')
162
+ state_dict = checkpoint['state_dict']
163
+ state_dict = {k[7:]: w for k, w in state_dict.items()}
164
+ model.load_state_dict(state_dict, strict=False)
165
+ if not args.from_zero:
166
+ start_epoch = checkpoint['epoch']
167
+ if not args.zero_score:
168
+ bce_best = checkpoint.get('bce_best', 0)
169
+ print("=> loaded checkpoint '{}' (epoch {}, bce_best {})"
170
+ .format(args.resume, checkpoint['epoch'], checkpoint['bce_best']))
171
+ else:
172
+ print("=> no checkpoint found at '{}'".format(args.resume))
173
+ if args.from_zero:
174
+ start_epoch = 0
175
+ current_epoch = start_epoch
176
+
177
+ if conf['fp16']:
178
+ model, optimizer = amp.initialize(model, optimizer,
179
+ opt_level=args.opt_level,
180
+ loss_scale='dynamic')
181
+
182
+ snapshot_name = "{}{}_{}_{}".format(conf.get("prefix", args.prefix), conf['network'], conf['encoder'], args.fold)
183
+
184
+ if args.distributed:
185
+ model = DistributedDataParallel(model, delay_allreduce=True)
186
+ else:
187
+ model = DataParallel(model).cuda()
188
+ data_val.reset(1, args.seed)
189
+ max_epochs = conf['optimizer']['schedule']['epochs']
190
+ for epoch in range(start_epoch, max_epochs):
191
+ data_train.reset(epoch, args.seed)
192
+ train_sampler = None
193
+ if args.distributed:
194
+ train_sampler = torch.utils.data.distributed.DistributedSampler(data_train)
195
+ train_sampler.set_epoch(epoch)
196
+ if epoch < args.freeze_epochs:
197
+ print("Freezing encoder!!!")
198
+ model.module.encoder.eval()
199
+ for p in model.module.encoder.parameters():
200
+ p.requires_grad = False
201
+ else:
202
+ model.module.encoder.train()
203
+ for p in model.module.encoder.parameters():
204
+ p.requires_grad = True
205
+
206
+ train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=args.workers,
207
+ shuffle=train_sampler is None, sampler=train_sampler, pin_memory=False,
208
+ drop_last=True)
209
+
210
+ train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
211
+ args.local_rank, args.only_changed_frames)
212
+ model = model.eval()
213
+
214
+ if args.local_rank == 0:
215
+ torch.save({
216
+ 'epoch': current_epoch + 1,
217
+ 'state_dict': model.state_dict(),
218
+ 'bce_best': bce_best,
219
+ }, args.output_dir + '/' + snapshot_name + "_last")
220
+ torch.save({
221
+ 'epoch': current_epoch + 1,
222
+ 'state_dict': model.state_dict(),
223
+ 'bce_best': bce_best,
224
+ }, args.output_dir + snapshot_name + "_{}".format(current_epoch))
225
+ if (epoch + 1) % args.test_every == 0:
226
+ bce_best = evaluate_val(args, val_data_loader, bce_best, model,
227
+ snapshot_name=snapshot_name,
228
+ current_epoch=current_epoch,
229
+ summary_writer=summary_writer)
230
+ current_epoch += 1
231
+
232
+
233
+ def evaluate_val(args, data_val, bce_best, model, snapshot_name, current_epoch, summary_writer):
234
+ print("Test phase")
235
+ model = model.eval()
236
+
237
+ bce, probs, targets = validate(model, data_loader=data_val)
238
+ if args.local_rank == 0:
239
+ summary_writer.add_scalar('val/bce', float(bce), global_step=current_epoch)
240
+ if bce < bce_best:
241
+ print("Epoch {} improved from {} to {}".format(current_epoch, bce_best, bce))
242
+ if args.output_dir is not None:
243
+ torch.save({
244
+ 'epoch': current_epoch + 1,
245
+ 'state_dict': model.state_dict(),
246
+ 'bce_best': bce,
247
+ }, args.output_dir + snapshot_name + "_best_dice")
248
+ bce_best = bce
249
+ with open("predictions_{}.json".format(args.fold), "w") as f:
250
+ json.dump({"probs": probs, "targets": targets}, f)
251
+ torch.save({
252
+ 'epoch': current_epoch + 1,
253
+ 'state_dict': model.state_dict(),
254
+ 'bce_best': bce_best,
255
+ }, args.output_dir + snapshot_name + "_last")
256
+ print("Epoch: {} bce: {}, bce_best: {}".format(current_epoch, bce, bce_best))
257
+ return bce_best
258
+
259
+
260
+ def validate(net, data_loader, prefix=""):
261
+ probs = defaultdict(list)
262
+ targets = defaultdict(list)
263
+
264
+ with torch.no_grad():
265
+ for sample in tqdm(data_loader):
266
+ imgs = sample["image"].cuda()
267
+ img_names = sample["img_name"]
268
+ labels = sample["labels"].cuda().float()
269
+ out = net(imgs)
270
+ labels = labels.cpu().numpy()
271
+ preds = torch.sigmoid(out).cpu().numpy()
272
+ for i in range(out.shape[0]):
273
+ video, img_id = img_names[i].split("/")
274
+ probs[video].append(preds[i].tolist())
275
+ targets[video].append(labels[i].tolist())
276
+ data_x = []
277
+ data_y = []
278
+ for vid, score in probs.items():
279
+ score = np.array(score)
280
+ lbl = targets[vid]
281
+
282
+ score = np.mean(score)
283
+ lbl = np.mean(lbl)
284
+ data_x.append(score)
285
+ data_y.append(lbl)
286
+ y = np.array(data_y)
287
+ x = np.array(data_x)
288
+ fake_idx = y > 0.1
289
+ real_idx = y < 0.1
290
+ fake_loss = log_loss(y[fake_idx], x[fake_idx], labels=[0, 1])
291
+ real_loss = log_loss(y[real_idx], x[real_idx], labels=[0, 1])
292
+ print("{}fake_loss".format(prefix), fake_loss)
293
+ print("{}real_loss".format(prefix), real_loss)
294
+
295
+ return (fake_loss + real_loss) / 2, probs, targets
296
+
297
+
298
+ def train_epoch(current_epoch, loss_functions, model, optimizer, scheduler, train_data_loader, summary_writer, conf,
299
+ local_rank, only_valid):
300
+ losses = AverageMeter()
301
+ fake_losses = AverageMeter()
302
+ real_losses = AverageMeter()
303
+ max_iters = conf["batches_per_epoch"]
304
+ print("training epoch {}".format(current_epoch))
305
+ model.train()
306
+ pbar = tqdm(enumerate(train_data_loader), total=max_iters, desc="Epoch {}".format(current_epoch), ncols=0)
307
+ if conf["optimizer"]["schedule"]["mode"] == "epoch":
308
+ scheduler.step(current_epoch)
309
+ for i, sample in pbar:
310
+ imgs = sample["image"].cuda()
311
+ labels = sample["labels"].cuda().float()
312
+ out_labels = model(imgs)
313
+ if only_valid:
314
+ valid_idx = sample["valid"].cuda().float() > 0
315
+ out_labels = out_labels[valid_idx]
316
+ labels = labels[valid_idx]
317
+ if labels.size(0) == 0:
318
+ continue
319
+
320
+ fake_loss = 0
321
+ real_loss = 0
322
+ fake_idx = labels > 0.5
323
+ real_idx = labels <= 0.5
324
+
325
+ ohem = conf.get("ohem_samples", None)
326
+ if torch.sum(fake_idx * 1) > 0:
327
+ fake_loss = loss_functions["classifier_loss"](out_labels[fake_idx], labels[fake_idx])
328
+ if torch.sum(real_idx * 1) > 0:
329
+ real_loss = loss_functions["classifier_loss"](out_labels[real_idx], labels[real_idx])
330
+ if ohem:
331
+ fake_loss = topk(fake_loss, k=min(ohem, fake_loss.size(0)), sorted=False)[0].mean()
332
+ real_loss = topk(real_loss, k=min(ohem, real_loss.size(0)), sorted=False)[0].mean()
333
+
334
+ loss = (fake_loss + real_loss) / 2
335
+ losses.update(loss.item(), imgs.size(0))
336
+ fake_losses.update(0 if fake_loss == 0 else fake_loss.item(), imgs.size(0))
337
+ real_losses.update(0 if real_loss == 0 else real_loss.item(), imgs.size(0))
338
+
339
+ optimizer.zero_grad()
340
+ pbar.set_postfix({"lr": float(scheduler.get_lr()[-1]), "epoch": current_epoch, "loss": losses.avg,
341
+ "fake_loss": fake_losses.avg, "real_loss": real_losses.avg})
342
+
343
+ if conf['fp16']:
344
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
345
+ scaled_loss.backward()
346
+ else:
347
+ loss.backward()
348
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
349
+ optimizer.step()
350
+ torch.cuda.synchronize()
351
+ if conf["optimizer"]["schedule"]["mode"] in ("step", "poly"):
352
+ scheduler.step(i + current_epoch * max_iters)
353
+ if i == max_iters - 1:
354
+ break
355
+ pbar.close()
356
+ if local_rank == 0:
357
+ for idx, param_group in enumerate(optimizer.param_groups):
358
+ lr = param_group['lr']
359
+ summary_writer.add_scalar('group{}/lr'.format(idx), float(lr), global_step=current_epoch)
360
+ summary_writer.add_scalar('train/loss', float(losses.avg), global_step=current_epoch)
361
+
362
+
363
+ if __name__ == '__main__':
364
+ main()
training/tools/__init__.py ADDED
File without changes
training/tools/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes). View file
 
training/tools/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.06 kB). View file
 
training/tools/__pycache__/schedulers.cpython-310.pyc ADDED
Binary file (3.01 kB). View file
 
training/tools/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.65 kB). View file
 
training/tools/config.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ DEFAULTS = {
4
+ "network": "dpn",
5
+ "encoder": "dpn92",
6
+ "model_params": {},
7
+ "optimizer": {
8
+ "batch_size": 32,
9
+ "type": "SGD", # supported: SGD, Adam
10
+ "momentum": 0.9,
11
+ "weight_decay": 0,
12
+ "clip": 1.,
13
+ "learning_rate": 0.1,
14
+ "classifier_lr": -1,
15
+ "nesterov": True,
16
+ "schedule": {
17
+ "type": "constant", # supported: constant, step, multistep, exponential, linear, poly
18
+ "mode": "epoch", # supported: epoch, step
19
+ "epochs": 10,
20
+ "params": {}
21
+ }
22
+ },
23
+ "normalize": {
24
+ "mean": [0.485, 0.456, 0.406],
25
+ "std": [0.229, 0.224, 0.225]
26
+ }
27
+ }
28
+
29
+
30
+ def _merge(src, dst):
31
+ for k, v in src.items():
32
+ if k in dst:
33
+ if isinstance(v, dict):
34
+ _merge(src[k], dst[k])
35
+ else:
36
+ dst[k] = v
37
+
38
+
39
+ def load_config(config_file, defaults=DEFAULTS):
40
+ with open(config_file, "r") as fd:
41
+ config = json.load(fd)
42
+ _merge(defaults, config)
43
+ return config
training/tools/schedulers.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bisect import bisect_right
2
+
3
+ from torch.optim.lr_scheduler import _LRScheduler
4
+
5
+
6
+ class LRStepScheduler(_LRScheduler):
7
+ def __init__(self, optimizer, steps, last_epoch=-1):
8
+ self.lr_steps = steps
9
+ super().__init__(optimizer, last_epoch)
10
+
11
+ def get_lr(self):
12
+ pos = max(bisect_right([x for x, y in self.lr_steps], self.last_epoch) - 1, 0)
13
+ return [self.lr_steps[pos][1] if self.lr_steps[pos][0] <= self.last_epoch else base_lr for base_lr in self.base_lrs]
14
+
15
+
16
+ class PolyLR(_LRScheduler):
17
+ """Sets the learning rate of each parameter group according to poly learning rate policy
18
+ """
19
+ def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
20
+ self.max_iter = max_iter
21
+ self.power = power
22
+ super(PolyLR, self).__init__(optimizer, last_epoch)
23
+
24
+ def get_lr(self):
25
+ self.last_epoch = (self.last_epoch + 1) % self.max_iter
26
+ return [base_lr * ((1 - float(self.last_epoch) / self.max_iter) ** (self.power)) for base_lr in self.base_lrs]
27
+
28
+ class ExponentialLRScheduler(_LRScheduler):
29
+ """Decays the learning rate of each parameter group by gamma every epoch.
30
+ When last_epoch=-1, sets initial lr as lr.
31
+
32
+ Args:
33
+ optimizer (Optimizer): Wrapped optimizer.
34
+ gamma (float): Multiplicative factor of learning rate decay.
35
+ last_epoch (int): The index of last epoch. Default: -1.
36
+ """
37
+
38
+ def __init__(self, optimizer, gamma, last_epoch=-1):
39
+ self.gamma = gamma
40
+ super(ExponentialLRScheduler, self).__init__(optimizer, last_epoch)
41
+
42
+ def get_lr(self):
43
+ if self.last_epoch <= 0:
44
+ return self.base_lrs
45
+ return [base_lr * self.gamma**self.last_epoch for base_lr in self.base_lrs]
46
+
training/tools/utils.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ from apex.optimizers import FusedAdam, FusedSGD
3
+ from timm.optim import AdamW
4
+ from torch import optim
5
+ from torch.optim import lr_scheduler
6
+ from torch.optim.rmsprop import RMSprop
7
+ from torch.optim.adamw import AdamW
8
+ from torch.optim.lr_scheduler import MultiStepLR, CyclicLR
9
+
10
+ from training.tools.schedulers import ExponentialLRScheduler, PolyLR, LRStepScheduler
11
+
12
+ cv2.ocl.setUseOpenCL(False)
13
+ cv2.setNumThreads(0)
14
+
15
+
16
+ class AverageMeter(object):
17
+ """Computes and stores the average and current value"""
18
+
19
+ def __init__(self):
20
+ self.reset()
21
+
22
+ def reset(self):
23
+ self.val = 0
24
+ self.avg = 0
25
+ self.sum = 0
26
+ self.count = 0
27
+
28
+ def update(self, val, n=1):
29
+ self.val = val
30
+ self.sum += val * n
31
+ self.count += n
32
+ self.avg = self.sum / self.count
33
+
34
+ def create_optimizer(optimizer_config, model, master_params=None):
35
+ """Creates optimizer and schedule from configuration
36
+
37
+ Parameters
38
+ ----------
39
+ optimizer_config : dict
40
+ Dictionary containing the configuration options for the optimizer.
41
+ model : Model
42
+ The network model.
43
+
44
+ Returns
45
+ -------
46
+ optimizer : Optimizer
47
+ The optimizer.
48
+ scheduler : LRScheduler
49
+ The learning rate scheduler.
50
+ """
51
+ if optimizer_config.get("classifier_lr", -1) != -1:
52
+ # Separate classifier parameters from all others
53
+ net_params = []
54
+ classifier_params = []
55
+ for k, v in model.named_parameters():
56
+ if not v.requires_grad:
57
+ continue
58
+ if k.find("encoder") != -1:
59
+ net_params.append(v)
60
+ else:
61
+ classifier_params.append(v)
62
+ params = [
63
+ {"params": net_params},
64
+ {"params": classifier_params, "lr": optimizer_config["classifier_lr"]},
65
+ ]
66
+ else:
67
+ if master_params:
68
+ params = master_params
69
+ else:
70
+ params = model.parameters()
71
+
72
+ if optimizer_config["type"] == "SGD":
73
+ optimizer = optim.SGD(params,
74
+ lr=optimizer_config["learning_rate"],
75
+ momentum=optimizer_config["momentum"],
76
+ weight_decay=optimizer_config["weight_decay"],
77
+ nesterov=optimizer_config["nesterov"])
78
+ elif optimizer_config["type"] == "FusedSGD":
79
+ optimizer = FusedSGD(params,
80
+ lr=optimizer_config["learning_rate"],
81
+ momentum=optimizer_config["momentum"],
82
+ weight_decay=optimizer_config["weight_decay"],
83
+ nesterov=optimizer_config["nesterov"])
84
+ elif optimizer_config["type"] == "Adam":
85
+ optimizer = optim.Adam(params,
86
+ lr=optimizer_config["learning_rate"],
87
+ weight_decay=optimizer_config["weight_decay"])
88
+ elif optimizer_config["type"] == "FusedAdam":
89
+ optimizer = FusedAdam(params,
90
+ lr=optimizer_config["learning_rate"],
91
+ weight_decay=optimizer_config["weight_decay"])
92
+ elif optimizer_config["type"] == "AdamW":
93
+ optimizer = AdamW(params,
94
+ lr=optimizer_config["learning_rate"],
95
+ weight_decay=optimizer_config["weight_decay"])
96
+ elif optimizer_config["type"] == "RmsProp":
97
+ optimizer = RMSprop(params,
98
+ lr=optimizer_config["learning_rate"],
99
+ weight_decay=optimizer_config["weight_decay"])
100
+ else:
101
+ raise KeyError("unrecognized optimizer {}".format(optimizer_config["type"]))
102
+
103
+ if optimizer_config["schedule"]["type"] == "step":
104
+ scheduler = LRStepScheduler(optimizer, **optimizer_config["schedule"]["params"])
105
+ elif optimizer_config["schedule"]["type"] == "clr":
106
+ scheduler = CyclicLR(optimizer, **optimizer_config["schedule"]["params"])
107
+ elif optimizer_config["schedule"]["type"] == "multistep":
108
+ scheduler = MultiStepLR(optimizer, **optimizer_config["schedule"]["params"])
109
+ elif optimizer_config["schedule"]["type"] == "exponential":
110
+ scheduler = ExponentialLRScheduler(optimizer, **optimizer_config["schedule"]["params"])
111
+ elif optimizer_config["schedule"]["type"] == "poly":
112
+ scheduler = PolyLR(optimizer, **optimizer_config["schedule"]["params"])
113
+ elif optimizer_config["schedule"]["type"] == "constant":
114
+ scheduler = lr_scheduler.LambdaLR(optimizer, lambda epoch: 1.0)
115
+ elif optimizer_config["schedule"]["type"] == "linear":
116
+ def linear_lr(it):
117
+ return it * optimizer_config["schedule"]["params"]["alpha"] + optimizer_config["schedule"]["params"]["beta"]
118
+
119
+ scheduler = lr_scheduler.LambdaLR(optimizer, linear_lr)
120
+
121
+ return optimizer, scheduler
training/transforms/__init__.py ADDED
File without changes
training/transforms/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes). View file
 
training/transforms/__pycache__/albu.cpython-310.pyc ADDED
Binary file (4.36 kB). View file
 
training/transforms/albu.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import cv2
4
+ import numpy as np
5
+ from albumentations import DualTransform, ImageOnlyTransform
6
+ from albumentations.augmentations.crops.functional import crop
7
+ #from albumentations.augmentations.functional import crop
8
+
9
+
10
+ def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC):
11
+ h, w = img.shape[:2]
12
+ if max(w, h) == size:
13
+ return img
14
+ if w > h:
15
+ scale = size / w
16
+ h = h * scale
17
+ w = size
18
+ else:
19
+ scale = size / h
20
+ w = w * scale
21
+ h = size
22
+ interpolation = interpolation_up if scale > 1 else interpolation_down
23
+ resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation)
24
+ return resized
25
+
26
+
27
+ class IsotropicResize(DualTransform):
28
+ def __init__(self, max_side, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC,
29
+ always_apply=False, p=1):
30
+ super(IsotropicResize, self).__init__(always_apply, p)
31
+ self.max_side = max_side
32
+ self.interpolation_down = interpolation_down
33
+ self.interpolation_up = interpolation_up
34
+
35
+ def apply(self, img, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC, **params):
36
+ return isotropically_resize_image(img, size=self.max_side, interpolation_down=interpolation_down,
37
+ interpolation_up=interpolation_up)
38
+
39
+ def apply_to_mask(self, img, **params):
40
+ return self.apply(img, interpolation_down=cv2.INTER_NEAREST, interpolation_up=cv2.INTER_NEAREST, **params)
41
+
42
+ def get_transform_init_args_names(self):
43
+ return ("max_side", "interpolation_down", "interpolation_up")
44
+
45
+
46
+ class Resize4xAndBack(ImageOnlyTransform):
47
+ def __init__(self, always_apply=False, p=0.5):
48
+ super(Resize4xAndBack, self).__init__(always_apply, p)
49
+
50
+ def apply(self, img, **params):
51
+ h, w = img.shape[:2]
52
+ scale = random.choice([2, 4])
53
+ img = cv2.resize(img, (w // scale, h // scale), interpolation=cv2.INTER_AREA)
54
+ img = cv2.resize(img, (w, h),
55
+ interpolation=random.choice([cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST]))
56
+ return img
57
+
58
+
59
+ class RandomSizedCropNonEmptyMaskIfExists(DualTransform):
60
+
61
+ def __init__(self, min_max_height, w2h_ratio=[0.7, 1.3], always_apply=False, p=0.5):
62
+ super(RandomSizedCropNonEmptyMaskIfExists, self).__init__(always_apply, p)
63
+
64
+ self.min_max_height = min_max_height
65
+ self.w2h_ratio = w2h_ratio
66
+
67
+ def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
68
+ cropped = crop(img, x_min, y_min, x_max, y_max)
69
+ return cropped
70
+
71
+ @property
72
+ def targets_as_params(self):
73
+ return ["mask"]
74
+
75
+ def get_params_dependent_on_targets(self, params):
76
+ mask = params["mask"]
77
+ mask_height, mask_width = mask.shape[:2]
78
+ crop_height = int(mask_height * random.uniform(self.min_max_height[0], self.min_max_height[1]))
79
+ w2h_ratio = random.uniform(*self.w2h_ratio)
80
+ crop_width = min(int(crop_height * w2h_ratio), mask_width - 1)
81
+ if mask.sum() == 0:
82
+ x_min = random.randint(0, mask_width - crop_width + 1)
83
+ y_min = random.randint(0, mask_height - crop_height + 1)
84
+ else:
85
+ mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
86
+ non_zero_yx = np.argwhere(mask)
87
+ y, x = random.choice(non_zero_yx)
88
+ x_min = x - random.randint(0, crop_width - 1)
89
+ y_min = y - random.randint(0, crop_height - 1)
90
+ x_min = np.clip(x_min, 0, mask_width - crop_width)
91
+ y_min = np.clip(y_min, 0, mask_height - crop_height)
92
+
93
+ x_max = x_min + crop_height
94
+ y_max = y_min + crop_width
95
+ y_max = min(mask_height, y_max)
96
+ x_max = min(mask_width, x_max)
97
+ return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
98
+
99
+ def get_transform_init_args_names(self):
100
+ return "min_max_height", "height", "width", "w2h_ratio"
training/zoo/__init__.py ADDED
File without changes
training/zoo/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
training/zoo/__pycache__/classifiers.cpython-310.pyc ADDED
Binary file (5.55 kB). View file
 
training/zoo/classifiers.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ import torch
5
+ from timm.models.efficientnet import tf_efficientnet_b4_ns, tf_efficientnet_b3_ns, \
6
+ tf_efficientnet_b5_ns, tf_efficientnet_b2_ns, tf_efficientnet_b6_ns, tf_efficientnet_b7_ns
7
+ from torch import nn
8
+ from torch.nn.modules.dropout import Dropout
9
+ from torch.nn.modules.linear import Linear
10
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
11
+
12
+ encoder_params = {
13
+ "tf_efficientnet_b3_ns": {
14
+ "features": 1536,
15
+ "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
16
+ },
17
+ "tf_efficientnet_b2_ns": {
18
+ "features": 1408,
19
+ "init_op": partial(tf_efficientnet_b2_ns, pretrained=False, drop_path_rate=0.2)
20
+ },
21
+ "tf_efficientnet_b4_ns": {
22
+ "features": 1792,
23
+ "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.5)
24
+ },
25
+ "tf_efficientnet_b5_ns": {
26
+ "features": 2048,
27
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
28
+ },
29
+ "tf_efficientnet_b4_ns_03d": {
30
+ "features": 1792,
31
+ "init_op": partial(tf_efficientnet_b4_ns, pretrained=True, drop_path_rate=0.3)
32
+ },
33
+ "tf_efficientnet_b5_ns_03d": {
34
+ "features": 2048,
35
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.3)
36
+ },
37
+ "tf_efficientnet_b5_ns_04d": {
38
+ "features": 2048,
39
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.4)
40
+ },
41
+ "tf_efficientnet_b6_ns": {
42
+ "features": 2304,
43
+ "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.2)
44
+ },
45
+ "tf_efficientnet_b7_ns": {
46
+ "features": 2560,
47
+ "init_op": partial(tf_efficientnet_b7_ns, pretrained=True, drop_path_rate=0.2)
48
+ },
49
+ "tf_efficientnet_b6_ns_04d": {
50
+ "features": 2304,
51
+ "init_op": partial(tf_efficientnet_b6_ns, pretrained=True, drop_path_rate=0.4)
52
+ },
53
+ }
54
+
55
+
56
+ def setup_srm_weights(input_channels: int = 3) -> torch.Tensor:
57
+ """Creates the SRM kernels for noise analysis."""
58
+ # note: values taken from Zhou et al., "Learning Rich Features for Image Manipulation Detection", CVPR2018
59
+ srm_kernel = torch.from_numpy(np.array([
60
+ [ # srm 1/2 horiz
61
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
62
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
63
+ [0., 1., -2., 1., 0.], # noqa: E241,E201
64
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
65
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
66
+ ], [ # srm 1/4
67
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
68
+ [0., -1., 2., -1., 0.], # noqa: E241,E201
69
+ [0., 2., -4., 2., 0.], # noqa: E241,E201
70
+ [0., -1., 2., -1., 0.], # noqa: E241,E201
71
+ [0., 0., 0., 0., 0.], # noqa: E241,E201
72
+ ], [ # srm 1/12
73
+ [-1., 2., -2., 2., -1.], # noqa: E241,E201
74
+ [2., -6., 8., -6., 2.], # noqa: E241,E201
75
+ [-2., 8., -12., 8., -2.], # noqa: E241,E201
76
+ [2., -6., 8., -6., 2.], # noqa: E241,E201
77
+ [-1., 2., -2., 2., -1.], # noqa: E241,E201
78
+ ]
79
+ ])).float()
80
+ srm_kernel[0] /= 2
81
+ srm_kernel[1] /= 4
82
+ srm_kernel[2] /= 12
83
+ return srm_kernel.view(3, 1, 5, 5).repeat(1, input_channels, 1, 1)
84
+
85
+
86
+ def setup_srm_layer(input_channels: int = 3) -> torch.nn.Module:
87
+ """Creates a SRM convolution layer for noise analysis."""
88
+ weights = setup_srm_weights(input_channels)
89
+ conv = torch.nn.Conv2d(input_channels, out_channels=3, kernel_size=5, stride=1, padding=2, bias=False)
90
+ with torch.no_grad():
91
+ conv.weight = torch.nn.Parameter(weights, requires_grad=False)
92
+ return conv
93
+
94
+
95
+ class DeepFakeClassifierSRM(nn.Module):
96
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
97
+ super().__init__()
98
+ self.encoder = encoder_params[encoder]["init_op"]()
99
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
100
+ self.srm_conv = setup_srm_layer(3)
101
+ self.dropout = Dropout(dropout_rate)
102
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
103
+
104
+ def forward(self, x):
105
+ noise = self.srm_conv(x)
106
+ x = self.encoder.forward_features(noise)
107
+ x = self.avg_pool(x).flatten(1)
108
+ x = self.dropout(x)
109
+ x = self.fc(x)
110
+ return x
111
+
112
+
113
+ class GlobalWeightedAvgPool2d(nn.Module):
114
+ """
115
+ Global Weighted Average Pooling from paper "Global Weighted Average
116
+ Pooling Bridges Pixel-level Localization and Image-level Classification"
117
+ """
118
+
119
+ def __init__(self, features: int, flatten=False):
120
+ super().__init__()
121
+ self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
122
+ self.flatten = flatten
123
+
124
+ def fscore(self, x):
125
+ m = self.conv(x)
126
+ m = m.sigmoid().exp()
127
+ return m
128
+
129
+ def norm(self, x: torch.Tensor):
130
+ return x / x.sum(dim=[2, 3], keepdim=True)
131
+
132
+ def forward(self, x):
133
+ input_x = x
134
+ x = self.fscore(x)
135
+ x = self.norm(x)
136
+ x = x * input_x
137
+ x = x.sum(dim=[2, 3], keepdim=not self.flatten)
138
+ return x
139
+
140
+
141
+ class DeepFakeClassifier(nn.Module):
142
+ def __init__(self, encoder, dropout_rate=0.0) -> None:
143
+ super().__init__()
144
+ self.encoder = encoder_params[encoder]["init_op"]()
145
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
146
+ self.dropout = Dropout(dropout_rate)
147
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
148
+
149
+ def forward(self, x):
150
+ x = self.encoder.forward_features(x)
151
+ x = self.avg_pool(x).flatten(1)
152
+ x = self.dropout(x)
153
+ x = self.fc(x)
154
+ return x
155
+
156
+
157
+
158
+
159
+ class DeepFakeClassifierGWAP(nn.Module):
160
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
161
+ super().__init__()
162
+ self.encoder = encoder_params[encoder]["init_op"]()
163
+ self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
164
+ self.dropout = Dropout(dropout_rate)
165
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
166
+
167
+ def forward(self, x):
168
+ x = self.encoder.forward_features(x)
169
+ x = self.avg_pool(x).flatten(1)
170
+ x = self.dropout(x)
171
+ x = self.fc(x)
172
+ return x
training/zoo/unet.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from timm.models.efficientnet import tf_efficientnet_b3_ns, tf_efficientnet_b5_ns
5
+ from torch import nn
6
+ from torch.nn import Dropout2d, Conv2d
7
+ from torch.nn.modules.dropout import Dropout
8
+ from torch.nn.modules.linear import Linear
9
+ from torch.nn.modules.pooling import AdaptiveAvgPool2d
10
+ from torch.nn.modules.upsampling import UpsamplingBilinear2d
11
+
12
+ encoder_params = {
13
+ "tf_efficientnet_b3_ns": {
14
+ "features": 1536,
15
+ "filters": [40, 32, 48, 136, 1536],
16
+ "decoder_filters": [64, 128, 256, 256],
17
+ "init_op": partial(tf_efficientnet_b3_ns, pretrained=True, drop_path_rate=0.2)
18
+ },
19
+ "tf_efficientnet_b5_ns": {
20
+ "features": 2048,
21
+ "filters": [48, 40, 64, 176, 2048],
22
+ "decoder_filters": [64, 128, 256, 256],
23
+ "init_op": partial(tf_efficientnet_b5_ns, pretrained=True, drop_path_rate=0.2)
24
+ },
25
+ }
26
+
27
+
28
+ class DecoderBlock(nn.Module):
29
+ def __init__(self, in_channels, out_channels):
30
+ super().__init__()
31
+ self.layer = nn.Sequential(
32
+ nn.Upsample(scale_factor=2),
33
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
34
+ nn.ReLU(inplace=True)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.layer(x)
39
+
40
+
41
+ class ConcatBottleneck(nn.Module):
42
+ def __init__(self, in_channels, out_channels):
43
+ super().__init__()
44
+ self.seq = nn.Sequential(
45
+ nn.Conv2d(in_channels, out_channels, 3, padding=1),
46
+ nn.ReLU(inplace=True)
47
+ )
48
+
49
+ def forward(self, dec, enc):
50
+ x = torch.cat([dec, enc], dim=1)
51
+ return self.seq(x)
52
+
53
+
54
+ class Decoder(nn.Module):
55
+ def __init__(self, decoder_filters, filters, upsample_filters=None,
56
+ decoder_block=DecoderBlock, bottleneck=ConcatBottleneck, dropout=0):
57
+ super().__init__()
58
+ self.decoder_filters = decoder_filters
59
+ self.filters = filters
60
+ self.decoder_block = decoder_block
61
+ self.decoder_stages = nn.ModuleList([self._get_decoder(idx) for idx in range(0, len(decoder_filters))])
62
+ self.bottlenecks = nn.ModuleList([bottleneck(self.filters[-i - 2] + f, f)
63
+ for i, f in enumerate(reversed(decoder_filters))])
64
+ self.dropout = Dropout2d(dropout) if dropout > 0 else None
65
+ self.last_block = None
66
+ if upsample_filters:
67
+ self.last_block = decoder_block(decoder_filters[0], out_channels=upsample_filters)
68
+ else:
69
+ self.last_block = UpsamplingBilinear2d(scale_factor=2)
70
+
71
+ def forward(self, encoder_results: list):
72
+ x = encoder_results[0]
73
+ bottlenecks = self.bottlenecks
74
+ for idx, bottleneck in enumerate(bottlenecks):
75
+ rev_idx = - (idx + 1)
76
+ x = self.decoder_stages[rev_idx](x)
77
+ x = bottleneck(x, encoder_results[-rev_idx])
78
+ if self.last_block:
79
+ x = self.last_block(x)
80
+ if self.dropout:
81
+ x = self.dropout(x)
82
+ return x
83
+
84
+ def _get_decoder(self, layer):
85
+ idx = layer + 1
86
+ if idx == len(self.decoder_filters):
87
+ in_channels = self.filters[idx]
88
+ else:
89
+ in_channels = self.decoder_filters[idx]
90
+ return self.decoder_block(in_channels, self.decoder_filters[max(layer, 0)])
91
+
92
+
93
+ def _initialize_weights(module):
94
+ for m in module.modules():
95
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
96
+ m.weight.data = nn.init.kaiming_normal_(m.weight.data)
97
+ if m.bias is not None:
98
+ m.bias.data.zero_()
99
+ elif isinstance(m, nn.BatchNorm2d):
100
+ m.weight.data.fill_(1)
101
+ m.bias.data.zero_()
102
+
103
+
104
+ class EfficientUnetClassifier(nn.Module):
105
+ def __init__(self, encoder, dropout_rate=0.5) -> None:
106
+ super().__init__()
107
+ self.decoder = Decoder(decoder_filters=encoder_params[encoder]["decoder_filters"],
108
+ filters=encoder_params[encoder]["filters"])
109
+ self.avg_pool = AdaptiveAvgPool2d((1, 1))
110
+ self.dropout = Dropout(dropout_rate)
111
+ self.fc = Linear(encoder_params[encoder]["features"], 1)
112
+ self.final = Conv2d(encoder_params[encoder]["decoder_filters"][0], out_channels=1, kernel_size=1, bias=False)
113
+ _initialize_weights(self)
114
+ self.encoder = encoder_params[encoder]["init_op"]()
115
+
116
+ def get_encoder_features(self, x):
117
+ encoder_results = []
118
+ x = self.encoder.conv_stem(x)
119
+ x = self.encoder.bn1(x)
120
+ x = self.encoder.act1(x)
121
+ encoder_results.append(x)
122
+ x = self.encoder.blocks[:2](x)
123
+ encoder_results.append(x)
124
+ x = self.encoder.blocks[2:3](x)
125
+ encoder_results.append(x)
126
+ x = self.encoder.blocks[3:5](x)
127
+ encoder_results.append(x)
128
+ x = self.encoder.blocks[5:](x)
129
+ x = self.encoder.conv_head(x)
130
+ x = self.encoder.bn2(x)
131
+ x = self.encoder.act2(x)
132
+ encoder_results.append(x)
133
+ encoder_results = list(reversed(encoder_results))
134
+ return encoder_results
135
+
136
+ def forward(self, x):
137
+ encoder_results = self.get_encoder_features(x)
138
+ seg = self.final(self.decoder(encoder_results))
139
+ x = encoder_results[0]
140
+ x = self.avg_pool(x).flatten(1)
141
+ x = self.dropout(x)
142
+ x = self.fc(x)
143
+ return x, seg
144
+
145
+
146
+ if __name__ == '__main__':
147
+ model = EfficientUnetClassifier("tf_efficientnet_b5_ns")
148
+ model.eval()
149
+ with torch.no_grad():
150
+ input = torch.rand(4, 3, 224, 224)
151
+ print(model(input))
weights/.gitkeep ADDED
File without changes
weights/b7_ns_best.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9db77ab9318863e2f8ab287c8eb83c2232584b82dc2fb41f1d614ddd7900cccb
3
+ size 266910617