Create Faster R-CNN DN
Browse files- Faster R-CNN DN +113 -0
Faster R-CNN DN
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
inputSize = [224 224 3];
|
2 |
+
|
3 |
+
preprossedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
|
4 |
+
numAnchors = 3;
|
5 |
+
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors)
|
6 |
+
|
7 |
+
featuresExtractionNetwork = resnet50;
|
8 |
+
|
9 |
+
featureLayer - "activation_40_relu";
|
10 |
+
|
11 |
+
numClasses = width(vehicleDataset)-1;
|
12 |
+
|
13 |
+
lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
|
14 |
+
|
15 |
+
augmentedTrainingData = transform(trainingData,@aumentData);
|
16 |
+
|
17 |
+
augmentedData = cell(4,1);
|
18 |
+
for k = 1:4
|
19 |
+
data = read(augmentedTrainingData);
|
20 |
+
augmentedData{k} = insertShape)data{1},"rectangle",data{2});
|
21 |
+
reset(augmentedTrainingData);
|
22 |
+
end
|
23 |
+
figure
|
24 |
+
montage(augmentedData,BorderSize=10)
|
25 |
+
|
26 |
+
trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
|
27 |
+
validationData = transform(validationData,@(data)preprocessData(data,inputSize));
|
28 |
+
|
29 |
+
data = read(trainingData);
|
30 |
+
|
31 |
+
I = data{1};
|
32 |
+
bbox = data{2};
|
33 |
+
annotatedImage = insertShape(I,"rectangle",bbox);
|
34 |
+
annotatedImage = imresize(annotatedImage,2);
|
35 |
+
figure
|
36 |
+
imshow(annotatedImage)
|
37 |
+
|
38 |
+
// Train Faster R-CNN
|
39 |
+
|
40 |
+
options = trainingOptions("sgdm",...
|
41 |
+
MaxEpochs=10,...
|
42 |
+
MiniBatchSize=2,...
|
43 |
+
InitialLearnRate=1e-3,...
|
44 |
+
CheckpointPatin=tempdir,...
|
45 |
+
ValidationData=validationData);
|
46 |
+
|
47 |
+
if doTraining
|
48 |
+
% Train the Faster R-CNN detector.
|
49 |
+
% * Adjust NegativeOveralpRange and PositiveOverlapRange to ensure
|
50 |
+
% that training samples tightly overlap with ground truth.
|
51 |
+
[detector, info] = trainFasterRCNNObjectDetector(training
|
52 |
+
NegativeOverlapRange=[0 0.3], ...
|
53 |
+
PositiveOverlapRange=[0.6 1]);
|
54 |
+
else
|
55 |
+
% Load pretrained detector for the example.
|
56 |
+
pretrained = load("fasterRCNNResNet50EndToEndVehicleExample.mat");
|
57 |
+
detector = pretrained.detetor;
|
58 |
+
end
|
59 |
+
|
60 |
+
I = imread(testDataTbl.imageFilename{3});
|
61 |
+
I = imresize(I,inputSize(1:2));
|
62 |
+
[bboxes,scores] = detect(detector,I);
|
63 |
+
|
64 |
+
I = insertObjectAnnotation(I,"rectangle",bboxes,scores);
|
65 |
+
figure
|
66 |
+
imshow(I)
|
67 |
+
|
68 |
+
testData = transform(testData,@(data)preprocessData(data,inputSize));
|
69 |
+
|
70 |
+
detectionResults = detect(detector,testData,...
|
71 |
+
Threshold=0.2,...
|
72 |
+
MiniBatchSize=4);
|
73 |
+
|
74 |
+
classID = 1;
|
75 |
+
metrics = evaluateObjectDetection(detectionResults,testData);
|
76 |
+
precision = metrics.ClassMetrics.Precision{classID};
|
77 |
+
recall = metrics.ClassMetrics.Recall{classID};
|
78 |
+
|
79 |
+
figure
|
80 |
+
plot(recall,precision)
|
81 |
+
xlabel("Recall")
|
82 |
+
ylable("Precision")
|
83 |
+
grid on
|
84 |
+
title(sprintf("Average Precision = %.2f", metrics.ClassMetrics.mAP(classID)))
|
85 |
+
|
86 |
+
function data = augmentData(data)
|
87 |
+
% Randomly flip images and bounding boxes horizontally.
|
88 |
+
tform = randomAffine2d("XReflection",true);
|
89 |
+
sz = size(data{1});
|
90 |
+
rout = affineOutputView(sz,tform);
|
91 |
+
data{1} = imwarp(data{1},tform,"OutputView",rout);
|
92 |
+
|
93 |
+
% Sanitize boxes, if needed. This helper function is attached as a
|
94 |
+
% supporting file. Open the example in MATLAB to open this function.
|
95 |
+
data{2} = helperSanitizeBoxes(data{2});
|
96 |
+
|
97 |
+
% Warp boxes.
|
98 |
+
data{2} = bboxwwarp(data{2},tform,rout);
|
99 |
+
end
|
100 |
+
|
101 |
+
function data = preprocessData(data,targetSize)
|
102 |
+
% Resize image and bounding boxes to targetSize.
|
103 |
+
sz = size(data{1},[1 2]);
|
104 |
+
scale = targetSize(1:2)./sz;
|
105 |
+
data{1} = imresize(data{1},targetSize(1:2));
|
106 |
+
|
107 |
+
% Sanitize boxes, if needed. This helper function is attached as a
|
108 |
+
% supporting file. Open the example in MATLAB to open this function.
|
109 |
+
data{2} = helperSanitizeBoxes(data{2});
|
110 |
+
|
111 |
+
% Resize boxes.
|
112 |
+
data{2} = bboxresize(data{2},scale);
|
113 |
+
end
|