|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once |
|
|
|
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <torch/types.h>
#include <vrs/RecordFormat.h>

#include <projectaria_tools/tools/samples/vrs_mutation/ImageMutationFilterCopier.h>

#include <torch/script.h>
#include <torch/serialize.h>
#include <torch/torch.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <c10/cuda/CUDACachingAllocator.h>
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
|
|
|
namespace EgoBlur { |
|
|
|
struct EgoBlurImageMutator : public vrs::utils::UserDefinedImageMutator { |
|
|
|
|
|
|
|
|
|
// TorchScript face detector. Set by the constructor through loadModel() when a
// non-empty face model path is given; stays null otherwise, in which case
// detectAndBlur() skips face detection.
std::shared_ptr<torch::jit::script::Module> faceModel_; |
|
// TorchScript license plate detector; same null-when-unset convention as
// faceModel_.
std::shared_ptr<torch::jit::script::Module> licensePlateModel_; |
|
// Score threshold handed to filterDetections() for face detections.
float faceModelConfidenceThreshold_; |
|
// Score threshold handed to filterDetections() for license plate detections.
float licensePlateModelConfidenceThreshold_; |
|
// Scale handed to blurImage(); boxes are grown about their centre by this
// factor (see scaleBox()) before being blurred.
float scaleFactorDetections_; |
|
// IoU threshold handed to performNMS().
float nmsThreshold_; |
|
// When true, getDevice() selects CUDA if torch reports it as available.
bool useGPU_; |
|
// When true, detectAndBlur() rotates the image 90 degrees clockwise before
// detection and rotates the blurred result back afterwards.
bool clockwise90Rotation_; |
|
// frameId -> category ("faces" / "licensePlate") -> number of boxes kept for
// that frame. Filled by operator() / detectAndBlur(), read by logStatistics().
std::unordered_map<std::string, std::unordered_map<std::string, int>> stats_; |
|
// Device the models and input tensors are moved to; the constructor
// overwrites this default with getDevice().
torch::Device device_ = torch::kCPU; |
|
|
|
explicit EgoBlurImageMutator( |
|
const std::string& faceModelPath = "", |
|
const float faceModelConfidenceThreshold = 0.1, |
|
const std::string& licensePlateModelPath = "", |
|
const float licensePlateModelConfidenceThreshold = 0.1, |
|
const float scaleFactorDetections = 1.15, |
|
const float nmsThreshold = 0.3, |
|
const bool useGPU = true, |
|
const bool clockwise90Rotation = true) |
|
: faceModelConfidenceThreshold_(faceModelConfidenceThreshold), |
|
licensePlateModelConfidenceThreshold_( |
|
licensePlateModelConfidenceThreshold), |
|
scaleFactorDetections_(scaleFactorDetections), |
|
nmsThreshold_(nmsThreshold), |
|
useGPU_(useGPU), |
|
clockwise90Rotation_(clockwise90Rotation) { |
|
device_ = getDevice(); |
|
std::cout << "attempting to load ego blur face model: " << faceModelPath |
|
<< std::endl; |
|
|
|
if (!faceModelPath.empty()) { |
|
faceModel_ = loadModel(faceModelPath); |
|
} |
|
|
|
std::cout << "attempting to load ego blur license plate model: " |
|
<< licensePlateModelPath << std::endl; |
|
|
|
if (!licensePlateModelPath.empty()) { |
|
licensePlateModel_ = loadModel(licensePlateModelPath); |
|
} |
|
} |
|
|
|
std::shared_ptr<torch::jit::script::Module> loadModel( |
|
const std::string& path) { |
|
std::shared_ptr<torch::jit::script::Module> model; |
|
try { |
|
model = std::make_shared<torch::jit::script::Module>(); |
|
|
|
*model = torch::jit::load(path); |
|
std::cout << "Loaded model: " << path << std::endl; |
|
model->to(device_); |
|
model->eval(); |
|
} catch (const c10::Error&) { |
|
std::cout << "Failed to load model: " << path << std::endl; |
|
throw; |
|
} |
|
return model; |
|
} |
|
|
|
// Chooses the device type for inference: CUDA only when the GPU was requested
// and torch reports CUDA as available, CPU in every other case.
at::DeviceType getDevice() const {
  const bool cudaUsable = useGPU_ && torch::cuda::is_available();
  return cudaUsable ? torch::kCUDA : torch::kCPU;
}
|
|
|
// Keeps the detections whose score is strictly greater than `scoreThreshold`
// and then runs performNMS() on them with nmsThreshold_.
//
// detections: model output tuple; element 0 is read as the boxes and
//   element 2 as the per-box scores (the other elements are not used here).
// Returns the boxes that survive thresholding and NMS.
// Tuple::elements().at() throws if the tuple is shorter than expected.
torch::Tensor filterDetections(
    c10::intrusive_ptr<c10::ivalue::Tuple> detections,
    float scoreThreshold) const {
  auto&& fields = detections->elements();
  const torch::Tensor allScores = fields.at(2).toTensor();
  const torch::Tensor allBoxes = fields.at(0).toTensor();

  // Boolean mask selecting the entries scoring above the threshold.
  const torch::Tensor keepMask =
      torch::gt(allScores, torch::tensor(scoreThreshold)).detach();

  const torch::Tensor keptBoxes = allBoxes.index({keepMask}).detach();
  const torch::Tensor keptScores = allScores.index({keepMask}).detach();

  // The intermediates are released when they go out of scope on return.
  return performNMS(keptBoxes, keptScores, nmsThreshold_).detach();
}
|
|
|
|
|
// Greedy non-maximum suppression.
//
// Boxes are visited from the highest to the lowest score; a box that is still
// alive suppresses every lower-ranked box whose IoU with it is strictly
// greater than `overlapThreshold`. Only surviving boxes suppress others, so a
// box is never removed because of a neighbour that was itself removed.
//
// boxes:  read through a float accessor of rank 2 using columns 0..3 as
//         x1, y1, x2, y2 (so it must be a float tensor with >= 4 columns).
// scores: read through a float accessor of rank 1, one score per box.
//         Scores are expected to be ordered (no NaN); filterDetections()
//         guarantees that because NaN never passes its threshold test.
// Returns a CPU tensor holding the surviving rows of `boxes`, in their
// original relative order.
// Throws c10::Error when `scores` does not have one entry per box.
torch::Tensor performNMS(
    const torch::Tensor& boxes,
    const torch::Tensor& scores,
    float overlapThreshold) const {
  // Element-wise access below needs host memory.
  const torch::Tensor boxesCPU = boxes.to(torch::kCPU).detach();
  const torch::Tensor scoresCPU = scores.to(torch::kCPU).detach();

  const int64_t numBoxes = boxesCPU.size(0);
  TORCH_CHECK(
      scoresCPU.size(0) == numBoxes,
      "performNMS: expected one score per box, got ",
      scoresCPU.size(0),
      " scores for ",
      numBoxes,
      " boxes");

  const auto boxAt = boxesCPU.accessor<float, 2>();
  const auto scoreAt = scoresCPU.accessor<float, 1>();

  // Processing order: best score first. On equal scores the later index goes
  // first, i.e. of two equally scored overlapping boxes the later one is kept.
  std::vector<int64_t> order(static_cast<std::size_t>(numBoxes));
  std::iota(order.begin(), order.end(), int64_t{0});
  std::sort(order.begin(), order.end(), [&](int64_t lhs, int64_t rhs) {
    if (scoreAt[lhs] != scoreAt[rhs]) {
      return scoreAt[lhs] > scoreAt[rhs];
    }
    return lhs > rhs;
  });

  const auto areaOf = [&](int64_t k) {
    return (boxAt[k][2] - boxAt[k][0]) * (boxAt[k][3] - boxAt[k][1]);
  };

  std::vector<char> suppressed(static_cast<std::size_t>(numBoxes), 0);
  for (std::size_t a = 0; a < order.size(); ++a) {
    const int64_t i = order[a];
    if (suppressed[static_cast<std::size_t>(i)]) {
      continue; // a removed box must not remove others
    }
    for (std::size_t b = a + 1; b < order.size(); ++b) {
      const int64_t j = order[b];
      if (suppressed[static_cast<std::size_t>(j)]) {
        continue;
      }
      const float x1 = std::max(boxAt[i][0], boxAt[j][0]);
      const float y1 = std::max(boxAt[i][1], boxAt[j][1]);
      const float x2 = std::min(boxAt[i][2], boxAt[j][2]);
      const float y2 = std::min(boxAt[i][3], boxAt[j][3]);

      const float intersection =
          std::max(0.0f, x2 - x1) * std::max(0.0f, y2 - y1);
      // For two degenerate boxes this is 0/0 = NaN, which compares false
      // below, so such boxes are simply never suppressed.
      const float iou = intersection / (areaOf(i) + areaOf(j) - intersection);

      if (iou > overlapThreshold) {
        suppressed[static_cast<std::size_t>(j)] = 1;
      }
    }
  }

  std::vector<int64_t> keptIndices;
  keptIndices.reserve(static_cast<std::size_t>(numBoxes));
  for (int64_t k = 0; k < numBoxes; ++k) {
    if (!suppressed[static_cast<std::size_t>(k)]) {
      keptIndices.push_back(k);
    }
  }

  // torch::tensor copies the indices, so the index tensor stays valid for an
  // empty selection and does not alias the local vector.
  const torch::Tensor keptIndexTensor = torch::tensor(keptIndices, torch::kLong);
  return torch::index_select(boxesCPU, 0, keptIndexTensor).detach();
}
|
|
|
// Grows (or shrinks) a box about its centre and clamps it to the image.
//
// box:       {x1, y1, x2, y2}; at least four elements are required.
// maxWidth:  upper bound for the x coordinates.
// maxHeight: upper bound for the y coordinates.
// scale:     factor applied to the box width and height.
// Returns {x1, y1, x2, y2} with every coordinate clamped to
// [0, maxWidth] / [0, maxHeight], so a box lying outside the image collapses
// to a zero-sized box on the border instead of keeping out-of-range corners.
// Throws std::out_of_range when `box` has fewer than four elements.
static std::vector<float> scaleBox(
    const std::vector<float>& box,
    int maxWidth,
    int maxHeight,
    float scale) {
  const float x1 = box.at(0);
  const float y1 = box.at(1);
  const float x2 = box.at(2);
  const float y2 = box.at(3);

  const float w = x2 - x1;
  const float h = y2 - y1;

  // Box centre.
  const float xc = x1 + (w / 2);
  const float yc = y1 + (h / 2);

  const float halfW = (scale * w) / 2;
  const float halfH = (scale * h) / 2;

  const auto clampTo = [](float value, float upper) {
    return std::min(std::max(value, 0.0f), upper);
  };
  const float widthLimit = static_cast<float>(maxWidth);
  const float heightLimit = static_cast<float>(maxHeight);

  return {
      clampTo(xc - halfW, widthLimit),
      clampTo(yc - halfH, heightLimit),
      clampTo(xc + halfW, widthLimit),
      clampTo(yc + halfH, heightLimit)};
}
|
|
|
// Returns a copy of `image` in which an ellipse inscribed in every detected
// box has been replaced by a box-blurred version of the image.
//
// image:      source image; 3-channel images get a 3-channel mask, everything
//             else a 1-channel mask.
// detections: tensors whose rows are read as four floats x1, y1, x2, y2
//             through data_ptr<float>() (assumes float rows in host memory,
//             as produced by performNMS() - TODO confirm for other callers).
// scale:      when different from 1, each box is resized with scaleBox().
//
// Boxes are clipped to the image and boxes that end up empty are skipped, so
// out-of-range or degenerate detections no longer trip OpenCV's ROI checks.
cv::Mat blurImage(
    const cv::Mat& image,
    const std::vector<torch::Tensor>& detections,
    float scale) {
  cv::Mat response = image.clone();

  const bool isColor = image.channels() == 3;
  cv::Mat mask = cv::Mat::zeros(image.size(), isColor ? CV_8UC3 : CV_8UC1);
  const cv::Scalar maskColor =
      isColor ? cv::Scalar(255, 255, 255) : cv::Scalar(255);

  // Kernel is 1/8 of the image extent, at least 1 px so cv::blur accepts it.
  // NOTE(review): cv::Size is (width, height) but rows/cols are passed in that
  // order, as in the original code; only matters for non-square images.
  const cv::Size ksize(
      std::max(1, image.rows / 8), std::max(1, image.cols / 8));
  const cv::Rect imageBounds(0, 0, image.cols, image.rows);

  for (const auto& detection : detections) {
    for (const auto& box : detection.unbind()) {
      if (box.numel() < 4) {
        continue; // not a complete x1,y1,x2,y2 row
      }
      const float* raw = box.data_ptr<float>();
      std::vector<float> corners(raw, raw + box.numel());
      if (scale != 1.0f) {
        corners = scaleBox(corners, image.cols, image.rows, scale);
      }

      const int left = static_cast<int>(corners[0]);
      const int top = static_cast<int>(corners[1]);
      const int right = static_cast<int>(corners[2]);
      const int bottom = static_cast<int>(corners[3]);

      // Clip to the image; the intersection is empty for inverted boxes too.
      const cv::Rect roi =
          cv::Rect(left, top, right - left, bottom - top) & imageBounds;
      if (roi.width <= 0 || roi.height <= 0) {
        continue;
      }

      const int x1 = roi.x;
      const int y1 = roi.y;
      const int x2 = roi.x + roi.width;
      const int y2 = roi.y + roi.height;

      // Filled (thickness -1) ellipse inscribed in the box marks the pixels
      // that will be taken from the blurred patch.
      cv::ellipse(
          mask,
          cv::Point((x1 + x2) / 2, (y1 + y2) / 2),
          cv::Size(roi.width / 2, roi.height / 2),
          0,
          0,
          360,
          maskColor,
          -1);

      cv::Mat blurredRegion;
      cv::blur(image(roi), blurredRegion, ksize);
      blurredRegion.copyTo(response(roi), mask(roi));
    }
  }
  return response;
}
|
|
|
// Runs the configured detectors on `frame` and blurs every detection.
//
// frame:   pixel frame to analyse; treated as 3-channel 8-bit when its pixel
//          format is RGB8 and as 1-channel 8-bit otherwise. Must not be null.
// frameId: key under which the per-category detection counts are added to
//          stats_ ("faces" and "licensePlate").
// Returns the blurred image in the frame's original orientation, tightly
// packed, or an empty cv::Mat when nothing was detected (or no model is
// loaded).
cv::Mat detectAndBlur(
    vrs::utils::PixelFrame* frame,
    const std::string& frameId) {
  const int width = frame->getWidth();
  const int height = frame->getHeight();
  const int channels =
      frame->getPixelFormat() == vrs::PixelFormat::RGB8 ? 3 : 1;

  // Wrap the frame buffer using the frame's own row stride (assumed to be in
  // bytes, as for cv::Mat's step) and clone it into a packed, owned image.
  cv::Mat img = cv::Mat(
                    height,
                    width,
                    CV_8UC(channels),
                    static_cast<void*>(frame->getBuffer().data()),
                    frame->getStride())
                    .clone();

  torch::NoGradGuard no_grad;

  // Build the model input from the packed copy *before* img is rotated in
  // place. The uint8 -> float conversion copies, so the tensor owns its data
  // and does not dangle when img is reallocated.
  torch::Tensor input =
      torch::from_blob(img.data, {height, width, channels}, torch::kUInt8)
          .to(torch::kFloat)
          .to(device_);

  if (clockwise90Rotation_) {
    cv::rotate(img, img, cv::ROTATE_90_CLOCKWISE);
    input = torch::rot90(input, -1); // k = -1: same clockwise quarter turn
  }

  // HWC -> CHW for the detectors.
  const std::vector<torch::jit::IValue> inputs = {input.permute({2, 0, 1})};

  std::vector<torch::Tensor> boundingBoxes;

  // Runs one detector, records its detection count and collects its boxes.
  const auto runDetector = [&](torch::jit::script::Module& model,
                               float threshold,
                               const char* statKey) {
    const torch::Tensor boxes =
        filterDetections(model.forward(inputs).toTuple(), threshold);
    const int64_t count = boxes.size(0);
    stats_[frameId][statKey] += static_cast<int>(count);
    if (count > 0) {
      boundingBoxes.push_back(boxes);
    }
  };

  if (faceModel_) {
    runDetector(*faceModel_, faceModelConfidenceThreshold_, "faces");
  }
  if (licensePlateModel_) {
    runDetector(
        *licensePlateModel_,
        licensePlateModelConfidenceThreshold_,
        "licensePlate");
  }

  cv::Mat finalImage;
  if (!boundingBoxes.empty()) {
    finalImage = blurImage(img, boundingBoxes, scaleFactorDetections_);
    if (clockwise90Rotation_) {
      // Undo the rotation so the result matches the frame's layout.
      cv::rotate(finalImage, finalImage, cv::ROTATE_90_COUNTERCLOCKWISE);
    }
  }
  return finalImage;
}
|
|
|
bool operator()( |
|
double timestamp, |
|
const vrs::StreamId& streamId, |
|
vrs::utils::PixelFrame* frame) override { |
|
|
|
if (!frame) { |
|
return false; |
|
} |
|
|
|
cv::Mat blurredImage; |
|
|
|
if (streamId.getNumericName().find("214") != std::string::npos || |
|
streamId.getNumericName().find("1201") != std::string::npos) { |
|
|
|
std::string frameId = |
|
streamId.getNumericName() + "_" + std::to_string(timestamp); |
|
stats_[frameId]["faces"] = 0; |
|
stats_[frameId]["licensePlate"] = 0; |
|
blurredImage = detectAndBlur(frame, frameId); |
|
} |
|
|
|
if (!blurredImage.empty()) { |
|
|
|
if (streamId.getNumericName().find("214") != std::string::npos) { |
|
std::memcpy( |
|
frame->wdata(), |
|
blurredImage.data, |
|
frame->getWidth() * frame->getStride()); |
|
} |
|
|
|
else if (streamId.getNumericName().find("1201") != std::string::npos) { |
|
std::memcpy( |
|
frame->wdata(), |
|
blurredImage.data, |
|
frame->getWidth() * frame->getHeight()); |
|
} |
|
} |
|
blurredImage.release(); |
|
c10::cuda::CUDACachingAllocator::emptyCache(); |
|
return true; |
|
} |
|
|
|
std::string logStatistics() const { |
|
std::string statsString; |
|
int totalFrames = 0; |
|
int totalRGBFramesWithFaces = 0; |
|
int totalRGBFaces = 0; |
|
int totalSLAMFramesWithFaces = 0; |
|
int totalSLAMFaces = 0; |
|
int totalRGBFramesWithLicensePlate = 0; |
|
int totalRGBLicensePlate = 0; |
|
int totalSLAMFramesWithLicensePlate = 0; |
|
int totalSLAMLicensePlate = 0; |
|
|
|
for (const auto& outer : stats_) { |
|
const std::string& frameId = outer.first; |
|
const std::unordered_map<std::string, int>& categoryBoxCountMapping = |
|
outer.second; |
|
|
|
|
|
for (const auto& innerPair : categoryBoxCountMapping) { |
|
const std::string& category = innerPair.first; |
|
int boxCount = innerPair.second; |
|
|
|
if (boxCount > 0) { |
|
if (category == "faces") { |
|
if (frameId.find("214") != std::string::npos) { |
|
totalRGBFramesWithFaces++; |
|
totalRGBFaces += boxCount; |
|
} else if (frameId.find("1201") != std::string::npos) { |
|
totalSLAMFramesWithFaces++; |
|
totalSLAMFaces += boxCount; |
|
} |
|
} |
|
if (category == "licensePlate") { |
|
if (frameId.find("214") != std::string::npos) { |
|
totalRGBFramesWithLicensePlate++; |
|
totalRGBLicensePlate += boxCount; |
|
} else if (frameId.find("1201") != std::string::npos) { |
|
totalSLAMFramesWithLicensePlate++; |
|
totalSLAMLicensePlate += boxCount; |
|
} |
|
} |
|
} |
|
} |
|
totalFrames++; |
|
} |
|
|
|
std::ostringstream summary; |
|
summary << " ----------------" << "\n| Summary |" |
|
<< "\n ----------------" << "\nTotal frames: " << totalFrames |
|
<< "\n Faces:" << "\n RGB - Total detected frame: " |
|
<< totalRGBFramesWithFaces |
|
<< "\n RGB - Total detections: " << totalRGBFaces |
|
<< "\n SLAM - Total detected frame: " << totalSLAMFramesWithFaces |
|
<< "\n SLAM - Total detections: " << totalSLAMFaces |
|
<< "\n License Plates:" << "\n RGB - Total detected frame: " |
|
<< totalRGBFramesWithLicensePlate |
|
<< "\n RGB - Total detections: " << totalRGBLicensePlate |
|
<< "\n SLAM - Total detected frame: " |
|
<< totalSLAMFramesWithLicensePlate |
|
<< "\n SLAM - Total detections: " << totalSLAMLicensePlate; |
|
return summary.str(); |
|
} |
|
}; |
|
|
|
} |
|
|