Spaces: Build error

committed on
Commit · b63fd37
Parent(s): 17b04e7

adding the best model to hugging face
Browse files
- app.py +8 -2
- data/checkpoints/model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500/trainer_state.json +0 -64
- data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/config.json +0 -0
- data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/optimizer.pt +1 -1
- data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/pytorch_model.bin +1 -1
- data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/rng_state.pth +1 -1
- data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/scaler.pt +1 -1
- data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/scheduler.pt +1 -1
- data/checkpoints/model_lhGqMDq/checkpoint-440/trainer_state.json +80 -0
- data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/training_args.bin +2 -2
- fake_face_detection/__pycache__/__init__.cpython-310.pyc +0 -0
- fake_face_detection/data/__pycache__/__init__.cpython-310.pyc +0 -0
- fake_face_detection/data/__pycache__/collator.cpython-310.pyc +0 -0
- fake_face_detection/data/__pycache__/fake_face_dataset.cpython-310.pyc +0 -0
- fake_face_detection/data/__pycache__/lion_cheetah_collator.cpython-310.pyc +0 -0
- fake_face_detection/data/__pycache__/lion_cheetah_dataset.cpython-310.pyc +0 -0
- fake_face_detection/data/lion_cheetah_collator.py +33 -0
- fake_face_detection/data/lion_cheetah_dataset.py +102 -0
- fake_face_detection/metrics/__pycache__/__init__.cpython-310.pyc +0 -0
- fake_face_detection/metrics/__pycache__/compute_metrics.cpython-310.pyc +0 -0
- fake_face_detection/metrics/__pycache__/make_predictions.cpython-310.pyc +0 -0
- fake_face_detection/metrics/make_predictions.py +36 -12
- fake_face_detection/optimization/__pycache__/bayesian_optimization.cpython-310.pyc +0 -0
- fake_face_detection/optimization/__pycache__/fake_face_bayesian_optimization.cpython-310.pyc +0 -0
- fake_face_detection/optimization/fake_face_bayesian_optimization.py +26 -16
- fake_face_detection/trainers/lion_cheetah_search_train.py +80 -0
- fake_face_detection/trainers/search_train.py +13 -8
- fake_face_detection/utils/__pycache__/compute_weights.cpython-310.pyc +0 -0
- fake_face_detection/utils/visualize_images.py +15 -1
app.py
CHANGED
@@ -51,7 +51,7 @@ def get_model():

     # recuperate the model
     model = ViTForImageClassification.from_pretrained(
-        'data/checkpoints/
+        'data/checkpoints/model_lhGqMDq/checkpoint-440',
         num_labels = len(characs['ids']),
         id2label = {name: key for key, name in characs['ids'].items()},
         label2id = characs['ids']
@@ -84,6 +84,12 @@ if file is not None:

     left.markdown("""---""")

+    # add a side for the scaler and the head number
+    scale = st.sidebar.slider("Attention scale", min_value=30, max_value =200)
+
+    head = int(st.sidebar.selectbox("Attention head", options=list(range(1, 13))))
+
+
     if left.button("SUBMIT"):

         # Let us convert the image format to 'RGB'
@@ -116,7 +122,7 @@ if file is not None:
         attention = outputs.attentions[-1][0]

         # Let us recuperate the attention image
-        attention_image = get_attention(image, attention, size = (224, 224), patch_size = (14, 14))
+        attention_image = get_attention(image, attention, size = (224, 224), patch_size = (14, 14), scale = scale, head = head)

         # Let us transform the attention image to a opencv image
         attention_image = cv2.cvtColor(attention_image.astype('float32'), cv2.COLOR_RGB2BGR)
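For reference, a minimal sketch of how the two new sidebar controls are meant to flow into the attention overlay. It mirrors the app.py hunks above; `image`, `attention` and `get_attention` are assumed to come from the rest of app.py, and the final `st.image` call is only an illustration of displaying the overlay, not taken from the commit.

import streamlit as st

# the two controls added in this commit: a multiplier for the attention map and the head to display
scale = st.sidebar.slider("Attention scale", min_value=30, max_value=200)
head = int(st.sidebar.selectbox("Attention head", options=list(range(1, 13))))  # ViT-base has 12 heads

# `image`, `attention` and `get_attention` are assumed from app.py; the overlay is
# multiplied by `scale` inside get_attention and clipped to [0, 1]
attention_image = get_attention(image, attention, size=(224, 224), patch_size=(14, 14), scale=scale, head=head)
st.image(attention_image, caption=f"Attention of head {head}")  # hypothetical display call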
data/checkpoints/model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500/trainer_state.json
DELETED
@@ -1,64 +0,0 @@
{
  "best_metric": 0.6927365064620972,
  "best_model_checkpoint": "data/checkpoints/model_2yW4AcqNIb6zLKNIb6zLK\\checkpoint-1500",
  "epoch": 1.710376282782212,
  "global_step": 1500,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 0.57,
      "learning_rate": 0.00012064414686134504,
      "loss": 0.6945,
      "step": 500
    },
    {
      "epoch": 0.57,
      "eval_accuracy": 0.5081081081081081,
      "eval_f1": 0.38095238095238093,
      "eval_loss": 0.6931825280189514,
      "eval_runtime": 6.1462,
      "eval_samples_per_second": 30.1,
      "eval_steps_per_second": 3.905,
      "step": 500
    },
    {
      "epoch": 1.14,
      "learning_rate": 0.00010514828346782865,
      "loss": 0.6937,
      "step": 1000
    },
    {
      "epoch": 1.14,
      "eval_accuracy": 0.4702702702702703,
      "eval_f1": 0.0,
      "eval_loss": 0.6942673325538635,
      "eval_runtime": 11.0225,
      "eval_samples_per_second": 16.784,
      "eval_steps_per_second": 2.177,
      "step": 1000
    },
    {
      "epoch": 1.71,
      "learning_rate": 8.962136623985633e-05,
      "loss": 0.6936,
      "step": 1500
    },
    {
      "epoch": 1.71,
      "eval_accuracy": 0.5297297297297298,
      "eval_f1": 0.6925795053003534,
      "eval_loss": 0.6927365064620972,
      "eval_runtime": 6.7463,
      "eval_samples_per_second": 27.423,
      "eval_steps_per_second": 3.558,
      "step": 1500
    }
  ],
  "max_steps": 4385,
  "num_train_epochs": 5,
  "total_flos": 2.323984768541614e+17,
  "trial_name": null,
  "trial_params": null
}
data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/config.json
RENAMED
File without changes
data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/optimizer.pt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:34ae20bd389738b6680714ac81db29b7818f40d430f3ee5f99989e50498d9dbd
 size 686518917
data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/pytorch_model.bin
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:2e40727744f63eb801348e4e0730accb61a5fdf014bb8086765db19a592fe248
 size 343268717
data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/rng_state.pth
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:26c8e955c34ff14bfa8f687fd165d49decdd3e034d94027e5187c7e7e7496c1a
 size 14575
data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/scaler.pt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:76c820f9778523d807e0a80d015f245f8f1a75ff5c8ee1aa94e258d65b1066f5
 size 557
data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/scheduler.pt
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:d56c15a0722c57b78cdab7fc0e6a1424469ef928d58365fbf9e1bfe0832b43e7
 size 627
data/checkpoints/model_lhGqMDq/checkpoint-440/trainer_state.json
ADDED
@@ -0,0 +1,80 @@
{
  "best_metric": 0.32448598742485046,
  "best_model_checkpoint": "data/checkpoints/model_lhGqMDq\\checkpoint-440",
  "epoch": 4.0,
  "global_step": 440,
  "is_hyper_param_search": false,
  "is_local_process_zero": true,
  "is_world_process_zero": true,
  "log_history": [
    {
      "epoch": 1.0,
      "learning_rate": 6.923196230748668e-05,
      "loss": 0.6551,
      "step": 110
    },
    {
      "epoch": 1.0,
      "eval_accuracy": 0.6702702702702703,
      "eval_f1": 0.6772486772486773,
      "eval_loss": 0.6143904328346252,
      "eval_runtime": 7.9911,
      "eval_samples_per_second": 23.151,
      "eval_steps_per_second": 3.003,
      "step": 110
    },
    {
      "epoch": 2.0,
      "learning_rate": 4.615464153832446e-05,
      "loss": 0.5106,
      "step": 220
    },
    {
      "epoch": 2.0,
      "eval_accuracy": 0.7675675675675676,
      "eval_f1": 0.7860696517412936,
      "eval_loss": 0.4895593523979187,
      "eval_runtime": 7.6354,
      "eval_samples_per_second": 24.229,
      "eval_steps_per_second": 3.143,
      "step": 220
    },
    {
      "epoch": 3.0,
      "learning_rate": 2.307732076916223e-05,
      "loss": 0.4299,
      "step": 330
    },
    {
      "epoch": 3.0,
      "eval_accuracy": 0.8108108108108109,
      "eval_f1": 0.8044692737430168,
      "eval_loss": 0.435648649930954,
      "eval_runtime": 8.1072,
      "eval_samples_per_second": 22.819,
      "eval_steps_per_second": 2.96,
      "step": 330
    },
    {
      "epoch": 4.0,
      "learning_rate": 0.0,
      "loss": 0.2903,
      "step": 440
    },
    {
      "epoch": 4.0,
      "eval_accuracy": 0.8594594594594595,
      "eval_f1": 0.8673469387755102,
      "eval_loss": 0.32448598742485046,
      "eval_runtime": 7.1509,
      "eval_samples_per_second": 25.871,
      "eval_steps_per_second": 3.356,
      "step": 440
    }
  ],
  "max_steps": 440,
  "num_train_epochs": 4,
  "total_flos": 5.433738311775191e+17,
  "trial_name": null,
  "trial_params": null
}
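This trainer_state.json is what records which checkpoint is best for the committed model. A small sketch of reading it back, assuming the checkpoint path used elsewhere in this commit:

import json

# read the trainer state shipped with the committed checkpoint
with open("data/checkpoints/model_lhGqMDq/checkpoint-440/trainer_state.json") as f:
    state = json.load(f)

print(state["best_model_checkpoint"])             # data/checkpoints/model_lhGqMDq\checkpoint-440
print(state["best_metric"])                       # 0.3244..., matches the final eval_loss
print(state["log_history"][-1]["eval_accuracy"])  # 0.859... on the evaluation split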
data/checkpoints/{model_2yW4AcqNIb6zLKNIb6zLK/checkpoint-1500 → model_lhGqMDq/checkpoint-440}/training_args.bin
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:74817fa301a4e63d886a80850533c40044c6316a596cd85b00f752249a979780
+size 3579
fake_face_detection/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/__pycache__/__init__.cpython-310.pyc and b/fake_face_detection/__pycache__/__init__.cpython-310.pyc differ

fake_face_detection/data/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/data/__pycache__/__init__.cpython-310.pyc and b/fake_face_detection/data/__pycache__/__init__.cpython-310.pyc differ

fake_face_detection/data/__pycache__/collator.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/data/__pycache__/collator.cpython-310.pyc and b/fake_face_detection/data/__pycache__/collator.cpython-310.pyc differ

fake_face_detection/data/__pycache__/fake_face_dataset.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/data/__pycache__/fake_face_dataset.cpython-310.pyc and b/fake_face_detection/data/__pycache__/fake_face_dataset.cpython-310.pyc differ

fake_face_detection/data/__pycache__/lion_cheetah_collator.cpython-310.pyc
ADDED
Binary file (953 Bytes).

fake_face_detection/data/__pycache__/lion_cheetah_dataset.cpython-310.pyc
ADDED
Binary file (2.09 kB).
fake_face_detection/data/lion_cheetah_collator.py
ADDED
@@ -0,0 +1,33 @@
import torch
import numpy as np

def lion_cheetah_collator(batch):
    """The data collator for training vision transformer models on the lion cheetah dataset

    Args:
        batch (list): A list of dictionaries containing the pixel values and the labels

    Returns:
        dict: The final dictionary
    """

    new_batch = {
        'pixel_values': [],
        'labels': []
    }

    for x in batch:

        pixel_values = torch.from_numpy(x['pixel_values'][0]) if isinstance(x['pixel_values'][0], np.ndarray) \
            else x['pixel_values'][0]

        new_batch['pixel_values'].append(pixel_values)

        new_batch['labels'].append(torch.tensor(x['labels']))

    new_batch['pixel_values'] = torch.stack(new_batch['pixel_values'])

    new_batch['labels'] = torch.stack(new_batch['labels'])

    return new_batch
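A sketch of how this collator can be plugged in, assuming a dataset whose items look like the dict described in the docstring ({'pixel_values': [array_or_tensor], 'labels': int}); the DataLoader usage below is only illustrative, the commit itself passes it to a Trainer as data_collator (see lion_cheetah_search_train.py further down).

from torch.utils.data import DataLoader
from fake_face_detection.data.lion_cheetah_collator import lion_cheetah_collator

# `train_dataset` is assumed to return {'pixel_values': [array_or_tensor], 'labels': int} per item
loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=lion_cheetah_collator)

batch = next(iter(loader))
print(batch['pixel_values'].shape)  # e.g. torch.Size([8, 3, 224, 224]) for ViT-sized inputs
print(batch['labels'].shape)        # torch.Size([8])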
fake_face_detection/data/lion_cheetah_dataset.py
ADDED
@@ -0,0 +1,102 @@
from fake_face_detection.utils.compute_weights import compute_weights
from torch.utils.data import Dataset
from PIL import Image
from glob import glob
import numpy as np
import torch
import os

class LionCheetahDataset(Dataset):

    def __init__(self, lion_path: str, cheetah_path: str, id_map: dict, transformer, **transformer_kwargs):

        # let us recuperate the transformer
        self.transformer = transformer

        # let us recuperate the transformer kwargs
        self.transformer_kwargs = transformer_kwargs

        # let us load the images
        lion_images = glob(os.path.join(lion_path, "*"))

        cheetah_images = glob(os.path.join(cheetah_path, "*"))

        # recuperate rgb images
        self.lion_images = []

        self.cheetah_images = []

        for lion in lion_images:

            try:

                with Image.open(lion) as img:

                    # let us add a transformation on the images
                    if self.transformer:

                        image = self.transformer(img, **self.transformer_kwargs)

                    self.lion_images.append(lion)

            except Exception as e:

                pass

        for cheetah in cheetah_images:

            try:

                with Image.open(cheetah) as img:

                    # let us add a transformation on the images
                    if self.transformer:

                        image = self.transformer(img, **self.transformer_kwargs)

                    self.cheetah_images.append(cheetah)

            except Exception as e:

                pass

        self.images = self.lion_images + self.cheetah_images

        # let us recuperate the labels
        self.lion_labels = [int(id_map['lion'])] * len(self.lion_images)

        self.cheetah_labels = [int(id_map['cheetah'])] * len(self.cheetah_images)

        self.labels = self.lion_labels + self.cheetah_labels

        # let us recuperate the weights
        self.weights = torch.from_numpy(compute_weights(self.labels))

        # let us recuperate the length
        self.length = len(self.labels)

    def __getitem__(self, index):

        # let us recuperate an image
        image = self.images[index]

        with Image.open(image) as img:

            # let us recuperate a label
            label = self.labels[index]

            # let us add a transformation on the images
            if self.transformer:

                image = self.transformer(img, **self.transformer_kwargs)

            # let us add the label inside the obtained dictionary
            image['labels'] = label

        return image

    def __len__(self):

        return self.length
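A hedged construction example for the new dataset, assuming lion/cheetah image folders and a Hugging Face feature extractor as the `transformer`; the folder names, the id_map values and the ViTFeatureExtractor choice are assumptions for illustration, not part of the commit.

from fake_face_detection.data.lion_cheetah_dataset import LionCheetahDataset
from transformers import ViTFeatureExtractor

# the feature extractor returns {'pixel_values': [...]} dicts, which is why __getitem__
# can attach 'labels' to its output; checkpoint name and paths are assumed
extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

dataset = LionCheetahDataset(
    lion_path="data/lion",
    cheetah_path="data/cheetah",
    id_map={"lion": 0, "cheetah": 1},
    transformer=extractor,
    return_tensors="np",  # forwarded through **transformer_kwargs to the extractor call
)

print(len(dataset), dataset[0].keys())  # e.g. N, dict_keys(['pixel_values', 'labels'])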
fake_face_detection/metrics/__pycache__/__init__.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/metrics/__pycache__/__init__.cpython-310.pyc and b/fake_face_detection/metrics/__pycache__/__init__.cpython-310.pyc differ

fake_face_detection/metrics/__pycache__/compute_metrics.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/metrics/__pycache__/compute_metrics.cpython-310.pyc and b/fake_face_detection/metrics/__pycache__/compute_metrics.cpython-310.pyc differ

fake_face_detection/metrics/__pycache__/make_predictions.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/metrics/__pycache__/make_predictions.cpython-310.pyc and b/fake_face_detection/metrics/__pycache__/make_predictions.cpython-310.pyc differ
fake_face_detection/metrics/make_predictions.py
CHANGED
@@ -16,7 +16,7 @@ import numpy as np
 import torch
 import os

-def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, size: tuple, patch_size: tuple):
+def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, size: tuple, patch_size: tuple, scale: int = 50, head: int = 1):

     # recuperate the image as a numpy array
     if isinstance(image, str):
@@ -33,7 +33,7 @@ def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, siz
     attention = attention[:, -1, 1:]

     # calculate the mean attention
-    attention = attention
+    attention = attention[head - 1]

     # let us reshape transform the image to a numpy array

@@ -48,9 +48,9 @@ def get_attention(image: Union[str, JpegImageFile], attention: torch.Tensor, siz
     attention = attention.reshape(size[0], size[1], 1)

     # recuperate the result
-    attention_image = img / 255 * attention.numpy()
-    return attention_image
+    attention_image = img / 255 * attention.numpy() * scale
+    return np.clip(attention_image, 0, 1)


 def make_predictions(test_dataset: FakeFaceDetectionDataset,
@@ -60,7 +60,28 @@ def make_predictions(test_dataset: FakeFaceDetectionDataset,
                      batch_size: int = 3,
                      size: tuple = (224, 224),
                      patch_size: tuple = (14, 14),
-                     figsize: tuple = (24, 24)
+                     figsize: tuple = (24, 24),
+                     attention_scale: int = 50,
+                     show: bool = True,
+                     head: int = 1):
+    """Make predictions with a vision transformer model
+
+    Args:
+        test_dataset (FakeFaceDetectionDataset): The test dataset
+        model (_type_): The model
+        log_dir (str, optional): The log directory. Defaults to "fake_face_logs".
+        tag (str, optional): The tag. Defaults to "Attentions".
+        batch_size (int, optional): The batch size. Defaults to 3.
+        size (tuple, optional): The size of the attention image. Defaults to (224, 224).
+        patch_size (tuple, optional): The patch size. Defaults to (14, 14).
+        figsize (tuple, optional): The figure size. Defaults to (24, 24).
+        attention_scale (int, optional): The attention scale. Defaults to 50.
+        show (bool, optional): A boolean value indicating if we want to recuperate the figure. Defaults to True.
+        head (int, optional): The head number. Defaults to 1.
+
+    Returns:
+        Union[Tuple[pd.DataFrame, dict], Tuple[pd.DataFrame, dict, figure]]: The predictions and the metrics
+    """

     with torch.no_grad():
@@ -86,22 +107,22 @@ def make_predictions(test_dataset: FakeFaceDetectionDataset,
         for data in test_dataloader:

             # recuperate the pixel values
-            pixel_values = data['pixel_values'][0]
+            pixel_values = data['pixel_values'][0]

             # recuperate the labels
-            labels_ = data['labels']
+            labels_ = data['labels']

             # # recuperate the outputs
             outputs = model(pixel_values, labels = labels_, output_attentions = True)

             # recuperate the predictions
-            predictions['predictions'].append(torch.softmax(outputs.logits.detach()
+            predictions['predictions'].append(torch.softmax(outputs.logits.detach(), axis = -1).numpy())

             # recuperate the attentions of the last encoder layer
-            predictions['attentions'].append(outputs.attentions[-1].detach()
+            predictions['attentions'].append(outputs.attentions[-1].detach())

             # add the loss
-            loss += outputs.loss.detach().
+            loss += outputs.loss.detach().item()

         predictions['predictions'] = np.concatenate(predictions['predictions'], axis = 0)
@@ -140,8 +161,11 @@ def make_predictions(test_dataset: FakeFaceDetectionDataset,
         del predictions['predictions']
         del predictions['attentions']

-        #
-        return pd.DataFrame(predictions), metrics
+        # show the figure if necessary
+        if show: return pd.DataFrame(predictions), metrics, fig
+        else:
+            # let us recuperate the metrics and the predictions
+            return pd.DataFrame(predictions), metrics
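To make the new get_attention signature concrete, a small sketch of calling it on one image, assuming a ViT model and preprocessed `inputs`/`image` objects from the same pipeline as app.py; the scale and head values are illustrative.

import torch
from fake_face_detection.metrics.make_predictions import get_attention

# `model`, `inputs` and `image` are assumed from the ViT pipeline used in app.py
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

attention = outputs.attentions[-1][0]  # last layer, first image: (heads, tokens, tokens)

# `head` selects one of the 12 attention heads; `scale` amplifies the map before it is clipped to [0, 1]
overlay = get_attention(image, attention, size=(224, 224), patch_size=(14, 14), scale=50, head=3)
print(overlay.shape)  # expected (224, 224, 3)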
fake_face_detection/optimization/__pycache__/bayesian_optimization.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/optimization/__pycache__/bayesian_optimization.cpython-310.pyc and b/fake_face_detection/optimization/__pycache__/bayesian_optimization.cpython-310.pyc differ

fake_face_detection/optimization/__pycache__/fake_face_bayesian_optimization.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/optimization/__pycache__/fake_face_bayesian_optimization.cpython-310.pyc and b/fake_face_detection/optimization/__pycache__/fake_face_bayesian_optimization.cpython-310.pyc differ
fake_face_detection/optimization/fake_face_bayesian_optimization.py
CHANGED
@@ -53,15 +53,15 @@ class SimpleBayesianOptimizationForFakeReal:

                 pickler = pickle.Unpickler(f)

-
+                checkpoint = pickler.load()

-                self.data =
+                self.data = checkpoint['data']

-                self.scores =
+                self.scores = checkpoint['scores']

-                self.model =
+                self.model = checkpoint['model']

-                self.current_trial =
+                self.current_trial = checkpoint['trial']

                 print(f"Checkpoint loaded at trial {self.current_trial}")

@@ -113,17 +113,8 @@ class SimpleBayesianOptimizationForFakeReal:
             new_sample = generate_sample(self.data, self.model, self.search_spaces, n_tests, maximize = self.maximize)
             config = {key: new_sample[i] for i, key in enumerate(self.search_spaces)}

-            #
-
-            # add random kwargs to the kwargs
-            self.kwargs.update(random_kwargs)
-
-            # add config to kwargs
-            self.kwargs['config'] = config
-
-            # calculate the first score
-            new_score = self.objective(**self.kwargs)
+            # recuperate a new score
+            new_score = self.get_score(config)

             # let us add the new sample, target and score to their lists
             self.data.append(new_sample)

@@ -148,6 +139,25 @@ class SimpleBayesianOptimizationForFakeReal:
             }

             pickler.dump(checkpoint)
+
+    def get_score(self, config: dict):
+
+        # reset the random seed (since we always have the same problem of the seed being fixed)
+        random.seed(None)
+
+        # initialize the random kwargs with random values
+        random_kwargs = {key: value + ''.join(random.choice(letters) for i in range(7)) for key, value in self.random_kwargs.items()}
+        print(random_kwargs)
+        # add random kwargs to the kwargs
+        self.kwargs.update(random_kwargs)
+
+        # add config to kwargs
+        self.kwargs['config'] = config
+
+        # calculate the first score
+        new_score = self.objective(**self.kwargs)
+
+        return new_score

     def get_results(self):
         """Recuperate the generated samples and the scores
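The refactor above also makes the checkpoint format explicit: a pickled dict with 'data', 'scores', 'model' and 'trial' keys. A minimal sketch of inspecting such a checkpoint, with a file path chosen only for illustration:

import pickle

# hypothetical path; the class writes checkpoints with pickle, so Unpickler reads them back
with open("data/bayesian_checkpoint.pkl", "rb") as f:
    checkpoint = pickle.Unpickler(f).load()

print(checkpoint["trial"])        # index of the last completed trial
print(len(checkpoint["data"]))    # one sampled configuration per trial
print(checkpoint["scores"][-1])   # score returned by objective(**kwargs) for the last trial
print(type(checkpoint["model"]))  # surrogate model consumed by generate_sample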
fake_face_detection/trainers/lion_cheetah_search_train.py
ADDED
@@ -0,0 +1,80 @@
from fake_face_detection.metrics.compute_metrics import compute_metrics
from fake_face_detection.data.lion_cheetah_collator import lion_cheetah_collator
from transformers import Trainer, TrainingArguments, set_seed
from torch.utils.tensorboard import SummaryWriter
from torch import nn
from typing import *
import numpy as np
import json
import os

def train(epochs: int, output_dir: str, config: dict, model: nn.Module, trainer, get_datasets: Callable, log_dir: str = "fake_face_logs", metric = 'accuracy', seed: int = 0):

    print("------------------------- Beginning of training")

    set_seed(seed)

    # initialize the model
    model = model()

    # reformat the config integer type
    for key, value in config.items():

        if isinstance(value, np.int32): config[key] = int(value)

    pretty = json.dumps(config, indent = 4)

    print(f"Current Config: \n {pretty}")

    print(f"Checkpoints in {output_dir}")

    # recuperate the dataset
    train_dataset, test_dataset = get_datasets(config['h_flip_p'], config['v_flip_p'], config['gray_scale_p'], config['rotation'])

    # initialize the arguments of the training
    training_args = TrainingArguments(output_dir,
                                      per_device_train_batch_size=config['batch_size'],
                                      evaluation_strategy='steps',
                                      save_strategy='steps',
                                      logging_strategy='steps',
                                      num_train_epochs=epochs,
                                      fp16=True,
                                      save_total_limit=2,
                                      remove_unused_columns=True,
                                      push_to_hub=False,
                                      logging_dir=os.path.join(log_dir, os.path.basename(output_dir)),
                                      load_best_model_at_end=True,
                                      learning_rate=config['lr'],
                                      weight_decay=config['weight_decay']
                                      )

    # initialize the trainer
    trainer_ = trainer(
        model = model,
        args = training_args,
        data_collator = lion_cheetah_collator,
        compute_metrics = compute_metrics,
        train_dataset = train_dataset,
        eval_dataset = test_dataset
    )

    # train the model
    trainer_.train()

    # evaluate the model and recuperate metrics
    metrics = trainer_.evaluate(test_dataset)

    # add metrics and config to the hyperparameter panel of tensorboard
    with SummaryWriter(os.path.join(log_dir, 'lchparams')) as logger:

        logger.add_hparams(
            config, metrics
        )

    print(metrics)

    print("------------------------- End of training")
    # recuperate the metric to evaluate
    return metrics[f'eval_{metric}']
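A hedged sketch of how this train function might be driven, assuming a model factory and a get_datasets helper analogous to the fake-face ones; the names get_lion_cheetah_datasets and model_demo are placeholders for illustration, not functions or paths from the commit.

from functools import partial
from transformers import Trainer, ViTForImageClassification
from fake_face_detection.trainers.lion_cheetah_search_train import train

# the function receives a callable and instantiates it itself (model = model()), so a partial works
model_factory = partial(
    ViTForImageClassification.from_pretrained,
    "google/vit-base-patch16-224-in21k",
    num_labels=2,
)

config = {"batch_size": 16, "lr": 1e-4, "weight_decay": 0.01,
          "h_flip_p": 0.5, "v_flip_p": 0.0, "gray_scale_p": 0.1, "rotation": 10}

# get_lion_cheetah_datasets is a placeholder that builds (train_dataset, test_dataset) from the augmentations
score = train(epochs=4, output_dir="data/checkpoints/model_demo", config=config,
              model=model_factory, trainer=Trainer, get_datasets=get_lion_cheetah_datasets,
              metric="f1", seed=0)
print(score)  # eval_f1 returned by trainer_.evaluate on the test split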
fake_face_detection/trainers/search_train.py
CHANGED
@@ -1,7 +1,7 @@

 from fake_face_detection.metrics.compute_metrics import compute_metrics
 from fake_face_detection.data.collator import fake_face_collator
-from transformers import Trainer, TrainingArguments
+from transformers import Trainer, TrainingArguments, set_seed
 from torch.utils.tensorboard import SummaryWriter
 from torch import nn
 from typing import *
@@ -9,10 +9,15 @@ import numpy as np
 import json
 import os

-def train(epochs: int, output_dir: str, config: dict, model: nn.Module, trainer, get_datasets: Callable, log_dir: str = "fake_face_logs", metric = 'accuracy'):
+def train(epochs: int, output_dir: str, config: dict, model: nn.Module, trainer, get_datasets: Callable, log_dir: str = "fake_face_logs", metric = 'accuracy', seed: int = 0):

     print("------------------------- Beginning of training")

+    set_seed(seed)
+
+    # initialize the model
+    model = model()
+
     # reformat the config integer type
     for key, value in config.items():
@@ -22,24 +27,24 @@ def train(epochs: int, output_dir: str, config: dict, model: nn.Module, trainer,

     print(f"Current Config: \n {pretty}")

+    print(f"Checkpoints in {output_dir}")
+
     # recuperate the dataset
     train_dataset, test_dataset = get_datasets(config['h_flip_p'], config['v_flip_p'], config['gray_scale_p'], config['rotation'])

     # initialize the arguments of the training
     training_args = TrainingArguments(output_dir,
                                       per_device_train_batch_size=config['batch_size'],
-                                      evaluation_strategy='
-                                      save_strategy='
-                                      logging_strategy='
+                                      evaluation_strategy='epoch',
+                                      save_strategy='epoch',
+                                      logging_strategy='epoch',
                                       num_train_epochs=epochs,
                                       fp16=True,
                                       save_total_limit=2,
-                                      remove_unused_columns=True,
                                       push_to_hub=False,
                                       logging_dir=os.path.join(log_dir, os.path.basename(output_dir)),
                                       load_best_model_at_end=True,
-                                      learning_rate=config['lr']
-                                      weight_decay=config['weight_decay']
+                                      learning_rate=config['lr']
                                       )

     # train the model
fake_face_detection/utils/__pycache__/compute_weights.cpython-310.pyc
CHANGED
Binary files a/fake_face_detection/utils/__pycache__/compute_weights.cpython-310.pyc and b/fake_face_detection/utils/__pycache__/compute_weights.cpython-310.pyc differ
fake_face_detection/utils/visualize_images.py
CHANGED
@@ -15,7 +15,19 @@ def visualize_images(images_dict: Dict[str, Iterable[Union[JpegImageFile, torch.
                      log_directory: str = "fake_face_logs",
                      n_images: int = 40,
                      figsize = (15, 15),
-                     seed: Union[int, None] = None
+                     seed: Union[int, None] = None,
+                     show: bool = True
+                     ):
+    """Visualize some images from a dictionary

+    Args:
+        images_dict (Dict[str, Iterable[Union[JpegImageFile, torch.Tensor, np.ndarray]]]): The dictionary of the images with key indicating the tag
+        log_directory (str, optional): The tensorboard log directory. Defaults to "fake_face_logs".
+        n_images (int, optional): The number of images. Defaults to 40.
+        figsize (tuple, optional): The figure size. Defaults to (15, 15).
+        seed (Union[int, None], optional): The seed. Defaults to None.
+        show (bool): Indicate if we want to show the figure. Defaults to True.
+    """

     assert len(images_dict) > 0

@@ -77,3 +89,5 @@ def visualize_images(images_dict: Dict[str, Iterable[Union[JpegImageFile, torch.

     writer.add_figure(tag = tag, figure = fig)

+    if show: return fig
+