最初实现

        for (auto i = 0; i < img.rows; i++)
        {
            auto *ptr = img.ptr<cv::Vec3b>(i);
            for (auto j = 0; j < img.cols; j++)
            {
                auto &pix = ptr[j];
                for (auto k = 0; k < img.channels(); ++k)
                {
                    if (pix[k] > n)
                        pix[k] = 255;
                    else
                        pix[k] = 0;
                }
            }
        }

Benchmark Time CPU Iterations

BM_threshold/128 1180128 ns 1179980 ns 675
BM_cvthreshold/128 110379 ns 106367 ns 6181

我自己的函数较之opencv落后十倍更多,所以我打算开始看opencv源码

threshold源码阅读

threshold函数入口:

int threshold(const uchar* src_data, size_t src_step, uchar* dst_data, size_t dst_step, int width, int height, int depth, int cn, double thresh, double maxValue, int thresholdType)
{
    return threshold_range(0, height, src_data, src_step, dst_data, dst_step, width, depth, cn, thresh, maxValue, thresholdType);
}

对外入口本身不做具体计算,它只是把整张图像 [0, height) 这一段交给 threshold_range()

那接下来我们来看**threshold_range**

它先做两件事:

  1. 把宽度乘上通道数
width *= cn;

说明这里不是按“像素对象”处理,而是按“连续元素流”处理,减少行列访问时的时间损耗,更像是在处理“一维数组”

  1. 根据 depth + thresholdType 选择threshold实现

接下来是真正做阈值计算的是模板版 threshold<helper, type>()

template<typename helper, int type, typename T = typename helper::ElemType>
static inline int threshold(...)

它的执行逻辑:

  1. 准备一个全 0 向量
auto zero = helper::vmv(0, helper::setvlmax());

很多阈值模式最终都要在 0maxValuesrcthresh 之间选,所以先准备一份全零向量,后面直接拿来 merge。

  1. 按行处理并拿到这一行的首地址,并强转为数组
    for (int i = start; i < end; i++)
    {
        const T* src = reinterpret_cast<const T*>(src_data + i * src_step);			//i为行号,通过行号*步长计算行数
        T* dst = reinterpret_cast<T*>(dst_data + i * dst_step);
  1. 按段计算
vl0 = helper::setvl(width - j);
vl1 = helper::setvl(width - j - vl0);
auto src0 = helper::vload(src + j, vl0);
auto src1 = helper::vload(src + j + vl0, vl1);

使用rvv进行批量加载和计算

  1. 使用rvv::vmerge进行向量计算,并且匹配各种格式

优化代码

static inline std::array<uchar, 256> makeThresholdLUT(int thresh, uchar maxVal = 255)
{
    std::array<uchar, 256> lut{};
    thresh = std::clamp(thresh, 0, 255);
    for (int i = 0; i < 256; ++i)
    {
        lut[i] = (i > thresh) ? maxVal : 0;
    }
    return lut;
}

void threshold_fast_parallel(const cv::Mat& src, cv::Mat& dst, int thresh, uchar maxVal = 255)
{
    const int channels = src.channels();
    dst.create(src.size(), src.type());
    const auto lut = makeThresholdLUT(thresh, maxVal);

    cv::parallel_for_(cv::Range(0, src.rows), [&](const cv::Range& range)
    {
        for (int r = range.start; r < range.end; ++r)
        {
            const uchar* s = src.ptr<uchar>(r);
            uchar* d = dst.ptr<uchar>(r);
            const int rowBytes = src.cols * channels;

            int c = 0;
            constexpr int kUnroll = 16;
            for (; c <= rowBytes - kUnroll; c += kUnroll)
            {
                d[c + 0]  = lut[s[c + 0]];
                d[c + 1]  = lut[s[c + 1]];
                d[c + 2]  = lut[s[c + 2]];
                d[c + 3]  = lut[s[c + 3]];
                d[c + 4]  = lut[s[c + 4]];
                d[c + 5]  = lut[s[c + 5]];
                d[c + 6]  = lut[s[c + 6]];
                d[c + 7]  = lut[s[c + 7]];
                d[c + 8]  = lut[s[c + 8]];
                d[c + 9]  = lut[s[c + 9]];
                d[c + 10] = lut[s[c + 10]];
                d[c + 11] = lut[s[c + 11]];
                d[c + 12] = lut[s[c + 12]];
                d[c + 13] = lut[s[c + 13]];
                d[c + 14] = lut[s[c + 14]];
                d[c + 15] = lut[s[c + 15]];
            }
            for (; c < rowBytes; ++c)
            {
                d[c] = lut[s[c]];
            }
        }
    });
}

Benchmark Time CPU Iterations

BM_threshold_fast/128 39377 ns 36453 ns 21471
BM_cvthreshold/128 11920 ns 11159 ns 72564

设计思路

这份实现的核心思路,其实就是把一个最朴素的 threshold:

dst = (src > thresh) ? maxVal : 0;

从“逐像素判断”的形式,改造成“先建规则表,再按表批量映射”的形式。

如果直接写朴素版本,内层循环里每处理一个字节,都要做一次比较和一次条件分支。图像一大,这种操作会重复很多很多次。这个实现不这么做,而是先利用 8 位图像像素值只可能在 0~255 之间这一特点,提前把 256 种输入的结果全算好,做成一张 LUT。这样后面真正扫图的时候,循环体就不再做阈值判断,而是直接查表。原来是“边遍历边决策”,现在变成了“先决策,再遍历时只执行映射”。

在这个基础上,它又做了第二层优化:并行。因为普通 threshold 每一行之间互不依赖,所以天然适合按行切分给多个线程处理。于是整个设计就变成了两层:第一层是 LUT,把每个像素的判断逻辑压缩成常量查表;第二层是 parallel_for_,把整张图按行分块,让多个线程同时做查表映射。再加上一点循环展开,减少内层循环控制开销,这就构成了这份实现的整体设计。

所以从本质上讲,这份代码不是靠复杂指令取胜,而是靠三个朴素但有效的原则:消掉分支、保持线性内存访问、利用多核并行


实现流程

整个实现流程其实很短,只有一条主线。

调用 threshold_fast_parallel 之后,函数先读取输入图像的一些基本信息,比如通道数,然后为输出图像 dst 创建和输入相同大小、相同类型的内存空间。紧接着,它会调用 makeThresholdLUT,根据当前的 threshmaxVal 生成一张 256 项的查找表。到这里,阈值规则其实已经完全确定了,后面不再需要做任何“像素值和阈值比较”的判断。

生成 LUT 之后,函数进入并行阶段。cv::parallel_for_ 会把图像的行区间 [0, rows) 拆成若干段,每个线程拿到其中一段行。每个线程在自己的行区间里,逐行取出源图像和目标图像的首地址,然后把这一行当成一段连续字节流处理。对于灰度图,一行字节数就是 cols;对于三通道图,一行字节数就是 cols * 3。接着代码对这一行做查表映射,把 src 里的每个字节读出来,作为下标访问 LUT,再把结果写到 dst 对应位置。

为了让内层循环更紧凑,它不是一次只处理 1 个字节,而是一次展开处理 16 个字节。这样做的目的不是改变算法本身,而是减少循环变量更新、边界判断这些额外指令的比例。等 16 个一组的大块处理完之后,再用一个收尾循环把这一行剩下不足 16 个的元素补完。所有线程都做完自己负责的行区间后,整张图的 threshold 就完成了。

所以这条实现流程可以概括成一句话:先生成 LUT,再按行并行遍历整张图,对每个字节执行 dst = lut[src]


makeThresholdLUT 的函数结构

这个函数虽然很短,但内部其实也可以分成两个大段。

第一大段是参数规范化。它先创建一个长度固定为 256 的数组,然后用 std::clamp 把传进来的 thresh 限制在 [0, 255] 范围内。这样做是为了保证这个阈值一定和 8 位图像的像素范围匹配。因为 LUT 的下标一定是 0 到 255,如果阈值本身跑到了这个范围外,就会让映射规则失真,所以要先修正。

第二大段是构造映射表。它用一个从 0 到 255 的循环,把每一个可能输入值的输出结果都提前算出来。对于每个 i,如果 i > thresh,那对应输出就是 maxVal;否则就是 0。这一步完成后,LUT 本身就已经等价于一个完整的二值化规则了。后面的图像处理不再关心 threshold 的数学含义,它只关心“源值是多少,就去表里取什么结果”。

所以这个函数的作用非常纯粹:它不是处理图像,而是把 threshold 规则编译成一张静态映射表


threshold_fast_parallel 的函数结构

这个函数内部可以分成三大段来看。

第一大段是准备阶段。它先从输入图像里取出 channels,因为后面每一行实际要处理多少个字节,取决于通道数。然后调用 dst.create(src.size(), src.type()),确保目标图像有和源图一样的大小和数据类型。最后调用 makeThresholdLUT 生成查表。这一段的作用是把后续处理所需的上下文全部准备好,也就是“输出缓冲区准备好、规则准备好”。

第二大段是并行分工阶段。这一段由 cv::parallel_for_ 驱动。它不会一次把整张图交给一个线程,而是把总行数切成若干个 Range。每个线程拿到一个 range 之后,只需要处理自己负责的 [range.start, range.end) 这一段行即可。这里的关键思想是,threshold 属于逐点独立运算,没有跨行依赖,所以按行切分不会产生同步问题。LUT 又是只读数据,所以多个线程共享同一张表也不会有竞争。这就让并行非常自然。

第三大段是单行处理阶段。这是 lambda 里面真正干活的地方。对于每一行,它先用 src.ptr<uchar>(r)dst.ptr<uchar>(r) 拿到这一行源数据和目标数据的首地址,再根据 src.cols * channels 算出这一行要处理的总字节数。然后进入核心循环。核心循环分成两层:前一层是按 16 个元素展开的大块处理,连续执行 16 次 d[c + k] = lut[s[c + k]];后一层是处理余数,也就是剩下不足 16 个元素的部分。这样一行结束后,再进入下一行,直到当前线程的所有行都处理完。

所以这个函数的结构其实很工整:前面做准备,中间做并行切分,里面对每一行做查表映射。

Logo

腾讯云面向开发者汇聚海量精品云计算使用和开发经验,营造开放的云计算技术生态圈。

更多推荐