Upload 26 files
TAD Bot Face Detection algorithm
- detectfaces.py +102 -0
- face_detection.py +23 -0
- main.py +602 -0
- models/.DS_Store +0 -0
- models/PosterV2_7cls.py +441 -0
- models/PosterV2_8cls.py +317 -0
- models/__pycache__/PosterV2_7cls.cpython-310.pyc +0 -0
- models/__pycache__/PosterV2_7cls.cpython-311.pyc +0 -0
- models/__pycache__/ir50.cpython-310.pyc +0 -0
- models/__pycache__/ir50.cpython-311.pyc +0 -0
- models/__pycache__/mobilefacenet.cpython-310.pyc +0 -0
- models/__pycache__/mobilefacenet.cpython-311.pyc +0 -0
- models/__pycache__/vit_model.cpython-310.pyc +0 -0
- models/__pycache__/vit_model.cpython-311.pyc +0 -0
- models/ir50.py +272 -0
- models/matrix.py +62 -0
- models/mobilefacenet.py +193 -0
- models/pretrain/.DS_Store +0 -0
- models/pretrain/.gitignore +2 -0
- models/pretrain/ir50.pth +3 -0
- models/pretrain/mobilefacenet_model_best.pth.tar +3 -0
- models/vit_model.py +828 -0
- models/vit_model_8.py +828 -0
- prediction.py +103 -0
- raf-db-model_best.pth +3 -0
- requirements.txt +131 -0
detectfaces.py
ADDED
@@ -0,0 +1,102 @@
from main import *
import cv2
import time

model_path = "raf-db-model_best.pth"

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = pyramid_trans_expr2(img_size=224, num_classes=7)

model = torch.nn.DataParallel(model)
model = model.to(device)
currtime = time.strftime("%H:%M:%S")
print(currtime)


def main():
    if model_path is not None:
        if os.path.isfile(model_path):
            print("=> loading checkpoint '{}'".format(model_path))
            checkpoint = torch.load(model_path, map_location=device)
            best_acc = checkpoint["best_acc"]
            best_acc = best_acc.to()
            print(f"best_acc:{best_acc}")
            model.load_state_dict(checkpoint["state_dict"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    model_path, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(model_path))
        imagecapture(model)
        return


def imagecapture(model):
    currtimeimg = time.strftime("%H:%M:%S")
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open webcam.")
        exit()

    face_cascade = cv2.CascadeClassifier(
        cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
    )

    start_time = None
    capturing = False

    while True:
        from prediction import predict

        ret, frame = cap.read()

        if not ret:
            print("Error: Could not read frame.")
            break

        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        faces = face_cascade.detectMultiScale(
            gray, scaleFactor=1.3, minNeighbors=5, minSize=(30, 30)
        )

        # Display the frame
        cv2.imshow("Webcam", frame)

        # If faces are detected, start the timer
        if len(faces) > 0:
            print(f"[!]Face detected at {currtimeimg}")
            face_region = frame[
                faces[0][1] : faces[0][1] + faces[0][3],
                faces[0][0] : faces[0][0] + faces[0][2],
            ]  # Crop the face region
            face_pil_image = Image.fromarray(
                cv2.cvtColor(face_region, cv2.COLOR_BGR2RGB)
            )  # Convert to PIL image
            print("[!]Start Expressions")
            print(f"-->Prediction starting at {currtimeimg}")
            predictions = predict(model, image_path=face_pil_image)
            print(f"-->Done prediction at {currtimeimg}")

            # Reset capturing
            capturing = False

        # Break the loop if the 'q' key is pressed
        if cv2.waitKey(1) & 0xFF == ord("q"):
            break

    # Release the webcam and close the OpenCV window
    cap.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
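For a quick sanity check without a webcam, the same checkpoint-loading and prediction path can be exercised on a single saved photo. The sketch below is illustrative and not part of the commit: it assumes you run it from the repository root, that `sample_face.jpg` is any local test image (hypothetical name), and that `prediction.predict` accepts a PIL image through its `image_path` argument, exactly as the loop above calls it.

# Hypothetical smoke test (not part of the upload): one prediction on a saved image.
import torch
from PIL import Image

from main import pyramid_trans_expr2   # re-exported via main.py's star import
from prediction import predict

device = "cuda" if torch.cuda.is_available() else "cpu"

model = torch.nn.DataParallel(pyramid_trans_expr2(img_size=224, num_classes=7)).to(device)
checkpoint = torch.load("raf-db-model_best.pth", map_location=device)
model.load_state_dict(checkpoint["state_dict"])

face = Image.open("sample_face.jpg").convert("RGB")  # hypothetical input file
predictions = predict(model, image_path=face)        # same call pattern as imagecapture()
print(predictions)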
face_detection.py
ADDED
@@ -0,0 +1,23 @@
from deepface import DeepFace
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import time


def face_detection(img_path):
    currtime = time.strftime("%H:%M:%S")
    face_objs = DeepFace.extract_faces(np.array(img_path), detector_backend="mtcnn", enforce_detection=False)

    coordinates = face_objs[0]["facial_area"]
    image = img_path
    cropped_image = image.crop(
        (
            coordinates["x"],
            coordinates["y"],
            coordinates["x"] + coordinates["w"],
            coordinates["y"] + coordinates["h"],
        )
    )
    cropped_image.save(f"Images/test_{currtime}.jpg")
    return cropped_image
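A usage sketch for the helper above (illustrative, not part of the commit): `face_detection` expects a PIL image, since it crops via `Image.crop`, and it writes the crop into an `Images/` folder that must already exist. The input file name `group_photo.jpg` is hypothetical.

# Hypothetical call site: crop the first face DeepFace finds in a photo.
import os
from PIL import Image

from face_detection import face_detection

os.makedirs("Images", exist_ok=True)                   # the function saves into Images/
photo = Image.open("group_photo.jpg").convert("RGB")   # hypothetical input image
face_crop = face_detection(photo)                      # returns the cropped PIL image
face_crop.show()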
main.py
ADDED
@@ -0,0 +1,602 @@
import shutil
import warnings
from sklearn import metrics
from sklearn.metrics import confusion_matrix
from PIL import Image

warnings.filterwarnings("ignore")
import torch.utils.data as data
import os
import argparse
from sklearn.metrics import f1_score, confusion_matrix
from data_preprocessing.sam import SAM
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import datetime
from torchsampler import ImbalancedDatasetSampler
from models.PosterV2_7cls import *


warnings.filterwarnings("ignore", category=UserWarning)

now = datetime.datetime.now()
time_str = now.strftime("[%m-%d]-[%H-%M]-")
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")

parser = argparse.ArgumentParser()
parser.add_argument("--data", type=str, default=r"raf-db/DATASET")
parser.add_argument(
    "--data_type",
    default="RAF-DB",
    choices=["RAF-DB", "AffectNet-7", "CAER-S"],
    type=str,
    help="dataset option",
)
parser.add_argument(
    "--checkpoint_path", type=str, default="./checkpoint/" + time_str + "model.pth"
)
parser.add_argument(
    "--best_checkpoint_path",
    type=str,
    default="./checkpoint/" + time_str + "model_best.pth",
)
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers",
)
parser.add_argument(
    "--epochs", default=200, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument("-b", "--batch-size", default=2, type=int, metavar="N")
parser.add_argument(
    "--optimizer", type=str, default="adam", help="Optimizer, adam or sgd."
)

parser.add_argument(
    "--lr", "--learning-rate", default=0.000035, type=float, metavar="LR", dest="lr"
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M")
parser.add_argument(
    "--wd", "--weight-decay", default=1e-4, type=float, metavar="W", dest="weight_decay"
)
parser.add_argument(
    "-p", "--print-freq", default=30, type=int, metavar="N", help="print frequency"
)
parser.add_argument(
    "--resume", default=None, type=str, metavar="PATH", help="path to checkpoint"
)
parser.add_argument(
    "-e", "--evaluate", default=None, type=str, help="evaluate model on test set"
)
parser.add_argument("--beta", type=float, default=0.6)
parser.add_argument("--gpu", type=str, default="0")

parser.add_argument(
    "-i", "--image", type=str, help="upload a single image to test the prediction"
)
parser.add_argument("-t", "--test", type=str, help="test model on single image")
args = parser.parse_args()


def main():
    # os.environ["CUDA_VISIBLE_DEVICES"] = device
    best_acc = 0
    # print("Training time: " + now.strftime("%m-%d %H:%M"))

    # create model
    model = pyramid_trans_expr2(img_size=224, num_classes=7)

    model = torch.nn.DataParallel(model)
    model = model.to(device)

    criterion = torch.nn.CrossEntropyLoss()

    if args.optimizer == "adamw":
        base_optimizer = torch.optim.AdamW
    elif args.optimizer == "adam":
        base_optimizer = torch.optim.Adam
    elif args.optimizer == "sgd":
        base_optimizer = torch.optim.SGD
    else:
        raise ValueError("Optimizer not supported.")

    optimizer = SAM(
        model.parameters(),
        base_optimizer,
        lr=args.lr,
        rho=0.05,
        adaptive=False,
    )
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
    recorder = RecorderMeter(args.epochs)
    recorder1 = RecorderMeter1(args.epochs)

    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            best_acc = checkpoint["best_acc"]
            recorder = checkpoint["recorder"]
            recorder1 = checkpoint["recorder1"]
            best_acc = best_acc.to()
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, "train")

    valdir = os.path.join(args.data, "test")

    if args.evaluate is None:
        if args.data_type == "RAF-DB":
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose(
                    [
                        transforms.Resize((224, 224)),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize(
                            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                        ),
                        transforms.RandomErasing(scale=(0.02, 0.1)),
                    ]
                ),
            )
        else:
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose(
                    [
                        transforms.Resize((224, 224)),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize(
                            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                        ),
                        transforms.RandomErasing(p=1, scale=(0.05, 0.05)),
                    ]
                ),
            )

        if args.data_type == "AffectNet-7":
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                sampler=ImbalancedDatasetSampler(train_dataset),
                batch_size=args.batch_size,
                shuffle=False,
                num_workers=args.workers,
                pin_memory=True,
            )

        else:
            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.workers,
                pin_memory=True,
            )

    test_dataset = datasets.ImageFolder(
        valdir,
        transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        ),
    )

    val_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    if args.evaluate is not None:
        from validation import validate

        if os.path.isfile(args.evaluate):
            print("=> loading checkpoint '{}'".format(args.evaluate))
            checkpoint = torch.load(args.evaluate, map_location=device)
            best_acc = checkpoint["best_acc"]
            best_acc = best_acc.to()
            print(f"best_acc:{best_acc}")
            model.load_state_dict(checkpoint["state_dict"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.evaluate, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.evaluate))
        validate(val_loader, model, criterion, args)
        return

    if args.test is not None:
        from prediction import predict

        if os.path.isfile(args.test):
            print("=> loading checkpoint '{}'".format(args.test))
            checkpoint = torch.load(args.test, map_location=device)
            best_acc = checkpoint["best_acc"]
            best_acc = best_acc.to()
            print(f"best_acc:{best_acc}")
            model.load_state_dict(checkpoint["state_dict"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.test, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.test))
        predict(model, image_path=args.image)

        return
    matrix = None

    for epoch in range(args.start_epoch, args.epochs):
        current_learning_rate = optimizer.state_dict()["param_groups"][0]["lr"]
        print("Current learning rate: ", current_learning_rate)
        txt_name = "./log/" + time_str + "log.txt"
        with open(txt_name, "a") as f:
            f.write("Current learning rate: " + str(current_learning_rate) + "\n")

        # train for one epoch
        train_acc, train_los = train(
            train_loader, model, criterion, optimizer, epoch, args
        )

        # evaluate on validation set
        val_acc, val_los, output, target, D = validate(
            val_loader, model, criterion, args
        )

        scheduler.step()

        recorder.update(epoch, train_los, train_acc, val_los, val_acc)
        recorder1.update(output, target)

        curve_name = time_str + "cnn.png"
        recorder.plot_curve(os.path.join("./log/", curve_name))

        # remember best acc and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        print("Current best accuracy: ", best_acc.item())

        if is_best:
            matrix = D

        print("Current best matrix: ", matrix)

        txt_name = "./log/" + time_str + "log.txt"
        with open(txt_name, "a") as f:
            f.write("Current best accuracy: " + str(best_acc.item()) + "\n")

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "best_acc": best_acc,
                "optimizer": optimizer.state_dict(),
                "recorder1": recorder1,
                "recorder": recorder,
            },
            is_best,
            args,
        )


def train(train_loader, model, criterion, optimizer, epoch, args):
    losses = AverageMeter("Loss", ":.4f")
    top1 = AverageMeter("Accuracy", ":6.3f")
    progress = ProgressMeter(
        len(train_loader), [losses, top1], prefix="Epoch: [{}]".format(epoch)
    )

    # switch to train mode
    model.train()

    for i, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, _ = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        # optimizer.step()
        optimizer.first_step(zero_grad=True)
        images = images.to(device)
        target = target.to(device)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, _ = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.second_step(zero_grad=True)

        # print loss and accuracy
        if i % args.print_freq == 0:
            progress.display(i)

    return top1.avg, losses.avg


def save_checkpoint(state, is_best, args):
    torch.save(state, args.checkpoint_path)
    if is_best:
        best_state = state.pop("optimizer")
        torch.save(best_state, args.best_checkpoint_path)


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print_txt = "\t".join(entries)
        print(print_txt)
        txt_name = "./log/" + time_str + "log.txt"
        with open(txt_name, "a") as f:
            f.write(print_txt + "\n")

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


labels = ["A", "B", "C", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O"]


class RecorderMeter1(object):
    """Computes and stores the minimum loss value and its epoch index"""

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )  # [epoch, train/val]
        self.epoch_accuracy = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )  # [epoch, train/val]

    def update(self, output, target):
        self.y_pred = output
        self.y_true = target

    def plot_confusion_matrix(self, cm, title="Confusion Matrix", cmap=plt.cm.binary):
        plt.imshow(cm, interpolation="nearest", cmap=cmap)
        y_true = self.y_true
        y_pred = self.y_pred

        plt.title(title)
        plt.colorbar()
        xlocations = np.array(range(len(labels)))
        plt.xticks(xlocations, labels, rotation=90)
        plt.yticks(xlocations, labels)
        plt.ylabel("True label")
        plt.xlabel("Predicted label")

        cm = confusion_matrix(y_true, y_pred)
        np.set_printoptions(precision=2)
        cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis]
        plt.figure(figsize=(12, 8), dpi=120)

        ind_array = np.arange(len(labels))
        x, y = np.meshgrid(ind_array, ind_array)
        for x_val, y_val in zip(x.flatten(), y.flatten()):
            c = cm_normalized[y_val][x_val]
            if c > 0.01:
                plt.text(
                    x_val,
                    y_val,
                    "%0.2f" % (c,),
                    color="red",
                    fontsize=7,
                    va="center",
                    ha="center",
                )
        # offset the tick
        tick_marks = np.arange(len(7))
        plt.gca().set_xticks(tick_marks, minor=True)
        plt.gca().set_yticks(tick_marks, minor=True)
        plt.gca().xaxis.set_ticks_position("none")
        plt.gca().yaxis.set_ticks_position("none")
        plt.grid(True, which="minor", linestyle="-")
        plt.gcf().subplots_adjust(bottom=0.15)

        plot_confusion_matrix(cm_normalized, title="Normalized confusion matrix")
        # show confusion matrix
        plt.savefig("./log/confusion_matrix.png", format="png")
        # fig.savefig(save_path, dpi=dpi, bbox_inches='tight')
        print("Saved figure")
        plt.show()

    def matrix(self):
        target = self.y_true
        output = self.y_pred
        im_re_label = np.array(target)
        im_pre_label = np.array(output)
        y_ture = im_re_label.flatten()
        # im_re_label.transpose()
        y_pred = im_pre_label.flatten()
        im_pre_label.transpose()


class RecorderMeter(object):
    """Computes and stores the minimum loss value and its epoch index"""

    def __init__(self, total_epoch):
        self.reset(total_epoch)

    def reset(self, total_epoch):
        self.total_epoch = total_epoch
        self.current_epoch = 0
        self.epoch_losses = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )  # [epoch, train/val]
        self.epoch_accuracy = np.zeros(
            (self.total_epoch, 2), dtype=np.float32
        )  # [epoch, train/val]

    def update(self, idx, train_loss, train_acc, val_loss, val_acc):
        self.epoch_losses[idx, 0] = train_loss * 30
        self.epoch_losses[idx, 1] = val_loss * 30
        self.epoch_accuracy[idx, 0] = train_acc
        self.epoch_accuracy[idx, 1] = val_acc
        self.current_epoch = idx + 1

    def plot_curve(self, save_path):
        title = "the accuracy/loss curve of train/val"
        dpi = 80
        width, height = 1800, 800
        legend_fontsize = 10
        figsize = width / float(dpi), height / float(dpi)

        fig = plt.figure(figsize=figsize)
        x_axis = np.array([i for i in range(self.total_epoch)])  # epochs
        y_axis = np.zeros(self.total_epoch)

        plt.xlim(0, self.total_epoch)
        plt.ylim(0, 100)
        interval_y = 5
        interval_x = 5
        plt.xticks(np.arange(0, self.total_epoch + interval_x, interval_x))
        plt.yticks(np.arange(0, 100 + interval_y, interval_y))
        plt.grid()
        plt.title(title, fontsize=20)
        plt.xlabel("the training epoch", fontsize=16)
        plt.ylabel("accuracy", fontsize=16)

        y_axis[:] = self.epoch_accuracy[:, 0]
        plt.plot(x_axis, y_axis, color="g", linestyle="-", label="train-accuracy", lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_accuracy[:, 1]
        plt.plot(x_axis, y_axis, color="y", linestyle="-", label="valid-accuracy", lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 0]
        plt.plot(x_axis, y_axis, color="g", linestyle=":", label="train-loss-x30", lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        y_axis[:] = self.epoch_losses[:, 1]
        plt.plot(x_axis, y_axis, color="y", linestyle=":", label="valid-loss-x30", lw=2)
        plt.legend(loc=4, fontsize=legend_fontsize)

        if save_path is not None:
            fig.savefig(save_path, dpi=dpi, bbox_inches="tight")
            print("Saved figure")
        plt.close(fig)


if __name__ == "__main__":
    main()
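The train() loop above runs every batch through the model twice because SAM (Sharpness-Aware Minimization) needs one backward pass to compute the weight perturbation and a second one at the perturbed weights. The sketch below isolates that pattern on a toy model; it is illustrative only and assumes the SAM wrapper from data_preprocessing/sam.py exposes the same first_step/second_step API that train() calls.

# Illustrative SAM two-pass update on a toy model (assumes the SAM API used in train()).
import torch
from data_preprocessing.sam import SAM  # shipped separately; not part of this upload

toy = torch.nn.Linear(10, 7)
criterion = torch.nn.CrossEntropyLoss()
optimizer = SAM(toy.parameters(), torch.optim.Adam, lr=3.5e-5, rho=0.05, adaptive=False)

x = torch.randn(4, 10)
y = torch.randint(0, 7, (4,))

# First pass: perturb the weights toward the locally sharpest point.
criterion(toy(x), y).backward()
optimizer.first_step(zero_grad=True)

# Second pass: take the actual descent step from the perturbed weights.
criterion(toy(x), y).backward()
optimizer.second_step(zero_grad=True)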
models/.DS_Store
ADDED
Binary file (8.2 kB)
models/PosterV2_7cls.py
ADDED
@@ -0,0 +1,441 @@
import os
import torch
import torch.nn as nn
from torch.nn import functional as F
from .mobilefacenet import MobileFaceNet
from .ir50 import Backbone
from .vit_model import VisionTransformer, PatchEmbed
from timm.models.layers import trunc_normal_, DropPath
from thop import profile


def load_pretrained_weights(model, checkpoint):
    import collections

    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        # If the pretrained state_dict was saved as nn.DataParallel,
        # keys would contain "module.", which should be ignored.
        if k.startswith("module."):
            k = k[7:]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)
    # new_state_dict.requires_grad = False
    model_dict.update(new_state_dict)

    model.load_state_dict(model_dict)
    print("load_weight", len(matched_layers))
    return model


def window_partition(x, window_size, h_w, w_w):
    """
    Args:
        x: (B, H, W, C)
        window_size: window size

    Returns:
        local window features (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, h_w, window_size, w_w, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows


class window(nn.Module):
    def __init__(self, window_size, dim):
        super(window, self).__init__()
        self.window_size = window_size
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        x = self.norm(x)
        shortcut = x
        h_w = int(torch.div(H, self.window_size).item())
        w_w = int(torch.div(W, self.window_size).item())
        x_windows = window_partition(x, self.window_size, h_w, w_w)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        return x_windows, shortcut


class WindowAttentionGlobal(nn.Module):
    """
    Global window attention based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention head.
            window_size: window size.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
        """

        super().__init__()
        window_size = (window_size, window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = torch.div(dim, num_heads)
        self.scale = qk_scale or head_dim**-0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
        )
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, q_global):
        # print(f'q_global.shape:{q_global.shape}')
        # print(f'x.shape:{x.shape}')
        B_, N, C = x.shape
        B = q_global.shape[0]
        head_dim = int(torch.div(C, self.num_heads).item())
        B_dim = int(torch.div(B_, B).item())
        kv = (
            self.qkv(x)
            .reshape(B_, N, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]
        q_global = q_global.repeat(1, B_dim, 1, 1, 1)
        q = q_global.reshape(B_, self.num_heads, N, head_dim)
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            -1,
        )
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def _to_channel_last(x):
    """
    Args:
        x: (B, C, H, W)

    Returns:
        x: (B, H, W, C)
    """
    return x.permute(0, 2, 3, 1)


def _to_channel_first(x):
    return x.permute(0, 3, 1, 2)


def _to_query(x, N, num_heads, dim_head):
    B = x.shape[0]
    x = x.reshape(B, 1, N, num_heads, dim_head).permute(0, 1, 3, 2, 4)
    return x


class Mlp(nn.Module):
    """
    Multi-Layer Perceptron (MLP) block
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        """
        Args:
            in_features: input features dimension.
            hidden_features: hidden features dimension.
            out_features: output features dimension.
            act_layer: activation function.
            drop: dropout rate.
        """

        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_reverse(windows, window_size, H, W, h_w, w_w):
    """
    Args:
        windows: local window features (num_windows*B, window_size, window_size, C)
        window_size: Window size
        H: Height of image
        W: Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, h_w, w_w, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class feedforward(nn.Module):
    def __init__(
        self,
        dim,
        window_size,
        mlp_ratio=4.0,
        act_layer=nn.GELU,
        drop=0.0,
        drop_path=0.0,
        layer_scale=None,
    ):
        super(feedforward, self).__init__()
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.layer_scale = True
            self.gamma1 = nn.Parameter(
                layer_scale * torch.ones(dim), requires_grad=True
            )
            self.gamma2 = nn.Parameter(
                layer_scale * torch.ones(dim), requires_grad=True
            )
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        self.window_size = window_size
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        self.norm = nn.LayerNorm(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, attn_windows, shortcut):
        B, H, W, C = shortcut.shape
        h_w = int(torch.div(H, self.window_size).item())
        w_w = int(torch.div(W, self.window_size).item())
        x = window_reverse(attn_windows, self.window_size, H, W, h_w, w_w)
        x = shortcut + self.drop_path(self.gamma1 * x)
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
        return x


class pyramid_trans_expr2(nn.Module):
    def __init__(
        self,
        img_size=224,
        num_classes=7,
        window_size=[28, 14, 7],
        num_heads=[2, 4, 8],
        dims=[64, 128, 256],
        embed_dim=768,
    ):
        super().__init__()

        self.img_size = img_size
        self.num_heads = num_heads
        self.dim_head = []
        for num_head, dim in zip(num_heads, dims):
            self.dim_head.append(int(torch.div(dim, num_head).item()))
        self.num_classes = num_classes
        self.window_size = window_size
        self.N = [win * win for win in window_size]
        self.face_landback = MobileFaceNet([112, 112], 136)

        mobilefacenet_path = os.path.join(
            os.getcwd(), "models/pretrain/mobilefacenet_model_best.pth.tar"
        )
        ir50_path = os.path.join(os.getcwd(), "models/pretrain/ir50.pth")

        print(mobilefacenet_path)
        face_landback_checkpoint = torch.load(
            mobilefacenet_path,
            map_location=lambda storage, loc: storage,
        )
        self.face_landback.load_state_dict(face_landback_checkpoint["state_dict"])

        for param in self.face_landback.parameters():
            param.requires_grad = False

        self.VIT = VisionTransformer(depth=2, embed_dim=embed_dim)

        self.ir_back = Backbone(50, 0.0, "ir")
        ir_checkpoint = torch.load(
            ir50_path,
            map_location=lambda storage, loc: storage,
        )

        self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)

        self.attn1 = WindowAttentionGlobal(
            dim=dims[0], num_heads=num_heads[0], window_size=window_size[0]
        )
        self.attn2 = WindowAttentionGlobal(
            dim=dims[1], num_heads=num_heads[1], window_size=window_size[1]
        )
        self.attn3 = WindowAttentionGlobal(
            dim=dims[2], num_heads=num_heads[2], window_size=window_size[2]
        )
        self.window1 = window(window_size=window_size[0], dim=dims[0])
        self.window2 = window(window_size=window_size[1], dim=dims[1])
        self.window3 = window(window_size=window_size[2], dim=dims[2])
        self.conv1 = nn.Conv2d(
            in_channels=dims[0],
            out_channels=dims[0],
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.conv2 = nn.Conv2d(
            in_channels=dims[1],
            out_channels=dims[1],
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.conv3 = nn.Conv2d(
            in_channels=dims[2],
            out_channels=dims[2],
            kernel_size=3,
            stride=2,
            padding=1,
        )

        dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
        self.ffn1 = feedforward(
            dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0]
        )
        self.ffn2 = feedforward(
            dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1]
        )
        self.ffn3 = feedforward(
            dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2]
        )

        self.last_face_conv = nn.Conv2d(
            in_channels=512, out_channels=256, kernel_size=3, padding=1
        )

        self.embed_q = nn.Sequential(
            nn.Conv2d(dims[0], 768, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1),
        )
        self.embed_k = nn.Sequential(
            nn.Conv2d(dims[1], 768, kernel_size=3, stride=2, padding=1)
        )
        self.embed_v = PatchEmbed(img_size=14, patch_size=14, in_c=256, embed_dim=768)

    def forward(self, x):
        x_face = F.interpolate(x, size=112)
        x_face1, x_face2, x_face3 = self.face_landback(x_face)
        x_face3 = self.last_face_conv(x_face3)
        x_face1, x_face2, x_face3 = (
            _to_channel_last(x_face1),
            _to_channel_last(x_face2),
            _to_channel_last(x_face3),
        )

        q1, q2, q3 = (
            _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]),
            _to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]),
            _to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2]),
        )

        x_ir1, x_ir2, x_ir3 = self.ir_back(x)

        x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
        x_window1, shortcut1 = self.window1(x_ir1)
        x_window2, shortcut2 = self.window2(x_ir2)
        x_window3, shortcut3 = self.window3(x_ir3)

        o1, o2, o3 = (
            self.attn1(x_window1, q1),
            self.attn2(x_window2, q2),
            self.attn3(x_window3, q3),
        )

        o1, o2, o3 = (
            self.ffn1(o1, shortcut1),
            self.ffn2(o2, shortcut2),
            self.ffn3(o3, shortcut3),
        )

        o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)

        o1, o2, o3 = (
            self.embed_q(o1).flatten(2).transpose(1, 2),
            self.embed_k(o2).flatten(2).transpose(1, 2),
            self.embed_v(o3),
        )

        o = torch.cat([o1, o2, o3], dim=1)

        out = self.VIT(o)
        return out


def compute_param_flop():
    model = pyramid_trans_expr2()
    img = torch.rand(size=(1, 3, 224, 224))
    flops, params = profile(model, inputs=(img,))
    print(f"flops:{flops/1000**3}G,params:{params/1000**2}M")
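compute_param_flop() above already shows the expected input shape; the sketch below runs the same thing as a plain forward pass instead of a FLOP count. It is illustrative, not part of the commit: it assumes the script is launched from the repository root so the models/pretrain/ checkpoints loaded in __init__ resolve, and that the VisionTransformer head emits one logit per expression class.

# Illustrative forward pass through the 7-class POSTER V2 model.
# Run from the repo root so models/pretrain/ir50.pth and the MobileFaceNet
# checkpoint are found when pyramid_trans_expr2 is constructed.
import torch
from models.PosterV2_7cls import pyramid_trans_expr2

model = pyramid_trans_expr2(img_size=224, num_classes=7).eval()

with torch.no_grad():
    dummy = torch.rand(1, 3, 224, 224)   # same input shape as compute_param_flop()
    logits = model(dummy)

print(logits.shape)  # expected: torch.Size([1, 7]), one score per expression class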
models/PosterV2_8cls.py
ADDED
@@ -0,0 +1,317 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
from .mobilefacenet import MobileFaceNet
from .ir50 import Backbone
from .vit_model_8 import VisionTransformer, PatchEmbed
from timm.models.layers import trunc_normal_, DropPath
from thop import profile

def load_pretrained_weights(model, checkpoint):
    import collections
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        # If the pretrained state_dict was saved as nn.DataParallel,
        # keys would contain "module.", which should be ignored.
        if k.startswith('module.'):
            k = k[7:]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)
    # new_state_dict.requires_grad = False
    model_dict.update(new_state_dict)

    model.load_state_dict(model_dict)
    print('load_weight', len(matched_layers))
    return model

def window_partition(x, window_size, h_w, w_w):
    """
    Args:
        x: (B, H, W, C)
        window_size: window size

    Returns:
        local window features (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, h_w, window_size, w_w, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

class window(nn.Module):
    def __init__(self, window_size, dim):
        super(window, self).__init__()
        self.window_size = window_size
        self.norm = nn.LayerNorm(dim)
    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        B, H, W, C = x.shape
        x = self.norm(x)
        shortcut = x
        h_w = int(torch.div(H, self.window_size).item())
        w_w = int(torch.div(W, self.window_size).item())
        x_windows = window_partition(x, self.window_size, h_w, w_w)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        return x_windows, shortcut

class WindowAttentionGlobal(nn.Module):
    """
    Global window attention based on: "Hatamizadeh et al.,
    Global Context Vision Transformers <https://arxiv.org/abs/2206.09959>"
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 ):
        """
        Args:
            dim: feature size dimension.
            num_heads: number of attention head.
            window_size: window size.
            qkv_bias: bool argument for query, key, value learnable bias.
            qk_scale: bool argument to scaling query, key.
            attn_drop: attention dropout rate.
            proj_drop: output dropout rate.
        """

        super().__init__()
        window_size = (window_size, window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = torch.div(dim, num_heads)
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, q_global):
        # print(f'q_global.shape:{q_global.shape}')
        # print(f'x.shape:{x.shape}')
        B_, N, C = x.shape
        B = q_global.shape[0]
        head_dim = int(torch.div(C, self.num_heads).item())
        B_dim = int(torch.div(B_, B).item())
        kv = self.qkv(x).reshape(B_, N, 2, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]
        q_global = q_global.repeat(1, B_dim, 1, 1, 1)
        q = q_global.reshape(B_, self.num_heads, N, head_dim)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

def _to_channel_last(x):
    """
    Args:
        x: (B, C, H, W)

    Returns:
        x: (B, H, W, C)
    """
    return x.permute(0, 2, 3, 1)

def _to_channel_first(x):
    return x.permute(0, 3, 1, 2)

def _to_query(x, N, num_heads, dim_head):
    B = x.shape[0]
    x = x.reshape(B, 1, N, num_heads, dim_head).permute(0, 1, 3, 2, 4)
    return x

class Mlp(nn.Module):
    """
    Multi-Layer Perceptron (MLP) block
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        """
        Args:
            in_features: input features dimension.
            hidden_features: hidden features dimension.
            out_features: output features dimension.
            act_layer: activation function.
            drop: dropout rate.
        """

        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def window_reverse(windows, window_size, H, W, h_w, w_w):
    """
    Args:
        windows: local window features (num_windows*B, window_size, window_size, C)
        window_size: Window size
        H: Height of image
        W: Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, h_w, w_w, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class feedforward(nn.Module):
    def __init__(self, dim, window_size, mlp_ratio=4., act_layer=nn.GELU, drop=0., drop_path=0., layer_scale=None):
        super(feedforward, self).__init__()
        if layer_scale is not None and type(layer_scale) in [int, float]:
            self.layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale * torch.ones(dim), requires_grad=True)
        else:
            self.gamma1 = 1.0
            self.gamma2 = 1.0
        self.window_size = window_size
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
        self.norm = nn.LayerNorm(dim)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
    def forward(self, attn_windows, shortcut):
        B, H, W, C = shortcut.shape
        h_w = int(torch.div(H, self.window_size).item())
        w_w = int(torch.div(W, self.window_size).item())
        x = window_reverse(attn_windows, self.window_size, H, W, h_w, w_w)
        x = shortcut + self.drop_path(self.gamma1 * x)
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm(x)))
        return x

class pyramid_trans_expr2(nn.Module):
    def __init__(self, img_size=224, num_classes=8, window_size=[28,14,7], num_heads=[2, 4, 8], dims=[64, 128, 256], embed_dim=768):
        super().__init__()

        self.img_size = img_size
        self.num_heads = num_heads
        self.dim_head = []
        for num_head, dim in zip(num_heads, dims):
            self.dim_head.append(int(torch.div(dim, num_head).item()))
        self.num_classes = num_classes
        self.window_size = window_size
        self.N = [win * win for win in window_size]
        self.face_landback = MobileFaceNet([112, 112], 136)
        face_landback_checkpoint = torch.load(r'./pretrain/mobilefacenet_model_best.pth.tar',
                                              map_location=lambda storage, loc: storage)
        self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])

        for param in self.face_landback.parameters():
            param.requires_grad = False

        self.VIT = VisionTransformer(depth=2, embed_dim=embed_dim, num_classes=num_classes)

        self.ir_back = Backbone(50, 0.0, 'ir')
        ir_checkpoint = torch.load(r'./pretrain/ir50.pth', map_location=lambda storage, loc: storage)

        self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)

        self.attn1 = WindowAttentionGlobal(dim=dims[0], num_heads=num_heads[0], window_size=window_size[0])
        self.attn2 = WindowAttentionGlobal(dim=dims[1], num_heads=num_heads[1], window_size=window_size[1])
        self.attn3 = WindowAttentionGlobal(dim=dims[2], num_heads=num_heads[2], window_size=window_size[2])
        self.window1 = window(window_size=window_size[0], dim=dims[0])
        self.window2 = window(window_size=window_size[1], dim=dims[1])
        self.window3 = window(window_size=window_size[2], dim=dims[2])
        self.conv1 = nn.Conv2d(in_channels=dims[0], out_channels=dims[0], kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=dims[1], out_channels=dims[1], kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=dims[2], out_channels=dims[2], kernel_size=3, stride=2, padding=1)

        dpr = [x.item() for x in torch.linspace(0, 0.5, 5)]
        self.ffn1 = feedforward(dim=dims[0], window_size=window_size[0], layer_scale=1e-5, drop_path=dpr[0])
        self.ffn2 = feedforward(dim=dims[1], window_size=window_size[1], layer_scale=1e-5, drop_path=dpr[1])
        self.ffn3 = feedforward(dim=dims[2], window_size=window_size[2], layer_scale=1e-5, drop_path=dpr[2])

        self.last_face_conv = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1)

        self.embed_q = nn.Sequential(nn.Conv2d(dims[0], 768, kernel_size=3, stride=2, padding=1),
|
280 |
+
nn.Conv2d(768, 768, kernel_size=3, stride=2, padding=1))
|
281 |
+
self.embed_k = nn.Sequential(nn.Conv2d(dims[1], 768, kernel_size=3, stride=2, padding=1))
|
282 |
+
self.embed_v = PatchEmbed(img_size=14, patch_size=14, in_c=256, embed_dim=768)
|
283 |
+
|
284 |
+
def forward(self, x):
|
285 |
+
x_face = F.interpolate(x, size=112)
|
286 |
+
x_face1 , x_face2, x_face3 = self.face_landback(x_face)
|
287 |
+
x_face3 = self.last_face_conv(x_face3)
|
288 |
+
x_face1, x_face2, x_face3 = _to_channel_last(x_face1), _to_channel_last(x_face2), _to_channel_last(x_face3)
|
289 |
+
|
290 |
+
q1, q2, q3 = _to_query(x_face1, self.N[0], self.num_heads[0], self.dim_head[0]), \
|
291 |
+
_to_query(x_face2, self.N[1], self.num_heads[1], self.dim_head[1]), \
|
292 |
+
_to_query(x_face3, self.N[2], self.num_heads[2], self.dim_head[2])
|
293 |
+
|
294 |
+
x_ir1, x_ir2, x_ir3 = self.ir_back(x)
|
295 |
+
x_ir1, x_ir2, x_ir3 = self.conv1(x_ir1), self.conv2(x_ir2), self.conv3(x_ir3)
|
296 |
+
x_window1, shortcut1 = self.window1(x_ir1)
|
297 |
+
x_window2, shortcut2 = self.window2(x_ir2)
|
298 |
+
x_window3, shortcut3 = self.window3(x_ir3)
|
299 |
+
|
300 |
+
o1, o2, o3 = self.attn1(x_window1, q1), self.attn2(x_window2, q2), self.attn3(x_window3, q3)
|
301 |
+
|
302 |
+
o1, o2, o3 = self.ffn1(o1, shortcut1), self.ffn2(o2, shortcut2), self.ffn3(o3, shortcut3)
|
303 |
+
|
304 |
+
o1, o2, o3 = _to_channel_first(o1), _to_channel_first(o2), _to_channel_first(o3)
|
305 |
+
|
306 |
+
o1, o2, o3 = self.embed_q(o1).flatten(2).transpose(1, 2), self.embed_k(o2).flatten(2).transpose(1, 2), self.embed_v(o3)
|
307 |
+
|
308 |
+
o = torch.cat([o1, o2, o3], dim=1)
|
309 |
+
|
310 |
+
out = self.VIT(o)
|
311 |
+
return out
|
312 |
+
|
313 |
+
def compute_param_flop():
|
314 |
+
model = pyramid_trans_expr2()
|
315 |
+
img = torch.rand(size=(1,3,224,224))
|
316 |
+
flops, params = profile(model, inputs=(img,))
|
317 |
+
print(f'flops:{flops/1000**3}G,params:{params/1000**2}M')
|
models/__pycache__/PosterV2_7cls.cpython-310.pyc
ADDED
Binary file (12.1 kB). View file
|
|
models/__pycache__/PosterV2_7cls.cpython-311.pyc
ADDED
Binary file (24.9 kB). View file
|
|
models/__pycache__/ir50.cpython-310.pyc
ADDED
Binary file (6.01 kB). View file
|
|
models/__pycache__/ir50.cpython-311.pyc
ADDED
Binary file (12 kB). View file
|
|
models/__pycache__/mobilefacenet.cpython-310.pyc
ADDED
Binary file (6.5 kB). View file
|
|
models/__pycache__/mobilefacenet.cpython-311.pyc
ADDED
Binary file (12.6 kB). View file
|
|
models/__pycache__/vit_model.cpython-310.pyc
ADDED
Binary file (19.6 kB). View file
|
|
models/__pycache__/vit_model.cpython-311.pyc
ADDED
Binary file (34.9 kB). View file
|
|
models/ir50.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout2d, Dropout, AvgPool2d, \
|
2 |
+
MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Parameter
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch
|
5 |
+
from collections import namedtuple
|
6 |
+
import math
|
7 |
+
import pdb
|
8 |
+
|
9 |
+
|
10 |
+
################################## Original Arcface Model #############################################################
|
11 |
+
|
12 |
+
class Flatten(Module):
|
13 |
+
def forward(self, input):
|
14 |
+
return input.view(input.size(0), -1)
|
15 |
+
|
16 |
+
|
17 |
+
def l2_norm(input, axis=1):
|
18 |
+
norm = torch.norm(input, 2, axis, True)
|
19 |
+
output = torch.div(input, norm)
|
20 |
+
return output
|
21 |
+
|
22 |
+
|
23 |
+
class SEModule(Module):
|
24 |
+
def __init__(self, channels, reduction):
|
25 |
+
super(SEModule, self).__init__()
|
26 |
+
self.avg_pool = AdaptiveAvgPool2d(1)
|
27 |
+
self.fc1 = Conv2d(
|
28 |
+
channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
29 |
+
self.relu = ReLU(inplace=True)
|
30 |
+
self.fc2 = Conv2d(
|
31 |
+
channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
32 |
+
self.sigmoid = Sigmoid()
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
module_input = x
|
36 |
+
x = self.avg_pool(x)
|
37 |
+
x = self.fc1(x)
|
38 |
+
x = self.relu(x)
|
39 |
+
x = self.fc2(x)
|
40 |
+
x = self.sigmoid(x)
|
41 |
+
return module_input * x
|
42 |
+
|
43 |
+
|
44 |
+
# i = 0
|
45 |
+
|
46 |
+
class bottleneck_IR(Module):
|
47 |
+
def __init__(self, in_channel, depth, stride):
|
48 |
+
super(bottleneck_IR, self).__init__()
|
49 |
+
if in_channel == depth:
|
50 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
51 |
+
else:
|
52 |
+
self.shortcut_layer = Sequential(
|
53 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False), BatchNorm2d(depth))
|
54 |
+
self.res_layer = Sequential(
|
55 |
+
BatchNorm2d(in_channel),
|
56 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
57 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth))
|
58 |
+
i = 0
|
59 |
+
|
60 |
+
def forward(self, x):
|
61 |
+
shortcut = self.shortcut_layer(x)
|
62 |
+
# print(shortcut.shape)
|
63 |
+
# print('---s---')
|
64 |
+
res = self.res_layer(x)
|
65 |
+
# print(res.shape)
|
66 |
+
# print('---r---')
|
67 |
+
# i = i + 50
|
68 |
+
# print(i)
|
69 |
+
# print('50')
|
70 |
+
return res + shortcut
|
71 |
+
|
72 |
+
|
73 |
+
class bottleneck_IR_SE(Module):
|
74 |
+
def __init__(self, in_channel, depth, stride):
|
75 |
+
super(bottleneck_IR_SE, self).__init__()
|
76 |
+
if in_channel == depth:
|
77 |
+
self.shortcut_layer = MaxPool2d(1, stride)
|
78 |
+
else:
|
79 |
+
self.shortcut_layer = Sequential(
|
80 |
+
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
81 |
+
BatchNorm2d(depth))
|
82 |
+
self.res_layer = Sequential(
|
83 |
+
BatchNorm2d(in_channel),
|
84 |
+
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
85 |
+
PReLU(depth),
|
86 |
+
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
87 |
+
BatchNorm2d(depth),
|
88 |
+
SEModule(depth, 16)
|
89 |
+
)
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
shortcut = self.shortcut_layer(x)
|
93 |
+
res = self.res_layer(x)
|
94 |
+
return res + shortcut
|
95 |
+
|
96 |
+
|
97 |
+
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
98 |
+
'''A named tuple describing a ResNet block.'''
|
99 |
+
# print('50')
|
100 |
+
|
101 |
+
|
102 |
+
def get_block(in_channel, depth, num_units, stride=2):
|
103 |
+
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
104 |
+
|
105 |
+
|
106 |
+
def get_blocks(num_layers):
|
107 |
+
if num_layers == 50:
|
108 |
+
blocks1 = [
|
109 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
110 |
+
# get_block(in_channel=64, depth=128, num_units=4),
|
111 |
+
# get_block(in_channel=128, depth=256, num_units=14),
|
112 |
+
# get_block(in_channel=256, depth=512, num_units=3)
|
113 |
+
]
|
114 |
+
blocks2 = [
|
115 |
+
# get_block(in_channel=64, depth=64, num_units=3),
|
116 |
+
get_block(in_channel=64, depth=128, num_units=4),
|
117 |
+
# get_block(in_channel=128, depth=256, num_units=14),
|
118 |
+
# get_block(in_channel=256, depth=512, num_units=3)
|
119 |
+
]
|
120 |
+
blocks3 = [
|
121 |
+
# get_block(in_channel=64, depth=64, num_units=3),
|
122 |
+
# get_block(in_channel=64, depth=128, num_units=4),
|
123 |
+
get_block(in_channel=128, depth=256, num_units=14),
|
124 |
+
# get_block(in_channel=256, depth=512, num_units=3)
|
125 |
+
]
|
126 |
+
|
127 |
+
elif num_layers == 100:
|
128 |
+
blocks = [
|
129 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
130 |
+
get_block(in_channel=64, depth=128, num_units=13),
|
131 |
+
get_block(in_channel=128, depth=256, num_units=30),
|
132 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
133 |
+
]
|
134 |
+
elif num_layers == 152:
|
135 |
+
blocks = [
|
136 |
+
get_block(in_channel=64, depth=64, num_units=3),
|
137 |
+
get_block(in_channel=64, depth=128, num_units=8),
|
138 |
+
get_block(in_channel=128, depth=256, num_units=36),
|
139 |
+
get_block(in_channel=256, depth=512, num_units=3)
|
140 |
+
]
|
141 |
+
return blocks1, blocks2, blocks3
|
142 |
+
|
143 |
+
|
144 |
+
class Backbone(Module):
|
145 |
+
def __init__(self, num_layers, drop_ratio, mode='ir'):
|
146 |
+
super(Backbone, self).__init__()
|
147 |
+
# assert num_layers in [50, 100, 152], 'num_layers should be 50,100, or 152'
|
148 |
+
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
|
149 |
+
blocks1, blocks2, blocks3 = get_blocks(num_layers)
|
150 |
+
# blocks2 = get_blocks(num_layers)
|
151 |
+
if mode == 'ir':
|
152 |
+
unit_module = bottleneck_IR
|
153 |
+
elif mode == 'ir_se':
|
154 |
+
unit_module = bottleneck_IR_SE
|
155 |
+
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
156 |
+
BatchNorm2d(64),
|
157 |
+
PReLU(64))
|
158 |
+
self.output_layer = Sequential(BatchNorm2d(512),
|
159 |
+
Dropout(drop_ratio),
|
160 |
+
Flatten(),
|
161 |
+
Linear(512 * 7 * 7, 512),
|
162 |
+
BatchNorm1d(512))
|
163 |
+
modules1 = []
|
164 |
+
for block in blocks1:
|
165 |
+
for bottleneck in block:
|
166 |
+
modules1.append(
|
167 |
+
unit_module(bottleneck.in_channel,
|
168 |
+
bottleneck.depth,
|
169 |
+
bottleneck.stride))
|
170 |
+
|
171 |
+
modules2 = []
|
172 |
+
for block in blocks2:
|
173 |
+
for bottleneck in block:
|
174 |
+
modules2.append(
|
175 |
+
unit_module(bottleneck.in_channel,
|
176 |
+
bottleneck.depth,
|
177 |
+
bottleneck.stride))
|
178 |
+
|
179 |
+
modules3 = []
|
180 |
+
for block in blocks3:
|
181 |
+
for bottleneck in block:
|
182 |
+
modules3.append(
|
183 |
+
unit_module(bottleneck.in_channel,
|
184 |
+
bottleneck.depth,
|
185 |
+
bottleneck.stride))
|
186 |
+
# modules4 = []
|
187 |
+
# for block in blocks4:
|
188 |
+
# for bottleneck in block:
|
189 |
+
# modules4.append(
|
190 |
+
# unit_module(bottleneck.in_channel,
|
191 |
+
# bottleneck.depth,
|
192 |
+
# bottleneck.stride))
|
193 |
+
self.body1 = Sequential(*modules1)
|
194 |
+
self.body2 = Sequential(*modules2)
|
195 |
+
self.body3 = Sequential(*modules3)
|
196 |
+
# self.body4 = Sequential(*modules4)
|
197 |
+
|
198 |
+
def forward(self, x):
|
199 |
+
x = F.interpolate(x, size=112)
|
200 |
+
x = self.input_layer(x)
|
201 |
+
x1 = self.body1(x)
|
202 |
+
x2 = self.body2(x1)
|
203 |
+
x3 = self.body3(x2)
|
204 |
+
|
205 |
+
# x = self.output_layer(x)
|
206 |
+
# return l2_norm(x)
|
207 |
+
|
208 |
+
return x1, x2, x3
|
209 |
+
|
210 |
+
def load_pretrained_weights(model, checkpoint):
|
211 |
+
import collections
|
212 |
+
if 'state_dict' in checkpoint:
|
213 |
+
state_dict = checkpoint['state_dict']
|
214 |
+
else:
|
215 |
+
state_dict = checkpoint
|
216 |
+
model_dict = model.state_dict()
|
217 |
+
new_state_dict = collections.OrderedDict()
|
218 |
+
matched_layers, discarded_layers = [], []
|
219 |
+
for i, (k, v) in enumerate(state_dict.items()):
|
220 |
+
# print(i)
|
221 |
+
|
222 |
+
# If the pretrained state_dict was saved as nn.DataParallel,
|
223 |
+
# keys would contain "module.", which should be ignored.
|
224 |
+
if k.startswith('module.'):
|
225 |
+
k = k[7:]
|
226 |
+
if k in model_dict and model_dict[k].size() == v.size():
|
227 |
+
|
228 |
+
new_state_dict[k] = v
|
229 |
+
matched_layers.append(k)
|
230 |
+
else:
|
231 |
+
# print(k)
|
232 |
+
discarded_layers.append(k)
|
233 |
+
# new_state_dict.requires_grad = False
|
234 |
+
model_dict.update(new_state_dict)
|
235 |
+
model.load_state_dict(model_dict)
|
236 |
+
print('load_weight', len(matched_layers))
|
237 |
+
return model
|
238 |
+
|
239 |
+
# model = Backbone(50, 0.0, 'ir')
|
240 |
+
# ir_checkpoint = torch.load(r'C:\Users\86187\Desktop\project\mixfacial\models\pretrain\new_ir50.pth')
|
241 |
+
# print('hello')
|
242 |
+
# i1, i2, i3 = 0, 0, 0
|
243 |
+
# ir_checkpoint = torch.load(r'C:\Users\86187\Desktop\project\mixfacial\models\pretrain\ir50.pth', map_location=lambda storage, loc: storage)
|
244 |
+
# for (k1, v1), (k2, v2) in zip(model.state_dict().items(), ir_checkpoint.items()):
|
245 |
+
# print(f'k1:{k1}, k2:{k2}')
|
246 |
+
# model.state_dict()[k1] = v2
|
247 |
+
|
248 |
+
# torch.save(model.state_dict(), r'C:\Users\86187\Desktop\project\mixfacial\models\pretrain\new_ir50.pth')
|
249 |
+
# print(k)
|
250 |
+
# if k.startswith('body1'):
|
251 |
+
# i1+=1
|
252 |
+
# if k.startswith('body2'):
|
253 |
+
# i2+=1
|
254 |
+
# if k.startswith('body3'):
|
255 |
+
# i3+=1
|
256 |
+
# print(f'i1:{i1}, i2:{i2}, i3:{i3}')
|
257 |
+
|
258 |
+
# print('-'*100)
|
259 |
+
# ir_checkpoint = torch.load(r'C:\Users\86187\Desktop\project\mixfacial\models\pretrain\ir50.pth', map_location=lambda storage, loc: storage)
|
260 |
+
# le = 0
|
261 |
+
# for k, v in ir_checkpoint.items():
|
262 |
+
# # print(k)
|
263 |
+
# if k.startswith('body'):
|
264 |
+
# if le < i1:
|
265 |
+
# le += 1
|
266 |
+
# key = k.split('.')[0] + str(1) + k.split('.')[1:]
|
267 |
+
# print(key)
|
268 |
+
# # ir_checkpoint = ir_checkpoint["model"]
|
269 |
+
# model = load_pretrained_weights(model, ir_checkpoint)
|
270 |
+
# img = torch.rand(size=(2,3,224,224))
|
271 |
+
# out1, out2, out3 = model(img)
|
272 |
+
# print(out1.shape, out2.shape, out3.shape)
|
models/matrix.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
plt.rcParams['font.sans-serif'] = ['SimHei']
|
7 |
+
plt.rcParams['axes.unicode_minus'] = False
|
8 |
+
|
9 |
+
|
10 |
+
# -*- coding:utf-8 -*-
|
11 |
+
|
12 |
+
def plot_confusion_matrix(cm, classes,
|
13 |
+
normalize=False,
|
14 |
+
title='Confusion matrix',
|
15 |
+
cmap=plt.cm.Blues):
|
16 |
+
"""
|
17 |
+
This function prints and plots the confusion matrix.
|
18 |
+
Normalization can be applied by setting `normalize=True`.
|
19 |
+
"""
|
20 |
+
if normalize:
|
21 |
+
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
22 |
+
print("Normalized confusion matrix")
|
23 |
+
else:
|
24 |
+
print('Confusion matrix, without normalization')
|
25 |
+
|
26 |
+
print(cm)
|
27 |
+
|
28 |
+
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
|
29 |
+
plt.title(title)
|
30 |
+
plt.colorbar()
|
31 |
+
tick_marks = np.arange(len(classes))
|
32 |
+
plt.xticks(tick_marks, classes, fontsize=16)
|
33 |
+
plt.yticks(tick_marks, classes, fontsize=16)
|
34 |
+
|
35 |
+
fmt = '.2f' if normalize else 'd'
|
36 |
+
thresh = cm.max() / 2.
|
37 |
+
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
|
38 |
+
plt.text(j, i, format(cm[i, j], fmt),
|
39 |
+
horizontalalignment="center",
|
40 |
+
color="white" if cm[i, j] > thresh else "black")
|
41 |
+
|
42 |
+
plt.tight_layout()
|
43 |
+
plt.ylabel('True Label',fontsize=12)
|
44 |
+
plt.xlabel('Predicted Label',fontsize=12)
|
45 |
+
plt.show()
|
46 |
+
|
47 |
+
|
48 |
+
|
49 |
+
cnf_matrix = np.array([[ 299 , 6 , 5 , 3 , 1 , 4, 11],
|
50 |
+
[ 9, 51 , 0, 2 , 8, 2 , 2],
|
51 |
+
[ 2 , 1 ,120 , 6 ,13 , 9 , 9],
|
52 |
+
[ 5 , 1 , 7 ,1148 , 2 , 4 , 18],
|
53 |
+
[ 0 , 0 , 9 , 4 ,442 , 1 , 22],
|
54 |
+
[ 2 ,0 , 7 , 3 , 0 ,145 , 5],
|
55 |
+
[ 10 ,0, 6 ,11, 29 , 0, 624]])
|
56 |
+
|
57 |
+
class_names = ["SU", 'FE', 'AN', 'HA', 'SA', 'DI', 'NE']
|
58 |
+
|
59 |
+
|
60 |
+
plt.figure(dpi=200)
|
61 |
+
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
|
62 |
+
title=None)
|
models/mobilefacenet.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, ReLU, Sigmoid, Dropout2d, Dropout, AvgPool2d, \
|
2 |
+
MaxPool2d, AdaptiveAvgPool2d, Sequential, Module, Parameter
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from collections import namedtuple
|
7 |
+
import math
|
8 |
+
import pdb
|
9 |
+
|
10 |
+
|
11 |
+
################################## Original Arcface Model #############################################################
|
12 |
+
######## ccc#######################
|
13 |
+
class Flatten(Module):
|
14 |
+
def forward(self, input):
|
15 |
+
return input.view(input.size(0), -1)
|
16 |
+
|
17 |
+
|
18 |
+
################################## MobileFaceNet #############################################################
|
19 |
+
|
20 |
+
class Conv_block(Module):
|
21 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
22 |
+
super(Conv_block, self).__init__()
|
23 |
+
self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
|
24 |
+
bias=False)
|
25 |
+
self.bn = BatchNorm2d(out_c)
|
26 |
+
self.prelu = PReLU(out_c)
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
x = self.conv(x)
|
30 |
+
x = self.bn(x)
|
31 |
+
x = self.prelu(x)
|
32 |
+
return x
|
33 |
+
|
34 |
+
|
35 |
+
class Linear_block(Module):
|
36 |
+
def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
|
37 |
+
super(Linear_block, self).__init__()
|
38 |
+
self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel, groups=groups, stride=stride, padding=padding,
|
39 |
+
bias=False)
|
40 |
+
self.bn = BatchNorm2d(out_c)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
x = self.conv(x)
|
44 |
+
x = self.bn(x)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class Depth_Wise(Module):
|
49 |
+
def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
|
50 |
+
super(Depth_Wise, self).__init__()
|
51 |
+
self.conv = Conv_block(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
|
52 |
+
self.conv_dw = Conv_block(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
|
53 |
+
self.project = Linear_block(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
|
54 |
+
self.residual = residual
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
if self.residual:
|
58 |
+
short_cut = x
|
59 |
+
x = self.conv(x)
|
60 |
+
x = self.conv_dw(x)
|
61 |
+
x = self.project(x)
|
62 |
+
if self.residual:
|
63 |
+
output = short_cut + x
|
64 |
+
else:
|
65 |
+
output = x
|
66 |
+
return output
|
67 |
+
|
68 |
+
|
69 |
+
class Residual(Module):
|
70 |
+
def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
|
71 |
+
super(Residual, self).__init__()
|
72 |
+
modules = []
|
73 |
+
for _ in range(num_block):
|
74 |
+
modules.append(
|
75 |
+
Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding, stride=stride, groups=groups))
|
76 |
+
self.model = Sequential(*modules)
|
77 |
+
|
78 |
+
def forward(self, x):
|
79 |
+
return self.model(x)
|
80 |
+
|
81 |
+
|
82 |
+
class GNAP(Module):
|
83 |
+
def __init__(self, embedding_size):
|
84 |
+
super(GNAP, self).__init__()
|
85 |
+
assert embedding_size == 512
|
86 |
+
self.bn1 = BatchNorm2d(512, affine=False)
|
87 |
+
self.pool = nn.AdaptiveAvgPool2d((1, 1))
|
88 |
+
|
89 |
+
self.bn2 = BatchNorm1d(512, affine=False)
|
90 |
+
|
91 |
+
def forward(self, x):
|
92 |
+
x = self.bn1(x)
|
93 |
+
x_norm = torch.norm(x, 2, 1, True)
|
94 |
+
x_norm_mean = torch.mean(x_norm)
|
95 |
+
weight = x_norm_mean / x_norm
|
96 |
+
x = x * weight
|
97 |
+
x = self.pool(x)
|
98 |
+
x = x.view(x.shape[0], -1)
|
99 |
+
feature = self.bn2(x)
|
100 |
+
return feature
|
101 |
+
|
102 |
+
|
103 |
+
class GDC(Module):
|
104 |
+
def __init__(self, embedding_size):
|
105 |
+
super(GDC, self).__init__()
|
106 |
+
self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
|
107 |
+
self.conv_6_flatten = Flatten()
|
108 |
+
self.linear = Linear(512, embedding_size, bias=False)
|
109 |
+
# self.bn = BatchNorm1d(embedding_size, affine=False)
|
110 |
+
self.bn = BatchNorm1d(embedding_size)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
x = self.conv_6_dw(x) #### [B, 512, 1, 1]
|
114 |
+
x = self.conv_6_flatten(x) #### [B, 512]
|
115 |
+
x = self.linear(x) #### [B, 136]
|
116 |
+
x = self.bn(x)
|
117 |
+
return x
|
118 |
+
|
119 |
+
|
120 |
+
class MobileFaceNet(Module):
|
121 |
+
def __init__(self, input_size, embedding_size=512, output_name="GDC"):
|
122 |
+
super(MobileFaceNet, self).__init__()
|
123 |
+
assert output_name in ["GNAP", 'GDC']
|
124 |
+
assert input_size[0] in [112]
|
125 |
+
self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
|
126 |
+
self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
|
127 |
+
self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
|
128 |
+
self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
|
129 |
+
self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
|
130 |
+
self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
|
131 |
+
self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
|
132 |
+
self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
|
133 |
+
self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
|
134 |
+
if output_name == "GNAP":
|
135 |
+
self.output_layer = GNAP(512)
|
136 |
+
else:
|
137 |
+
self.output_layer = GDC(embedding_size)
|
138 |
+
|
139 |
+
self._initialize_weights()
|
140 |
+
|
141 |
+
def _initialize_weights(self):
|
142 |
+
for m in self.modules():
|
143 |
+
if isinstance(m, nn.Conv2d):
|
144 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
145 |
+
if m.bias is not None:
|
146 |
+
m.bias.data.zero_()
|
147 |
+
elif isinstance(m, nn.BatchNorm2d):
|
148 |
+
m.weight.data.fill_(1)
|
149 |
+
m.bias.data.zero_()
|
150 |
+
elif isinstance(m, nn.Linear):
|
151 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
152 |
+
if m.bias is not None:
|
153 |
+
m.bias.data.zero_()
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
out = self.conv1(x)
|
157 |
+
# print(out.shape)
|
158 |
+
out = self.conv2_dw(out)
|
159 |
+
# print(out.shape)
|
160 |
+
out = self.conv_23(out)
|
161 |
+
# print(out.shape)
|
162 |
+
out3 = self.conv_3(out)
|
163 |
+
# print(out.shape)
|
164 |
+
out = self.conv_34(out3)
|
165 |
+
# print(out.shape)
|
166 |
+
out4 = self.conv_4(out) # [128, 14, 14]
|
167 |
+
# print(out.shape)
|
168 |
+
out = self.conv_45(out4) # [128, 7, 7]
|
169 |
+
# print(out.shape)
|
170 |
+
out = self.conv_5(out) # [128, 7, 7]
|
171 |
+
# print(out.shape)
|
172 |
+
conv_features = self.conv_6_sep(out) ##### [B, 512, 7, 7]
|
173 |
+
out = self.output_layer(conv_features) ##### [B, 136]
|
174 |
+
return out3, out4, conv_features
|
175 |
+
|
176 |
+
|
177 |
+
# model = MobileFaceNet([112, 112],136)
|
178 |
+
# input = torch.ones(8,3,112,112).cuda()
|
179 |
+
# model = model.cuda()
|
180 |
+
# x = model(input)
|
181 |
+
# import numpy as np
|
182 |
+
# parameters = model.parameters()
|
183 |
+
# parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
|
184 |
+
# print('Total Parameters: %.3fM' % parameters)
|
185 |
+
#
|
186 |
+
#
|
187 |
+
# from ptflops import get_model_complexity_info
|
188 |
+
# macs, params = get_model_complexity_info(model, (3, 112, 112), as_strings=True,
|
189 |
+
# print_per_layer_stat=True, verbose=True)
|
190 |
+
# print('{:<30} {:<8}'.format('Computational complexity: ', macs))
|
191 |
+
# print('{:<30} {:<8}'.format('Number of parameters: ', params))
|
192 |
+
#
|
193 |
+
# print(x.shape)
|
models/pretrain/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
models/pretrain/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*
|
2 |
+
!.gitignore
|
models/pretrain/ir50.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:62fcfa833776648f818b15fac4f5b760d76847316097e8e046f77ac445defb75
|
3 |
+
size 122022895
|
models/pretrain/mobilefacenet_model_best.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b994af026bfddbafc507a6f1c8737a9896bab20ed2b0cfb6ae90b81736970313
|
3 |
+
size 12281146
|
models/vit_model.py
ADDED
@@ -0,0 +1,828 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
original code from rwightman:
|
3 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
4 |
+
"""
|
5 |
+
from functools import partial
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.hub
|
14 |
+
from functools import partial
|
15 |
+
# import mat
|
16 |
+
# from vision_transformer.ir50 import Backbone
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import torch.hub
|
23 |
+
from functools import partial
|
24 |
+
import math
|
25 |
+
|
26 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
27 |
+
from timm.models.registry import register_model
|
28 |
+
from timm.models.vision_transformer import _cfg, Mlp, Block
|
29 |
+
# from .ir50 import Backbone
|
30 |
+
|
31 |
+
|
32 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
33 |
+
"""3x3 convolution with padding"""
|
34 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
35 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
36 |
+
|
37 |
+
|
38 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
39 |
+
"""1x1 convolution"""
|
40 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
41 |
+
|
42 |
+
|
43 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
44 |
+
"""
|
45 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
46 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
47 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
48 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
49 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
50 |
+
'survival rate' as the argument.
|
51 |
+
"""
|
52 |
+
if drop_prob == 0. or not training:
|
53 |
+
return x
|
54 |
+
keep_prob = 1 - drop_prob
|
55 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
56 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
57 |
+
random_tensor.floor_() # binarize
|
58 |
+
output = x.div(keep_prob) * random_tensor
|
59 |
+
return output
|
60 |
+
|
61 |
+
|
62 |
+
class BasicBlock(nn.Module):
|
63 |
+
__constants__ = ['downsample']
|
64 |
+
|
65 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
66 |
+
super(BasicBlock, self).__init__()
|
67 |
+
norm_layer = nn.BatchNorm2d
|
68 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
69 |
+
self.bn1 = norm_layer(planes)
|
70 |
+
self.relu = nn.ReLU(inplace=True)
|
71 |
+
self.conv2 = conv3x3(planes, planes)
|
72 |
+
self.bn2 = norm_layer(planes)
|
73 |
+
self.downsample = downsample
|
74 |
+
self.stride = stride
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
identity = x
|
78 |
+
|
79 |
+
out = self.conv1(x)
|
80 |
+
out = self.bn1(out)
|
81 |
+
out = self.relu(out)
|
82 |
+
out = self.conv2(out)
|
83 |
+
out = self.bn2(out)
|
84 |
+
|
85 |
+
if self.downsample is not None:
|
86 |
+
identity = self.downsample(x)
|
87 |
+
|
88 |
+
out += identity
|
89 |
+
out = self.relu(out)
|
90 |
+
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class DropPath(nn.Module):
|
95 |
+
"""
|
96 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, drop_prob=None):
|
100 |
+
super(DropPath, self).__init__()
|
101 |
+
self.drop_prob = drop_prob
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
return drop_path(x, self.drop_prob, self.training)
|
105 |
+
|
106 |
+
|
107 |
+
class PatchEmbed(nn.Module):
|
108 |
+
"""
|
109 |
+
2D Image to Patch Embedding
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None):
|
113 |
+
super().__init__()
|
114 |
+
img_size = (img_size, img_size)
|
115 |
+
patch_size = (patch_size, patch_size)
|
116 |
+
self.img_size = img_size
|
117 |
+
self.patch_size = patch_size
|
118 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
119 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
120 |
+
|
121 |
+
self.proj = nn.Conv2d(256, 768, kernel_size=1)
|
122 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
B, C, H, W = x.shape
|
126 |
+
# assert H == self.img_size[0] and W == self.img_size[1], \
|
127 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
128 |
+
# print(x.shape)
|
129 |
+
|
130 |
+
# flatten: [B, C, H, W] -> [B, C, HW]
|
131 |
+
# transpose: [B, C, HW] -> [B, HW, C]
|
132 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
133 |
+
x = self.norm(x)
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
137 |
+
class Attention(nn.Module):
|
138 |
+
def __init__(self,
|
139 |
+
dim, in_chans, # 输入token的dim
|
140 |
+
num_heads=8,
|
141 |
+
qkv_bias=False,
|
142 |
+
qk_scale=None,
|
143 |
+
attn_drop_ratio=0.,
|
144 |
+
proj_drop_ratio=0.):
|
145 |
+
super(Attention, self).__init__()
|
146 |
+
self.num_heads = 8
|
147 |
+
self.img_chanel = in_chans + 1
|
148 |
+
head_dim = dim // num_heads
|
149 |
+
self.scale = head_dim ** -0.5
|
150 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
151 |
+
self.attn_drop = nn.Dropout(attn_drop_ratio)
|
152 |
+
self.proj = nn.Linear(dim, dim)
|
153 |
+
self.proj_drop = nn.Dropout(proj_drop_ratio)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
x_img = x[:, :self.img_chanel, :]
|
157 |
+
# [batch_size, num_patches + 1, total_embed_dim]
|
158 |
+
B, N, C = x_img.shape
|
159 |
+
# print(C)
|
160 |
+
qkv = self.qkv(x_img).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
161 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
162 |
+
# k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
163 |
+
# q = x_img.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
164 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
165 |
+
attn = attn.softmax(dim=-1)
|
166 |
+
attn = self.attn_drop(attn)
|
167 |
+
|
168 |
+
x_img = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
169 |
+
x_img = self.proj(x_img)
|
170 |
+
x_img = self.proj_drop(x_img)
|
171 |
+
#
|
172 |
+
#
|
173 |
+
# # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
|
174 |
+
# # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
|
175 |
+
# # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
176 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
177 |
+
# # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
178 |
+
# q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
179 |
+
#
|
180 |
+
# # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
|
181 |
+
# # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
|
182 |
+
# attn = (q @ k.transpose(-2, -1)) * self.scale
|
183 |
+
# attn = attn.softmax(dim=-1)
|
184 |
+
# attn = self.attn_drop(attn)
|
185 |
+
#
|
186 |
+
# # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
187 |
+
# # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
|
188 |
+
# # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
|
189 |
+
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
190 |
+
# x = self.proj(x)
|
191 |
+
# x = self.proj_drop(x)
|
192 |
+
return x_img
|
193 |
+
|
194 |
+
|
195 |
+
class AttentionBlock(nn.Module):
|
196 |
+
__constants__ = ['downsample']
|
197 |
+
|
198 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
199 |
+
super(AttentionBlock, self).__init__()
|
200 |
+
norm_layer = nn.BatchNorm2d
|
201 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
202 |
+
self.bn1 = norm_layer(planes)
|
203 |
+
self.relu = nn.ReLU(inplace=True)
|
204 |
+
self.conv2 = conv3x3(planes, planes)
|
205 |
+
self.bn2 = norm_layer(planes)
|
206 |
+
self.downsample = downsample
|
207 |
+
self.stride = stride
|
208 |
+
# self.cbam = CBAM(planes, 16)
|
209 |
+
self.inplanes = inplanes
|
210 |
+
self.eca_block = eca_block()
|
211 |
+
|
212 |
+
def forward(self, x):
|
213 |
+
identity = x
|
214 |
+
|
215 |
+
out = self.conv1(x)
|
216 |
+
out = self.bn1(out)
|
217 |
+
out = self.relu(out)
|
218 |
+
|
219 |
+
out = self.conv2(out)
|
220 |
+
out = self.bn2(out)
|
221 |
+
inplanes = self.inplanes
|
222 |
+
out = self.eca_block(out)
|
223 |
+
if self.downsample is not None:
|
224 |
+
identity = self.downsample(x)
|
225 |
+
|
226 |
+
out += identity
|
227 |
+
out = self.relu(out)
|
228 |
+
|
229 |
+
return out
|
230 |
+
|
231 |
+
|
232 |
+
class Mlp(nn.Module):
|
233 |
+
"""
|
234 |
+
MLP as used in Vision Transformer, MLP-Mixer and related networks
|
235 |
+
"""
|
236 |
+
|
237 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
238 |
+
super().__init__()
|
239 |
+
out_features = out_features or in_features
|
240 |
+
hidden_features = hidden_features or in_features
|
241 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
242 |
+
self.act = act_layer()
|
243 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
244 |
+
self.drop = nn.Dropout(drop)
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
x = self.fc1(x)
|
248 |
+
x = self.act(x)
|
249 |
+
x = self.drop(x)
|
250 |
+
x = self.fc2(x)
|
251 |
+
x = self.drop(x)
|
252 |
+
return x
|
253 |
+
|
254 |
+
|
255 |
+
class Block(nn.Module):
|
256 |
+
def __init__(self,
|
257 |
+
dim, in_chans,
|
258 |
+
num_heads,
|
259 |
+
mlp_ratio=4.,
|
260 |
+
qkv_bias=False,
|
261 |
+
qk_scale=None,
|
262 |
+
drop_ratio=0.,
|
263 |
+
attn_drop_ratio=0.,
|
264 |
+
drop_path_ratio=0.,
|
265 |
+
act_layer=nn.GELU,
|
266 |
+
norm_layer=nn.LayerNorm):
|
267 |
+
super(Block, self).__init__()
|
268 |
+
self.norm1 = norm_layer(dim)
|
269 |
+
self.img_chanel = in_chans + 1
|
270 |
+
|
271 |
+
self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
|
272 |
+
self.attn = Attention(dim, in_chans=in_chans, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
273 |
+
attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
|
274 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
275 |
+
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
|
276 |
+
self.norm2 = norm_layer(dim)
|
277 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
278 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
# x = x + self.drop_path(self.attn(self.norm1(x)))
|
282 |
+
# x = x + self.drop_path(self.mlp(self.norm2(x)))
|
283 |
+
|
284 |
+
x_img = x
|
285 |
+
# [:, :self.img_chanel, :]
|
286 |
+
# x_lm = x[:, self.img_chanel:, :]
|
287 |
+
x_img = x_img + self.drop_path(self.attn(self.norm1(x)))
|
288 |
+
x = x_img + self.drop_path(self.mlp(self.norm2(x_img)))
|
289 |
+
#
|
290 |
+
# x_lm = x_lm + self.drop_path(self.attn_lm(self.norm3(x)))
|
291 |
+
# x_lm = x_lm + self.drop_path(self.mlp2(self.norm4(x_lm)))
|
292 |
+
# x = torch.cat((x_img, x_lm), dim=1)
|
293 |
+
# x = self.conv(x)
|
294 |
+
|
295 |
+
return x
|
296 |
+
|
297 |
+
|
298 |
+
class ClassificationHead(nn.Module):
|
299 |
+
def __init__(self, input_dim: int, target_dim: int):
|
300 |
+
super().__init__()
|
301 |
+
self.linear = torch.nn.Linear(input_dim, target_dim)
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
x = x.view(x.size(0), -1)
|
305 |
+
y_hat = self.linear(x)
|
306 |
+
return y_hat
|
307 |
+
|
308 |
+
|
309 |
+
def load_pretrained_weights(model, checkpoint):
|
310 |
+
import collections
|
311 |
+
if 'state_dict' in checkpoint:
|
312 |
+
state_dict = checkpoint['state_dict']
|
313 |
+
else:
|
314 |
+
state_dict = checkpoint
|
315 |
+
model_dict = model.state_dict()
|
316 |
+
new_state_dict = collections.OrderedDict()
|
317 |
+
matched_layers, discarded_layers = [], []
|
318 |
+
for k, v in state_dict.items():
|
319 |
+
# If the pretrained state_dict was saved as nn.DataParallel,
|
320 |
+
# keys would contain "module.", which should be ignored.
|
321 |
+
if k.startswith('module.'):
|
322 |
+
k = k[7:]
|
323 |
+
if k in model_dict and model_dict[k].size() == v.size():
|
324 |
+
new_state_dict[k] = v
|
325 |
+
matched_layers.append(k)
|
326 |
+
else:
|
327 |
+
discarded_layers.append(k)
|
328 |
+
# new_state_dict.requires_grad = False
|
329 |
+
model_dict.update(new_state_dict)
|
330 |
+
|
331 |
+
model.load_state_dict(model_dict)
|
332 |
+
print('load_weight', len(matched_layers))
|
333 |
+
return model
|
334 |
+
|
335 |
+
class eca_block(nn.Module):
|
336 |
+
def __init__(self, channel=128, b=1, gamma=2):
|
337 |
+
super(eca_block, self).__init__()
|
338 |
+
kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
|
339 |
+
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
|
340 |
+
|
341 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
342 |
+
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
|
343 |
+
self.sigmoid = nn.Sigmoid()
|
344 |
+
|
345 |
+
def forward(self, x):
|
346 |
+
y = self.avg_pool(x)
|
347 |
+
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
348 |
+
y = self.sigmoid(y)
|
349 |
+
return x * y.expand_as(x)
|
350 |
+
#
|
351 |
+
#
|
352 |
+
# class IR20(nn.Module):
|
353 |
+
# def __init__(self, img_size_=112, num_classes=7, layers=[2, 2, 2, 2]):
|
354 |
+
# super().__init__()
|
355 |
+
# norm_layer = nn.BatchNorm2d
|
356 |
+
# self.img_size = img_size_
|
357 |
+
# self._norm_layer = norm_layer
|
358 |
+
# self.num_classes = num_classes
|
359 |
+
# self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
360 |
+
# self.bn1 = norm_layer(64)
|
361 |
+
# self.relu = nn.ReLU(inplace=True)
|
362 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
363 |
+
# # self.face_landback = MobileFaceNet([112, 112],136)
|
364 |
+
# # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
|
365 |
+
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
|
366 |
+
# self.layer1 = self._make_layer(BasicBlock, 64, 64, layers[0])
|
367 |
+
# self.layer2 = self._make_layer(BasicBlock, 64, 128, layers[1], stride=2)
|
368 |
+
# self.layer3 = self._make_layer(AttentionBlock, 128, 256, layers[2], stride=2)
|
369 |
+
# self.layer4 = self._make_layer(AttentionBlock, 256, 256, layers[3], stride=1)
|
370 |
+
# self.ir_back = Backbone(50, 51, 52, 0.0, 'ir')
|
371 |
+
# self.ir_layer = nn.Linear(1024, 512)
|
372 |
+
# # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\Pretrained_on_MSCeleb.pth.tar',
|
373 |
+
# # map_location=lambda storage, loc: storage)
|
374 |
+
# # ir_checkpoint = ir_checkpoint['state_dict']
|
375 |
+
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
|
376 |
+
# # checkpoint = torch.load('./checkpoint/Pretrained_on_MSCeleb.pth.tar')
|
377 |
+
# # pre_trained_dict = checkpoint['state_dict']
|
378 |
+
# # IR20.load_state_dict(ir_checkpoint, strict=False)
|
379 |
+
# # self.IR = load_pretrained_weights(IR, ir_checkpoint)
|
380 |
+
#
|
381 |
+
# def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
382 |
+
# norm_layer = self._norm_layer
|
383 |
+
# downsample = None
|
384 |
+
# if stride != 1 or inplanes != planes:
|
385 |
+
# downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))
|
386 |
+
# layers = []
|
387 |
+
# layers.append(block(inplanes, planes, stride, downsample))
|
388 |
+
# inplanes = planes
|
389 |
+
# for _ in range(1, blocks):
|
390 |
+
# layers.append(block(inplanes, planes))
|
391 |
+
# return nn.Sequential(*layers)
|
392 |
+
#
|
393 |
+
# def forward(self, x):
|
394 |
+
# x_ir = self.ir_back(x)
|
395 |
+
# # x_ir = self.ir_layer(x_ir)
|
396 |
+
# # print(x_ir.shape)
|
397 |
+
# # x = F.interpolate(x, size=112)
|
398 |
+
# # x = self.conv1(x)
|
399 |
+
# # x = self.bn1(x)
|
400 |
+
# # x = self.relu(x)
|
401 |
+
# # x = self.maxpool(x)
|
402 |
+
# #
|
403 |
+
# # x = self.layer1(x)
|
404 |
+
# # x = self.layer2(x)
|
405 |
+
# # x = self.layer3(x)
|
406 |
+
# # x = self.layer4(x)
|
407 |
+
# # print(x.shape)
|
408 |
+
# # print(x)
|
409 |
+
# out = x_ir
|
410 |
+
#
|
411 |
+
# return out
|
412 |
+
#
|
413 |
+
#
|
414 |
+
# class IR(nn.Module):
|
415 |
+
# def __init__(self, img_size_=112, num_classes=7):
|
416 |
+
# super().__init__()
|
417 |
+
# depth = 8
|
418 |
+
# # if type == "small":
|
419 |
+
# # depth = 4
|
420 |
+
# # if type == "base":
|
421 |
+
# # depth = 6
|
422 |
+
# # if type == "large":
|
423 |
+
# # depth = 8
|
424 |
+
#
|
425 |
+
# self.img_size = img_size_
|
426 |
+
# self.num_classes = num_classes
|
427 |
+
# self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
428 |
+
# # self.bn1 = norm_layer(64)
|
429 |
+
# self.relu = nn.ReLU(inplace=True)
|
430 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
431 |
+
# # self.face_landback = MobileFaceNet([112, 112],136)
|
432 |
+
# # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
|
433 |
+
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
|
434 |
+
#
|
435 |
+
# # for param in self.face_landback.parameters():
|
436 |
+
# # param.requires_grad = False
|
437 |
+
#
|
438 |
+
# ###########################################################################333
|
439 |
+
#
|
440 |
+
# self.ir_back = IR20()
|
441 |
+
#
|
442 |
+
# # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\ir50.pth',
|
443 |
+
# # map_location=lambda storage, loc: storage)
|
444 |
+
# # # ir_checkpoint = ir_checkpoint["model"]
|
445 |
+
# # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
|
446 |
+
# # load_state_dict(checkpoint_model, strict=False)
|
447 |
+
# # self.ir_layer = nn.Linear(1024,512)
|
448 |
+
#
|
449 |
+
# #############################################################3
|
450 |
+
# #
|
451 |
+
# # self.pyramid_fuse = HyVisionTransformer(in_chans=49, q_chanel = 49, embed_dim=512,
|
452 |
+
# # depth=depth, num_heads=8, mlp_ratio=2.,
|
453 |
+
# # drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1)
|
454 |
+
#
|
455 |
+
# # self.se_block = SE_block(input_dim=512)
|
456 |
+
# self.head = ClassificationHead(input_dim=768, target_dim=self.num_classes)
|
457 |
+
#
|
458 |
+
# def forward(self, x):
|
459 |
+
# B_ = x.shape[0]
|
460 |
+
# # x_face = F.interpolate(x, size=112)
|
461 |
+
# # _, x_face = self.face_landback(x_face)
|
462 |
+
# # x_face = x_face.view(B_, -1, 49).transpose(1,2)
|
463 |
+
# ############### landmark x_face ([B, 49, 512])
|
464 |
+
# x_ir = self.ir_back(x)
|
465 |
+
# # print(x_ir.shape)
|
466 |
+
# # x_ir = self.ir_layer(x_ir)
|
467 |
+
# # print(x_ir.shape)
|
468 |
+
# ############### image x_ir ([B, 49, 512])
|
469 |
+
#
|
470 |
+
# # y_hat = self.pyramid_fuse(x_ir, x_face)
|
471 |
+
# # y_hat = self.se_block(y_hat)
|
472 |
+
# # y_feat = y_hat
|
473 |
+
#
|
474 |
+
# # out = self.head(x_ir)
|
475 |
+
#
|
476 |
+
# out = x_ir
|
477 |
+
# return out
|
478 |
+
|
479 |
+
|
480 |
+
class eca_block(nn.Module):
|
481 |
+
def __init__(self, channel=196, b=1, gamma=2):
|
482 |
+
super(eca_block, self).__init__()
|
483 |
+
kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
|
484 |
+
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
|
485 |
+
|
486 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
487 |
+
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
|
488 |
+
self.sigmoid = nn.Sigmoid()
|
489 |
+
|
490 |
+
def forward(self, x):
|
491 |
+
y = self.avg_pool(x)
|
492 |
+
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
493 |
+
y = self.sigmoid(y)
|
494 |
+
return x * y.expand_as(x)
|
495 |
+
|
496 |
+
class SE_block(nn.Module):
|
497 |
+
def __init__(self, input_dim: int):
|
498 |
+
super().__init__()
|
499 |
+
self.linear1 = torch.nn.Linear(input_dim, input_dim)
|
500 |
+
self.relu = nn.ReLU()
|
501 |
+
self.linear2 = torch.nn.Linear(input_dim, input_dim)
|
502 |
+
self.sigmod = nn.Sigmoid()
|
503 |
+
|
504 |
+
def forward(self, x):
|
505 |
+
x1 = self.linear1(x)
|
506 |
+
x1 = self.relu(x1)
|
507 |
+
x1 = self.linear2(x1)
|
508 |
+
x1 = self.sigmod(x1)
|
509 |
+
x = x * x1
|
510 |
+
return x
|
511 |
+
|
512 |
+
|
class VisionTransformer(nn.Module):
    def __init__(self, img_size=14, patch_size=14, in_c=147, num_classes=7,
                 embed_dim=768, depth=6, num_heads=8, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_c (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            distilled (bool): model includes a distillation token and head as in DeiT models
            drop_ratio (float): dropout rate
            attn_drop_ratio (float): attention dropout rate
            drop_path_ratio (float): stochastic depth rate
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, in_c + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)

        self.se_block = SE_block(input_dim=embed_dim)


        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768)
        num_patches = self.patch_embed.num_patches
        self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        # self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_ratio)
        # self.IR = IR()
        self.eca_block = eca_block()


        # self.ir_back = Backbone(50, 0.0, 'ir')
        # ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
        # # ir_checkpoint = ir_checkpoint["model"]
        # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)

        self.CON1 = nn.Conv2d(256, 768, kernel_size=1, stride=1, bias=False)
        self.IRLinear1 = nn.Linear(1024, 768)
        self.IRLinear2 = nn.Linear(768, 512)
        self.eca_block = eca_block()
        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, in_chans=in_c, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                  qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        # x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        # print(x.shape)

        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
        # print(x.shape)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):

        # B = x.shape[0]
        # print(x)
        # x = self.eca_block(x)
        # x = self.IR(x)
        # x = eca_block(x)
        # x = self.ir_back(x)
        # print(x.shape)
        # x = self.CON1(x)
        # x = x.view(-1, 196, 768)
        #
        # # print(x.shape)
        # # x = self.IRLinear1(x)
        # # print(x)
        # x_cls = torch.mean(x, 1).view(B, 1, -1)
        # x = torch.cat((x_cls, x), dim=1)
        # # print(x.shape)
        # x = self.pos_drop(x + self.pos_embed)
        # # print(x.shape)
        # x = self.blocks(x)
        # # print(x)
        # x = self.norm(x)
        # # print(x)
        # # x1 = self.IRLinear2(x)
        # x1 = x[:, 0, :]

        # print(x1)
        # print(x1.shape)

        x = self.forward_features(x)
        # # print(x.shape)
        # if self.head_dist is not None:
        #     x, x_dist = self.head(x[0]), self.head_dist(x[1])
        #     if self.training and not torch.jit.is_scripting():
        #         # during inference, return the average of both classifier predictions
        #         return x, x_dist
        #     else:
        #         return (x + x_dist) / 2
        # else:
        # print(x.shape)
        x = self.se_block(x)

        x1 = self.head(x)

        return x1

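As wired above, forward() never calls patch_embed (that call is commented out), so the model consumes an already-tokenised [B, in_c, embed_dim] sequence and returns [B, num_classes] logits. A rough smoke test under that assumption, with the class imported from models.vit_model:

# Rough smoke test (assumption: VisionTransformer is importable from models.vit_model).
import torch
from models.vit_model import VisionTransformer

model = VisionTransformer()          # defaults: in_c=147, embed_dim=768, depth=6, num_classes=7
tokens = torch.randn(2, 147, 768)    # forward_features prepends a cls token -> [2, 148, 768]

with torch.no_grad():
    logits = model(tokens)

print(logits.shape)  # torch.Size([2, 7])
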
def _init_vit_weights(m):
    """
    ViT weight initialization
    :param m: module
    """
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=.01)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode="fan_out")
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.LayerNorm):
        nn.init.zeros_(m.bias)
        nn.init.ones_(m.weight)

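VisionTransformer.__init__ applies this initializer recursively through self.apply(_init_vit_weights); a minimal sketch of the same pattern on a toy module, assuming the helper is importable from models.vit_model:

# Minimal sketch of applying the initializer recursively (assumption: importable from models.vit_model).
import torch.nn as nn
from models.vit_model import _init_vit_weights

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Conv2d(3, 3, 1))
toy.apply(_init_vit_weights)          # same call pattern as self.apply(_init_vit_weights)

print(toy[1].weight.data.unique())    # LayerNorm weights reset to ones, biases to zeros
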
def vit_base_patch16_224(num_classes: int = 7):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)

    return model


def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224(num_classes: int = 1000):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  password: s5hl
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=768,
                              depth=12,
                              num_heads=12,
                              representation_size=768 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224(num_classes: int = 1000):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  password: qqt8
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=None,
                              num_classes=num_classes)
    return model


def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=16,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    weights ported from official Google JAX impl:
    https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
    """
    model = VisionTransformer(img_size=224,
                              patch_size=32,
                              embed_dim=1024,
                              depth=24,
                              num_heads=16,
                              representation_size=1024 if has_logits else None,
                              num_classes=num_classes)
    return model


def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
    """
    ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    NOTE: converted weights not currently available, too large for github release hosting.
    """
    model = VisionTransformer(img_size=224,
                              patch_size=14,
                              embed_dim=1280,
                              depth=32,
                              num_heads=16,
                              representation_size=1280 if has_logits else None,
                              num_classes=num_classes)
    return model
models/vit_model_8.py
ADDED
@@ -0,0 +1,828 @@
1 |
+
"""
|
2 |
+
original code from rwightman:
|
3 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
4 |
+
"""
|
5 |
+
from functools import partial
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
import torch.hub
|
14 |
+
from functools import partial
|
15 |
+
# import mat
|
16 |
+
# from vision_transformer.ir50 import Backbone
|
17 |
+
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.nn.functional as F
|
22 |
+
import torch.hub
|
23 |
+
from functools import partial
|
24 |
+
import math
|
25 |
+
|
26 |
+
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
27 |
+
from timm.models.registry import register_model
|
28 |
+
from timm.models.vision_transformer import _cfg, Mlp, Block
|
29 |
+
from .ir50 import Backbone
|
30 |
+
|
31 |
+
|
32 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
33 |
+
"""3x3 convolution with padding"""
|
34 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
35 |
+
padding=dilation, groups=groups, bias=False, dilation=dilation)
|
36 |
+
|
37 |
+
|
38 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
39 |
+
"""1x1 convolution"""
|
40 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
41 |
+
|
42 |
+
|
43 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
44 |
+
"""
|
45 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
46 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
47 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
48 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
49 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
50 |
+
'survival rate' as the argument.
|
51 |
+
"""
|
52 |
+
if drop_prob == 0. or not training:
|
53 |
+
return x
|
54 |
+
keep_prob = 1 - drop_prob
|
55 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
56 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
57 |
+
random_tensor.floor_() # binarize
|
58 |
+
output = x.div(keep_prob) * random_tensor
|
59 |
+
return output
|
60 |
+
|
61 |
+
|
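drop_path, as defined above, zeroes a whole sample's residual branch with probability drop_prob and rescales the survivors by 1/keep_prob; at eval time it is a no-op. A small sketch of that behaviour, assuming the function is importable from models.vit_model_8:

# Sketch of the stochastic-depth behaviour of drop_path (illustrative values only).
import torch
from models.vit_model_8 import drop_path   # assumption: importable from this module

x = torch.ones(4, 3, 8)                    # 4 samples in the batch
out_train = drop_path(x, drop_prob=0.5, training=True)
out_eval = drop_path(x, drop_prob=0.5, training=False)

# In training, each sample is either dropped (all zeros) or scaled to 1/keep_prob = 2.0.
print(out_train[:, 0, 0])
# With training=False the input passes through unchanged.
print(torch.equal(out_eval, x))  # True
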
62 |
+
class BasicBlock(nn.Module):
|
63 |
+
__constants__ = ['downsample']
|
64 |
+
|
65 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
66 |
+
super(BasicBlock, self).__init__()
|
67 |
+
norm_layer = nn.BatchNorm2d
|
68 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
69 |
+
self.bn1 = norm_layer(planes)
|
70 |
+
self.relu = nn.ReLU(inplace=True)
|
71 |
+
self.conv2 = conv3x3(planes, planes)
|
72 |
+
self.bn2 = norm_layer(planes)
|
73 |
+
self.downsample = downsample
|
74 |
+
self.stride = stride
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
identity = x
|
78 |
+
|
79 |
+
out = self.conv1(x)
|
80 |
+
out = self.bn1(out)
|
81 |
+
out = self.relu(out)
|
82 |
+
out = self.conv2(out)
|
83 |
+
out = self.bn2(out)
|
84 |
+
|
85 |
+
if self.downsample is not None:
|
86 |
+
identity = self.downsample(x)
|
87 |
+
|
88 |
+
out += identity
|
89 |
+
out = self.relu(out)
|
90 |
+
|
91 |
+
return out
|
92 |
+
|
93 |
+
|
94 |
+
class DropPath(nn.Module):
|
95 |
+
"""
|
96 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
97 |
+
"""
|
98 |
+
|
99 |
+
def __init__(self, drop_prob=None):
|
100 |
+
super(DropPath, self).__init__()
|
101 |
+
self.drop_prob = drop_prob
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
return drop_path(x, self.drop_prob, self.training)
|
105 |
+
|
106 |
+
|
107 |
+
class PatchEmbed(nn.Module):
|
108 |
+
"""
|
109 |
+
2D Image to Patch Embedding
|
110 |
+
"""
|
111 |
+
|
112 |
+
def __init__(self, img_size=14, patch_size=16, in_c=256, embed_dim=768, norm_layer=None):
|
113 |
+
super().__init__()
|
114 |
+
img_size = (img_size, img_size)
|
115 |
+
patch_size = (patch_size, patch_size)
|
116 |
+
self.img_size = img_size
|
117 |
+
self.patch_size = patch_size
|
118 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
119 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
120 |
+
|
121 |
+
self.proj = nn.Conv2d(256, 768, kernel_size=1)
|
122 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
123 |
+
|
124 |
+
def forward(self, x):
|
125 |
+
B, C, H, W = x.shape
|
126 |
+
# assert H == self.img_size[0] and W == self.img_size[1], \
|
127 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
128 |
+
# print(x.shape)
|
129 |
+
|
130 |
+
# flatten: [B, C, H, W] -> [B, C, HW]
|
131 |
+
# transpose: [B, C, HW] -> [B, HW, C]
|
132 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
133 |
+
x = self.norm(x)
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
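Note that this PatchEmbed projects with a 1x1 convolution, so it emits one token per spatial position rather than per patch_size patch; a 14x14, 256-channel feature map therefore becomes 196 tokens of width 768. A quick shape check, assuming the class is importable from models.vit_model_8:

# Shape check for PatchEmbed as defined above (assumption: importable from models.vit_model_8).
import torch
from models.vit_model_8 import PatchEmbed

embed = PatchEmbed(img_size=14, patch_size=16, in_c=256, embed_dim=768)
feature_map = torch.randn(2, 256, 14, 14)   # e.g. a backbone stage output
tokens = embed(feature_map)
print(tokens.shape)                         # torch.Size([2, 196, 768]) -- one token per pixel
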
137 |
+
class Attention(nn.Module):
|
138 |
+
def __init__(self,
|
139 |
+
dim, in_chans,  # dim of the input tokens
|
140 |
+
num_heads=8,
|
141 |
+
qkv_bias=False,
|
142 |
+
qk_scale=None,
|
143 |
+
attn_drop_ratio=0.,
|
144 |
+
proj_drop_ratio=0.):
|
145 |
+
super(Attention, self).__init__()
|
146 |
+
self.num_heads = 8
|
147 |
+
self.img_chanel = in_chans + 1
|
148 |
+
head_dim = dim // num_heads
|
149 |
+
self.scale = head_dim ** -0.5
|
150 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
151 |
+
self.attn_drop = nn.Dropout(attn_drop_ratio)
|
152 |
+
self.proj = nn.Linear(dim, dim)
|
153 |
+
self.proj_drop = nn.Dropout(proj_drop_ratio)
|
154 |
+
|
155 |
+
def forward(self, x):
|
156 |
+
x_img = x[:, :self.img_chanel, :]
|
157 |
+
# [batch_size, num_patches + 1, total_embed_dim]
|
158 |
+
B, N, C = x_img.shape
|
159 |
+
# print(C)
|
160 |
+
qkv = self.qkv(x_img).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
161 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
162 |
+
# k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
163 |
+
# q = x_img.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
164 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
165 |
+
attn = attn.softmax(dim=-1)
|
166 |
+
attn = self.attn_drop(attn)
|
167 |
+
|
168 |
+
x_img = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
169 |
+
x_img = self.proj(x_img)
|
170 |
+
x_img = self.proj_drop(x_img)
|
171 |
+
#
|
172 |
+
#
|
173 |
+
# # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
|
174 |
+
# # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
|
175 |
+
# # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
176 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
177 |
+
# # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
178 |
+
# q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
179 |
+
#
|
180 |
+
# # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
|
181 |
+
# # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
|
182 |
+
# attn = (q @ k.transpose(-2, -1)) * self.scale
|
183 |
+
# attn = attn.softmax(dim=-1)
|
184 |
+
# attn = self.attn_drop(attn)
|
185 |
+
#
|
186 |
+
# # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
|
187 |
+
# # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
|
188 |
+
# # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
|
189 |
+
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
190 |
+
# x = self.proj(x)
|
191 |
+
# x = self.proj_drop(x)
|
192 |
+
return x_img
|
193 |
+
|
194 |
+
|
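The Attention module above slices the sequence to the first in_chans + 1 tokens and runs standard multi-head self-attention on that slice. A brief shape walk, assuming the class is importable from models.vit_model_8:

# Shape walk for Attention as defined above (assumption: importable from models.vit_model_8).
import torch
from models.vit_model_8 import Attention

attn = Attention(dim=768, in_chans=147, num_heads=8, qkv_bias=True)
x = torch.randn(2, 148, 768)   # cls token + 147 feature tokens
out = attn(x)                  # qkv tensor is [3, 2, 8, 148, 96]; head width = 768 // 8 = 96
print(out.shape)               # torch.Size([2, 148, 768])
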
195 |
+
class AttentionBlock(nn.Module):
|
196 |
+
__constants__ = ['downsample']
|
197 |
+
|
198 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
199 |
+
super(AttentionBlock, self).__init__()
|
200 |
+
norm_layer = nn.BatchNorm2d
|
201 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
202 |
+
self.bn1 = norm_layer(planes)
|
203 |
+
self.relu = nn.ReLU(inplace=True)
|
204 |
+
self.conv2 = conv3x3(planes, planes)
|
205 |
+
self.bn2 = norm_layer(planes)
|
206 |
+
self.downsample = downsample
|
207 |
+
self.stride = stride
|
208 |
+
# self.cbam = CBAM(planes, 16)
|
209 |
+
self.inplanes = inplanes
|
210 |
+
self.eca_block = eca_block()
|
211 |
+
|
212 |
+
def forward(self, x):
|
213 |
+
identity = x
|
214 |
+
|
215 |
+
out = self.conv1(x)
|
216 |
+
out = self.bn1(out)
|
217 |
+
out = self.relu(out)
|
218 |
+
|
219 |
+
out = self.conv2(out)
|
220 |
+
out = self.bn2(out)
|
221 |
+
inplanes = self.inplanes
|
222 |
+
out = self.eca_block(out)
|
223 |
+
if self.downsample is not None:
|
224 |
+
identity = self.downsample(x)
|
225 |
+
|
226 |
+
out += identity
|
227 |
+
out = self.relu(out)
|
228 |
+
|
229 |
+
return out
|
230 |
+
|
231 |
+
|
232 |
+
class Mlp(nn.Module):
|
233 |
+
"""
|
234 |
+
MLP as used in Vision Transformer, MLP-Mixer and related networks
|
235 |
+
"""
|
236 |
+
|
237 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
238 |
+
super().__init__()
|
239 |
+
out_features = out_features or in_features
|
240 |
+
hidden_features = hidden_features or in_features
|
241 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
242 |
+
self.act = act_layer()
|
243 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
244 |
+
self.drop = nn.Dropout(drop)
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
x = self.fc1(x)
|
248 |
+
x = self.act(x)
|
249 |
+
x = self.drop(x)
|
250 |
+
x = self.fc2(x)
|
251 |
+
x = self.drop(x)
|
252 |
+
return x
|
253 |
+
|
254 |
+
|
255 |
+
class Block(nn.Module):
|
256 |
+
def __init__(self,
|
257 |
+
dim, in_chans,
|
258 |
+
num_heads,
|
259 |
+
mlp_ratio=4.,
|
260 |
+
qkv_bias=False,
|
261 |
+
qk_scale=None,
|
262 |
+
drop_ratio=0.,
|
263 |
+
attn_drop_ratio=0.,
|
264 |
+
drop_path_ratio=0.,
|
265 |
+
act_layer=nn.GELU,
|
266 |
+
norm_layer=nn.LayerNorm):
|
267 |
+
super(Block, self).__init__()
|
268 |
+
self.norm1 = norm_layer(dim)
|
269 |
+
self.img_chanel = in_chans + 1
|
270 |
+
|
271 |
+
self.conv = nn.Conv1d(self.img_chanel, self.img_chanel, 1)
|
272 |
+
self.attn = Attention(dim, in_chans=in_chans, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
273 |
+
attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
|
274 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
275 |
+
self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
|
276 |
+
self.norm2 = norm_layer(dim)
|
277 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
278 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
|
279 |
+
|
280 |
+
def forward(self, x):
|
281 |
+
# x = x + self.drop_path(self.attn(self.norm1(x)))
|
282 |
+
# x = x + self.drop_path(self.mlp(self.norm2(x)))
|
283 |
+
|
284 |
+
x_img = x
|
285 |
+
# [:, :self.img_chanel, :]
|
286 |
+
# x_lm = x[:, self.img_chanel:, :]
|
287 |
+
x_img = x_img + self.drop_path(self.attn(self.norm1(x)))
|
288 |
+
x = x_img + self.drop_path(self.mlp(self.norm2(x_img)))
|
289 |
+
#
|
290 |
+
# x_lm = x_lm + self.drop_path(self.attn_lm(self.norm3(x)))
|
291 |
+
# x_lm = x_lm + self.drop_path(self.mlp2(self.norm4(x_lm)))
|
292 |
+
# x = torch.cat((x_img, x_lm), dim=1)
|
293 |
+
# x = self.conv(x)
|
294 |
+
|
295 |
+
return x
|
296 |
+
|
297 |
+
|
298 |
+
class ClassificationHead(nn.Module):
|
299 |
+
def __init__(self, input_dim: int, target_dim: int):
|
300 |
+
super().__init__()
|
301 |
+
self.linear = torch.nn.Linear(input_dim, target_dim)
|
302 |
+
|
303 |
+
def forward(self, x):
|
304 |
+
x = x.view(x.size(0), -1)
|
305 |
+
y_hat = self.linear(x)
|
306 |
+
return y_hat
|
307 |
+
|
308 |
+
|
309 |
+
def load_pretrained_weights(model, checkpoint):
|
310 |
+
import collections
|
311 |
+
if 'state_dict' in checkpoint:
|
312 |
+
state_dict = checkpoint['state_dict']
|
313 |
+
else:
|
314 |
+
state_dict = checkpoint
|
315 |
+
model_dict = model.state_dict()
|
316 |
+
new_state_dict = collections.OrderedDict()
|
317 |
+
matched_layers, discarded_layers = [], []
|
318 |
+
for k, v in state_dict.items():
|
319 |
+
# If the pretrained state_dict was saved as nn.DataParallel,
|
320 |
+
# keys would contain "module.", which should be ignored.
|
321 |
+
if k.startswith('module.'):
|
322 |
+
k = k[7:]
|
323 |
+
if k in model_dict and model_dict[k].size() == v.size():
|
324 |
+
new_state_dict[k] = v
|
325 |
+
matched_layers.append(k)
|
326 |
+
else:
|
327 |
+
discarded_layers.append(k)
|
328 |
+
# new_state_dict.requires_grad = False
|
329 |
+
model_dict.update(new_state_dict)
|
330 |
+
|
331 |
+
model.load_state_dict(model_dict)
|
332 |
+
print('load_weight', len(matched_layers))
|
333 |
+
return model
|
334 |
+
|
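load_pretrained_weights copies only the checkpoint tensors whose names (after stripping any DataParallel "module." prefix) and shapes match the target model, and discards the rest. A usage sketch mirroring the commented-out IR-50 loading code in this file; the Backbone signature and checkpoint path are taken from those comments and are not verified here:

# Usage sketch for load_pretrained_weights (Backbone args and path follow the comments in this file).
import torch
from models.vit_model_8 import load_pretrained_weights
from models.ir50 import Backbone

backbone = Backbone(50, 0.0, 'ir')
checkpoint = torch.load('./models/pretrain/ir50.pth', map_location='cpu')
backbone = load_pretrained_weights(backbone, checkpoint)   # prints how many layers matched
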
335 |
+
class eca_block(nn.Module):
|
336 |
+
def __init__(self, channel=128, b=1, gamma=2):
|
337 |
+
super(eca_block, self).__init__()
|
338 |
+
kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
|
339 |
+
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
|
340 |
+
|
341 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
342 |
+
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
|
343 |
+
self.sigmoid = nn.Sigmoid()
|
344 |
+
|
345 |
+
def forward(self, x):
|
346 |
+
y = self.avg_pool(x)
|
347 |
+
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
348 |
+
y = self.sigmoid(y)
|
349 |
+
return x * y.expand_as(x)
|
350 |
+
#
|
351 |
+
#
|
352 |
+
# class IR20(nn.Module):
|
353 |
+
# def __init__(self, img_size_=112, num_classes=7, layers=[2, 2, 2, 2]):
|
354 |
+
# super().__init__()
|
355 |
+
# norm_layer = nn.BatchNorm2d
|
356 |
+
# self.img_size = img_size_
|
357 |
+
# self._norm_layer = norm_layer
|
358 |
+
# self.num_classes = num_classes
|
359 |
+
# self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
360 |
+
# self.bn1 = norm_layer(64)
|
361 |
+
# self.relu = nn.ReLU(inplace=True)
|
362 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
363 |
+
# # self.face_landback = MobileFaceNet([112, 112],136)
|
364 |
+
# # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
|
365 |
+
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
|
366 |
+
# self.layer1 = self._make_layer(BasicBlock, 64, 64, layers[0])
|
367 |
+
# self.layer2 = self._make_layer(BasicBlock, 64, 128, layers[1], stride=2)
|
368 |
+
# self.layer3 = self._make_layer(AttentionBlock, 128, 256, layers[2], stride=2)
|
369 |
+
# self.layer4 = self._make_layer(AttentionBlock, 256, 256, layers[3], stride=1)
|
370 |
+
# self.ir_back = Backbone(50, 51, 52, 0.0, 'ir')
|
371 |
+
# self.ir_layer = nn.Linear(1024, 512)
|
372 |
+
# # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\Pretrained_on_MSCeleb.pth.tar',
|
373 |
+
# # map_location=lambda storage, loc: storage)
|
374 |
+
# # ir_checkpoint = ir_checkpoint['state_dict']
|
375 |
+
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
|
376 |
+
# # checkpoint = torch.load('./checkpoint/Pretrained_on_MSCeleb.pth.tar')
|
377 |
+
# # pre_trained_dict = checkpoint['state_dict']
|
378 |
+
# # IR20.load_state_dict(ir_checkpoint, strict=False)
|
379 |
+
# # self.IR = load_pretrained_weights(IR, ir_checkpoint)
|
380 |
+
#
|
381 |
+
# def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
382 |
+
# norm_layer = self._norm_layer
|
383 |
+
# downsample = None
|
384 |
+
# if stride != 1 or inplanes != planes:
|
385 |
+
# downsample = nn.Sequential(conv1x1(inplanes, planes, stride), norm_layer(planes))
|
386 |
+
# layers = []
|
387 |
+
# layers.append(block(inplanes, planes, stride, downsample))
|
388 |
+
# inplanes = planes
|
389 |
+
# for _ in range(1, blocks):
|
390 |
+
# layers.append(block(inplanes, planes))
|
391 |
+
# return nn.Sequential(*layers)
|
392 |
+
#
|
393 |
+
# def forward(self, x):
|
394 |
+
# x_ir = self.ir_back(x)
|
395 |
+
# # x_ir = self.ir_layer(x_ir)
|
396 |
+
# # print(x_ir.shape)
|
397 |
+
# # x = F.interpolate(x, size=112)
|
398 |
+
# # x = self.conv1(x)
|
399 |
+
# # x = self.bn1(x)
|
400 |
+
# # x = self.relu(x)
|
401 |
+
# # x = self.maxpool(x)
|
402 |
+
# #
|
403 |
+
# # x = self.layer1(x)
|
404 |
+
# # x = self.layer2(x)
|
405 |
+
# # x = self.layer3(x)
|
406 |
+
# # x = self.layer4(x)
|
407 |
+
# # print(x.shape)
|
408 |
+
# # print(x)
|
409 |
+
# out = x_ir
|
410 |
+
#
|
411 |
+
# return out
|
412 |
+
#
|
413 |
+
#
|
414 |
+
# class IR(nn.Module):
|
415 |
+
# def __init__(self, img_size_=112, num_classes=7):
|
416 |
+
# super().__init__()
|
417 |
+
# depth = 8
|
418 |
+
# # if type == "small":
|
419 |
+
# # depth = 4
|
420 |
+
# # if type == "base":
|
421 |
+
# # depth = 6
|
422 |
+
# # if type == "large":
|
423 |
+
# # depth = 8
|
424 |
+
#
|
425 |
+
# self.img_size = img_size_
|
426 |
+
# self.num_classes = num_classes
|
427 |
+
# self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
428 |
+
# # self.bn1 = norm_layer(64)
|
429 |
+
# self.relu = nn.ReLU(inplace=True)
|
430 |
+
# self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
431 |
+
# # self.face_landback = MobileFaceNet([112, 112],136)
|
432 |
+
# # face_landback_checkpoint = torch.load('./models/pretrain/mobilefacenet_model_best.pth.tar', map_location=lambda storage, loc: storage)
|
433 |
+
# # self.face_landback.load_state_dict(face_landback_checkpoint['state_dict'])
|
434 |
+
#
|
435 |
+
# # for param in self.face_landback.parameters():
|
436 |
+
# # param.requires_grad = False
|
437 |
+
#
|
438 |
+
# ###########################################################################333
|
439 |
+
#
|
440 |
+
# self.ir_back = IR20()
|
441 |
+
#
|
442 |
+
# # ir_checkpoint = torch.load(r'F:\0815crossvit\vision_transformer\models\pretrain\ir50.pth',
|
443 |
+
# # map_location=lambda storage, loc: storage)
|
444 |
+
# # # ir_checkpoint = ir_checkpoint["model"]
|
445 |
+
# # self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
|
446 |
+
# # load_state_dict(checkpoint_model, strict=False)
|
447 |
+
# # self.ir_layer = nn.Linear(1024,512)
|
448 |
+
#
|
449 |
+
# #############################################################3
|
450 |
+
# #
|
451 |
+
# # self.pyramid_fuse = HyVisionTransformer(in_chans=49, q_chanel = 49, embed_dim=512,
|
452 |
+
# # depth=depth, num_heads=8, mlp_ratio=2.,
|
453 |
+
# # drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1)
|
454 |
+
#
|
455 |
+
# # self.se_block = SE_block(input_dim=512)
|
456 |
+
# self.head = ClassificationHead(input_dim=768, target_dim=self.num_classes)
|
457 |
+
#
|
458 |
+
# def forward(self, x):
|
459 |
+
# B_ = x.shape[0]
|
460 |
+
# # x_face = F.interpolate(x, size=112)
|
461 |
+
# # _, x_face = self.face_landback(x_face)
|
462 |
+
# # x_face = x_face.view(B_, -1, 49).transpose(1,2)
|
463 |
+
# ############### landmark x_face ([B, 49, 512])
|
464 |
+
# x_ir = self.ir_back(x)
|
465 |
+
# # print(x_ir.shape)
|
466 |
+
# # x_ir = self.ir_layer(x_ir)
|
467 |
+
# # print(x_ir.shape)
|
468 |
+
# ############### image x_ir ([B, 49, 512])
|
469 |
+
#
|
470 |
+
# # y_hat = self.pyramid_fuse(x_ir, x_face)
|
471 |
+
# # y_hat = self.se_block(y_hat)
|
472 |
+
# # y_feat = y_hat
|
473 |
+
#
|
474 |
+
# # out = self.head(x_ir)
|
475 |
+
#
|
476 |
+
# out = x_ir
|
477 |
+
# return out
|
478 |
+
|
479 |
+
|
480 |
+
class eca_block(nn.Module):
|
481 |
+
def __init__(self, channel=196, b=1, gamma=2):
|
482 |
+
super(eca_block, self).__init__()
|
483 |
+
kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
|
484 |
+
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
|
485 |
+
|
486 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
487 |
+
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
|
488 |
+
self.sigmoid = nn.Sigmoid()
|
489 |
+
|
490 |
+
def forward(self, x):
|
491 |
+
y = self.avg_pool(x)
|
492 |
+
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
|
493 |
+
y = self.sigmoid(y)
|
494 |
+
return x * y.expand_as(x)
|
495 |
+
|
496 |
+
class SE_block(nn.Module):
|
497 |
+
def __init__(self, input_dim: int):
|
498 |
+
super().__init__()
|
499 |
+
self.linear1 = torch.nn.Linear(input_dim, input_dim)
|
500 |
+
self.relu = nn.ReLU()
|
501 |
+
self.linear2 = torch.nn.Linear(input_dim, input_dim)
|
502 |
+
self.sigmod = nn.Sigmoid()
|
503 |
+
|
504 |
+
def forward(self, x):
|
505 |
+
x1 = self.linear1(x)
|
506 |
+
x1 = self.relu(x1)
|
507 |
+
x1 = self.linear2(x1)
|
508 |
+
x1 = self.sigmod(x1)
|
509 |
+
x = x * x1
|
510 |
+
return x
|
511 |
+
|
512 |
+
|
513 |
+
class VisionTransformer(nn.Module):
|
514 |
+
def __init__(self, img_size=14, patch_size=14, in_c=147, num_classes=8,
|
515 |
+
embed_dim=768, depth=6, num_heads=8, mlp_ratio=4.0, qkv_bias=True,
|
516 |
+
qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
|
517 |
+
attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
|
518 |
+
act_layer=None):
|
519 |
+
"""
|
520 |
+
Args:
|
521 |
+
img_size (int, tuple): input image size
|
522 |
+
patch_size (int, tuple): patch size
|
523 |
+
in_c (int): number of input channels
|
524 |
+
num_classes (int): number of classes for classification head
|
525 |
+
embed_dim (int): embedding dimension
|
526 |
+
depth (int): depth of transformer
|
527 |
+
num_heads (int): number of attention heads
|
528 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
529 |
+
qkv_bias (bool): enable bias for qkv if True
|
530 |
+
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
531 |
+
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
532 |
+
distilled (bool): model includes a distillation token and head as in DeiT models
|
533 |
+
drop_ratio (float): dropout rate
|
534 |
+
attn_drop_ratio (float): attention dropout rate
|
535 |
+
drop_path_ratio (float): stochastic depth rate
|
536 |
+
embed_layer (nn.Module): patch embedding layer
|
537 |
+
norm_layer: (nn.Module): normalization layer
|
538 |
+
"""
|
539 |
+
super(VisionTransformer, self).__init__()
|
540 |
+
self.num_classes = num_classes
|
541 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
542 |
+
self.num_tokens = 2 if distilled else 1
|
543 |
+
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
|
544 |
+
act_layer = act_layer or nn.GELU
|
545 |
+
|
546 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
547 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, in_c + 1, embed_dim))
|
548 |
+
self.pos_drop = nn.Dropout(p=drop_ratio)
|
549 |
+
|
550 |
+
self.se_block = SE_block(input_dim=embed_dim)
|
551 |
+
|
552 |
+
|
553 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=256, embed_dim=768)
|
554 |
+
num_patches = self.patch_embed.num_patches
|
555 |
+
self.head = ClassificationHead(input_dim=embed_dim, target_dim=self.num_classes)
|
556 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
557 |
+
self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
|
558 |
+
# self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
559 |
+
self.pos_drop = nn.Dropout(p=drop_ratio)
|
560 |
+
# self.IR = IR()
|
561 |
+
self.eca_block = eca_block()
|
562 |
+
|
563 |
+
|
564 |
+
# self.ir_back = Backbone(50, 0.0, 'ir')
|
565 |
+
# ir_checkpoint = torch.load('./models/pretrain/ir50.pth', map_location=lambda storage, loc: storage)
|
566 |
+
# # ir_checkpoint = ir_checkpoint["model"]
|
567 |
+
# self.ir_back = load_pretrained_weights(self.ir_back, ir_checkpoint)
|
568 |
+
|
569 |
+
self.CON1 = nn.Conv2d(256, 768, kernel_size=1, stride=1, bias=False)
|
570 |
+
self.IRLinear1 = nn.Linear(1024, 768)
|
571 |
+
self.IRLinear2 = nn.Linear(768, 512)
|
572 |
+
self.eca_block = eca_block()
|
573 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)] # stochastic depth decay rule
|
574 |
+
self.blocks = nn.Sequential(*[
|
575 |
+
Block(dim=embed_dim, in_chans=in_c, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
576 |
+
qk_scale=qk_scale,
|
577 |
+
drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
|
578 |
+
norm_layer=norm_layer, act_layer=act_layer)
|
579 |
+
for i in range(depth)
|
580 |
+
])
|
581 |
+
self.norm = norm_layer(embed_dim)
|
582 |
+
|
583 |
+
# Representation layer
|
584 |
+
if representation_size and not distilled:
|
585 |
+
self.has_logits = True
|
586 |
+
self.num_features = representation_size
|
587 |
+
self.pre_logits = nn.Sequential(OrderedDict([
|
588 |
+
("fc", nn.Linear(embed_dim, representation_size)),
|
589 |
+
("act", nn.Tanh())
|
590 |
+
]))
|
591 |
+
else:
|
592 |
+
self.has_logits = False
|
593 |
+
self.pre_logits = nn.Identity()
|
594 |
+
|
595 |
+
# Classifier head(s)
|
596 |
+
# self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
597 |
+
self.head_dist = None
|
598 |
+
if distilled:
|
599 |
+
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
|
600 |
+
|
601 |
+
# Weight init
|
602 |
+
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
603 |
+
if self.dist_token is not None:
|
604 |
+
nn.init.trunc_normal_(self.dist_token, std=0.02)
|
605 |
+
|
606 |
+
nn.init.trunc_normal_(self.cls_token, std=0.02)
|
607 |
+
self.apply(_init_vit_weights)
|
608 |
+
|
609 |
+
def forward_features(self, x):
|
610 |
+
# [B, C, H, W] -> [B, num_patches, embed_dim]
|
611 |
+
# x = self.patch_embed(x) # [B, 196, 768]
|
612 |
+
# [1, 1, 768] -> [B, 1, 768]
|
613 |
+
# print(x.shape)
|
614 |
+
|
615 |
+
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
|
616 |
+
if self.dist_token is None:
|
617 |
+
x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]
|
618 |
+
else:
|
619 |
+
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
620 |
+
# print(x.shape)
|
621 |
+
x = self.pos_drop(x + self.pos_embed)
|
622 |
+
x = self.blocks(x)
|
623 |
+
x = self.norm(x)
|
624 |
+
if self.dist_token is None:
|
625 |
+
return self.pre_logits(x[:, 0])
|
626 |
+
else:
|
627 |
+
return x[:, 0], x[:, 1]
|
628 |
+
|
629 |
+
def forward(self, x):
|
630 |
+
|
631 |
+
# B = x.shape[0]
|
632 |
+
# print(x)
|
633 |
+
# x = self.eca_block(x)
|
634 |
+
# x = self.IR(x)
|
635 |
+
# x = eca_block(x)
|
636 |
+
# x = self.ir_back(x)
|
637 |
+
# print(x.shape)
|
638 |
+
# x = self.CON1(x)
|
639 |
+
# x = x.view(-1, 196, 768)
|
640 |
+
#
|
641 |
+
# # print(x.shape)
|
642 |
+
# # x = self.IRLinear1(x)
|
643 |
+
# # print(x)
|
644 |
+
# x_cls = torch.mean(x, 1).view(B, 1, -1)
|
645 |
+
# x = torch.cat((x_cls, x), dim=1)
|
646 |
+
# # print(x.shape)
|
647 |
+
# x = self.pos_drop(x + self.pos_embed)
|
648 |
+
# # print(x.shape)
|
649 |
+
# x = self.blocks(x)
|
650 |
+
# # print(x)
|
651 |
+
# x = self.norm(x)
|
652 |
+
# # print(x)
|
653 |
+
# # x1 = self.IRLinear2(x)
|
654 |
+
# x1 = x[:, 0, :]
|
655 |
+
|
656 |
+
# print(x1)
|
657 |
+
# print(x1.shape)
|
658 |
+
|
659 |
+
x = self.forward_features(x)
|
660 |
+
# # print(x.shape)
|
661 |
+
# if self.head_dist is not None:
|
662 |
+
# x, x_dist = self.head(x[0]), self.head_dist(x[1])
|
663 |
+
# if self.training and not torch.jit.is_scripting():
|
664 |
+
# # during inference, return the average of both classifier predictions
|
665 |
+
# return x, x_dist
|
666 |
+
# else:
|
667 |
+
# return (x + x_dist) / 2
|
668 |
+
# else:
|
669 |
+
# print(x.shape)
|
670 |
+
x = self.se_block(x)
|
671 |
+
|
672 |
+
x1 = self.head(x)
|
673 |
+
|
674 |
+
return x1
|
675 |
+
|
676 |
+
|
677 |
+
def _init_vit_weights(m):
|
678 |
+
"""
|
679 |
+
ViT weight initialization
|
680 |
+
:param m: module
|
681 |
+
"""
|
682 |
+
if isinstance(m, nn.Linear):
|
683 |
+
nn.init.trunc_normal_(m.weight, std=.01)
|
684 |
+
if m.bias is not None:
|
685 |
+
nn.init.zeros_(m.bias)
|
686 |
+
elif isinstance(m, nn.Conv2d):
|
687 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out")
|
688 |
+
if m.bias is not None:
|
689 |
+
nn.init.zeros_(m.bias)
|
690 |
+
elif isinstance(m, nn.LayerNorm):
|
691 |
+
nn.init.zeros_(m.bias)
|
692 |
+
nn.init.ones_(m.weight)
|
693 |
+
|
694 |
+
|
695 |
+
def vit_base_patch16_224(num_classes: int = 7):
|
696 |
+
"""
|
697 |
+
ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
698 |
+
ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
699 |
+
weights ported from official Google JAX impl:
|
700 |
+
link: https://pan.baidu.com/s/1zqb08naP0RPqqfSXfkB2EA  password: eu9f
|
701 |
+
"""
|
702 |
+
model = VisionTransformer(img_size=224,
|
703 |
+
patch_size=16,
|
704 |
+
embed_dim=768,
|
705 |
+
depth=12,
|
706 |
+
num_heads=12,
|
707 |
+
representation_size=None,
|
708 |
+
num_classes=num_classes)
|
709 |
+
|
710 |
+
return model
|
711 |
+
|
712 |
+
|
713 |
+
def vit_base_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
|
714 |
+
"""
|
715 |
+
ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
716 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
717 |
+
weights ported from official Google JAX impl:
|
718 |
+
https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth
|
719 |
+
"""
|
720 |
+
model = VisionTransformer(img_size=224,
|
721 |
+
patch_size=16,
|
722 |
+
embed_dim=768,
|
723 |
+
depth=12,
|
724 |
+
num_heads=12,
|
725 |
+
representation_size=768 if has_logits else None,
|
726 |
+
num_classes=num_classes)
|
727 |
+
return model
|
728 |
+
|
729 |
+
|
730 |
+
def vit_base_patch32_224(num_classes: int = 1000):
|
731 |
+
"""
|
732 |
+
ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
733 |
+
ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
734 |
+
weights ported from official Google JAX impl:
|
735 |
+
link: https://pan.baidu.com/s/1hCv0U8pQomwAtHBYc4hmZg  password: s5hl
|
736 |
+
"""
|
737 |
+
model = VisionTransformer(img_size=224,
|
738 |
+
patch_size=32,
|
739 |
+
embed_dim=768,
|
740 |
+
depth=12,
|
741 |
+
num_heads=12,
|
742 |
+
representation_size=None,
|
743 |
+
num_classes=num_classes)
|
744 |
+
return model
|
745 |
+
|
746 |
+
|
747 |
+
def vit_base_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
|
748 |
+
"""
|
749 |
+
ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
|
750 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
751 |
+
weights ported from official Google JAX impl:
|
752 |
+
https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth
|
753 |
+
"""
|
754 |
+
model = VisionTransformer(img_size=224,
|
755 |
+
patch_size=32,
|
756 |
+
embed_dim=768,
|
757 |
+
depth=12,
|
758 |
+
num_heads=12,
|
759 |
+
representation_size=768 if has_logits else None,
|
760 |
+
num_classes=num_classes)
|
761 |
+
return model
|
762 |
+
|
763 |
+
|
764 |
+
def vit_large_patch16_224(num_classes: int = 1000):
|
765 |
+
"""
|
766 |
+
ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
767 |
+
ImageNet-1k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
768 |
+
weights ported from official Google JAX impl:
|
769 |
+
link: https://pan.baidu.com/s/1cxBgZJJ6qUWPSBNcE4TdRQ  password: qqt8
|
770 |
+
"""
|
771 |
+
model = VisionTransformer(img_size=224,
|
772 |
+
patch_size=16,
|
773 |
+
embed_dim=1024,
|
774 |
+
depth=24,
|
775 |
+
num_heads=16,
|
776 |
+
representation_size=None,
|
777 |
+
num_classes=num_classes)
|
778 |
+
return model
|
779 |
+
|
780 |
+
|
781 |
+
def vit_large_patch16_224_in21k(num_classes: int = 21843, has_logits: bool = True):
|
782 |
+
"""
|
783 |
+
ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
|
784 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
785 |
+
weights ported from official Google JAX impl:
|
786 |
+
https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth
|
787 |
+
"""
|
788 |
+
model = VisionTransformer(img_size=224,
|
789 |
+
patch_size=16,
|
790 |
+
embed_dim=1024,
|
791 |
+
depth=24,
|
792 |
+
num_heads=16,
|
793 |
+
representation_size=1024 if has_logits else None,
|
794 |
+
num_classes=num_classes)
|
795 |
+
return model
|
796 |
+
|
797 |
+
|
798 |
+
def vit_large_patch32_224_in21k(num_classes: int = 21843, has_logits: bool = True):
|
799 |
+
"""
|
800 |
+
ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
|
801 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
802 |
+
weights ported from official Google JAX impl:
|
803 |
+
https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth
|
804 |
+
"""
|
805 |
+
model = VisionTransformer(img_size=224,
|
806 |
+
patch_size=32,
|
807 |
+
embed_dim=1024,
|
808 |
+
depth=24,
|
809 |
+
num_heads=16,
|
810 |
+
representation_size=1024 if has_logits else None,
|
811 |
+
num_classes=num_classes)
|
812 |
+
return model
|
813 |
+
|
814 |
+
|
815 |
+
def vit_huge_patch14_224_in21k(num_classes: int = 21843, has_logits: bool = True):
|
816 |
+
"""
|
817 |
+
ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
|
818 |
+
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
819 |
+
NOTE: converted weights not currently available, too large for github release hosting.
|
820 |
+
"""
|
821 |
+
model = VisionTransformer(img_size=224,
|
822 |
+
patch_size=14,
|
823 |
+
embed_dim=1280,
|
824 |
+
depth=32,
|
825 |
+
num_heads=16,
|
826 |
+
representation_size=1280 if has_logits else None,
|
827 |
+
num_classes=num_classes)
|
828 |
+
return model
|
prediction.py
ADDED
@@ -0,0 +1,103 @@
from main import *
from deepface import DeepFace

# Checking for all types of devices available
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Using device: {device}")
# Predicting the model
# def prediction(model, image_path):
model = pyramid_trans_expr2(img_size=224, num_classes=7)

model = torch.nn.DataParallel(model)
model = model.to(device)

model_path = "raf-db-model_best.pth"
image_arr = []
for foldername, subfolders, filenames in os.walk(
    "/Users/futuregadgetlab/Downloads/Testing/"
):
    for filename in filenames:
        # Construct the full path to the file
        file_path = os.path.join(foldername, filename)
        image_arr.append(f"{file_path}")


def main():
    if model_path is not None:
        if os.path.isfile(model_path):
            print("=> loading checkpoint '{}'".format(model_path))
            checkpoint = torch.load(model_path, map_location=device)
            best_acc = checkpoint["best_acc"]
            best_acc = best_acc.to()
            print(f"best_acc:{best_acc}")
            model.load_state_dict(checkpoint["state_dict"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    model_path, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(model_path))
    predict(model, image_path=image_arr)
    return


def predict(model, image_path):
    from face_detection import face_detection

    with torch.no_grad():
        transform = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
                transforms.RandomErasing(p=1, scale=(0.05, 0.05)),
            ]
        )
        face = face_detection(image_path)
        image_tensor = transform(face).unsqueeze(0)
        image_tensor = image_tensor.to(device)

        model.eval()
        img_pred = model(image_tensor)
        topk = (3,)
        with torch.no_grad():
            maxk = max(topk)
            # batch_size = target.size(0)
            _, pred = img_pred.topk(maxk, 1, True, True)
            pred = pred.t()

            img_pred = pred
            img_pred = img_pred.squeeze().cpu().numpy()
            im_pre_label = np.array(img_pred)
            y_pred = im_pre_label.flatten()
            emotions = {
                0: "Surprise",
                1: "Fear",
                2: "Disgust",
                3: "Happy",
                4: "Sad",
                5: "Angry",
                6: "Neutral",
            }
            labels = []
            for i in y_pred:
                labels.append(emotions.get(i))

            print(
                f"-->Image Path {image_path} [!] The predicted labels are {y_pred} and the label is {labels}"
            )
    return


if __name__ == "__main__":
    main()
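predict() keeps the top-3 class indices from the model output and maps them through the 7-class RAF-DB emotion table. The same post-processing on a made-up logit tensor, as a standalone sketch:

# Standalone sketch of the top-k post-processing used in predict() (dummy logits, no model needed).
import torch

emotions = {0: "Surprise", 1: "Fear", 2: "Disgust", 3: "Happy", 4: "Sad", 5: "Angry", 6: "Neutral"}
logits = torch.tensor([[0.1, 0.2, 0.05, 2.5, 0.9, 0.3, 1.1]])   # pretend model output for one face

_, pred = logits.topk(3, 1, True, True)          # top-3 class indices, highest score first
y_pred = pred.t().squeeze().cpu().numpy().flatten()
labels = [emotions.get(int(i)) for i in y_pred]
print(y_pred, labels)                            # [3 6 4] ['Happy', 'Neutral', 'Sad']
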
raf-db-model_best.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d9bf1d0d88238966ce0d1a289a2bb5f927ec2fe635ef1ec4396c323028924701
size 238971279
requirements.txt
ADDED
@@ -0,0 +1,131 @@
appdirs==1.4.4
asgiref==3.7.2
attr==0.3.1
azure-core==1.29.5
azure-storage-blob==12.18.3
bleach==5.0.1
boto==2.49.0
boto3==1.16.63
botocore==1.19.63
boxing==0.1.4
Brotli @ file:///Users/runner/miniforge3/conda-bld/brotli-split_1695989934239/work
certifi==2023.7.22
cffi==1.16.0
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1698833585322/work
click==8.1.7
colorama==0.4.6
contourpy @ file:///Users/runner/miniforge3/conda-bld/contourpy_1699041448398/work
coreapi==2.3.3
coreschema==0.0.4
cryptography==41.0.5
cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1696677705766/work
defusedxml==0.7.1
Django==3.2.20
django-annoying==0.10.6
django-cors-headers==3.6.0
django-debug-toolbar==3.2.1
django-environ==0.10.0
django-extensions==3.1.0
django-filter==2.4.0
django-model-utils==4.1.1
django-ranged-fileresponse==0.1.2
django-rest-swagger==2.2.0
django-rq==2.5.1
django-storages==1.12.3
django-user-agents==0.4.0
djangorestframework==3.13.1
drf-dynamic-fields==0.3.0
drf-flex-fields==0.9.5
drf-generators==0.3.0
drf-yasg==1.20.0
expiringdict==1.2.2
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1698714947081/work
fonttools @ file:///Users/runner/miniforge3/conda-bld/fonttools_1699023568720/work
fsspec==2023.10.0
gmpy2 @ file:///Users/runner/miniforge3/conda-bld/gmpy2_1666808749046/work
google-api-core==2.11.0
google-cloud-appengine-logging==1.1.0
google-cloud-audit-log==0.2.0
google-cloud-core==2.3.2
google-cloud-logging==2.7.1
google-cloud-storage==2.5.0
google-crc32c==1.5.0
google-resumable-media==2.3.3
googleapis-common-protos==1.56.4
grpc-google-iam-v1==0.12.4
grpcio-status==1.59.2
htmlmin==0.1.12
huggingface-hub==0.18.0
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
ijson==3.2.3
inflection==0.5.1
isodate==0.6.1
itypes==1.2.0
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1654302431367/work
jmespath==0.10.0
joblib==1.3.2
jsonschema==3.2.0
kiwisolver @ file:///Users/runner/miniforge3/conda-bld/kiwisolver_1695380058985/work
label-studio==1.8.2.post1
label-studio-converter==0.0.54rc0
label-studio-tools==0.0.3
launchdarkly-server-sdk==7.5.0
lockfile==0.12.2
lxml==4.9.3
MarkupSafe @ file:///Users/runner/miniforge3/conda-bld/markupsafe_1695367660391/work
matplotlib @ file:///Users/runner/miniforge3/conda-bld/matplotlib-suite_1698868590489/work
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
munkres==1.1.4
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
nltk==3.6.7
numpy @ file:///Users/runner/miniforge3/conda-bld/numpy_1694920094885/work/dist/numpy-1.26.0-cp311-cp311-macosx_11_0_arm64.whl#sha256=6909902123b8421906e90ad77fb0041d9eb2d95bbdc29f3d09c7d244b0e0e5a5
openapi-codec==1.3.2
ordered-set==4.0.2
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1696202382185/work
pandas==2.1.2
Pillow @ file:///Users/runner/miniforge3/conda-bld/pillow_1697423665652/work
proto-plus==1.22.3
psycopg2-binary==2.9.6
pycparser==2.21
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1690737849915/work
pyRFC3339==1.1
pyrsistent==0.20.0
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-json-logger==2.0.4
pytz==2023.3.post1
PyYAML @ file:///Users/runner/miniforge3/conda-bld/pyyaml_1695373486380/work
redis==3.5.3
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
rq==1.10.1
ruamel.yaml==0.18.5
ruamel.yaml.clib==0.2.8
rules==2.2
s3transfer==0.3.7
safetensors==0.4.0
scikit-learn==1.3.2
scipy==1.11.3
semver==2.13.0
sentry-sdk==1.34.0
simplejson==3.19.2
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
sqlparse==0.4.4
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1684180540116/work
thop==0.1.1.post2209072238
threadpoolctl==3.2.0
timm==0.9.10
torch==2.1.0
torchaudio==2.1.0
torchsampler==0.1.2
torchvision==0.16.0
tornado @ file:///Users/runner/miniforge3/conda-bld/tornado_1695373481350/work
tqdm==4.66.1
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1695040754690/work
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1680081134351/work
ua-parser==0.18.0
ujson==5.8.0
uritemplate==4.1.1
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1697720414277/work
user-agents==2.2.0
webencodings==0.5.1
xmljson==0.2.0