admin commited on
Commit
319d3b5
·
1 Parent(s): 2f1013d
.gitattributes CHANGED
@@ -1,35 +1,34 @@
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
4
  *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
  *.model filter=lfs diff=lfs merge=lfs -text
13
  *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
  *.onnx filter=lfs diff=lfs merge=lfs -text
17
  *.ot filter=lfs diff=lfs merge=lfs -text
18
  *.parquet filter=lfs diff=lfs merge=lfs -text
19
  *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
  *.pt filter=lfs diff=lfs merge=lfs -text
23
  *.pth filter=lfs diff=lfs merge=lfs -text
24
  *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
  *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
  *.tflite filter=lfs diff=lfs merge=lfs -text
30
  *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
  *.bz2 filter=lfs diff=lfs merge=lfs -text
 
6
  *.ftz filter=lfs diff=lfs merge=lfs -text
7
  *.gz filter=lfs diff=lfs merge=lfs -text
8
  *.h5 filter=lfs diff=lfs merge=lfs -text
9
  *.joblib filter=lfs diff=lfs merge=lfs -text
10
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
 
11
  *.model filter=lfs diff=lfs merge=lfs -text
12
  *.msgpack filter=lfs diff=lfs merge=lfs -text
 
 
13
  *.onnx filter=lfs diff=lfs merge=lfs -text
14
  *.ot filter=lfs diff=lfs merge=lfs -text
15
  *.parquet filter=lfs diff=lfs merge=lfs -text
16
  *.pb filter=lfs diff=lfs merge=lfs -text
 
 
17
  *.pt filter=lfs diff=lfs merge=lfs -text
18
  *.pth filter=lfs diff=lfs merge=lfs -text
19
  *.rar filter=lfs diff=lfs merge=lfs -text
 
20
  saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
  *.tar.* filter=lfs diff=lfs merge=lfs -text
 
22
  *.tflite filter=lfs diff=lfs merge=lfs -text
23
  *.tgz filter=lfs diff=lfs merge=lfs -text
 
24
  *.xz filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
34
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ mivolo/model/__pycache__/*
2
+ mivolo/data/__pycache__/*
3
+ mivolo/__pycache__/*
4
+ __pycache__/*
5
+ model/*
6
+ rename.sh
README.md CHANGED
@@ -1,13 +1,14 @@
1
  ---
2
- title: Demo
3
- emoji: 🌖
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Gender Age Detector
3
+ emoji: 👩🧑‍🦲
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.36.0
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
+ arxiv: 2307.04616
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import imghdr
4
+ import shutil
5
+ import warnings
6
+ import numpy as np
7
+ import gradio as gr
8
+ from dataclasses import dataclass
9
+ from mivolo.predictor import Predictor
10
+ from utils import is_url, download_file, get_jpg_files, MODEL_DIR
11
+
12
+ TMP_DIR = "./__pycache__"
13
+
14
+
15
+ @dataclass
16
+ class Cfg:
17
+ detector_weights: str
18
+ checkpoint: str
19
+ device: str = "cpu"
20
+ with_persons: bool = True
21
+ disable_faces: bool = False
22
+ draw: bool = True
23
+
24
+
25
+ class ValidImgDetector:
26
+ predictor = None
27
+
28
+ def __init__(self):
29
+ detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
30
+ age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
31
+ predictor_cfg = Cfg(detector_path, age_gender_path)
32
+ self.predictor = Predictor(predictor_cfg)
33
+
34
+ def _detect(
35
+ self,
36
+ image: np.ndarray,
37
+ score_threshold: float,
38
+ iou_threshold: float,
39
+ mode: str,
40
+ predictor: Predictor,
41
+ ) -> np.ndarray:
42
+ # input is rgb image, output must be rgb too
43
+ predictor.detector.detector_kwargs["conf"] = score_threshold
44
+ predictor.detector.detector_kwargs["iou"] = iou_threshold
45
+ if mode == "Use persons and faces":
46
+ use_persons = True
47
+ disable_faces = False
48
+
49
+ elif mode == "Use persons only":
50
+ use_persons = True
51
+ disable_faces = True
52
+
53
+ elif mode == "Use faces only":
54
+ use_persons = False
55
+ disable_faces = False
56
+
57
+ predictor.age_gender_model.meta.use_persons = use_persons
58
+ predictor.age_gender_model.meta.disable_faces = disable_faces
59
+ # image = image[:, :, ::-1] # RGB -> BGR
60
+ detected_objects, out_im = predictor.recognize(image)
61
+ has_child, has_female, has_male = False, False, False
62
+ if len(detected_objects.ages) > 0:
63
+ has_child = min(detected_objects.ages) < 18
64
+ has_female = "female" in detected_objects.genders
65
+ has_male = "male" in detected_objects.genders
66
+
67
+ return out_im[:, :, ::-1], has_child, has_female, has_male
68
+
69
+ def valid_img(self, img_path):
70
+ image = cv2.imread(img_path)
71
+ return self._detect(image, 0.4, 0.7, "Use persons and faces", self.predictor)
72
+
73
+
74
+ def infer(photo: str):
75
+ if is_url(photo):
76
+ if os.path.exists(TMP_DIR):
77
+ shutil.rmtree(TMP_DIR)
78
+
79
+ photo = download_file(photo, f"{TMP_DIR}/download.jpg")
80
+
81
+ detector = ValidImgDetector()
82
+ if not photo or not os.path.exists(photo) or imghdr.what(photo) == None:
83
+ return None, None, None, "请正确输入图片 Please input image correctly"
84
+
85
+ return detector.valid_img(photo)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ with gr.Blocks() as iface:
90
+ warnings.filterwarnings("ignore")
91
+ with gr.Tab("上传模式 Upload Mode"):
92
+ gr.Interface(
93
+ fn=infer,
94
+ inputs=gr.Image(label="上传照片 Upload Photo", type="filepath"),
95
+ outputs=[
96
+ gr.Image(label="检测结果 Detection Result", type="numpy"),
97
+ gr.Textbox(label="存在儿童 Has Child"),
98
+ gr.Textbox(label="存在女性 Has Female"),
99
+ gr.Textbox(label="存在男性 Has Male"),
100
+ ],
101
+ examples=get_jpg_files(f"{MODEL_DIR}/examples"),
102
+ allow_flagging="never",
103
+ )
104
+
105
+ with gr.Tab("在线模式 Online Mode"):
106
+ gr.Interface(
107
+ fn=infer,
108
+ inputs=gr.Textbox(label="网络图片链接 Online Picture URL"),
109
+ outputs=[
110
+ gr.Image(label="检测结果 Detection Result", type="numpy"),
111
+ gr.Textbox(label="存在儿童 Has Child"),
112
+ gr.Textbox(label="存在女性 Has Female"),
113
+ gr.Textbox(label="存在男性 Has Male"),
114
+ ],
115
+ allow_flagging="never",
116
+ cache_examples=False,
117
+ )
118
+
119
+ iface.launch()
mivolo/data/data_reader.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import defaultdict
3
+ from dataclasses import dataclass, field
4
+ from enum import Enum
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import pandas as pd
8
+
9
+ IMAGES_EXT: Tuple = (".jpeg", ".jpg", ".png", ".webp", ".bmp", ".gif")
10
+ VIDEO_EXT: Tuple = (".mp4", ".avi", ".mov", ".mkv", ".webm")
11
+
12
+
13
+ @dataclass
14
+ class PictureInfo:
15
+ image_path: str
16
+ age: Optional[str] # age or age range(start;end format) or "-1"
17
+ gender: Optional[str] # "M" of "F" or "-1"
18
+ bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # face bbox: xyxy
19
+ person_bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # person bbox: xyxy
20
+
21
+ @property
22
+ def has_person_bbox(self) -> bool:
23
+ return any(coord != -1 for coord in self.person_bbox)
24
+
25
+ @property
26
+ def has_face_bbox(self) -> bool:
27
+ return any(coord != -1 for coord in self.bbox)
28
+
29
+ def has_gt(self, only_age: bool = False) -> bool:
30
+ if only_age:
31
+ return self.age != "-1"
32
+ else:
33
+ return not (self.age == "-1" and self.gender == "-1")
34
+
35
+ def clear_person_bbox(self):
36
+ self.person_bbox = [-1, -1, -1, -1]
37
+
38
+ def clear_face_bbox(self):
39
+ self.bbox = [-1, -1, -1, -1]
40
+
41
+
42
+ class AnnotType(Enum):
43
+ ORIGINAL = "original"
44
+ PERSONS = "persons"
45
+ NONE = "none"
46
+
47
+ @classmethod
48
+ def _missing_(cls, value):
49
+ print(f"WARN: Unknown annotation type {value}.")
50
+ return AnnotType.NONE
51
+
52
+
53
+ def get_all_files(path: str, extensions: Tuple = IMAGES_EXT):
54
+ files_all = []
55
+ for root, subFolders, files in os.walk(path):
56
+ for name in files:
57
+ # linux tricks with .directory that still is file
58
+ if "directory" not in name and sum([ext.lower() in name.lower() for ext in extensions]) > 0:
59
+ files_all.append(os.path.join(root, name))
60
+ return files_all
61
+
62
+
63
+ class InputType(Enum):
64
+ Image = 0
65
+ Video = 1
66
+ VideoStream = 2
67
+
68
+
69
+ def get_input_type(input_path: str) -> InputType:
70
+ if os.path.isdir(input_path):
71
+ print("Input is a folder, only images will be processed")
72
+ return InputType.Image
73
+ elif os.path.isfile(input_path):
74
+ if input_path.endswith(VIDEO_EXT):
75
+ return InputType.Video
76
+ if input_path.endswith(IMAGES_EXT):
77
+ return InputType.Image
78
+ else:
79
+ raise ValueError(
80
+ f"Unknown or unsupported input file format {input_path}, \
81
+ supported video formats: {VIDEO_EXT}, \
82
+ supported image formats: {IMAGES_EXT}"
83
+ )
84
+ elif input_path.startswith("http") and not input_path.endswith(IMAGES_EXT):
85
+ return InputType.VideoStream
86
+ else:
87
+ raise ValueError(f"Unknown input {input_path}")
88
+
89
+
90
+ def read_csv_annotation_file(annotation_file: str, images_dir: str, ignore_without_gt=False):
91
+ bboxes_per_image: Dict[str, List[PictureInfo]] = defaultdict(list)
92
+
93
+ df = pd.read_csv(annotation_file, sep=",")
94
+
95
+ annot_type = AnnotType("persons") if "person_x0" in df.columns else AnnotType("original")
96
+ print(f"Reading {annotation_file} (type: {annot_type})...")
97
+
98
+ missing_images = 0
99
+ for index, row in df.iterrows():
100
+ img_path = os.path.join(images_dir, row["img_name"])
101
+ if not os.path.exists(img_path):
102
+ missing_images += 1
103
+ continue
104
+
105
+ face_x1, face_y1, face_x2, face_y2 = row["face_x0"], row["face_y0"], row["face_x1"], row["face_y1"]
106
+ age, gender = str(row["age"]), str(row["gender"])
107
+
108
+ if ignore_without_gt and (age == "-1" or gender == "-1"):
109
+ continue
110
+
111
+ if annot_type == AnnotType.PERSONS:
112
+ p_x1, p_y1, p_x2, p_y2 = row["person_x0"], row["person_y0"], row["person_x1"], row["person_y1"]
113
+ person_bbox = list(map(int, [p_x1, p_y1, p_x2, p_y2]))
114
+ else:
115
+ person_bbox = [-1, -1, -1, -1]
116
+
117
+ bbox = list(map(int, [face_x1, face_y1, face_x2, face_y2]))
118
+ pic_info = PictureInfo(img_path, age, gender, bbox, person_bbox)
119
+ assert isinstance(pic_info.person_bbox, list)
120
+
121
+ bboxes_per_image[img_path].append(pic_info)
122
+
123
+ if missing_images > 0:
124
+ print(f"WARNING: Missing images: {missing_images}/{len(df)}")
125
+ return bboxes_per_image, annot_type
mivolo/data/dataset/__init__.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ from mivolo.model.mi_volo import MiVOLO
5
+
6
+ from .age_gender_dataset import AgeGenderDataset
7
+ from .age_gender_loader import create_loader
8
+ from .classification_dataset import AdienceDataset, FairFaceDataset
9
+
10
+ DATASET_CLASS_MAP = {
11
+ "utk": AgeGenderDataset,
12
+ "lagenda": AgeGenderDataset,
13
+ "imdb": AgeGenderDataset,
14
+ "adience": AdienceDataset,
15
+ "fairface": FairFaceDataset,
16
+ }
17
+
18
+
19
+ def build(
20
+ name: str,
21
+ images_path: str,
22
+ annotations_path: str,
23
+ split: str,
24
+ mivolo_model: MiVOLO,
25
+ workers: int,
26
+ batch_size: int,
27
+ ) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
28
+
29
+ dataset_class = DATASET_CLASS_MAP[name]
30
+
31
+ dataset: torch.utils.data.Dataset = dataset_class(
32
+ images_path=images_path,
33
+ annotations_path=annotations_path,
34
+ name=name,
35
+ split=split,
36
+ target_size=mivolo_model.input_size,
37
+ max_age=mivolo_model.meta.max_age,
38
+ min_age=mivolo_model.meta.min_age,
39
+ model_with_persons=mivolo_model.meta.with_persons_model,
40
+ use_persons=mivolo_model.meta.use_persons,
41
+ disable_faces=mivolo_model.meta.disable_faces,
42
+ only_age=mivolo_model.meta.only_age,
43
+ )
44
+
45
+ data_config = mivolo_model.data_config
46
+
47
+ in_chans = 3 if not mivolo_model.meta.with_persons_model else 6
48
+ input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size)
49
+
50
+ dataset_loader: torch.utils.data.DataLoader = create_loader(
51
+ dataset,
52
+ input_size=input_size,
53
+ batch_size=batch_size,
54
+ mean=data_config["mean"],
55
+ std=data_config["std"],
56
+ num_workers=workers,
57
+ crop_pct=data_config["crop_pct"],
58
+ crop_mode=data_config["crop_mode"],
59
+ pin_memory=False,
60
+ device=mivolo_model.device,
61
+ target_type=dataset.target_dtype,
62
+ )
63
+
64
+ return dataset, dataset_loader
mivolo/data/dataset/age_gender_dataset.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, List, Optional, Set
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
8
+ from PIL import Image
9
+ from torchvision import transforms
10
+
11
+ _logger = logging.getLogger("AgeGenderDataset")
12
+
13
+
14
+ class AgeGenderDataset(torch.utils.data.Dataset):
15
+ def __init__(
16
+ self,
17
+ images_path,
18
+ annotations_path,
19
+ name=None,
20
+ split="train",
21
+ load_bytes=False,
22
+ img_mode="RGB",
23
+ transform=None,
24
+ is_training=False,
25
+ seed=1234,
26
+ target_size=224,
27
+ min_age=None,
28
+ max_age=None,
29
+ model_with_persons=False,
30
+ use_persons=False,
31
+ disable_faces=False,
32
+ only_age=False,
33
+ ):
34
+ reader = ReaderAgeGender(
35
+ images_path,
36
+ annotations_path,
37
+ split=split,
38
+ seed=seed,
39
+ target_size=target_size,
40
+ with_persons=use_persons,
41
+ disable_faces=disable_faces,
42
+ only_age=only_age,
43
+ )
44
+
45
+ self.name = name
46
+ self.model_with_persons = model_with_persons
47
+ self.reader = reader
48
+ self.load_bytes = load_bytes
49
+ self.img_mode = img_mode
50
+ self.transform = transform
51
+ self._consecutive_errors = 0
52
+ self.is_training = is_training
53
+ self.random_flip = 0.0
54
+
55
+ # Setting up classes.
56
+ # If min and max classes are passed - use them to have the same preprocessing for validation
57
+ self.max_age: float = None
58
+ self.min_age: float = None
59
+ self.avg_age: float = None
60
+ self.set_ages_min_max(min_age, max_age)
61
+
62
+ self.genders = ["M", "F"]
63
+ self.num_classes_gender = len(self.genders)
64
+
65
+ self.age_classes: Optional[List[str]] = self.set_age_classes()
66
+
67
+ self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
68
+ self.num_classes: int = self.num_classes_age + self.num_classes_gender
69
+ self.target_dtype = torch.float32
70
+
71
+ def set_age_classes(self) -> Optional[List[str]]:
72
+ return None # for regression dataset
73
+
74
+ def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
75
+
76
+ assert all(age is None for age in [min_age, max_age]) or all(
77
+ age is not None for age in [min_age, max_age]
78
+ ), "Both min and max age must be passed or none of them"
79
+
80
+ if max_age is not None and min_age is not None:
81
+ _logger.info(f"Received predefined min_age {min_age} and max_age {max_age}")
82
+ self.max_age = max_age
83
+ self.min_age = min_age
84
+ else:
85
+ # collect statistics from loaded dataset
86
+ all_ages_set: Set[int] = set()
87
+ for img_path, image_samples in self.reader._ann.items():
88
+ for image_sample_info in image_samples:
89
+ if image_sample_info.age == "-1":
90
+ continue
91
+ age = round(float(image_sample_info.age))
92
+ all_ages_set.add(age)
93
+
94
+ self.max_age = max(all_ages_set)
95
+ self.min_age = min(all_ages_set)
96
+
97
+ self.avg_age = (self.max_age + self.min_age) / 2.0
98
+
99
+ def _norm_age(self, age):
100
+ return (age - self.avg_age) / (self.max_age - self.min_age)
101
+
102
+ def parse_gender(self, _gender: str) -> float:
103
+ if _gender != "-1":
104
+ gender = float(0 if _gender == "M" or _gender == "0" else 1)
105
+ else:
106
+ gender = -1
107
+ return gender
108
+
109
+ def parse_target(self, _age: str, gender: str) -> List[Any]:
110
+ if _age != "-1":
111
+ age = round(float(_age))
112
+ age = self._norm_age(float(age))
113
+ else:
114
+ age = -1
115
+
116
+ target: List[float] = [age, self.parse_gender(gender)]
117
+ return target
118
+
119
+ @property
120
+ def transform(self):
121
+ return self._transform
122
+
123
+ @transform.setter
124
+ def transform(self, transform):
125
+ # Disable pretrained monkey-patched transforms
126
+ if not transform:
127
+ return
128
+
129
+ _trans = []
130
+ for trans in transform.transforms:
131
+ if "Resize" in str(trans):
132
+ continue
133
+ if "Crop" in str(trans):
134
+ continue
135
+ _trans.append(trans)
136
+ self._transform = transforms.Compose(_trans)
137
+
138
+ def apply_tranforms(self, image: Optional[np.ndarray]) -> np.ndarray:
139
+ if image is None:
140
+ return None
141
+
142
+ if self.transform is None:
143
+ return image
144
+
145
+ image = convert_to_pil(image, self.img_mode)
146
+ for trans in self.transform.transforms:
147
+ image = trans(image)
148
+ return image
149
+
150
+ def __getitem__(self, index):
151
+ # get preprocessed face and person crops (np.ndarray)
152
+ # resize + pad, for person crops: cut off other bboxes
153
+ images, target = self.reader[index]
154
+
155
+ target = self.parse_target(*target)
156
+
157
+ if self.model_with_persons:
158
+ face_image, person_image = images
159
+ person_image: np.ndarray = self.apply_tranforms(person_image)
160
+ else:
161
+ face_image = images[0]
162
+ person_image = None
163
+
164
+ face_image: np.ndarray = self.apply_tranforms(face_image)
165
+
166
+ if person_image is not None:
167
+ img = np.concatenate([face_image, person_image], axis=0)
168
+ else:
169
+ img = face_image
170
+
171
+ return img, target
172
+
173
+ def __len__(self):
174
+ return len(self.reader)
175
+
176
+ def filename(self, index, basename=False, absolute=False):
177
+ return self.reader.filename(index, basename, absolute)
178
+
179
+ def filenames(self, basename=False, absolute=False):
180
+ return self.reader.filenames(basename, absolute)
181
+
182
+
183
+ def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> "Image":
184
+ if cv_im is None:
185
+ return None
186
+
187
+ if img_mode == "RGB":
188
+ cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)
189
+ else:
190
+ raise Exception("Incorrect image mode has been passed!")
191
+
192
+ cv_im = np.ascontiguousarray(cv_im)
193
+ pil_image = Image.fromarray(cv_im)
194
+ return pil_image
mivolo/data/dataset/age_gender_loader.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import logging
8
+ from contextlib import suppress
9
+ from functools import partial
10
+ from itertools import repeat
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.utils.data
15
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
+ from timm.data.dataset import IterableImageDataset
17
+ from timm.data.loader import PrefetchLoader, _worker_init
18
+ from timm.data.transforms_factory import create_transform
19
+
20
+ _logger = logging.getLogger(__name__)
21
+
22
+
23
+ def fast_collate(batch, target_dtype=torch.uint8):
24
+ """A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels)"""
25
+ assert isinstance(batch[0], tuple)
26
+ batch_size = len(batch)
27
+ if isinstance(batch[0][0], np.ndarray):
28
+ targets = torch.tensor([b[1] for b in batch], dtype=target_dtype)
29
+ assert len(targets) == batch_size
30
+ tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
31
+ for i in range(batch_size):
32
+ tensor[i] += torch.from_numpy(batch[i][0])
33
+ return tensor, targets
34
+ else:
35
+ raise ValueError(f"Incorrect batch type: {type(batch[0][0])}")
36
+
37
+
38
+ def adapt_to_chs(x, n):
39
+ if not isinstance(x, (tuple, list)):
40
+ x = tuple(repeat(x, n))
41
+ elif len(x) != n:
42
+ # doubled channels
43
+ if len(x) * 2 == n:
44
+ x = np.concatenate((x, x))
45
+ _logger.warning(f"Pretrained mean/std different shape than model (doubled channes), using concat: {x}.")
46
+ else:
47
+ x_mean = np.mean(x).item()
48
+ x = (x_mean,) * n
49
+ _logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
50
+ else:
51
+ assert len(x) == n, "normalization stats must match image channels"
52
+ return x
53
+
54
+
55
+ class PrefetchLoaderForMultiInput(PrefetchLoader):
56
+ def __init__(
57
+ self,
58
+ loader,
59
+ mean=IMAGENET_DEFAULT_MEAN,
60
+ std=IMAGENET_DEFAULT_STD,
61
+ channels=3,
62
+ device=torch.device("cpu"),
63
+ img_dtype=torch.float32,
64
+ ):
65
+
66
+ mean = adapt_to_chs(mean, channels)
67
+ std = adapt_to_chs(std, channels)
68
+ normalization_shape = (1, channels, 1, 1)
69
+
70
+ self.loader = loader
71
+ self.device = device
72
+ self.img_dtype = img_dtype
73
+ self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
74
+ self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
75
+
76
+ self.is_cuda = torch.cuda.is_available() and device.type == "cpu"
77
+
78
+ def __iter__(self):
79
+ first = True
80
+ if self.is_cuda:
81
+ stream = torch.cuda.Stream()
82
+ stream_context = partial(torch.cuda.stream, stream=stream)
83
+ else:
84
+ stream = None
85
+ stream_context = suppress
86
+
87
+ for next_input, next_target in self.loader:
88
+
89
+ with stream_context():
90
+ next_input = next_input.to(device=self.device, non_blocking=True)
91
+ next_target = next_target.to(device=self.device, non_blocking=True)
92
+ next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
93
+
94
+ if not first:
95
+ yield input, target # noqa: F823, F821
96
+ else:
97
+ first = False
98
+
99
+ if stream is not None:
100
+ torch.cuda.current_stream().wait_stream(stream)
101
+
102
+ input = next_input
103
+ target = next_target
104
+
105
+ yield input, target
106
+
107
+
108
+ def create_loader(
109
+ dataset,
110
+ input_size,
111
+ batch_size,
112
+ mean=IMAGENET_DEFAULT_MEAN,
113
+ std=IMAGENET_DEFAULT_STD,
114
+ num_workers=1,
115
+ crop_pct=None,
116
+ crop_mode=None,
117
+ pin_memory=False,
118
+ img_dtype=torch.float32,
119
+ device=torch.device("cpu"),
120
+ persistent_workers=True,
121
+ worker_seeding="all",
122
+ target_type=torch.int64,
123
+ ):
124
+
125
+ transform = create_transform(
126
+ input_size,
127
+ is_training=False,
128
+ use_prefetcher=True,
129
+ mean=mean,
130
+ std=std,
131
+ crop_pct=crop_pct,
132
+ crop_mode=crop_mode,
133
+ )
134
+ dataset.transform = transform
135
+
136
+ if isinstance(dataset, IterableImageDataset):
137
+ # give Iterable datasets early knowledge of num_workers so that sample estimates
138
+ # are correct before worker processes are launched
139
+ dataset.set_loader_cfg(num_workers=num_workers)
140
+ raise ValueError("Incorrect dataset type: IterableImageDataset")
141
+
142
+ loader_class = torch.utils.data.DataLoader
143
+ loader_args = dict(
144
+ batch_size=batch_size,
145
+ shuffle=False,
146
+ num_workers=num_workers,
147
+ sampler=None,
148
+ collate_fn=lambda batch: fast_collate(batch, target_dtype=target_type),
149
+ pin_memory=pin_memory,
150
+ drop_last=False,
151
+ worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
152
+ persistent_workers=persistent_workers,
153
+ )
154
+ try:
155
+ loader = loader_class(dataset, **loader_args)
156
+ except TypeError:
157
+ loader_args.pop("persistent_workers") # only in Pytorch 1.7+
158
+ loader = loader_class(dataset, **loader_args)
159
+
160
+ loader = PrefetchLoaderForMultiInput(
161
+ loader,
162
+ mean=mean,
163
+ std=std,
164
+ channels=input_size[0],
165
+ device=device,
166
+ img_dtype=img_dtype,
167
+ )
168
+
169
+ return loader
mivolo/data/dataset/classification_dataset.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Optional
2
+
3
+ import torch
4
+
5
+ from .age_gender_dataset import AgeGenderDataset
6
+
7
+
8
+ class ClassificationDataset(AgeGenderDataset):
9
+ def __init__(self, *args, **kwargs):
10
+ super().__init__(*args, **kwargs)
11
+
12
+ self.target_dtype = torch.int32
13
+
14
+ def set_age_classes(self) -> Optional[List[str]]:
15
+ raise NotImplementedError
16
+
17
+ def parse_target(self, age: str, gender: str) -> List[Any]:
18
+ assert self.age_classes is not None
19
+ if age != "-1":
20
+ assert age in self.age_classes, f"Unknown category in {self.name} dataset: {age}"
21
+ age_ind = self.age_classes.index(age)
22
+ else:
23
+ age_ind = -1
24
+
25
+ target: List[int] = [age_ind, int(self.parse_gender(gender))]
26
+ return target
27
+
28
+
29
+ class FairFaceDataset(ClassificationDataset):
30
+ def set_age_classes(self) -> Optional[List[str]]:
31
+ age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"]
32
+ # a[i-1] <= v < a[i] => age_classes[i-1]
33
+ self._intervals = torch.tensor([0, 3, 10, 20, 30, 40, 50, 60, 70])
34
+
35
+ return age_classes
36
+
37
+
38
+ class AdienceDataset(ClassificationDataset):
39
+ def __init__(self, *args, **kwargs):
40
+ super().__init__(*args, **kwargs)
41
+
42
+ self.target_dtype = torch.int32
43
+
44
+ def set_age_classes(self) -> Optional[List[str]]:
45
+ age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"]
46
+ # a[i-1] <= v < a[i] => age_classes[i-1]
47
+ self._intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57])
48
+ return age_classes
mivolo/data/dataset/reader_age_gender.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from functools import partial
4
+ from multiprocessing.pool import ThreadPool
5
+ from typing import Dict, List, Optional, Tuple
6
+
7
+ import cv2
8
+ import numpy as np
9
+ from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file
10
+ from mivolo.data.misc import IOU, class_letterbox, cropout_black_parts
11
+ from timm.data.readers.reader import Reader
12
+ from tqdm import tqdm
13
+
14
+ CROP_ROUND_TOL = 0.3
15
+ MIN_PERSON_SIZE = 100
16
+ MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
17
+
18
+ _logger = logging.getLogger("ReaderAgeGender")
19
+
20
+
21
+ class ReaderAgeGender(Reader):
22
+ """
23
+ Reader for almost original imdb-wiki cleaned dataset.
24
+ Two changes:
25
+ 1. Your annotation must be in ./annotation subdir of dataset root
26
+ 2. Images must be in images subdir
27
+
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ images_path,
33
+ annotations_path,
34
+ split="validation",
35
+ target_size=224,
36
+ min_size=5,
37
+ seed=1234,
38
+ with_persons=False,
39
+ min_person_size=MIN_PERSON_SIZE,
40
+ disable_faces=False,
41
+ only_age=False,
42
+ min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
43
+ crop_round_tol=CROP_ROUND_TOL,
44
+ ):
45
+ super().__init__()
46
+
47
+ self.with_persons = with_persons
48
+ self.disable_faces = disable_faces
49
+ self.only_age = only_age
50
+
51
+ # can be only black for now, even though it's not very good with further normalization
52
+ self.crop_out_color = (0, 0, 0)
53
+
54
+ self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
55
+ self.empty_crop = self.empty_crop.astype(np.uint8)
56
+
57
+ self.min_person_size = min_person_size
58
+ self.min_person_aftercut_ratio = min_person_aftercut_ratio
59
+ self.crop_round_tol = crop_round_tol
60
+
61
+ self.split = split
62
+ self.min_size = min_size
63
+ self.seed = seed
64
+ self.target_size = target_size
65
+
66
+ # Reading annotations. Can be multiple files if annotations_path dir
67
+ self._ann: Dict[str, List[PictureInfo]] = {} # list of samples for each image
68
+ self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
69
+ self._faces_list: List[Tuple[str, int]] = [] # samples from this list will be loaded in __getitem__
70
+
71
+ self._read_annotations(images_path, annotations_path)
72
+ _logger.info(f"Dataset length: {len(self._faces_list)} crops")
73
+
74
+ def __getitem__(self, index):
75
+ return self._read_img_and_label(index)
76
+
77
+ def __len__(self):
78
+ return len(self._faces_list)
79
+
80
+ def _filename(self, index, basename=False, absolute=False):
81
+ img_p = self._faces_list[index][0]
82
+ return os.path.basename(img_p) if basename else img_p
83
+
84
+ def _read_annotations(self, images_path, csvs_path):
85
+ self._ann = {}
86
+ self._faces_list = []
87
+ self._associated_objects = {}
88
+
89
+ csvs = get_all_files(csvs_path, [".csv"])
90
+ csvs = [c for c in csvs if self.split in os.path.basename(c)]
91
+
92
+ # load annotations per image
93
+ for csv in csvs:
94
+ db, ann_type = read_csv_annotation_file(csv, images_path)
95
+ if self.with_persons and ann_type != AnnotType.PERSONS:
96
+ raise ValueError(
97
+ f"Annotation type in file {csv} contains no persons, "
98
+ f"but annotations with persons are requested."
99
+ )
100
+ self._ann.update(db)
101
+
102
+ if len(self._ann) == 0:
103
+ raise ValueError("Annotations are empty!")
104
+
105
+ self._ann, self._associated_objects = self.prepare_annotations()
106
+ images_list = list(self._ann.keys())
107
+
108
+ for img_path in images_list:
109
+ for index, image_sample_info in enumerate(self._ann[img_path]):
110
+ assert image_sample_info.has_gt(
111
+ self.only_age
112
+ ), "Annotations must be checked with self.prepare_annotations() func"
113
+ self._faces_list.append((img_path, index))
114
+
115
+ def _read_img_and_label(self, index):
116
+ if not isinstance(index, int):
117
+ raise TypeError("ReaderAgeGender expected index to be integer")
118
+
119
+ img_p, face_index = self._faces_list[index]
120
+ ann: PictureInfo = self._ann[img_p][face_index]
121
+ img = cv2.imread(img_p)
122
+
123
+ face_empty = True
124
+ if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
125
+ face_crop, face_empty = self._get_crop(ann.bbox, img)
126
+
127
+ if not self.with_persons and face_empty:
128
+ # model without persons
129
+ raise ValueError("Annotations must be checked with self.prepare_annotations() func")
130
+
131
+ if face_empty:
132
+ face_crop = self.empty_crop
133
+
134
+ person_empty = True
135
+ if self.with_persons or self.disable_faces:
136
+ if ann.has_person_bbox:
137
+ # cut off all associated objects from person crop
138
+ objects = self._associated_objects[img_p][face_index]
139
+ person_crop, person_empty = self._get_crop(
140
+ ann.person_bbox,
141
+ img,
142
+ crop_out_color=self.crop_out_color,
143
+ asced_objects=objects,
144
+ )
145
+
146
+ if face_empty and person_empty:
147
+ raise ValueError("Annotations must be checked with self.prepare_annotations() func")
148
+
149
+ if person_empty:
150
+ person_crop = self.empty_crop
151
+
152
+ return (face_crop, person_crop), [ann.age, ann.gender]
153
+
154
+ def _get_crop(
155
+ self,
156
+ bbox,
157
+ img,
158
+ asced_objects=None,
159
+ crop_out_color=(0, 0, 0),
160
+ ) -> Tuple[np.ndarray, bool]:
161
+
162
+ empty_bbox = False
163
+
164
+ xmin, ymin, xmax, ymax = bbox
165
+ assert not (
166
+ ymax - ymin < self.min_size or xmax - xmin < self.min_size
167
+ ), "Annotations must be checked with self.prepare_annotations() func"
168
+
169
+ crop = img[ymin:ymax, xmin:xmax]
170
+
171
+ if asced_objects:
172
+ # cut off other objects for person crop
173
+ crop, empty_bbox = _cropout_asced_objs(
174
+ asced_objects,
175
+ bbox,
176
+ crop.copy(),
177
+ crop_out_color=crop_out_color,
178
+ min_person_size=self.min_person_size,
179
+ crop_round_tol=self.crop_round_tol,
180
+ min_person_aftercut_ratio=self.min_person_aftercut_ratio,
181
+ )
182
+ if empty_bbox:
183
+ crop = self.empty_crop
184
+
185
+ crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
186
+ return crop, empty_bbox
187
+
188
+ def prepare_annotations(self):
189
+
190
+ good_anns: Dict[str, List[PictureInfo]] = {}
191
+ all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
192
+
193
+ if not self.with_persons:
194
+ # remove all persons
195
+ for img_path, bboxes in self._ann.items():
196
+ for sample in bboxes:
197
+ sample.clear_person_bbox()
198
+
199
+ # check dataset and collect associated_objects
200
+ verify_images_func = partial(
201
+ verify_images,
202
+ min_size=self.min_size,
203
+ min_person_size=self.min_person_size,
204
+ with_persons=self.with_persons,
205
+ disable_faces=self.disable_faces,
206
+ crop_round_tol=self.crop_round_tol,
207
+ min_person_aftercut_ratio=self.min_person_aftercut_ratio,
208
+ only_age=self.only_age,
209
+ )
210
+ num_threads = min(8, os.cpu_count())
211
+
212
+ all_msgs = []
213
+ broken = 0
214
+ skipped = 0
215
+ all_skipped_crops = 0
216
+ desc = "Check annotations..."
217
+ with ThreadPool(num_threads) as pool:
218
+ pbar = tqdm(
219
+ pool.imap_unordered(verify_images_func, list(self._ann.items())),
220
+ desc=desc,
221
+ total=len(self._ann),
222
+ )
223
+
224
+ for (img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops) in pbar:
225
+ broken += 1 if is_corrupted else 0
226
+ all_msgs.extend(msgs)
227
+ all_skipped_crops += skipped_crops
228
+ skipped += 1 if is_empty_annotations else 0
229
+ if img_info is not None:
230
+ img_path, img_samples = img_info
231
+ good_anns[img_path] = img_samples
232
+ all_associated_objects.update({img_path: associated_objects})
233
+
234
+ pbar.desc = (
235
+ f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
236
+ f"{broken} images corrupted"
237
+ )
238
+
239
+ pbar.close()
240
+
241
+ for msg in all_msgs:
242
+ print(msg)
243
+ print(f"\nLeft images: {len(good_anns)}")
244
+
245
+ return good_anns, all_associated_objects
246
+
247
+
248
+ def verify_images(
249
+ img_info,
250
+ min_size: int,
251
+ min_person_size: int,
252
+ with_persons: bool,
253
+ disable_faces: bool,
254
+ crop_round_tol: float,
255
+ min_person_aftercut_ratio: float,
256
+ only_age: bool,
257
+ ):
258
+ # If crop is too small, if image can not be read or if image does not exist
259
+ # then filter out this sample
260
+
261
+ disable_faces = disable_faces and with_persons
262
+ kwargs = dict(
263
+ min_person_size=min_person_size,
264
+ disable_faces=disable_faces,
265
+ with_persons=with_persons,
266
+ crop_round_tol=crop_round_tol,
267
+ min_person_aftercut_ratio=min_person_aftercut_ratio,
268
+ only_age=only_age,
269
+ )
270
+
271
+ def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
272
+ ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
273
+ crop_h, crop_w = ymax - ymin, xmax - xmin
274
+ if crop_h < min_size or crop_w < min_size:
275
+ return False, [-1, -1, -1, -1]
276
+ bbox = [xmin, ymin, xmax, ymax]
277
+ return True, bbox
278
+
279
+ msgs = []
280
+ skipped_crops = 0
281
+ is_corrupted = False
282
+ is_empty_annotations = False
283
+
284
+ img_path: str = img_info[0]
285
+ img_samples: List[PictureInfo] = img_info[1]
286
+ try:
287
+ im_cv = cv2.imread(img_path)
288
+ im_h, im_w = im_cv.shape[:2]
289
+ except Exception:
290
+ msgs.append(f"Can not load image {img_path}")
291
+ is_corrupted = True
292
+ return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops
293
+
294
+ out_samples: List[PictureInfo] = []
295
+ for sample in img_samples:
296
+ # correct face bbox
297
+ if sample.has_face_bbox:
298
+ is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
299
+ if not is_correct and sample.has_gt(only_age):
300
+ msgs.append("Small face. Passing..")
301
+ skipped_crops += 1
302
+
303
+ # correct person bbox
304
+ if sample.has_person_bbox:
305
+ is_correct, sample.person_bbox = bbox_correct(
306
+ sample.person_bbox, max(min_person_size, min_size), im_h, im_w
307
+ )
308
+ if not is_correct and sample.has_gt(only_age):
309
+ msgs.append(f"Small person {img_path}. Passing..")
310
+ skipped_crops += 1
311
+
312
+ if sample.has_face_bbox or sample.has_person_bbox:
313
+ out_samples.append(sample)
314
+ elif sample.has_gt(only_age):
315
+ msgs.append("Sample hs no face and no body. Passing..")
316
+ skipped_crops += 1
317
+
318
+ # sort that samples with undefined age and gender be the last
319
+ out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)
320
+
321
+ # for each person find other faces and persons bboxes, intersected with it
322
+ associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)
323
+
324
+ out_samples, associated_objects, skipped_crops = filter_bad_samples(
325
+ out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
326
+ )
327
+
328
+ out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
329
+ if len(out_samples) == 0:
330
+ out_img_info = None
331
+ is_empty_annotations = True
332
+
333
+ return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops
334
+
335
+
336
+ def filter_bad_samples(
337
+ out_samples: List[PictureInfo],
338
+ associated_objects: dict,
339
+ im_cv: np.ndarray,
340
+ msgs: List[str],
341
+ skipped_crops: int,
342
+ **kwargs,
343
+ ):
344
+ with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
345
+ kwargs["with_persons"],
346
+ kwargs["disable_faces"],
347
+ kwargs["min_person_size"],
348
+ kwargs["crop_round_tol"],
349
+ kwargs["min_person_aftercut_ratio"],
350
+ kwargs["only_age"],
351
+ )
352
+
353
+ # left only samples with annotations
354
+ inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
355
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
356
+
357
+ if kwargs["disable_faces"]:
358
+ # clear all faces
359
+ for ind, sample in enumerate(out_samples):
360
+ sample.clear_face_bbox()
361
+
362
+ # left only samples with person_bbox
363
+ inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
364
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
365
+
366
+ if with_persons or disable_faces:
367
+ # check that preprocessing func
368
+ # _cropout_asced_objs() return not empty person_image for each out sample
369
+
370
+ inds = []
371
+ for ind, sample in enumerate(out_samples):
372
+ person_empty = True
373
+ if sample.has_person_bbox:
374
+ xmin, ymin, xmax, ymax = sample.person_bbox
375
+ crop = im_cv[ymin:ymax, xmin:xmax]
376
+ # cut off all associated objects from person crop
377
+ _, person_empty = _cropout_asced_objs(
378
+ associated_objects[ind],
379
+ sample.person_bbox,
380
+ crop.copy(),
381
+ min_person_size=min_person_size,
382
+ crop_round_tol=crop_round_tol,
383
+ min_person_aftercut_ratio=min_person_aftercut_ratio,
384
+ )
385
+
386
+ if person_empty and not sample.has_face_bbox:
387
+ msgs.append("Small person after preprocessing. Passing..")
388
+ skipped_crops += 1
389
+ else:
390
+ inds.append(ind)
391
+ out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
392
+
393
+ assert len(associated_objects) == len(out_samples)
394
+ return out_samples, associated_objects, skipped_crops
395
+
396
+
397
+ def _filter_by_ind(out_samples, associated_objects, inds):
398
+ _associated_objects = {}
399
+ _out_samples = []
400
+ for ind, sample in enumerate(out_samples):
401
+ if ind in inds:
402
+ _associated_objects[len(_out_samples)] = associated_objects[ind]
403
+ _out_samples.append(sample)
404
+
405
+ return _out_samples, _associated_objects
406
+
407
+
408
+ def find_associated_objects(
409
+ image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
410
+ ) -> Dict[int, List[List[int]]]:
411
+ """
412
+ For each person (which has gt age and gt gender) find other faces and persons bboxes, intersected with it
413
+ """
414
+ associated_objects: Dict[int, List[List[int]]] = {}
415
+
416
+ for iindex, image_sample_info in enumerate(image_samples):
417
+ # add own face
418
+ associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else []
419
+
420
+ if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age):
421
+ # if sample has not gt => not be used
422
+ continue
423
+
424
+ iperson_box = image_sample_info.person_bbox
425
+ for jindex, other_image_sample in enumerate(image_samples):
426
+ if iindex == jindex:
427
+ continue
428
+ if other_image_sample.has_face_bbox:
429
+ jface_bbox = other_image_sample.bbox
430
+ iou = _get_iou(jface_bbox, iperson_box)
431
+ if iou >= iou_thresh:
432
+ associated_objects[iindex].append(jface_bbox)
433
+ if other_image_sample.has_person_bbox:
434
+ jperson_bbox = other_image_sample.person_bbox
435
+ iou = _get_iou(jperson_bbox, iperson_box)
436
+ if iou >= iou_thresh:
437
+ associated_objects[iindex].append(jperson_bbox)
438
+
439
+ return associated_objects
440
+
441
+
442
+ def _cropout_asced_objs(
443
+ asced_objects,
444
+ person_bbox,
445
+ crop,
446
+ min_person_size,
447
+ crop_round_tol,
448
+ min_person_aftercut_ratio,
449
+ crop_out_color=(0, 0, 0),
450
+ ):
451
+ empty = False
452
+ xmin, ymin, xmax, ymax = person_bbox
453
+
454
+ for a_obj in asced_objects:
455
+ aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj
456
+
457
+ aobj_ymin = int(max(aobj_ymin - ymin, 0))
458
+ aobj_xmin = int(max(aobj_xmin - xmin, 0))
459
+ aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin))
460
+ aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin))
461
+
462
+ crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color
463
+
464
+ crop, cropped_ratio = cropout_black_parts(crop, crop_round_tol)
465
+ if (
466
+ crop.shape[0] < min_person_size or crop.shape[1] < min_person_size
467
+ ) or cropped_ratio < min_person_aftercut_ratio:
468
+ crop = None
469
+ empty = True
470
+
471
+ return crop, empty
472
+
473
+
474
+ def _correct_bbox(bbox, h, w):
475
+ xmin, ymin, xmax, ymax = bbox
476
+ ymin = min(max(ymin, 0), h)
477
+ ymax = min(max(ymax, 0), h)
478
+ xmin = min(max(xmin, 0), w)
479
+ xmax = min(max(xmax, 0), w)
480
+ return ymin, ymax, xmin, xmax
481
+
482
+
483
+ def _get_iou(bbox1, bbox2):
484
+ xmin1, ymin1, xmax1, ymax1 = bbox1
485
+ xmin2, ymin2, xmax2, ymax2 = bbox2
486
+ iou = IOU(
487
+ [ymin1, xmin1, ymax1, xmax1],
488
+ [ymin2, xmin2, ymax2, xmax2],
489
+ )
490
+ return iou
mivolo/data/misc.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import ast
3
+ import re
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ import torchvision.transforms.functional as F
10
+ from scipy.optimize import linear_sum_assignment
11
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
12
+
13
+ CROP_ROUND_RATE = 0.1
14
+ MIN_PERSON_CROP_NONZERO = 0.5
15
+
16
+
17
+ def aggregate_votes_winsorized(ages, max_age_dist=6):
18
+ # Replace any annotation that is more than a max_age_dist away from the median
19
+ # with the median + max_age_dist if higher or max_age_dist - max_age_dist if below
20
+ median = np.median(ages)
21
+ ages = np.clip(ages, median - max_age_dist, median + max_age_dist)
22
+ return np.mean(ages)
23
+
24
+
25
+ def cropout_black_parts(img, tol=0.3):
26
+ # Create a binary mask of zero pixels
27
+ zero_pixels_mask = np.all(img == 0, axis=2)
28
+ # Calculate the threshold for zero pixels in rows and columns
29
+ threshold = img.shape[0] - img.shape[0] * tol
30
+ # Calculate row sums and column sums of zero pixels mask
31
+ row_sums = np.sum(zero_pixels_mask, axis=1)
32
+ col_sums = np.sum(zero_pixels_mask, axis=0)
33
+ # Find the first and last rows with zero pixel sums above the threshold
34
+ start_row = np.argmin(row_sums > threshold)
35
+ end_row = img.shape[0] - np.argmin(row_sums[::-1] > threshold)
36
+ # Find the first and last columns with zero pixel sums above the threshold
37
+ start_col = np.argmin(col_sums > threshold)
38
+ end_col = img.shape[1] - np.argmin(col_sums[::-1] > threshold)
39
+ # Crop the image
40
+ cropped_img = img[start_row:end_row, start_col:end_col, :]
41
+ area = cropped_img.shape[0] * cropped_img.shape[1]
42
+ area_orig = img.shape[0] * img.shape[1]
43
+ return cropped_img, area / area_orig
44
+
45
+
46
+ def natural_key(string_):
47
+ """See http://www.codinghorror.com/blog/archives/001018.html"""
48
+ return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
49
+
50
+
51
+ def add_bool_arg(parser, name, default=False, help=""):
52
+ dest_name = name.replace("-", "_")
53
+ group = parser.add_mutually_exclusive_group(required=False)
54
+ group.add_argument("--" + name, dest=dest_name, action="store_true", help=help)
55
+ group.add_argument("--no-" + name, dest=dest_name, action="store_false", help=help)
56
+ parser.set_defaults(**{dest_name: default})
57
+
58
+
59
+ def cumulative_score(pred_ages, gt_ages, L, tol=1e-6):
60
+ n = pred_ages.shape[0]
61
+ num_correct = torch.sum(torch.abs(pred_ages - gt_ages) <= L + tol)
62
+ cs_score = num_correct / n
63
+ return cs_score
64
+
65
+
66
+ def cumulative_error(pred_ages, gt_ages, L, tol=1e-6):
67
+ n = pred_ages.shape[0]
68
+ num_correct = torch.sum(torch.abs(pred_ages - gt_ages) >= L + tol)
69
+ cs_score = num_correct / n
70
+ return cs_score
71
+
72
+
73
+ class ParseKwargs(argparse.Action):
74
+ def __call__(self, parser, namespace, values, option_string=None):
75
+ kw = {}
76
+ for value in values:
77
+ key, value = value.split("=")
78
+ try:
79
+ kw[key] = ast.literal_eval(value)
80
+ except ValueError:
81
+ kw[key] = str(value) # fallback to string (avoid need to escape on command line)
82
+ setattr(namespace, self.dest, kw)
83
+
84
+
85
+ def box_iou(box1, box2, over_second=False):
86
+ """
87
+ Return intersection-over-union (Jaccard index) of boxes.
88
+ If over_second == True, return mean(intersection-over-union, (inter / area2))
89
+
90
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
91
+
92
+ Arguments:
93
+ box1 (Tensor[N, 4])
94
+ box2 (Tensor[M, 4])
95
+ Returns:
96
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
97
+ IoU values for every element in boxes1 and boxes2
98
+ """
99
+
100
+ def box_area(box):
101
+ # box = 4xn
102
+ return (box[2] - box[0]) * (box[3] - box[1])
103
+
104
+ area1 = box_area(box1.T)
105
+ area2 = box_area(box2.T)
106
+
107
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
108
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
109
+
110
+ iou = inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
111
+ if over_second:
112
+ return (inter / area2 + iou) / 2 # mean(inter / area2, iou)
113
+ else:
114
+ return iou
115
+
116
+
117
+ def split_batch(bs: int, dev: int) -> Tuple[int, int]:
118
+ full_bs = (bs // dev) * dev
119
+ part_bs = bs - full_bs
120
+ return full_bs, part_bs
121
+
122
+
123
+ def assign_faces(
124
+ persons_bboxes: List[torch.tensor], faces_bboxes: List[torch.tensor], iou_thresh: float = 0.0001
125
+ ) -> Tuple[List[Optional[int]], List[int]]:
126
+ """
127
+ Assign person to each face if it is possible.
128
+ Return:
129
+ - assigned_faces List[Optional[int]]: mapping of face_ind to person_ind
130
+ ( assigned_faces[face_ind] = person_ind ). person_ind can be None
131
+ - unassigned_persons_inds List[int]: persons indexes without any assigned face
132
+ """
133
+
134
+ assigned_faces: List[Optional[int]] = [None for _ in range(len(faces_bboxes))]
135
+ unassigned_persons_inds: List[int] = [p_ind for p_ind in range(len(persons_bboxes))]
136
+
137
+ if len(persons_bboxes) == 0 or len(faces_bboxes) == 0:
138
+ return assigned_faces, unassigned_persons_inds
139
+
140
+ cost_matrix = box_iou(torch.stack(persons_bboxes), torch.stack(faces_bboxes), over_second=True).cpu().numpy()
141
+ persons_indexes, face_indexes = [], []
142
+
143
+ if len(cost_matrix) > 0:
144
+ persons_indexes, face_indexes = linear_sum_assignment(cost_matrix, maximize=True)
145
+
146
+ matched_persons = set()
147
+ for person_idx, face_idx in zip(persons_indexes, face_indexes):
148
+ ciou = cost_matrix[person_idx][face_idx]
149
+ if ciou > iou_thresh:
150
+ if person_idx in matched_persons:
151
+ # Person can not be assigned twice, in reality this should not happen
152
+ continue
153
+ assigned_faces[face_idx] = person_idx
154
+ matched_persons.add(person_idx)
155
+
156
+ unassigned_persons_inds = [p_ind for p_ind in range(len(persons_bboxes)) if p_ind not in matched_persons]
157
+
158
+ return assigned_faces, unassigned_persons_inds
159
+
160
+
161
+ def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True):
162
+ # Resize and pad image while meeting stride-multiple constraints
163
+ shape = im.shape[:2] # current shape [height, width]
164
+ if isinstance(new_shape, int):
165
+ new_shape = (new_shape, new_shape)
166
+
167
+ if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]:
168
+ return im
169
+
170
+ # Scale ratio (new / old)
171
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
172
+ if not scaleup: # only scale down, do not scale up (for better val mAP)
173
+ r = min(r, 1.0)
174
+
175
+ # Compute padding
176
+ # ratio = r, r # width, height ratios
177
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
178
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
179
+
180
+ dw /= 2 # divide padding into 2 sides
181
+ dh /= 2
182
+
183
+ if shape[::-1] != new_unpad: # resize
184
+ im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
185
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
186
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
187
+ im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
188
+ return im
189
+
190
+
191
+ def prepare_classification_images(
192
+ img_list: List[Optional[np.ndarray]],
193
+ target_size: int = 224,
194
+ mean=IMAGENET_DEFAULT_MEAN,
195
+ std=IMAGENET_DEFAULT_STD,
196
+ device=None,
197
+ ) -> torch.tensor:
198
+
199
+ prepared_images: List[torch.tensor] = []
200
+
201
+ for img in img_list:
202
+ if img is None:
203
+ img = torch.zeros((3, target_size, target_size), dtype=torch.float32)
204
+ img = F.normalize(img, mean=mean, std=std)
205
+ img = img.unsqueeze(0)
206
+ prepared_images.append(img)
207
+ continue
208
+ img = class_letterbox(img, new_shape=(target_size, target_size))
209
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
210
+
211
+ img = img / 255.0
212
+ img = (img - mean) / std
213
+ img = img.astype(dtype=np.float32)
214
+
215
+ img = img.transpose((2, 0, 1))
216
+ img = np.ascontiguousarray(img)
217
+ img = torch.from_numpy(img)
218
+ img = img.unsqueeze(0)
219
+
220
+ prepared_images.append(img)
221
+
222
+ prepared_input = torch.concat(prepared_images)
223
+
224
+ if device:
225
+ prepared_input = prepared_input.to(device)
226
+
227
+ return prepared_input
228
+
229
+
230
+ def IOU(bb1: Union[tuple, list], bb2: Union[tuple, list], norm_second_bbox: bool = False) -> float:
231
+ # expects [ymin, xmin, ymax, xmax], doesnt matter absolute or relative
232
+ assert bb1[1] < bb1[3]
233
+ assert bb1[0] < bb1[2]
234
+ assert bb2[1] < bb2[3]
235
+ assert bb2[0] < bb2[2]
236
+
237
+ # determine the coordinates of the intersection rectangle
238
+ x_left = max(bb1[1], bb2[1])
239
+ y_top = max(bb1[0], bb2[0])
240
+ x_right = min(bb1[3], bb2[3])
241
+ y_bottom = min(bb1[2], bb2[2])
242
+
243
+ if x_right < x_left or y_bottom < y_top:
244
+ return 0.0
245
+
246
+ # The intersection of two axis-aligned bounding boxes is always an
247
+ # axis-aligned bounding box
248
+ intersection_area = (x_right - x_left) * (y_bottom - y_top)
249
+ # compute the area of both AABBs
250
+ bb1_area = (bb1[3] - bb1[1]) * (bb1[2] - bb1[0])
251
+ bb2_area = (bb2[3] - bb2[1]) * (bb2[2] - bb2[0])
252
+ if not norm_second_bbox:
253
+ # compute the intersection over union by taking the intersection
254
+ # area and dividing it by the sum of prediction + ground-truth
255
+ # areas - the interesection area
256
+ iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
257
+ else:
258
+ # for cases when we search if second bbox is inside first one
259
+ iou = intersection_area / float(bb2_area)
260
+
261
+ assert iou >= 0.0
262
+ assert iou <= 1.01
263
+
264
+ return iou
mivolo/model/create_timm_model.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import os
8
+ from typing import Any, Dict, Optional, Union
9
+
10
+ import timm
11
+
12
+ # register new models
13
+ from mivolo.model.mivolo_model import * # noqa: F403, F401
14
+ from timm.layers import set_layer_config
15
+ from timm.models._factory import parse_model_name
16
+ from timm.models._helpers import load_state_dict, remap_checkpoint
17
+ from timm.models._hub import load_model_config_from_hf
18
+ from timm.models._pretrained import PretrainedCfg, split_model_name_tag
19
+ from timm.models._registry import is_model, model_entrypoint
20
+
21
+
22
+ def load_checkpoint(
23
+ model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
24
+ ):
25
+ if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
26
+ # numpy checkpoint, try to load via model specific load_pretrained fn
27
+ if hasattr(model, "load_pretrained"):
28
+ timm.models._model_builder.load_pretrained(checkpoint_path)
29
+ else:
30
+ raise NotImplementedError("Model cannot load numpy checkpoint")
31
+ return
32
+ state_dict = load_state_dict(checkpoint_path, use_ema)
33
+ if remap:
34
+ state_dict = remap_checkpoint(model, state_dict)
35
+ if filter_keys:
36
+ for sd_key in list(state_dict.keys()):
37
+ for filter_key in filter_keys:
38
+ if filter_key in sd_key:
39
+ if sd_key in state_dict:
40
+ del state_dict[sd_key]
41
+
42
+ rep = []
43
+ if state_dict_map is not None:
44
+ # 'patch_embed.conv1.' : 'patch_embed.conv.'
45
+ for state_k in list(state_dict.keys()):
46
+ for target_k, target_v in state_dict_map.items():
47
+ if target_v in state_k:
48
+ target_name = state_k.replace(target_v, target_k)
49
+ state_dict[target_name] = state_dict[state_k]
50
+ rep.append(state_k)
51
+ for r in rep:
52
+ if r in state_dict:
53
+ del state_dict[r]
54
+
55
+ incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
56
+ return incompatible_keys
57
+
58
+
59
+ def create_model(
60
+ model_name: str,
61
+ pretrained: bool = False,
62
+ pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
63
+ pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
64
+ checkpoint_path: str = "",
65
+ scriptable: Optional[bool] = None,
66
+ exportable: Optional[bool] = None,
67
+ no_jit: Optional[bool] = None,
68
+ filter_keys=None,
69
+ state_dict_map=None,
70
+ **kwargs,
71
+ ):
72
+ """Create a model
73
+ Lookup model's entrypoint function and pass relevant args to create a new model.
74
+ """
75
+ # Parameters that aren't supported by all models or are intended to only override model defaults if set
76
+ # should default to None in command line args/cfg. Remove them if they are present and not set so that
77
+ # non-supporting models don't break and default args remain in effect.
78
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
79
+
80
+ model_source, model_name = parse_model_name(model_name)
81
+ if model_source == "hf-hub":
82
+ assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
83
+ # For model names specified in the form `hf-hub:path/architecture_name@revision`,
84
+ # load model weights + pretrained_cfg from Hugging Face hub.
85
+ pretrained_cfg, model_name = load_model_config_from_hf(model_name)
86
+ else:
87
+ model_name, pretrained_tag = split_model_name_tag(model_name)
88
+ if not pretrained_cfg:
89
+ # a valid pretrained_cfg argument takes priority over tag in model name
90
+ pretrained_cfg = pretrained_tag
91
+
92
+ if not is_model(model_name):
93
+ raise RuntimeError("Unknown model (%s)" % model_name)
94
+
95
+ create_fn = model_entrypoint(model_name)
96
+ with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
97
+ model = create_fn(
98
+ pretrained=pretrained,
99
+ pretrained_cfg=pretrained_cfg,
100
+ pretrained_cfg_overlay=pretrained_cfg_overlay,
101
+ **kwargs,
102
+ )
103
+
104
+ if checkpoint_path:
105
+ load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
106
+
107
+ return model
mivolo/model/cross_bottleneck_attn.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code based on timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from timm.layers.bottleneck_attn import PosEmbedRel
10
+ from timm.layers.helpers import make_divisible
11
+ from timm.layers.mlp import Mlp
12
+ from timm.layers.trace_utils import _assert
13
+ from timm.layers.weight_init import trunc_normal_
14
+
15
+
16
+ class CrossBottleneckAttn(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim,
20
+ dim_out=None,
21
+ feat_size=None,
22
+ stride=1,
23
+ num_heads=4,
24
+ dim_head=None,
25
+ qk_ratio=1.0,
26
+ qkv_bias=False,
27
+ scale_pos_embed=False,
28
+ ):
29
+ super().__init__()
30
+ assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
31
+ dim_out = dim_out or dim
32
+ assert dim_out % num_heads == 0
33
+
34
+ self.num_heads = num_heads
35
+ self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
36
+ self.dim_head_v = dim_out // self.num_heads
37
+ self.dim_out_qk = num_heads * self.dim_head_qk
38
+ self.dim_out_v = num_heads * self.dim_head_v
39
+ self.scale = self.dim_head_qk**-0.5
40
+ self.scale_pos_embed = scale_pos_embed
41
+
42
+ self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
43
+ self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
44
+
45
+ # NOTE I'm only supporting relative pos embedding for now
46
+ self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
47
+
48
+ self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
49
+ mlp_ratio = 4
50
+ self.mlp = Mlp(
51
+ in_features=self.dim_out_v * 2,
52
+ hidden_features=int(dim * mlp_ratio),
53
+ act_layer=nn.GELU,
54
+ out_features=dim_out,
55
+ drop=0,
56
+ use_conv=True,
57
+ )
58
+
59
+ self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
60
+ self.reset_parameters()
61
+
62
+ def reset_parameters(self):
63
+ trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in
64
+ trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in
65
+ trunc_normal_(self.pos_embed.height_rel, std=self.scale)
66
+ trunc_normal_(self.pos_embed.width_rel, std=self.scale)
67
+
68
+ def get_qkv(self, x, qvk_conv):
69
+ B, C, H, W = x.shape
70
+
71
+ x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
72
+
73
+ q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
74
+
75
+ q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
76
+ k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
77
+ v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
78
+
79
+ return q, k, v
80
+
81
+ def apply_attn(self, q, k, v, B, H, W, dropout=None):
82
+ if self.scale_pos_embed:
83
+ attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
84
+ else:
85
+ attn = (q @ k) * self.scale + self.pos_embed(q)
86
+ attn = attn.softmax(dim=-1)
87
+ if dropout:
88
+ attn = dropout(attn)
89
+
90
+ out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
91
+ return out
92
+
93
+ def forward(self, x):
94
+ B, C, H, W = x.shape
95
+
96
+ dim = int(C / 2)
97
+ x1 = x[:, :dim, :, :]
98
+ x2 = x[:, dim:, :, :]
99
+
100
+ _assert(H == self.pos_embed.height, "")
101
+ _assert(W == self.pos_embed.width, "")
102
+
103
+ q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
104
+ q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)
105
+
106
+ # person to face
107
+ out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
108
+ # face to person
109
+ out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)
110
+
111
+ x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W
112
+ x_pf = self.norm(x_pf)
113
+ x_pf = self.mlp(x_pf) # B, dim_out, H, W
114
+
115
+ out = self.pool(x_pf)
116
+ return out
mivolo/model/mi_volo.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from mivolo.data.misc import prepare_classification_images
7
+ from mivolo.model.create_timm_model import create_model
8
+ from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult
9
+ from timm.data import resolve_data_config
10
+
11
+ _logger = logging.getLogger("MiVOLO")
12
+ has_compile = hasattr(torch, "compile")
13
+
14
+
15
+ class Meta:
16
+ def __init__(self):
17
+ self.min_age = None
18
+ self.max_age = None
19
+ self.avg_age = None
20
+ self.num_classes = None
21
+
22
+ self.in_chans = 3
23
+ self.with_persons_model = False
24
+ self.disable_faces = False
25
+ self.use_persons = True
26
+ self.only_age = False
27
+
28
+ self.num_classes_gender = 2
29
+
30
+ def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
31
+
32
+ state = torch.load(ckpt_path, map_location="cpu")
33
+
34
+ self.min_age = state["min_age"]
35
+ self.max_age = state["max_age"]
36
+ self.avg_age = state["avg_age"]
37
+ self.only_age = state["no_gender"]
38
+
39
+ only_age = state["no_gender"]
40
+
41
+ self.disable_faces = disable_faces
42
+ if "with_persons_model" in state:
43
+ self.with_persons_model = state["with_persons_model"]
44
+ else:
45
+ self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False
46
+
47
+ self.num_classes = 1 if only_age else 3
48
+ self.in_chans = 3 if not self.with_persons_model else 6
49
+ self.use_persons = use_persons and self.with_persons_model
50
+
51
+ if not self.with_persons_model and self.disable_faces:
52
+ raise ValueError("You can not use disable-faces for faces-only model")
53
+ if self.with_persons_model and self.disable_faces and not self.use_persons:
54
+ raise ValueError("You can not disable faces and persons together")
55
+
56
+ return self
57
+
58
+ def __str__(self):
59
+ attrs = vars(self)
60
+ attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
61
+ return ", ".join("%s: %s" % item for item in attrs.items())
62
+
63
+ @property
64
+ def use_person_crops(self) -> bool:
65
+ return self.with_persons_model and self.use_persons
66
+
67
+ @property
68
+ def use_face_crops(self) -> bool:
69
+ return not self.disable_faces or not self.with_persons_model
70
+
71
+
72
+ class MiVOLO:
73
+ def __init__(
74
+ self,
75
+ ckpt_path: str,
76
+ device: str = "cpu",
77
+ half: bool = True,
78
+ disable_faces: bool = False,
79
+ use_persons: bool = True,
80
+ verbose: bool = False,
81
+ torchcompile: Optional[str] = None,
82
+ ):
83
+ self.verbose = verbose
84
+ self.device = torch.device(device)
85
+ self.half = half and self.device.type != "cpu"
86
+
87
+ self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
88
+ if self.verbose:
89
+ _logger.info(f"Model meta:\n{str(self.meta)}")
90
+
91
+ model_name = "mivolo_d1_224"
92
+ self.model = create_model(
93
+ model_name=model_name,
94
+ num_classes=self.meta.num_classes,
95
+ in_chans=self.meta.in_chans,
96
+ pretrained=False,
97
+ checkpoint_path=ckpt_path,
98
+ filter_keys=["fds."],
99
+ )
100
+ self.param_count = sum([m.numel() for m in self.model.parameters()])
101
+ _logger.info(f"Model {model_name} created, param count: {self.param_count}")
102
+
103
+ self.data_config = resolve_data_config(
104
+ model=self.model,
105
+ verbose=verbose,
106
+ use_test_size=True,
107
+ )
108
+ self.data_config["crop_pct"] = 1.0
109
+ c, h, w = self.data_config["input_size"]
110
+ assert h == w, "Incorrect data_config"
111
+ self.input_size = w
112
+
113
+ self.model = self.model.to(self.device)
114
+
115
+ if torchcompile:
116
+ assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
117
+ torch._dynamo.reset()
118
+ self.model = torch.compile(self.model, backend=torchcompile)
119
+
120
+ self.model.eval()
121
+ if self.half:
122
+ self.model = self.model.half()
123
+
124
+ def warmup(self, batch_size: int, steps=10):
125
+ if self.meta.with_persons_model:
126
+ input_size = (6, self.input_size, self.input_size)
127
+ else:
128
+ input_size = self.data_config["input_size"]
129
+
130
+ input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)
131
+
132
+ for _ in range(steps):
133
+ out = self.inference(input) # noqa: F841
134
+
135
+ if torch.cuda.is_available():
136
+ torch.cuda.synchronize()
137
+
138
+ def inference(self, model_input: torch.tensor) -> torch.tensor:
139
+
140
+ with torch.no_grad():
141
+ if self.half:
142
+ model_input = model_input.half()
143
+ output = self.model(model_input)
144
+ return output
145
+
146
+ def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
147
+ if detected_bboxes.n_objects == 0:
148
+ return
149
+
150
+ faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)
151
+
152
+ if self.meta.with_persons_model:
153
+ model_input = torch.cat((faces_input, person_input), dim=1)
154
+ else:
155
+ model_input = faces_input
156
+ output = self.inference(model_input)
157
+
158
+ # write gender and age results into detected_bboxes
159
+ self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
160
+
161
+ def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
162
+ if self.meta.only_age:
163
+ age_output = output
164
+ gender_probs, gender_indx = None, None
165
+ else:
166
+ age_output = output[:, 2]
167
+ gender_output = output[:, :2].softmax(-1)
168
+ gender_probs, gender_indx = gender_output.topk(1)
169
+
170
+ assert output.shape[0] == len(faces_inds) == len(bodies_inds)
171
+
172
+ # per face
173
+ for index in range(output.shape[0]):
174
+ face_ind = faces_inds[index]
175
+ body_ind = bodies_inds[index]
176
+
177
+ # get_age
178
+ age = age_output[index].item()
179
+ age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
180
+ age = round(age, 2)
181
+
182
+ detected_bboxes.set_age(face_ind, age)
183
+ detected_bboxes.set_age(body_ind, age)
184
+
185
+ _logger.info(f"\tage: {age}")
186
+
187
+ if gender_probs is not None:
188
+ gender = "male" if gender_indx[index].item() == 0 else "female"
189
+ gender_score = gender_probs[index].item()
190
+
191
+ _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
192
+
193
+ detected_bboxes.set_gender(face_ind, gender, gender_score)
194
+ detected_bboxes.set_gender(body_ind, gender, gender_score)
195
+
196
+ def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
197
+
198
+ if self.meta.use_person_crops and self.meta.use_face_crops:
199
+ detected_bboxes.associate_faces_with_persons()
200
+
201
+ crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
202
+ (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
203
+ self.meta.use_person_crops, self.meta.use_face_crops
204
+ )
205
+
206
+ if not self.meta.use_face_crops:
207
+ assert all(f is None for f in faces_crops)
208
+
209
+ faces_input = prepare_classification_images(
210
+ faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
211
+ )
212
+
213
+ if not self.meta.use_person_crops:
214
+ assert all(p is None for p in bodies_crops)
215
+
216
+ person_input = prepare_classification_images(
217
+ bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
218
+ )
219
+
220
+ _logger.info(
221
+ f"faces_input: {faces_input.shape if faces_input is not None else None}, "
222
+ f"person_input: {person_input.shape if person_input is not None else None}"
223
+ )
224
+
225
+ return faces_input, person_input, faces_inds, bodies_inds
226
+
227
+
228
+ if __name__ == "__main__":
229
+ model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0")
mivolo/model/mivolo_model.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code adapted from timm https://github.com/huggingface/pytorch-image-models
3
+
4
+ Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn
10
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+ from timm.layers import trunc_normal_
12
+ from timm.models._builder import build_model_with_cfg
13
+ from timm.models._registry import register_model
14
+ from timm.models.volo import VOLO
15
+
16
+ __all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
17
+
18
+
19
+ def _cfg(url="", **kwargs):
20
+ return {
21
+ "url": url,
22
+ "num_classes": 1000,
23
+ "input_size": (3, 224, 224),
24
+ "pool_size": None,
25
+ "crop_pct": 0.96,
26
+ "interpolation": "bicubic",
27
+ "fixed_input_size": True,
28
+ "mean": IMAGENET_DEFAULT_MEAN,
29
+ "std": IMAGENET_DEFAULT_STD,
30
+ "first_conv": None,
31
+ "classifier": ("head", "aux_head"),
32
+ **kwargs,
33
+ }
34
+
35
+
36
+ default_cfgs = {
37
+ "mivolo_d1_224": _cfg(
38
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
39
+ ),
40
+ "mivolo_d1_384": _cfg(
41
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
42
+ crop_pct=1.0,
43
+ input_size=(3, 384, 384),
44
+ ),
45
+ "mivolo_d2_224": _cfg(
46
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
47
+ ),
48
+ "mivolo_d2_384": _cfg(
49
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
50
+ crop_pct=1.0,
51
+ input_size=(3, 384, 384),
52
+ ),
53
+ "mivolo_d3_224": _cfg(
54
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
55
+ ),
56
+ "mivolo_d3_448": _cfg(
57
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
58
+ crop_pct=1.0,
59
+ input_size=(3, 448, 448),
60
+ ),
61
+ "mivolo_d4_224": _cfg(
62
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
63
+ ),
64
+ "mivolo_d4_448": _cfg(
65
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
66
+ crop_pct=1.15,
67
+ input_size=(3, 448, 448),
68
+ ),
69
+ "mivolo_d5_224": _cfg(
70
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
71
+ ),
72
+ "mivolo_d5_448": _cfg(
73
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
74
+ crop_pct=1.15,
75
+ input_size=(3, 448, 448),
76
+ ),
77
+ "mivolo_d5_512": _cfg(
78
+ url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
79
+ crop_pct=1.15,
80
+ input_size=(3, 512, 512),
81
+ ),
82
+ }
83
+
84
+
85
+ def get_output_size(input_shape, conv_layer):
86
+ padding = conv_layer.padding
87
+ dilation = conv_layer.dilation
88
+ kernel_size = conv_layer.kernel_size
89
+ stride = conv_layer.stride
90
+
91
+ output_size = [
92
+ ((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
93
+ ]
94
+ return output_size
95
+
96
+
97
+ def get_output_size_module(input_size, stem):
98
+ output_size = input_size
99
+
100
+ for module in stem:
101
+ if isinstance(module, nn.Conv2d):
102
+ output_size = [
103
+ (
104
+ (output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
105
+ // module.stride[i]
106
+ )
107
+ + 1
108
+ for i in range(2)
109
+ ]
110
+
111
+ return output_size
112
+
113
+
114
+ class PatchEmbed(nn.Module):
115
+ """Image to Patch Embedding."""
116
+
117
+ def __init__(
118
+ self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
119
+ ):
120
+ super().__init__()
121
+ assert patch_size in [4, 8, 16]
122
+ assert in_chans in [3, 6]
123
+ self.with_persons_model = in_chans == 6
124
+ self.use_cross_attn = True
125
+
126
+ if stem_conv:
127
+ if not self.with_persons_model:
128
+ self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
129
+ else:
130
+ self.conv = True # just to match interface
131
+ # split
132
+ self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
133
+ self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
134
+ else:
135
+ self.conv = None
136
+
137
+ if self.with_persons_model:
138
+
139
+ self.proj1 = nn.Conv2d(
140
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
141
+ )
142
+ self.proj2 = nn.Conv2d(
143
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
144
+ )
145
+
146
+ stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
147
+ self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
148
+
149
+ self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
150
+
151
+ else:
152
+ self.proj = nn.Conv2d(
153
+ hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
154
+ )
155
+
156
+ self.patch_dim = img_size // patch_size
157
+ self.num_patches = self.patch_dim**2
158
+
159
+ def create_stem(self, stem_stride, in_chans, hidden_dim):
160
+ return nn.Sequential(
161
+ nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
162
+ nn.BatchNorm2d(hidden_dim),
163
+ nn.ReLU(inplace=True),
164
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
165
+ nn.BatchNorm2d(hidden_dim),
166
+ nn.ReLU(inplace=True),
167
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
168
+ nn.BatchNorm2d(hidden_dim),
169
+ nn.ReLU(inplace=True),
170
+ )
171
+
172
+ def forward(self, x):
173
+ if self.conv is not None:
174
+ if self.with_persons_model:
175
+ x1 = x[:, :3]
176
+ x2 = x[:, 3:]
177
+
178
+ x1 = self.conv1(x1)
179
+ x1 = self.proj1(x1)
180
+
181
+ x2 = self.conv2(x2)
182
+ x2 = self.proj2(x2)
183
+
184
+ x = torch.cat([x1, x2], dim=1)
185
+ x = self.map(x)
186
+ else:
187
+ x = self.conv(x)
188
+ x = self.proj(x) # B, C, H, W
189
+
190
+ return x
191
+
192
+
193
+ class MiVOLOModel(VOLO):
194
+ """
195
+ Vision Outlooker, the main class of our model
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ layers,
201
+ img_size=224,
202
+ in_chans=3,
203
+ num_classes=1000,
204
+ global_pool="token",
205
+ patch_size=8,
206
+ stem_hidden_dim=64,
207
+ embed_dims=None,
208
+ num_heads=None,
209
+ downsamples=(True, False, False, False),
210
+ outlook_attention=(True, False, False, False),
211
+ mlp_ratio=3.0,
212
+ qkv_bias=False,
213
+ drop_rate=0.0,
214
+ attn_drop_rate=0.0,
215
+ drop_path_rate=0.0,
216
+ norm_layer=nn.LayerNorm,
217
+ post_layers=("ca", "ca"),
218
+ use_aux_head=True,
219
+ use_mix_token=False,
220
+ pooling_scale=2,
221
+ ):
222
+ super().__init__(
223
+ layers,
224
+ img_size,
225
+ in_chans,
226
+ num_classes,
227
+ global_pool,
228
+ patch_size,
229
+ stem_hidden_dim,
230
+ embed_dims,
231
+ num_heads,
232
+ downsamples,
233
+ outlook_attention,
234
+ mlp_ratio,
235
+ qkv_bias,
236
+ drop_rate,
237
+ attn_drop_rate,
238
+ drop_path_rate,
239
+ norm_layer,
240
+ post_layers,
241
+ use_aux_head,
242
+ use_mix_token,
243
+ pooling_scale,
244
+ )
245
+
246
+ self.patch_embed = PatchEmbed(
247
+ stem_conv=True,
248
+ stem_stride=2,
249
+ patch_size=patch_size,
250
+ in_chans=in_chans,
251
+ hidden_dim=stem_hidden_dim,
252
+ embed_dim=embed_dims[0],
253
+ )
254
+
255
+ trunc_normal_(self.pos_embed, std=0.02)
256
+ self.apply(self._init_weights)
257
+
258
+ def forward_features(self, x):
259
+ x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
260
+
261
+ # step2: tokens learning in the two stages
262
+ x = self.forward_tokens(x)
263
+
264
+ # step3: post network, apply class attention or not
265
+ if self.post_network is not None:
266
+ x = self.forward_cls(x)
267
+ x = self.norm(x)
268
+ return x
269
+
270
+ def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
271
+ if self.global_pool == "avg":
272
+ out = x.mean(dim=1)
273
+ elif self.global_pool == "token":
274
+ out = x[:, 0]
275
+ else:
276
+ out = x
277
+ if pre_logits:
278
+ return out
279
+
280
+ features = out
281
+ fds_enabled = hasattr(self, "_fds_forward")
282
+ if fds_enabled:
283
+ features = self._fds_forward(features, targets, epoch)
284
+
285
+ out = self.head(features)
286
+ if self.aux_head is not None:
287
+ # generate classes in all feature tokens, see token labeling
288
+ aux = self.aux_head(x[:, 1:])
289
+ out = out + 0.5 * aux.max(1)[0]
290
+
291
+ return (out, features) if (fds_enabled and self.training) else out
292
+
293
+ def forward(self, x, targets=None, epoch=None):
294
+ """simplified forward (without mix token training)"""
295
+ x = self.forward_features(x)
296
+ x = self.forward_head(x, targets=targets, epoch=epoch)
297
+ return x
298
+
299
+
300
+ def _create_mivolo(variant, pretrained=False, **kwargs):
301
+ if kwargs.get("features_only", None):
302
+ raise RuntimeError("features_only not implemented for Vision Transformer models.")
303
+ return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
304
+
305
+
306
+ @register_model
307
+ def mivolo_d1_224(pretrained=False, **kwargs):
308
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
309
+ model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
310
+ return model
311
+
312
+
313
+ @register_model
314
+ def mivolo_d1_384(pretrained=False, **kwargs):
315
+ model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
316
+ model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
317
+ return model
318
+
319
+
320
+ @register_model
321
+ def mivolo_d2_224(pretrained=False, **kwargs):
322
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
323
+ model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
324
+ return model
325
+
326
+
327
+ @register_model
328
+ def mivolo_d2_384(pretrained=False, **kwargs):
329
+ model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
330
+ model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
331
+ return model
332
+
333
+
334
+ @register_model
335
+ def mivolo_d3_224(pretrained=False, **kwargs):
336
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
337
+ model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
338
+ return model
339
+
340
+
341
+ @register_model
342
+ def mivolo_d3_448(pretrained=False, **kwargs):
343
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
344
+ model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
345
+ return model
346
+
347
+
348
+ @register_model
349
+ def mivolo_d4_224(pretrained=False, **kwargs):
350
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
351
+ model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
352
+ return model
353
+
354
+
355
+ @register_model
356
+ def mivolo_d4_448(pretrained=False, **kwargs):
357
+ """VOLO-D4 model, Params: 193M"""
358
+ model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
359
+ model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
360
+ return model
361
+
362
+
363
+ @register_model
364
+ def mivolo_d5_224(pretrained=False, **kwargs):
365
+ model_args = dict(
366
+ layers=(12, 12, 20, 4),
367
+ embed_dims=(384, 768, 768, 768),
368
+ num_heads=(12, 16, 16, 16),
369
+ mlp_ratio=4,
370
+ stem_hidden_dim=128,
371
+ **kwargs
372
+ )
373
+ model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
374
+ return model
375
+
376
+
377
+ @register_model
378
+ def mivolo_d5_448(pretrained=False, **kwargs):
379
+ model_args = dict(
380
+ layers=(12, 12, 20, 4),
381
+ embed_dims=(384, 768, 768, 768),
382
+ num_heads=(12, 16, 16, 16),
383
+ mlp_ratio=4,
384
+ stem_hidden_dim=128,
385
+ **kwargs
386
+ )
387
+ model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
388
+ return model
389
+
390
+
391
+ @register_model
392
+ def mivolo_d5_512(pretrained=False, **kwargs):
393
+ model_args = dict(
394
+ layers=(12, 12, 20, 4),
395
+ embed_dims=(384, 768, 768, 768),
396
+ num_heads=(12, 16, 16, 16),
397
+ mlp_ratio=4,
398
+ stem_hidden_dim=128,
399
+ **kwargs
400
+ )
401
+ model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
402
+ return model
mivolo/model/yolo_detector.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from mivolo.structures import PersonAndFaceResult
8
+ from ultralytics import YOLO
9
+ # from ultralytics.yolo.engine.results import Results
10
+
11
+ # because of ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after the module importing
12
+ os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
13
+
14
+
15
+ class Detector:
16
+ def __init__(
17
+ self,
18
+ weights: str,
19
+ device: str = "cpu",
20
+ half: bool = True,
21
+ verbose: bool = False,
22
+ conf_thresh: float = 0.4,
23
+ iou_thresh: float = 0.7,
24
+ ):
25
+ self.yolo = YOLO(weights)
26
+ self.yolo.fuse()
27
+
28
+ self.device = torch.device(device)
29
+ self.half = half and self.device.type != "cpu"
30
+
31
+ if self.half:
32
+ self.yolo.model = self.yolo.model.half()
33
+
34
+ self.detector_names: Dict[int, str] = self.yolo.model.names
35
+
36
+ # init yolo.predictor
37
+ self.detector_kwargs = {
38
+ "conf": conf_thresh, "iou": iou_thresh, "half": self.half, "verbose": verbose}
39
+ # self.yolo.predict(**self.detector_kwargs)
40
+
41
+ def predict(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
42
+ results = self.yolo.predict(image, **self.detector_kwargs)[0]
43
+ return PersonAndFaceResult(results)
44
+
45
+ def track(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
46
+ results = self.yolo.track(
47
+ image, persist=True, **self.detector_kwargs)[0]
48
+ return PersonAndFaceResult(results)
mivolo/predictor.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import defaultdict
2
+ from typing import Dict, Generator, List, Optional, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import tqdm
7
+ from mivolo.model.mi_volo import MiVOLO
8
+ from mivolo.model.yolo_detector import Detector
9
+ from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
10
+
11
+
12
+ class Predictor:
13
+ def __init__(self, config, verbose: bool = False):
14
+ self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
15
+ self.age_gender_model = MiVOLO(
16
+ config.checkpoint,
17
+ config.device,
18
+ half=True,
19
+ use_persons=config.with_persons,
20
+ disable_faces=config.disable_faces,
21
+ verbose=verbose,
22
+ )
23
+ self.draw = config.draw
24
+
25
+ def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray]]:
26
+ detected_objects: PersonAndFaceResult = self.detector.predict(image)
27
+ self.age_gender_model.predict(image, detected_objects)
28
+
29
+ out_im = None
30
+ if self.draw:
31
+ # plot results on image
32
+ out_im = detected_objects.plot()
33
+
34
+ return detected_objects, out_im
35
+
36
+ def recognize_video(self, source: str) -> Generator:
37
+ video_capture = cv2.VideoCapture(source)
38
+ if not video_capture.isOpened():
39
+ raise ValueError(f"Failed to open video source {source}")
40
+
41
+ detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)
42
+
43
+ total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
44
+ for _ in tqdm.tqdm(range(total_frames)):
45
+ ret, frame = video_capture.read()
46
+ if not ret:
47
+ break
48
+
49
+ detected_objects: PersonAndFaceResult = self.detector.track(frame)
50
+ self.age_gender_model.predict(frame, detected_objects)
51
+
52
+ current_frame_objs = detected_objects.get_results_for_tracking()
53
+ cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
54
+ cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]
55
+
56
+ # add tr_persons and tr_faces to history
57
+ for guid, data in cur_persons.items():
58
+ # not useful for tracking :)
59
+ if None not in data:
60
+ detected_objects_history[guid].append(data)
61
+ for guid, data in cur_faces.items():
62
+ if None not in data:
63
+ detected_objects_history[guid].append(data)
64
+
65
+ detected_objects.set_tracked_age_gender(detected_objects_history)
66
+ if self.draw:
67
+ frame = detected_objects.plot()
68
+ yield detected_objects_history, frame
mivolo/structures.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from copy import deepcopy
4
+ from typing import Dict, List, Optional, Tuple
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from mivolo.data.misc import aggregate_votes_winsorized, assign_faces, box_iou, cropout_black_parts
10
+ from ultralytics.engine.results import Results
11
+ from ultralytics.utils.plotting import Annotator, colors
12
+
13
+ # because of ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after the module importing
14
+ os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
15
+
16
+ AGE_GENDER_TYPE = Tuple[float, str]
17
+
18
+
19
+ class PersonAndFaceCrops:
20
+ def __init__(self):
21
+ # int: index of person along results
22
+ self.crops_persons: Dict[int, np.ndarray] = {}
23
+
24
+ # int: index of face along results
25
+ self.crops_faces: Dict[int, np.ndarray] = {}
26
+
27
+ # int: index of face along results
28
+ self.crops_faces_wo_body: Dict[int, np.ndarray] = {}
29
+
30
+ # int: index of person along results
31
+ self.crops_persons_wo_face: Dict[int, np.ndarray] = {}
32
+
33
+ def _add_to_output(
34
+ self, crops: Dict[int, np.ndarray], out_crops: List[np.ndarray], out_crop_inds: List[Optional[int]]
35
+ ):
36
+ inds_to_add = list(crops.keys())
37
+ crops_to_add = list(crops.values())
38
+ out_crops.extend(crops_to_add)
39
+ out_crop_inds.extend(inds_to_add)
40
+
41
+ def _get_all_faces(
42
+ self, use_persons: bool, use_faces: bool
43
+ ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
44
+ """
45
+ Returns
46
+ if use_persons and use_faces
47
+ faces: faces_with_bodies + faces_without_bodies + [None] * len(crops_persons_wo_face)
48
+ if use_persons and not use_faces
49
+ faces: [None] * n_persons
50
+ if not use_persons and use_faces:
51
+ faces: faces_with_bodies + faces_without_bodies
52
+ """
53
+
54
+ def add_none_to_output(faces_inds, faces_crops, num):
55
+ faces_inds.extend([None for _ in range(num)])
56
+ faces_crops.extend([None for _ in range(num)])
57
+
58
+ faces_inds: List[Optional[int]] = []
59
+ faces_crops: List[Optional[np.ndarray]] = []
60
+
61
+ if not use_faces:
62
+ add_none_to_output(faces_inds, faces_crops, len(
63
+ self.crops_persons) + len(self.crops_persons_wo_face))
64
+ return faces_inds, faces_crops
65
+
66
+ self._add_to_output(self.crops_faces, faces_crops, faces_inds)
67
+ self._add_to_output(self.crops_faces_wo_body, faces_crops, faces_inds)
68
+
69
+ if use_persons:
70
+ add_none_to_output(faces_inds, faces_crops,
71
+ len(self.crops_persons_wo_face))
72
+
73
+ return faces_inds, faces_crops
74
+
75
+ def _get_all_bodies(
76
+ self, use_persons: bool, use_faces: bool
77
+ ) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
78
+ """
79
+ Returns
80
+ if use_persons and use_faces
81
+ persons: bodies_with_faces + [None] * len(faces_without_bodies) + bodies_without_faces
82
+ if use_persons and not use_faces
83
+ persons: bodies_with_faces + bodies_without_faces
84
+ if not use_persons and use_faces
85
+ persons: [None] * n_faces
86
+ """
87
+
88
+ def add_none_to_output(bodies_inds, bodies_crops, num):
89
+ bodies_inds.extend([None for _ in range(num)])
90
+ bodies_crops.extend([None for _ in range(num)])
91
+
92
+ bodies_inds: List[Optional[int]] = []
93
+ bodies_crops: List[Optional[np.ndarray]] = []
94
+
95
+ if not use_persons:
96
+ add_none_to_output(bodies_inds, bodies_crops, len(
97
+ self.crops_faces) + len(self.crops_faces_wo_body))
98
+ return bodies_inds, bodies_crops
99
+
100
+ self._add_to_output(self.crops_persons, bodies_crops, bodies_inds)
101
+ if use_faces:
102
+ add_none_to_output(bodies_inds, bodies_crops,
103
+ len(self.crops_faces_wo_body))
104
+
105
+ self._add_to_output(self.crops_persons_wo_face,
106
+ bodies_crops, bodies_inds)
107
+
108
+ return bodies_inds, bodies_crops
109
+
110
+ def get_faces_with_bodies(self, use_persons: bool, use_faces: bool):
111
+ """
112
+ Return
113
+ faces: faces_with_bodies, faces_without_bodies, [None] * len(crops_persons_wo_face)
114
+ persons: bodies_with_faces, [None] * len(faces_without_bodies), bodies_without_faces
115
+ """
116
+
117
+ bodies_inds, bodies_crops = self._get_all_bodies(
118
+ use_persons, use_faces)
119
+ faces_inds, faces_crops = self._get_all_faces(use_persons, use_faces)
120
+
121
+ return (bodies_inds, bodies_crops), (faces_inds, faces_crops)
122
+
123
+ def save(self, out_dir="output"):
124
+ ind = 0
125
+ os.makedirs(out_dir, exist_ok=True)
126
+ for crops in [self.crops_persons, self.crops_faces, self.crops_faces_wo_body, self.crops_persons_wo_face]:
127
+ for crop in crops.values():
128
+ if crop is None:
129
+ continue
130
+ out_name = os.path.join(out_dir, f"{ind}_crop.jpg")
131
+ cv2.imwrite(out_name, crop)
132
+ ind += 1
133
+
134
+
135
+ class PersonAndFaceResult:
136
+ def __init__(self, results: Results):
137
+
138
+ self.yolo_results = results
139
+ names = set(results.names.values())
140
+ assert "person" in names and "face" in names
141
+
142
+ # initially no faces and persons are associated to each other
143
+ self.face_to_person_map: Dict[int, Optional[int]] = {
144
+ ind: None for ind in self.get_bboxes_inds("face")}
145
+ self.unassigned_persons_inds: List[int] = self.get_bboxes_inds(
146
+ "person")
147
+ n_objects = len(self.yolo_results.boxes)
148
+ self.ages: List[Optional[float]] = [None for _ in range(n_objects)]
149
+ self.genders: List[Optional[str]] = [None for _ in range(n_objects)]
150
+ self.gender_scores: List[Optional[float]] = [
151
+ None for _ in range(n_objects)]
152
+
153
+ @property
154
+ def n_objects(self) -> int:
155
+ return len(self.yolo_results.boxes)
156
+
157
+ def get_bboxes_inds(self, category: str) -> List[int]:
158
+ bboxes: List[int] = []
159
+ for ind, det in enumerate(self.yolo_results.boxes):
160
+ name = self.yolo_results.names[int(det.cls)]
161
+ if name == category:
162
+ bboxes.append(ind)
163
+
164
+ return bboxes
165
+
166
+ def get_distance_to_center(self, bbox_ind: int) -> float:
167
+ """
168
+ Calculate euclidian distance between bbox center and image center.
169
+ """
170
+ im_h, im_w = self.yolo_results[bbox_ind].orig_shape
171
+ x1, y1, x2, y2 = self.get_bbox_by_ind(bbox_ind).cpu().numpy()
172
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
173
+ dist = math.dist([center_x, center_y], [im_w / 2, im_h / 2])
174
+ return dist
175
+
176
+ def plot(
177
+ self,
178
+ conf=False,
179
+ line_width=None,
180
+ font_size=None,
181
+ font="Arial.ttf",
182
+ pil=False,
183
+ img=None,
184
+ labels=True,
185
+ boxes=True,
186
+ probs=True,
187
+ ages=True,
188
+ genders=True,
189
+ gender_probs=False,
190
+ ):
191
+ """
192
+ Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
193
+ Args:
194
+ conf (bool): Whether to plot the detection confidence score.
195
+ line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
196
+ font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
197
+ font (str): The font to use for the text.
198
+ pil (bool): Whether to return the image as a PIL Image.
199
+ img (numpy.ndarray): Plot to another image. if not, plot to original image.
200
+ labels (bool): Whether to plot the label of bounding boxes.
201
+ boxes (bool): Whether to plot the bounding boxes.
202
+ probs (bool): Whether to plot classification probability
203
+ ages (bool): Whether to plot the age of bounding boxes.
204
+ genders (bool): Whether to plot the genders of bounding boxes.
205
+ gender_probs (bool): Whether to plot gender classification probability
206
+ Returns:
207
+ (numpy.ndarray): A numpy array of the annotated image.
208
+ """
209
+
210
+ # return self.yolo_results.plot()
211
+ colors_by_ind = {}
212
+ for face_ind, person_ind in self.face_to_person_map.items():
213
+ if person_ind is not None:
214
+ colors_by_ind[face_ind] = face_ind + 2
215
+ colors_by_ind[person_ind] = face_ind + 2
216
+ else:
217
+ colors_by_ind[face_ind] = 0
218
+ for person_ind in self.unassigned_persons_inds:
219
+ colors_by_ind[person_ind] = 1
220
+
221
+ names = self.yolo_results.names
222
+ annotator = Annotator(
223
+ deepcopy(self.yolo_results.orig_img if img is None else img),
224
+ line_width,
225
+ font_size,
226
+ font,
227
+ pil,
228
+ example=names,
229
+ )
230
+ pred_boxes, show_boxes = self.yolo_results.boxes, boxes
231
+ pred_probs, show_probs = self.yolo_results.probs, probs
232
+
233
+ if pred_boxes and show_boxes:
234
+ for bb_ind, (d, age, gender, gender_score) in enumerate(
235
+ zip(pred_boxes, self.ages, self.genders, self.gender_scores)
236
+ ):
237
+ c, conf, guid = int(d.cls), float(
238
+ d.conf) if conf else None, None if d.id is None else int(d.id.item())
239
+ name = ("" if guid is None else f"id:{guid} ") + names[c]
240
+ label = (
241
+ f"{name} {conf:.2f}" if conf else name) if labels else None
242
+ if ages and age is not None:
243
+ label += f" {age:.1f}"
244
+ if genders and gender is not None:
245
+ label += f" {'F' if gender == 'female' else 'M'}"
246
+ if gender_probs and gender_score is not None:
247
+ label += f" ({gender_score:.1f})"
248
+ annotator.box_label(d.xyxy.squeeze(), label,
249
+ color=colors(colors_by_ind[bb_ind], True))
250
+
251
+ if pred_probs is not None and show_probs:
252
+ text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
253
+ annotator.text((32, 32), text, txt_color=(
254
+ 255, 255, 255)) # TODO: allow setting colors
255
+
256
+ return annotator.result()
257
+
258
+ def set_tracked_age_gender(self, tracked_objects: Dict[int, List[AGE_GENDER_TYPE]]):
259
+ """
260
+ Update age and gender for objects based on history from tracked_objects.
261
+ Args:
262
+ tracked_objects (dict[int, list[AGE_GENDER_TYPE]]): info about tracked objects by guid
263
+ """
264
+
265
+ for face_ind, person_ind in self.face_to_person_map.items():
266
+ pguid = self._get_id_by_ind(person_ind)
267
+ fguid = self._get_id_by_ind(face_ind)
268
+
269
+ if fguid == -1 and pguid == -1:
270
+ # YOLO might not assign ids for some objects in some cases:
271
+ # https://github.com/ultralytics/ultralytics/issues/3830
272
+ continue
273
+ age, gender = self._gather_tracking_result(
274
+ tracked_objects, fguid, pguid)
275
+ if age is None or gender is None:
276
+ continue
277
+ self.set_age(face_ind, age)
278
+ self.set_gender(face_ind, gender, 1.0)
279
+ if pguid != -1:
280
+ self.set_gender(person_ind, gender, 1.0)
281
+ self.set_age(person_ind, age)
282
+
283
+ for person_ind in self.unassigned_persons_inds:
284
+ pid = self._get_id_by_ind(person_ind)
285
+ if pid == -1:
286
+ continue
287
+ age, gender = self._gather_tracking_result(
288
+ tracked_objects, -1, pid)
289
+ if age is None or gender is None:
290
+ continue
291
+ self.set_gender(person_ind, gender, 1.0)
292
+ self.set_age(person_ind, age)
293
+
294
+ def _get_id_by_ind(self, ind: Optional[int] = None) -> int:
295
+ if ind is None:
296
+ return -1
297
+ obj_id = self.yolo_results.boxes[ind].id
298
+ if obj_id is None:
299
+ return -1
300
+ return obj_id.item()
301
+
302
+ def get_bbox_by_ind(self, ind: int, im_h: int = None, im_w: int = None) -> torch.tensor:
303
+ bb = self.yolo_results.boxes[ind].xyxy.squeeze().type(torch.int32)
304
+ if im_h is not None and im_w is not None:
305
+ bb[0] = torch.clamp(bb[0], min=0, max=im_w - 1)
306
+ bb[1] = torch.clamp(bb[1], min=0, max=im_h - 1)
307
+ bb[2] = torch.clamp(bb[2], min=0, max=im_w - 1)
308
+ bb[3] = torch.clamp(bb[3], min=0, max=im_h - 1)
309
+ return bb
310
+
311
+ def set_age(self, ind: Optional[int], age: float):
312
+ if ind is not None:
313
+ self.ages[ind] = age
314
+
315
+ def set_gender(self, ind: Optional[int], gender: str, gender_score: float):
316
+ if ind is not None:
317
+ self.genders[ind] = gender
318
+ self.gender_scores[ind] = gender_score
319
+
320
+ @staticmethod
321
+ def _gather_tracking_result(
322
+ tracked_objects: Dict[int, List[AGE_GENDER_TYPE]],
323
+ fguid: int = -1,
324
+ pguid: int = -1,
325
+ minimum_sample_size: int = 10,
326
+ ) -> AGE_GENDER_TYPE:
327
+
328
+ assert fguid != -1 or pguid != -1, "Incorrect tracking behaviour"
329
+
330
+ face_ages = [r[0] for r in tracked_objects[fguid] if r[0]
331
+ is not None] if fguid in tracked_objects else []
332
+ face_genders = [r[1] for r in tracked_objects[fguid]
333
+ if r[1] is not None] if fguid in tracked_objects else []
334
+ person_ages = [r[0] for r in tracked_objects[pguid]
335
+ if r[0] is not None] if pguid in tracked_objects else []
336
+ person_genders = [r[1] for r in tracked_objects[pguid]
337
+ if r[1] is not None] if pguid in tracked_objects else []
338
+
339
+ if not face_ages and not person_ages: # both empty
340
+ return None, None
341
+
342
+ # You can play here with different aggregation strategies
343
+ # Face ages - predictions based on face or face + person, depends on history of object
344
+ # Person ages - predictions based on person or face + person, depends on history of object
345
+
346
+ if len(person_ages + face_ages) >= minimum_sample_size:
347
+ age = aggregate_votes_winsorized(person_ages + face_ages)
348
+ else:
349
+ face_age = np.mean(face_ages) if face_ages else None
350
+ person_age = np.mean(person_ages) if person_ages else None
351
+ if face_age is None:
352
+ face_age = person_age
353
+ if person_age is None:
354
+ person_age = face_age
355
+ age = (face_age + person_age) / 2.0
356
+
357
+ genders = face_genders + person_genders
358
+ assert len(genders) > 0
359
+ # take mode of genders
360
+ gender = max(set(genders), key=genders.count)
361
+
362
+ return age, gender
363
+
364
+ def get_results_for_tracking(self) -> Tuple[Dict[int, AGE_GENDER_TYPE], Dict[int, AGE_GENDER_TYPE]]:
365
+ """
366
+ Get objects from current frame
367
+ """
368
+ persons: Dict[int, AGE_GENDER_TYPE] = {}
369
+ faces: Dict[int, AGE_GENDER_TYPE] = {}
370
+
371
+ names = self.yolo_results.names
372
+ pred_boxes = self.yolo_results.boxes
373
+ for _, (det, age, gender, _) in enumerate(zip(pred_boxes, self.ages, self.genders, self.gender_scores)):
374
+ if det.id is None:
375
+ continue
376
+ cat_id, _, guid = int(det.cls), float(det.conf), int(det.id.item())
377
+ name = names[cat_id]
378
+ if name == "person":
379
+ persons[guid] = (age, gender)
380
+ elif name == "face":
381
+ faces[guid] = (age, gender)
382
+
383
+ return persons, faces
384
+
385
+ def associate_faces_with_persons(self):
386
+ face_bboxes_inds: List[int] = self.get_bboxes_inds("face")
387
+ person_bboxes_inds: List[int] = self.get_bboxes_inds("person")
388
+
389
+ face_bboxes: List[torch.tensor] = [
390
+ self.get_bbox_by_ind(ind) for ind in face_bboxes_inds]
391
+ person_bboxes: List[torch.tensor] = [
392
+ self.get_bbox_by_ind(ind) for ind in person_bboxes_inds]
393
+
394
+ self.face_to_person_map = {ind: None for ind in face_bboxes_inds}
395
+ assigned_faces, unassigned_persons_inds = assign_faces(
396
+ person_bboxes, face_bboxes)
397
+
398
+ for face_ind, person_ind in enumerate(assigned_faces):
399
+ face_ind = face_bboxes_inds[face_ind]
400
+ person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None
401
+ self.face_to_person_map[face_ind] = person_ind
402
+
403
+ self.unassigned_persons_inds = [
404
+ person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds]
405
+
406
+ def crop_object(
407
+ self, full_image: np.ndarray, ind: int, cut_other_classes: Optional[List[str]] = None
408
+ ) -> Optional[np.ndarray]:
409
+
410
+ IOU_THRESH = 0.000001
411
+ MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
412
+ CROP_ROUND_RATE = 0.3
413
+ MIN_PERSON_SIZE = 50
414
+
415
+ obj_bbox = self.get_bbox_by_ind(ind, *full_image.shape[:2])
416
+ x1, y1, x2, y2 = obj_bbox
417
+ cur_cat = self.yolo_results.names[int(
418
+ self.yolo_results.boxes[ind].cls)]
419
+ # get crop of face or person
420
+ obj_image = full_image[y1:y2, x1:x2].copy()
421
+ crop_h, crop_w = obj_image.shape[:2]
422
+
423
+ if cur_cat == "person" and (crop_h < MIN_PERSON_SIZE or crop_w < MIN_PERSON_SIZE):
424
+ return None
425
+
426
+ if not cut_other_classes:
427
+ return obj_image
428
+
429
+ # calc iou between obj_bbox and other bboxes
430
+ other_bboxes: List[torch.tensor] = [
431
+ self.get_bbox_by_ind(other_ind, *full_image.shape[:2]) for other_ind in range(len(self.yolo_results.boxes))
432
+ ]
433
+
434
+ iou_matrix = box_iou(torch.stack([obj_bbox]), torch.stack(
435
+ other_bboxes)).cpu().numpy()[0]
436
+
437
+ # cut out other objects in case of intersection
438
+ for other_ind, (det, iou) in enumerate(zip(self.yolo_results.boxes, iou_matrix)):
439
+ other_cat = self.yolo_results.names[int(det.cls)]
440
+ if ind == other_ind or iou < IOU_THRESH or other_cat not in cut_other_classes:
441
+ continue
442
+ o_x1, o_y1, o_x2, o_y2 = det.xyxy.squeeze().type(torch.int32)
443
+
444
+ # remap current_person_bbox to reference_person_bbox coordinates
445
+ o_x1 = max(o_x1 - x1, 0)
446
+ o_y1 = max(o_y1 - y1, 0)
447
+ o_x2 = min(o_x2 - x1, crop_w)
448
+ o_y2 = min(o_y2 - y1, crop_h)
449
+
450
+ if other_cat != "face":
451
+ if (o_y1 / crop_h) < CROP_ROUND_RATE:
452
+ o_y1 = 0
453
+ if ((crop_h - o_y2) / crop_h) < CROP_ROUND_RATE:
454
+ o_y2 = crop_h
455
+ if (o_x1 / crop_w) < CROP_ROUND_RATE:
456
+ o_x1 = 0
457
+ if ((crop_w - o_x2) / crop_w) < CROP_ROUND_RATE:
458
+ o_x2 = crop_w
459
+
460
+ obj_image[o_y1:o_y2, o_x1:o_x2] = 0
461
+
462
+ obj_image, remain_ratio = cropout_black_parts(
463
+ obj_image, CROP_ROUND_RATE)
464
+ if remain_ratio < MIN_PERSON_CROP_AFTERCUT_RATIO:
465
+ return None
466
+
467
+ return obj_image
468
+
469
+ def collect_crops(self, image) -> PersonAndFaceCrops:
470
+
471
+ crops_data = PersonAndFaceCrops()
472
+ for face_ind, person_ind in self.face_to_person_map.items():
473
+ face_image = self.crop_object(
474
+ image, face_ind, cut_other_classes=[])
475
+
476
+ if person_ind is None:
477
+ crops_data.crops_faces_wo_body[face_ind] = face_image
478
+ continue
479
+
480
+ person_image = self.crop_object(
481
+ image, person_ind, cut_other_classes=["face", "person"])
482
+
483
+ crops_data.crops_faces[face_ind] = face_image
484
+ crops_data.crops_persons[person_ind] = person_image
485
+
486
+ for person_ind in self.unassigned_persons_inds:
487
+ person_image = self.crop_object(
488
+ image, person_ind, cut_other_classes=["face", "person"])
489
+ crops_data.crops_persons_wo_face[person_ind] = person_image
490
+
491
+ # uncomment to save preprocessed crops
492
+ # crops_data.save()
493
+ return crops_data
mivolo/version.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = "0.3.0dev"
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ultralytics==8.0.187
2
+ timm==0.8.13.dev0
3
+ tqdm
4
+ requests
5
+ opencv-python
6
+ omegaconf
utils.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ from tqdm import tqdm
4
+ from modelscope import snapshot_download
5
+ from urllib.parse import urlparse
6
+
7
+ MODEL_DIR = snapshot_download("MuGeminorum/MiVOLO", cache_dir="./mivolo/__pycache__")
8
+
9
+
10
+ def is_url(s: str):
11
+ try:
12
+ # 解析字符串
13
+ result = urlparse(s)
14
+ # 检查scheme(如http, https)和netloc(域名)
15
+ return all([result.scheme, result.netloc])
16
+
17
+ except:
18
+ # 如果解析过程中发生异常,则返回False
19
+ return False
20
+
21
+
22
+ def download_file(url: str, save_path: str):
23
+ if os.path.exists(save_path):
24
+ print("目标已存在,无需下载")
25
+ return
26
+
27
+ create_dir(os.path.dirname(save_path))
28
+ response = requests.get(url, stream=True)
29
+ total_size = int(response.headers.get("content-length", 0))
30
+ # 使用 tqdm 创建一个进度条
31
+ progress_bar = tqdm(total=total_size, unit="B", unit_scale=True)
32
+ with open(save_path, "wb") as file:
33
+ for data in response.iter_content(chunk_size=1024):
34
+ file.write(data)
35
+ progress_bar.update(len(data))
36
+
37
+ progress_bar.close()
38
+ if total_size != 0 and progress_bar.n != total_size:
39
+ os.remove(save_path)
40
+ print("下载失败,重试中...")
41
+ download_file(url, save_path)
42
+
43
+ else:
44
+ print("下载完成")
45
+
46
+ return save_path
47
+
48
+
49
+ def create_dir(dir_path: str):
50
+ if not os.path.exists(dir_path):
51
+ os.makedirs(dir_path)
52
+
53
+
54
+ def get_jpg_files(folder_path: str):
55
+ all_files = os.listdir(folder_path)
56
+ return [
57
+ os.path.join(folder_path, file)
58
+ for file in all_files
59
+ if file.lower().endswith(".jpg")
60
+ ]