OpenCV学习笔记 - DNN模块使用(含源码、详细解释)
最近翻了翻以前做的一些笔记,碰巧翻到了2019年刚开始学习OpenCV时候做的笔记,不知不觉已经过去两年了,这两年从一个小白到现在不是太小白的小白o(╥﹏╥)o,在此分享一下,希望能帮助到更多的人。相关视频:https://www.bilibili.com/video/BV1FJ411T7W5?p=2文章目录DNN模块Googlenet模型实现图像分类介绍:代码:结果展示:SSD模型实现对象检测介
最近翻了翻以前做的一些笔记,碰巧翻到了2019年刚开始学习OpenCV时候做的笔记,不知不觉已经过去两年了,这两年从一个小白到现在不是太小白的小白o(╥﹏╥)o,在此分享一下,希望能帮助到更多的人。
相关视频:https://www.bilibili.com/video/BV1FJ411T7W5?p=2
文章目录
DNN模块
Googlenet模型实现图像分类
介绍:
论文:https://github.com/SnailTyan/deep-learning-papers-translation
这里有很多翻译好的论文,很方便。
所需文件:二进制模型文件,模型参数描述文件,分类label文件。
模型下载:
http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel
卷积层提取特征,全连接层进行分类。
描述文件:bvlc_googlenet.prototxt
这个在opencv的源码里边有opencv-3.3.1\samples\data\dnn
模型输出为一个1000维的向量,代表1000个分类的概率。
代码:
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
#include <fstream>
using namespace cv;
using namespace std;
using namespace cv::dnn;
// Caffe weights of the GoogLeNet classifier.
String model_bin_file = "model/bvlc_googlenet.caffemodel";
// Network description (prototxt) that belongs to the weights above.
String model_txt_file = "model/bvlc_googlenet.prototxt";
// Text file with the class labels; read by readLabels().
String labels_txt_file = "model/synset_words.txt";
// Reads the class labels from labels_txt_file; defined after main().
vector<String> readLabels();
int main(int argc, char** argv)
{
Mat src = imread("pictures/girl.jpg");
if (src.empty())
{
cout << "could not open image……" << endl;
return -1;
}
namedWindow("src", WINDOW_FREERATIO);
imshow("src", src);
// 读取labels
vector<String> labels = readLabels();
// 读取网络 包括模型描述文件和和模型文件
Net net = readNetFromCaffe(model_txt_file, model_bin_file);
if (net.empty())
{
cout << "net could not load……" << endl;
return -1;
}
Mat inputBlob = blobFromImage(src, 1.0, Size(224, 224), Scalar(104, 117, 123));
Mat prob;
for (size_t i = 0; i < 10; i++)
{
net.setInput(inputBlob, "data");
prob = net.forward("prob"); // 输出为1×1000 1000类的概率
}
Mat proMat = prob.reshape(1, 1); // 单通道 一行
Point classNumber;
double classProb;
minMaxLoc(proMat, NULL, &classProb, NULL, &classNumber);
int classidx = classNumber.x;
cout << "current image classification:" << labels.at(classidx).c_str()
<< "possible:" << classProb << endl;
putText(src, labels.at(classidx), Point(20, 20), FONT_HERSHEY_PLAIN, 1.5, Scalar(0, 0, 255), 1, 8);
imshow("image", src);
waitKey(0);
return 0;
}
// Reads the class labels from labels_txt_file, one label per non-empty line.
//
// Everything up to and including the first space of a line is dropped and the
// remainder is stored as the label; a line without a space is stored whole.
// Terminates the program with exit(-1) when the file cannot be opened.
vector<String> readLabels()
{
    vector<String> classNames;
    ifstream fin(labels_txt_file.c_str());
    if (!fin.is_open())
    {
        cout << "could not open the file……" << endl;
        exit(-1);
    }
    string name;
    // Loop on the read itself: a stream that fails without reaching EOF ends
    // the loop instead of spinning forever.
    while (getline(fin, name))
    {
        if (name.empty())
        {
            continue;
        }
        // The + 1 belongs to the position returned by find(), not to the
        // searched text. When no space exists find() returns npos and
        // npos + 1 wraps to 0, which keeps the whole line.
        classNames.push_back(name.substr(name.find(' ') + 1));
    }
    return classNames;
}
结果展示:
SSD模型实现对象检测
介绍:
模型下载:
https://github.com/weiliu89/caffe/tree/ssd#models
结构:
比传统的R-CNN要好很多。把两步合为一步,帧率得到了提高。
模型文件:还是有三个 二进制模型文件,模型参数描述文件,分类label文件
模型输出为一个7维向量 后四维为检测出来目标框的矩形坐标 倒数第5维为置信度
代码:
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
#include <fstream>
using namespace std;
using namespace cv;
using namespace cv::dnn;
// Size in pixels that preprocess() resizes the image to.
const size_t width = 300;
const size_t height = 300;
// Files of the SSD model: label map, Caffe weights and network description.
String labelFile = "model\\models_VGGNet_ILSVRC2016_SSD_300x300\\models\\VGGNet\\ILSVRC2016\\SSD_300x300\\labelmap_ilsvrc_det.prototxt";
String modelFile = "model\\models_VGGNet_ILSVRC2016_SSD_300x300\\models\\VGGNet\\ILSVRC2016\\SSD_300x300\\VGG_ILSVRC2016_SSD_300x300_iter_440000.caffemodel";
String model_text_file = "model\\models_VGGNet_ILSVRC2016_SSD_300x300\\models\\VGGNet\\ILSVRC2016\\SSD_300x300\\deploy.prototxt";
// Per-channel mean values; getMean() fills one image plane with each of them.
const int meanValues[3] = { 104, 117, 123 };
// Helpers defined after main().
vector<String> readLabels();
static Mat getMean(const size_t &w, const size_t &h);
static Mat preprocess(const Mat& frame);
// Runs the SSD detector on pictures/cat.jpg and draws every detection whose
// confidence exceeds the threshold.
//
// Returns 0 on success and -1 when the image cannot be read or the network
// cannot be loaded.
int main(int argc, char** argv)
{
    Mat frame = imread("pictures/cat.jpg");
    if (frame.empty())
    {
        cout << "could not open image……" << endl;
        return -1;
    }
    namedWindow("input image", WINDOW_FREERATIO);
    imshow("input image", frame);

    // Class names; readLabels() exits the program if the file cannot be opened.
    vector<String> objNames = readLabels();

    // import Caffe SSD model
    Net net = readNetFromCaffe(model_text_file, modelFile);
    if (net.empty())
    {
        cout << "read caffe model data failure..." << endl;
        return -1;
    }

    Mat input_image = preprocess(frame);
    Mat blobImage = blobFromImage(input_image);
    net.setInput(blobImage, "data");
    Mat detection = net.forward("detection_out");

    // View the last two dimensions of the output as a 2-D matrix with one
    // detection per row.
    Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
    const float confidence_threshold = 0.1f;
    for (int i = 0; i < detectionMat.rows; i++)
    {
        // Columns used per row: 1 = class index, 2 = confidence,
        // 3..6 = box corners as fractions of the frame width/height.
        const float confidence = detectionMat.at<float>(i, 2);
        if (confidence <= confidence_threshold)
        {
            continue;
        }
        const int objIndex = static_cast<int>(detectionMat.at<float>(i, 1));
        const float tl_x = detectionMat.at<float>(i, 3) * frame.cols;
        const float tl_y = detectionMat.at<float>(i, 4) * frame.rows;
        const float br_x = detectionMat.at<float>(i, 5) * frame.cols;
        const float br_y = detectionMat.at<float>(i, 6) * frame.rows;
        Rect object_box((int)tl_x, (int)tl_y, (int)(br_x - tl_x), (int)(br_y - tl_y));
        rectangle(frame, object_box, Scalar(0, 0, 255), 2, 8, 0);

        // Only label the box when the class index has a name; indexing past
        // the end of objNames would be undefined behaviour.
        if (objIndex >= 0 && objIndex < static_cast<int>(objNames.size()))
        {
            putText(frame, objNames[objIndex], Point((int)tl_x, (int)tl_y), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(255, 0, 0), 2);
        }
    }
    imshow("ssd-demo", frame);
    waitKey(0);
    return 0;
}
// Reads the class names from the label-map file labelFile.
//
// Every line containing `display_name:` contributes one name: the text
// between the first double quote after the key and the last double quote of
// the line. Lines without the key or without a quoted value are skipped.
// Terminates the program with exit(-1) when the file cannot be opened.
vector<String> readLabels()
{
    vector<String> objNames;
    ifstream fin(labelFile);
    if (!fin.is_open())
    {
        cout << "could not load labeFile……" << endl;
        exit(-1);
    }
    const string key = "display_name:";
    string line;
    // Loop on the read itself so a failed stream cannot cause an endless loop.
    while (getline(fin, line))
    {
        const size_t keyPos = line.find(key);
        if (keyPos == string::npos)
        {
            continue;
        }
        // Locate the quotes instead of relying on fixed columns, so different
        // indentation or a trailing '\r' does not corrupt the name.
        const size_t openQuote = line.find('"', keyPos + key.length());
        const size_t closeQuote = line.rfind('"');
        if (openQuote == string::npos || closeQuote <= openQuote)
        {
            continue;
        }
        objNames.push_back(line.substr(openQuote + 1, closeQuote - openQuote - 1));
    }
    return objNames;
}
// Builds an h x w three-channel 32-bit float image in which channel k of
// every pixel holds meanValues[k].
Mat getMean(const size_t& w, const size_t& h)
{
    // Constructing the three-channel image directly with a fill value yields
    // the same pixels as filling three planes and merging them.
    const Scalar fillValue(meanValues[0], meanValues[1], meanValues[2]);
    return Mat(static_cast<int>(h), static_cast<int>(w), CV_32FC3, fillValue);
}
// Prepares a frame for the network: converts it to 32-bit float, resizes it
// to width x height and subtracts the mean image produced by getMean().
Mat preprocess(const Mat& frame)
{
    Mat asFloat;
    frame.convertTo(asFloat, CV_32F);

    Mat resized;
    resize(asFloat, resized, Size(width, height));

    Mat centered;
    subtract(resized, getMean(width, height), centered);
    return centered;
}
结果展示:
MobileNetSSD模型实时对象检测
介绍:
对SSD模型进行了简化,从1000个分类缩减为20个。
还是模型二进制文件,模型描述文件,label文件。
模型下载地址:https://github.com/PINTO0309/MobileNet-SSD-RealSense/blob/master/caffemodel/MobileNetSSD/MobileNetSSD_deploy.caffemodel
注意要使用deploy版本的。
模型输出也为一个7维向量 后四维为检测出来目标框的矩形坐标 倒数第5维为置信度
代码:
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
#include <fstream>
using namespace std;
using namespace cv;
using namespace cv::dnn;
// Network input size in pixels, passed to blobFromImage() in main().
const size_t width = 300;
const size_t height = 300;
// The following two values are the official parameters.
const float meanVal = 127.5;
const float scaleFactor = 0.0078;
// Files of the MobileNet-SSD model: class names, Caffe weights and network description.
String labelFile = "model/mobileNetSSD/pascal-classes.txt";
String modelFile = "model/mobileNetSSD/MobileNetSSD_deploy.caffemodel";
String model_text_file = "model/mobileNetSSD/MobileNetSSD_deploy.prototxt";
// Reads the class names from labelFile; defined after main().
vector<String> readLabels();
int main(int argc, char** argv)
{
VideoCapture capture;
capture.open("pictures/vtest.avi");
namedWindow("input", CV_WINDOW_FREERATIO);
namedWindow("ssd-video-demo", CV_WINDOW_FREERATIO);
int w = capture.get(CAP_PROP_FRAME_WIDTH);
int h = capture.get(CAP_PROP_FRAME_HEIGHT);
printf("frame width:%d, frame height:%d\n", w, h);
// set up net
Net net = readNetFromCaffe(model_text_file, modelFile);
if (net.empty())
{
cout << "could not load NetModel……" << endl;
return -1;
}
// read the label
vector<String> classNames = readLabels();
Mat frame;
int i = 0;
while (capture.read(frame))
{
i++;
imshow("input", frame);
// 预测
double t1 = (double)getTickCount();
Mat inputblob = blobFromImage(frame, scaleFactor, Size(width, height), meanVal, false);
net.setInput(inputblob, "data");
Mat detection = net.forward("detection_out");
double t2 = (double)getTickCount();
cout << "第" << i << "帧" << "耗费时间:" << (t2 - t1) / getTickFrequency() << "s\n" << endl;
// 绘制
Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());
float confidence_threshold = 0.25;
for (int i = 0; i < detectionMat.rows; i++) {
float confidence = detectionMat.at<float>(i, 2);
if (confidence > confidence_threshold) {
size_t objIndex = (size_t)(detectionMat.at<float>(i, 1));
float tl_x = detectionMat.at<float>(i, 3) * frame.cols;
float tl_y = detectionMat.at<float>(i, 4) * frame.rows;
float br_x = detectionMat.at<float>(i, 5) * frame.cols;
float br_y = detectionMat.at<float>(i, 6) * frame.rows;
Rect object_box((int)tl_x, (int)tl_y, (int)(br_x - tl_x), (int)(br_y - tl_y));
rectangle(frame, object_box, Scalar(0, 0, 255), 2, 8, 0);
//putText(frame, format("%s", classNames[objIndex]), Point(tl_x, tl_y), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(255, 0, 0), 2);
putText(frame, classNames[objIndex], Point(tl_x, tl_y), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(255, 0, 0), 2);
}
}
imshow("ssd-video-demo", frame);
char c = waitKey(50);
if (c == 27) // ESC
{
break;
}
}
waitKey(0);
return 0;
}
// Reads the class names from labelFile: for every non-empty line, the text
// before the first space (the whole line when there is no space).
// Terminates the program with exit(-1) when the file cannot be opened.
vector<String> readLabels()
{
    vector<String> objNames;
    ifstream fin(labelFile);
    if (!fin.is_open())
    {
        cout << "could not load labeFile……" << endl;
        exit(-1);
    }
    string line;
    // Loop on the read itself: testing eof() before reading never terminates
    // when the stream fails without reaching the end of the file.
    while (getline(fin, line))
    {
        if (line.empty())
        {
            continue;
        }
        // find() returns npos when there is no space, which makes substr()
        // keep the whole line.
        objNames.push_back(line.substr(0, line.find(' ')));
    }
    return objNames;
}
结果展示:
FCN模型图像分割
介绍:
论文:https://www.cv-foundation.org/openaccess/content_cvpr_2015/papers/Long_Fully_Convolutional_Networks_2015_CVPR_paper.pdf
全卷积网络
模型与数据:
还是三个文件:
模型下载地址:https://github.com/shelhamer/fcn.berkeleyvision.org
模型输出为21×500×500的数组。21为channel,也就是类别。500×500为rows×cols,对应于图片中的每一个像素值。
代码:
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdio.h>
#include <string.h>
using namespace std;
using namespace cv;
using namespace cv::dnn;
// Image size in pixels (500 x 500).
const size_t width = 500;
const size_t height = 500;
// Files of the FCN model: class names with colors, Caffe weights and network description.
String labelFile = "model\\FCN\\pascal-classes.txt";
String modelFile = "model\\FCN\\fcn8s-heavy-pascal.caffemodel";
String model_text_file = "model\\FCN\\fcn8s-heavy-pascal.prototxt";
// Per-channel mean values.
Scalar meanValues = Scalar(104, 117, 123);
// Helpers defined after main(); both read labelFile.
vector<Vec3b> readColors();
vector<String> readLabels();
// Segments pictures/rgb.jpg with the FCN model: every pixel is assigned the
// class with the highest score, painted in that class's color, and the color
// map is blended over the input image.
//
// Returns 0 on success and -1 when the image cannot be read, the network
// cannot be loaded, or the label file has too few colors.
int main(int argc, char** argv)
{
    Mat frame = imread("pictures/rgb.jpg");
    // Check before any processing: image operations throw on an empty Mat.
    if (frame.empty())
    {
        cout << "could not open image……" << endl;
        return -1;
    }
    namedWindow("input image", WINDOW_FREERATIO);
    imshow("input image", frame);
    resize(frame, frame, Size(500, 500));

    // One color per class; readColors() exits if the file cannot be opened.
    vector<Vec3b> colors = readColors();

    // import Caffe FCN model
    Net net = readNetFromCaffe(model_text_file, modelFile);
    if (net.empty())
    {
        cout << "read caffe model data failure..." << endl;
        return -1;
    }
    Mat blobImage = blobFromImage(frame);

    // Inference.
    net.setInput(blobImage, "data");
    Mat score = net.forward("score");

    // score is indexed as (batch, class, row, col).
    const int chns = score.size[1];
    const int rows = score.size[2];
    const int cols = score.size[3];

    // Every class index found below is used to index colors.
    if (colors.size() < static_cast<size_t>(chns))
    {
        cout << "label file has fewer colors than the network has classes" << endl;
        return -1;
    }

    Mat maxCl = Mat::zeros(rows, cols, CV_8UC1); // class with the highest score per pixel
    Mat maxVal(rows, cols, CV_32FC1);            // that highest score; filled by class 0 below

    // Per-pixel argmax over the classes.
    for (int c = 0; c < chns; c++)
    {
        for (int row = 0; row < rows; row++)
        {
            const float* ptrScore = score.ptr<float>(0, c, row);
            uchar* ptrMaxCl = maxCl.ptr<uchar>(row);
            float* ptrMaxVal = maxVal.ptr<float>(row);
            for (int col = 0; col < cols; col++)
            {
                // Class 0 seeds the running maximum, so the uninitialised
                // contents of maxVal are never compared against.
                if (c == 0 || ptrScore[col] > ptrMaxVal[col])
                {
                    ptrMaxVal[col] = ptrScore[col]; // score
                    ptrMaxCl[col] = (uchar)c;       // class
                }
            }
        }
    }

    // look up colors
    Mat result = Mat::zeros(rows, cols, CV_8UC3);
    for (int row = 0; row < rows; row++)
    {
        const uchar* ptrMaxCl = maxCl.ptr<uchar>(row);
        Vec3b* ptrColor = result.ptr<Vec3b>(row);
        for (int col = 0; col < cols; col++)
        {
            ptrColor[col] = colors[ptrMaxCl[col]]; // color of the pixel's class
        }
    }

    // addWeighted needs both images to have the same size.
    if (result.size() != frame.size())
    {
        resize(result, result, frame.size(), 0, 0, INTER_NEAREST);
    }
    Mat dst;
    addWeighted(frame, 0.3, result, 0.7, 0, dst);
    imshow("FCN-demo", dst);
    waitKey(0);
    return 0;
}
// Reads one color per non-empty line of labelFile.
//
// A line is parsed as `name c0 c1 c2`; the three integers become the three
// components of the color, in that order. Components that are missing stay 0.
// Terminates the program with exit(-1) when the file cannot be opened.
vector<Vec3b> readColors()
{
    vector<Vec3b> objColors;
    ifstream fin(labelFile);
    if (!fin.is_open())
    {
        cout << "could not load labeFile……" << endl;
        exit(-1);
    }
    string line;
    // Loop on the read itself so a failed stream cannot cause an endless loop.
    while (getline(fin, line))
    {
        if (line.empty())
        {
            continue;
        }
        stringstream ss(line);
        string name;
        // Start from 0 so a short line never leaves a component uninitialised.
        int c0 = 0;
        int c1 = 0;
        int c2 = 0;
        ss >> name >> c0 >> c1 >> c2;
        // Every non-empty line yields exactly one entry, which keeps the
        // entry index equal to the line's position among the classes.
        objColors.push_back(Vec3b((uchar)c0, (uchar)c1, (uchar)c2));
    }
    return objColors;
}
// Reads the class names from labelFile.
//
// labelFile is the file readColors() parses as `name c0 c1 c2` lines, so the
// name is the text before the first space of every non-empty line (the whole
// line when there is no space). Entry i therefore belongs to the same line as
// color i from readColors().
// Terminates the program with exit(-1) when the file cannot be opened.
vector<String> readLabels()
{
    vector<String> objNames;
    ifstream fin(labelFile);
    if (!fin.is_open())
    {
        cout << "could not load labeFile……" << endl;
        exit(-1);
    }
    string line;
    // Loop on the read itself so a failed stream cannot cause an endless loop.
    while (getline(fin, line))
    {
        if (line.empty())
        {
            continue;
        }
        // find() returns npos when there is no space, which makes substr()
        // keep the whole line.
        objNames.push_back(line.substr(0, line.find(' ')));
    }
    return objNames;
}
结果展示:
CNN预测年龄和性别
介绍:
论文:https://talhassner.github.io/home/projects/cnn_agegender/CVPR2015_CNN_AgeGenderEstimation.pdf
模型以及描述文件下载:
https://talhassner.github.io/home/publication/2015_CVPR
使用模型的方式与之前的差不多,我自己写了一个,但是感觉年龄识别结果相当不准。
代码1:
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn/dnn.hpp>
#include <iostream>
#include <fstream>
using namespace std;
using namespace cv;
using namespace cv::dnn;
// Age ranges; main() indexes this array with the position of the highest score of the age network.
string age_labels[] = { "0-2", "4-6", "8-13", "15-20", "25-32", "38-43", "48-53", "60-"};
// Caffe weights and network description of the age network.
string age_model_file = "model/ageClassication/age_net.caffemodel";
string age_model_prototxt = "model/ageClassication/deploy_age.prototxt";
// Gender names; main() indexes this array with the position of the highest score of the gender network.
string gender_labels[] = { "man", "woman"};
// Caffe weights and network description of the gender network.
string gender_model_file = "model/genderClassication/gender_net.caffemodel";
string gender_model_prototxt = "model/genderClassication/deploy_gender.prototxt";
int main(int argc, char** argv)
{
system("color 0A");
// 加载图片
Mat img = imread("pictures/boy.jpg");
if (img.empty())
{
cout << "could not load img……" << endl;
return -1;
}
namedWindow("input", CV_WINDOW_AUTOSIZE);
imshow("input", img);
// 加载网络模型
Net age_net = readNetFromCaffe(age_model_prototxt, age_model_file);
if (age_net.empty())
{
cout << "could not load Net age_model……" << endl;
exit(-1);
}
Net gender_net = readNetFromCaffe(gender_model_prototxt, gender_model_file);
if (gender_net.empty())
{
cout << "could not load Net gender_model……" << endl;
exit(-1);
}
// 预测
Mat input = blobFromImage(img, 1.0, Size(227, 227));
age_net.setInput(input, "data");
Mat age_prob = age_net.forward("prob");
gender_net.setInput(input, "data");
Mat gender_prob = gender_net.forward("prob");
// 在图像上表示结果
Point age_class_Number;
double age_class_Prob;
Mat age_probMat = age_prob.reshape(1, 1);
minMaxLoc(age_probMat, NULL, &age_class_Prob, NULL, &age_class_Number);
int age_index = age_class_Number.x;
cout << "对象年龄为:" << age_labels[age_index] << endl;
cout << "概率为:" << age_class_Prob << endl;
Point gender_class_Number;
double gender_class_Prob;
Mat gender_probMat = gender_prob.reshape(1, 1);
minMaxLoc(gender_prob, NULL, &gender_class_Prob, NULL, &gender_class_Number);
int gender_index = gender_class_Number.x;
cout << "对象性别为:" << gender_labels[gender_index] << endl;
cout << "概率为:" << gender_class_Prob << endl;
putText(img, "age:" + age_labels[age_index], Point(20, 20), FONT_HERSHEY_PLAIN, 1.5, Scalar(0, 0, 255), 1, 8);
putText(img, "gender:" + gender_labels[gender_index], Point(20, 40), FONT_HERSHEY_PLAIN, 1.5, Scalar(0, 255, 0), 1, 8);
namedWindow("results", CV_WINDOW_AUTOSIZE);
imshow("results", img);
waitKey(0);
return 0;
}
结果1展示:
把小孩识别成38-43岁……
视频里边用了一个文件haarcascade_frontalface_alt_tree.xml,先把人脸部分提取出来了:
主要使用了一个多尺度检测的函数detectMultiScale(),得到人脸所在的矩形区域,能够检测出来一张图片中的多张人脸。
然后直接把人脸部分输入,其他地方和上面的差不多。
代码2:
#include <opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
using namespace cv;
using namespace cv::dnn;
using namespace std;
// Haar cascade file loaded by the face detector in main().
String haar_file = "D:/opencv/build/etc/haarcascades/haarcascade_frontalface_alt_tree.xml";
// Caffe weights and network description of the age network.
String age_model = "model/ageClassication/age_net.caffemodel";
String age_text = "model/ageClassication/deploy_age.prototxt";
// Caffe weights and network description of the gender network.
String gender_model = "model/genderClassication/gender_net.caffemodel";
String gender_text = "model/genderClassication/deploy_gender.prototxt";
// Run the given network on image and draw the result onto image; defined after main().
void predict_age(Net& net, Mat image);
void predict_gender(Net& net, Mat image);
int main(int argc, char** argv) {
Mat src = imread("pictures/mutiFace1.jpg");
if (src.empty()) {
printf("could not load image...\n");
return -1;
}
namedWindow("input", CV_WINDOW_AUTOSIZE);
imshow("input", src);
// 检测人脸区域
CascadeClassifier detector;
detector.load(haar_file);
vector<Rect> faces;
Mat gray;
cvtColor(src, gray, COLOR_BGR2GRAY);
detector.detectMultiScale(gray, faces, 1.02, 1, 0, Size(40, 40), Size(1000, 1000));
// 加载网络模型
Net age_net = readNetFromCaffe(age_text, age_model);
Net gender_net = readNetFromCaffe(gender_text, gender_model);
for (size_t t = 0; t < faces.size(); t++) {
rectangle(src, faces[t], Scalar(30, 255, 30), 2, 8, 0);
predict_age(age_net, src(faces[t])); // 将人脸区域作为感兴趣区域输入网络
predict_gender(age_net, src(faces[t]));
}
imshow("age-gender-prediction-demo", src);
waitKey(0);
return 0;
}
// Returns the eight age-range labels in their fixed order.
vector<String> ageLabels() {
    static const char* const kAgeRanges[] = {
        "0-2", "4 - 6", "8 - 13", "15 - 20", "25 - 32", "38 - 43", "48 - 53", "60-"
    };
    const size_t count = sizeof(kAgeRanges) / sizeof(kAgeRanges[0]);
    return vector<String>(kAgeRanges, kAgeRanges + count);
}
// Runs net on image and writes "age:<label>" onto image, where <label> is
// the entry of ageLabels() at the position of the highest "prob" score.
void predict_age(Net& net, Mat image) {
    // Feed the image as a 227x227 blob.
    Mat inputBlob = blobFromImage(image, 1.0, Size(227, 227));
    net.setInput(inputBlob, "data");

    // Scores flattened to a single row; the x position of the maximum is the class.
    Mat scores = net.forward("prob").reshape(1, 1);
    Point bestLocation;
    double bestScore;
    minMaxLoc(scores, NULL, &bestScore, NULL, &bestLocation);

    const String label = ageLabels().at(bestLocation.x);
    putText(image, format("age:%s", label.c_str()), Point(2, 10), FONT_HERSHEY_PLAIN, 0.8, Scalar(0, 0, 255), 1);
}
// Runs net on image and writes "gender:M" onto image when the first "prob"
// score is greater than the second, otherwise "gender:F".
void predict_gender(Net& net, Mat image) {
    // Feed the image as a 227x227 blob.
    Mat inputBlob = blobFromImage(image, 1.0, Size(227, 227));
    net.setInput(inputBlob, "data");

    // Scores flattened to a single row.
    Mat scores = net.forward("prob").reshape(1, 1);
    const float firstScore = scores.at<float>(0, 0);
    const float secondScore = scores.at<float>(0, 1);

    const char* tag = "F";
    if (firstScore > secondScore) {
        tag = "M";
    }
    putText(image, format("gender:%s", tag),
            Point(2, 20), FONT_HERSHEY_PLAIN, 0.8, Scalar(0, 0, 255), 1);
}
结果2展示:
GOTURN模型实现对象跟踪
介绍:
GOTURN(Generic Object Tracking Using Regression Networks)使用回归网络进行追踪
资料参考:https://zhuanlan.zhihu.com/p/25338674
算法框架
整个算法的框架其实非常简单:输入当前帧和前一帧进入网络,输出当前帧bounding-box的位置。
输入输出
网络输出目标在search region上的相对坐标(top-left和bottom-right)。
模型下载:
https://github.com/opencv/opencv_extra/tree/c4219d5eb3105ed8e634278fad312a1a8d2c182d/testdata/tracking
note: 这四个压缩包都得下载,否则会解压出错。
可以参考opencv的samples里边的例子:https://github.com/opencv/opencv_contrib/blob/3.3.1/modules/tracking/samples/goturnTracker.cpp
该网络输入为上一帧要追踪的区域data1和当前帧区域data2,输出为单通道4×1的Mat:
表示上一帧中要追踪的box在当前帧中预测的box的位置(左上角和右下角坐标)。
输入:
input: "data1"
input_dim: 1
input_dim: 3
input_dim: 227
input_dim: 227
input: "data2"
input_dim: 1
input_dim: 3
input_dim: 227
input_dim: 227
代码:
#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn/dnn.hpp>
#include <opencv2/video/video.hpp>
#include <iostream>
#include <fstream>
using namespace std;
using namespace cv;
using namespace cv::dnn;
// Caffe weights and network description of the GOTURN model.
string model_file = "model/GOTURN/goturn.caffemodel";
string model_prototxt = "model/GOTURN/goturn.prototxt";
// Network loaded in main() and run in trackObjects().
Net net;
// Computes the object's box in frame from prevFrame and prevBB; defined after main().
Rect trackObjects(Mat& frame, Mat& prevFrame);
// Current and previous video frame.
Mat frame, prevFrame;
// Box of the tracked object in the previous frame; read by trackObjects().
Rect prevBB;
int main(int argc, char** argv) {
net = readNetFromCaffe(model_prototxt, model_file);
if (net.empty())
{
cout << "could not load model file……";
exit(-1);
}
VideoCapture capture;
capture.open("pictures/vtest.avi");
capture.read(frame);
frame.copyTo(prevFrame);
prevBB = selectROI(frame, false, false);
namedWindow("frame", CV_WINDOW_AUTOSIZE);
while (capture.read(frame)) {
Rect currentBB = trackObjects(frame, prevFrame);
rectangle(frame, currentBB, Scalar(0, 0, 255), 2, 8, 0);
// ready for next frame
frame.copyTo(prevFrame);
prevBB.x = currentBB.x;
prevBB.y = currentBB.y;
prevBB.width = currentBB.width;
prevBB.height = currentBB.height;
imshow("frame", frame);
char c = waitKey(50);
if (c == 27) {
break;
}
}
}
// Predicts the box of the tracked object in frame.
//
// A patch padTargetPatch times the size of the global prevBB, centred on
// prevBB, is cut at the same position from prevFrame and from frame. Both
// patches are resized to INPUT_SIZE x INPUT_SIZE and given to the global net
// as inputs "data1" (previous frame) and "data2" (current frame). The four
// values of the "scale" output are read as top-left and bottom-right corner
// inside the patch and mapped back to frame coordinates.
//
// frame      current frame; not modified
// prevFrame  previous frame; not modified
// returns    the predicted box in frame coordinates
Rect trackObjects(Mat& frame, Mat& prevFrame) {
    const int INPUT_SIZE = 227;
    const float padTargetPatch = 2.0f;

    // Centre of the previous box.
    Point2f prevCenter;
    prevCenter.x = (float)(prevBB.x + prevBB.width / 2);
    prevCenter.y = (float)(prevBB.y + prevBB.height / 2);

    // Patch covering padTargetPatch times the box, i.e. the box plus background.
    Rect2f targetPatchRect;
    targetPatchRect.width = (float)(prevBB.width * padTargetPatch);
    targetPatchRect.height = (float)(prevBB.height * padTargetPatch);
    // The frames are padded below by one patch width/height on every side, so
    // the patch origin in padded coordinates is
    // centre - size / 2 + size = centre + size / 2.
    targetPatchRect.x = (float)(prevCenter.x + targetPatchRect.width / 2.0);
    targetPatchRect.y = (float)(prevCenter.y + targetPatchRect.height / 2.0);

    const int padY = (int)targetPatchRect.height;
    const int padX = (int)targetPatchRect.width;

    // Pad the previous frame and cut out the target patch.
    Mat prevFramePadded;
    copyMakeBorder(prevFrame, prevFramePadded, padY, padY, padX, padX, BORDER_REPLICATE);
    Mat targetPatch = prevFramePadded(targetPatchRect).clone();

    // Pad the current frame and cut out the search patch at the same position.
    Mat curFramePadded;
    copyMakeBorder(frame, curFramePadded, padY, padY, padX, padX, BORDER_REPLICATE);
    Mat searchPatch = curFramePadded(targetPatchRect).clone();

    // Resize to the network input size.
    resize(targetPatch, targetPatch, Size(INPUT_SIZE, INPUT_SIZE));
    resize(searchPatch, searchPatch, Size(INPUT_SIZE, INPUT_SIZE));

    // Convert to float first and subtract the mean afterwards: subtracting on
    // 8-bit data saturates at 0 and would discard every value below 128.
    targetPatch.convertTo(targetPatch, CV_32F);
    searchPatch.convertTo(searchPatch, CV_32F);
    subtract(targetPatch, Scalar::all(128), targetPatch);
    subtract(searchPatch, Scalar::all(128), searchPatch);

    Mat targetBlob = blobFromImage(targetPatch);
    Mat searchBlob = blobFromImage(searchPatch);
    net.setInput(targetBlob, "data1");
    net.setInput(searchBlob, "data2");
    Mat res = net.forward("scale");
    Mat resMat = res.reshape(1, 1);

    // Scale the corners from network input pixels to patch pixels and remove
    // the padding offset to get frame coordinates.
    Rect2d curBB;
    curBB.x = (double)targetPatchRect.x + (double)(resMat.at<float>(0) * targetPatchRect.width / INPUT_SIZE) - (double)targetPatchRect.width;
    curBB.y = (double)targetPatchRect.y + (double)(resMat.at<float>(1) * targetPatchRect.height / INPUT_SIZE) - (double)targetPatchRect.height;
    curBB.width = (resMat.at<float>(2) - resMat.at<float>(0)) * targetPatchRect.width / INPUT_SIZE;
    curBB.height = (resMat.at<float>(3) - resMat.at<float>(1)) * targetPatchRect.height / INPUT_SIZE;

    // Predicted box, converted to integer coordinates.
    Rect boundingBox = curBB;
    return boundingBox;
}
结果展示:
更多推荐
所有评论(0)