CVPR论文《100+ Times Faster Weighted Median Filter (WMF)》的实现和解析(附源代码)。
                        
                     
                    
                    
                        四年前第一次看到《100+ Times FasterWeighted Median Filter (WMF)》一文时,因为他附带了源代码,而且还是CVPR论文,因此,当时也对代码进行了一定的整理和解读,但是当时觉得这个算法虽然对原始速度有不少的提高,但是还是比较慢。因此,没有怎么在意,这几天有几位朋友又提到这篇文章,于是把当时的代码和论文又仔细的研读了一番,对论文的思想和其中的实现也有了一些新的新的,再次做个总结和分享。
  这篇文章的官网地址是:http://www.cse.cuhk.edu.hk/~leojia/projects/fastwmedian/,其中主要作者Jiaya Jia教授的官网地址是:http://jiaya.me/,根据Jiaya Jia的说法,这个算法很快将被OpenCv所收录,到时候OpenCv的大神应该对他还有所改进吧。
  在百度上搜索加权中值模糊,似乎只有一篇博客对这个文章进行了简单的描述,详见:https://blog.csdn.net/streamchuanxi/article/details/79573302?utm_source=blogxgwz9。
  由于作者只给出了最后的优化实现代码,而论文中还提出了各种中间过程的时间,因此本文以实现和验证论文中有关说法为主,涉及到的理论知识比较肤浅,一般是一笔而过。
  根据论文中得说法,所谓的加权中值滤波,也是一种非线性的图像平滑技术,他取一个局部窗口内所有像素的加权中值来代替局部窗口的中心点的值。用较为数学的方法表示如下:
  在图像I中的像素p,我们考虑以p为中心,半径为R的局部窗口,不同于普通的中值模糊,对于属于内每一个像素q,都有一个基于对应的特征图像的相似度的权重系数wpq,如下式所示:
                           
  f(p)和f(q)是像素p和q在对应的特征图中得特征值。g是一个权重函数,最常用的即为高斯函数,反应了像素p和q的相似程度。
  我们用I(q)表示像素点q的像素值,在窗口内的像素总数量用n表示,则n=(2r+1)*(2r+1),那么窗口内像素值和权重值构成一个对序列,即,对这个序列按照I(q)的值进行排序。排序后,我们依次累加权重值,直到累加的权重大于等于所有权重值的一半时停止,此时对应的I(q)即作为本局部窗口中心点的新的像素值。
                               
  很明显,上面的过程要比标准的中值模糊复杂一些,在处理时多了特征图和权重函数项,而标准的中值模糊我们可以认为是加权中值模糊的特例,即所有局部窗口的权重都为1或者说相等。
  在这里,特征图可以直接是源图像,也可以是其他的一些特征,比如原图像的边缘检测结果、局部均方差、局部熵或者其他的更为高级的特征。
  按照这个定义,我们给出一段针对灰度数据的Brute-force处理代码:
复制代码
int __cdecl ComparisonFunction(const void *X, const void *Y)        //    一定要用__cdecl这个标识符
{
    Value_Weight VWX = *(Value_Weight *)X;
    Value_Weight VWY = *(Value_Weight *)Y;
    if (VWX.Value < VWY.Value)
        return -1;
    else if (VWX.Value > VWY.Value)
        return +1;
    else
        return 0;
}
//    加权中值模糊,直接按照算法的定义实现。
//    Input        -    输入图像,灰度图,LevelV = 256级
//    FeatureMap    -    特征图像,灰度图,LevelF = 256级
//    Weight        -    特征的权重矩阵,大小是LevelF * LevelF
//    Output        -    输出图像,不能和Input为同一个数据。
int IM_WeightedMedianBlur_00(unsigned char *Input, unsigned char *FeatureMap, float *Weight, unsigned char *Output, int Width, int Height, int Stride, int Radius)
{
    int Channel = Stride / Width;
    if ((Input == NULL) || (Output == NULL))                                        return IM_STATUS_NULLREFRENCE;
    if ((FeatureMap == NULL) || (Weight == NULL))                                    return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0) || (Radius <= 0))                              return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1))                                                      return IM_STATUS_NOTSUPPORTED;
    const int LevelV = 256;                //    Value 可能出现的不同数量
    const int LevelF = 256;                //    Feature 可能出现的不同数量
    Value_Weight *VW = (Value_Weight *)malloc((2 * Radius + 1) * (2 * Radius + 1) * sizeof(Value_Weight));            //    值和特征序列对内存
    if (VW == NULL)    return IM_STATUS_OK;
    for (int Y = 0; Y < Height; Y++)
    {
        unsigned char *LinePF = FeatureMap + Y * Stride;
        unsigned char *LinePD = Output + Y * Stride;
        for (int X = 0; X < Width; X++)
        {
            int CF_Index = LinePF[X] * LevelF;
            int PixelAmount = 0;
            float SumW = 0;
            for (int J = IM_Max(Y - Radius, 0); J <= IM_Min(Y + Radius, Height - 1); J++)
            {
                int Index = J * Stride;
                for (int I = IM_Max(X - Radius, 0); I <= IM_Min(X + Radius, Width - 1); I++)        //    注意越界
                {
                    int Value = Input[Index + I];                            //    值
                    int Feature = FeatureMap[Index  + I];                    //    特征
                    float CurWeight = Weight[CF_Index + Feature];            //    对应的权重
                    VW[PixelAmount].Value = Value;
                    VW[PixelAmount].Weight = CurWeight;                        //    保存数据
                    SumW += CurWeight;                                        //    计算累加数据
                    PixelAmount++;                                            //    有效的数据量    
                }
            }
            float HalfSumW = SumW * 0.5f;                                    //    一半的权重
            SumW = 0;
            qsort(VW, PixelAmount, sizeof VW[0], &ComparisonFunction);        //    调用系统的qsort按照Value的值从小到大排序,注意qsort的结果仍然保存在第一个参数中
            for (int I = 0; I < PixelAmount; I++)                            //    计算中值
            {
                SumW += VW[I].Weight;
                if (SumW >= HalfSumW)
                {
                    LinePD[X] = VW[I].Value;
                    break;
                }
            }
        }
    }
    free(VW);
    return IM_STATUS_OK;
}
复制代码
  很明显,这个函数的时间复杂度是o(radius * radius),空间复杂度到时很小。
  我们在一台 I5,3.3GHZ的机器上进行了测试,上述代码处理一副1000*1000像素的灰度图,半径为10(窗口大小21*21)时,处理时间约为27s,论文里给的Cpu和我的差不多,给出的处理one - metalpixel的RGB图用时90.7s,考虑到RGB的通道的数据量以及一些其他的处理,应该说论文如实汇报了测试数据。
  那么从代码优化上面讲,上面代码虽然还有优化的地方,但是都是小打小闹了。使用VS的性能分析器,可以大概获得如下的结果:
       
  可见核心代码基本都用于排序了,使用更快的排序有助于进一步提高速度。
  针对这个情况,论文的作者从多方面提出了改进措施,主要有三个方面,我们简单的重复下。
  一、联合直方图(Joint Histgram)
  直方图优化在很多算法中都有应用,比如标准的中值滤波,现在看到的最快的实现方式还是基于直方图的,详见:任意半径中值滤波(扩展至百分比滤波器)O(1)时间复杂度算法的原理、实现及效果,但是在加权中值滤波中,传统的一维直方图已经无法应用,因为这个算法不仅涉及到原图的像素值,还和另外一幅特征图有关,因此,文中提出了联合直方图,也是一种二维直方图。
  如果图像中的像素最多有LevelV个不同值,其对应的特征最多有LevelF个不同的值,那么我们定义一个宽和高分别为LevelV * LevelF大小的直方图。对于某一个窗口,统计其内部的(2r+1)*(2r+1)个像素和特征对的直方图数据,即如果某个点的像素值为V,对应的特征值为F,则相应位置的直方图数据加1。
  如果我们统计出这个二维的直方图数据后,由于中心点的特征值是固定的,因此,对于直方图的每一个LevelF值,权重是一定的了,我们只需计算出直方图内每一个Value值所对应所有的Feature的权重后,就可方便的统计出中值所在的位置了。
  那么如果每个像素点都进行领域直方图的计算,这个的工作量也是蛮大的,同一维直方图的优化思路一样,在进行逐像素行处理的时候,对直方图数据可以进行逐步的更新,去除掉移走的那一列的直方图信息,在加入即将进入那一列数据,而中间重叠部分则不需要调整。
  按照论文中的Joint Histgram的布局,即行方向大小为LevelV,列方向大小为LevelF,编制Joint Histgram实现的加权中值算法代码如下所示:
复制代码
//    加权中值模糊,基于论文中图示的内存布局设置的Joint Histgram。 
int IM_WeightedMedianBlur_01(unsigned char *Input, unsigned char *FeatureMap, float *Weight, unsigned char *Output, int Width, int Height, int Stride, int Radius)
{
    int Channel = Stride / Width;
    if ((Input == NULL) || (Output == NULL))                                        return IM_STATUS_NULLREFRENCE;
    if ((FeatureMap == NULL) || (Weight == NULL))                                    return IM_STATUS_NULLREFRENCE;    
    if ((Width <= 0) || (Height <= 0) || (Radius <= 0))                                return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3))                                            return IM_STATUS_NOTSUPPORTED;
    int Status = IM_STATUS_OK;
    const int LevelV = 256;                //    Value 可能出现的不同数量
    const int LevelF = 256;                //    Feature 可能出现的不同数量
    int *Histgram = (int *)malloc(LevelF * LevelV * sizeof(int));
    float *Sum = (float *)malloc(LevelV * sizeof(float));
    if ((Histgram == NULL) || (Sum == NULL))
    {
        Status = IM_STATUS_OUTOFMEMORY;
        goto FreeMemory;
    }
    for (int Y = 0; Y < Height; Y++)
    {
        unsigned char *LinePF = FeatureMap + Y * Stride;
        unsigned char *LinePD = Output + Y * Stride;
        memset(Histgram, 0, LevelF * LevelV * sizeof(int));
        for (int J = IM_Max(Y - Radius, 0); J <= IM_Min(Y + Radius, Height - 1); J++)
        {
            for (int I = IM_Max(0 - Radius, 0); I <= IM_Min(0 + Radius, Width - 1); I++)
            {
                int Value = Input[J * Stride + I];
                int Feature = FeatureMap[J * Stride + I];        //    统计二维直方图
                Histgram[Feature * LevelV + Value]++;
            }
        }
        for (int X = 0; X < Width; X++)
        {
            int Feature = LinePF[X];
            float SumW = 0, HalfSumW = 0;;
            for (int I = 0; I < LevelV; I++)
            {
                float Cum = 0;
                for (int J = 0; J < LevelF; J++)        //    计算每个Value列针对的不同的Feature的权重的累计值
                {
                    Cum += Histgram[J * LevelV + I] * Weight[J * LevelF + Feature];
                }
                Sum[I] = Cum;
                SumW += Cum;
            }
            HalfSumW = SumW / 2;
            SumW = 0;
            for (int I = 0; I < LevelV; I++)
            {
                SumW += Sum[I];
                if (SumW >= HalfSumW)                //    计算中值
                {
                    LinePD[X] = I;
                    break;
                }
            }
            if ((X - Radius) >= 0)                    //    移出的那一列的直方图
            {
                for (int J = IM_Max(Y - Radius, 0); J <= IM_Min(Y + Radius, Height - 1); J++)
                {
                    int Value = Input[J * Stride + X - Radius];
                    int Feature = FeatureMap[J * Stride + X - Radius];
                    Histgram[Feature * LevelV + Value]--;
                }
            }
            if ((X + Radius + 1) <= Width - 1)        //    移入的那一列的直方图
            {
                for (int J = IM_Max(Y - Radius, 0); J <= IM_Min(Y + Radius, Height - 1); J++)
                {
                    int Value = Input[J * Stride + X + Radius + 1];
                    int Feature = FeatureMap[J * Stride + X + Radius + 1];
                    Histgram[Feature * LevelV + Value]++;
                }
            }
        }
    }
FreeMemory:
    if (Histgram != NULL)    free(Histgram);
    if (Sum != NULL)        free(Sum);
    return Status;
}
复制代码
  编译后测试,同样是21*21的窗口,one - metalpixel的灰度图像计算用时多达108s,比直接实现慢很多了。
  分析原因,核心就是在中值的查找上,由于我们采用的内存布局方式,导致计算每个Value对应的权重累加存在的大量的Cache miss现象,即下面这条语句:
复制代码
for (int J = 0; J < LevelF; J++)        //    计算每个Value列针对的不同的Feature的权重的累计值
{
    Cum += Histgram[J * LevelV + I] * Weight[J * LevelF + Feature];
}
复制代码
  我们换种Joint Histgram的布局,即行方向大小为LevelF,列方向大小为LevelV,此时的代码如下:
复制代码
//    加权中值模糊,修改内存布局设置的Joint Histgram。 
int IM_WeightedMedianBlur_02(unsigned char *Input, unsigned char *FeatureMap, float *Weight, unsigned char *Output, int Width, int Height, int Stride, int Radius)
{
    int Channel = Stride / Width;
    if ((Input == NULL) || (Output == NULL))                                        return IM_STATUS_NULLREFRENCE;
    if ((FeatureMap == NULL) || (Weight == NULL))                                    return IM_STATUS_NULLREFRENCE;
    if ((Width <= 0) || (Height <= 0) || (Radius <= 0))                                return IM_STATUS_INVALIDPARAMETER;
    if ((Channel != 1) && (Channel != 3))                                            return IM_STATUS_NOTSUPPORTED;
    int Status = IM_STATUS_OK;
    const int LevelV = 256;                //    Value 可能出现的不同数量
    const int LevelF = 256;                //    Feature 可能出现的不同数量
    int *Histgram = (int *)malloc(LevelF * LevelV * sizeof(int));
    float *Sum = (float *)malloc(LevelV * sizeof(float));
    if ((Histgram == NULL) || (Sum == NULL))
    {
        Status = IM_STATUS_OUTOFMEMORY;
        goto FreeMemory;
    }
    for (int Y = 0; Y < Height; Y++)
    {
        unsigned char *LinePF = FeatureMap + Y * Stride;
        unsigned char *LinePD = Output + Y * Stride;
        memset(Histgram, 0, LevelF * LevelV * sizeof(int));
        for (int J = IM_Max(Y - Radius, 0); J <= IM_Min(Y + Radius, Height - 1); J++)
        {
            int Index = J * Stride;
            for (int I = IM_Max(0 - Radius, 0); I <= IM_Min(0 + Radius, Width - 1); I++)
            {
                int Value = Input[J * Stride + I];
                int Feature = FeatureMap[J * Stride + I];
                Histgram[Value * LevelF + Feature]++;            //    注意索引的方式的不同
            }
        }
        for (int X = 0; X < Width; X++)
        {
            int IndexF = LinePF[X] * LevelF;
            float SumW = 0, HalfSumW = 0;;
            for (int I = 0; I < LevelV; I++)
            {
                float Cum = 0;
                int Index = I * LevelF;
                for (int J = 0; J < LevelF; J++)        //    核心就这里不同
                {
                    Cum += Histgram[Index + J] * Weight[IndexF + J];
                }
                Sum[I] = Cum;
                SumW += Cum;
            }
            HalfSumW = SumW / 2;
            SumW = 0;
            for (int I = 0; I < LevelV; I++)
            {
                SumW += Sum[I];
                if (SumW >= HalfSumW)
                {
                    LinePD[X] = I;
                    break;
                }
            }
            if ((X - Radius) >= 0)
            {
                for (int J = IM_Max(Y - Radius, 0); J <= IM_Min(Y + Radius, Height - 1); J++)
                {
                    int Value = Input[J * Stride + X - Radius];
                    int Feature = FeatureMap[J * Stride + X - Radius];
                    Histgram[Value * LevelF + Feature]--;
                }
            }
            if ((X + Radius + 1) <= Width - 1)
            {
                for (int J = IM_Max(Y - Radius, 0); J <= IM_Min(Y + Radius, Height - 1); J++)
                {
                    int Value = Input[J * Stride + X + Radius + 1];
                    int Feature = FeatureMap[J * Stride + X + Radius + 1];
                    Histgram[Value * LevelF + Feature]++;
                }
            }
        }
    }
FreeMemory:
    if (Histgram != NULL)    free(Histgram);
    if (Sum != NULL)        free(Sum);
    return Status;
}
复制代码
  修改后,同样的测试条件和图片,速度提升到了17s,仅仅是更改了一个内存布局而已,原论文的图没有采用这种布局方式,也许只是为了表达算法清晰而已。
  和原论文比较,原论文的joint histgram时间要比直接实现慢(156.9s vs 90.7s),而我这里的一个版本比brute force的快,一个比brute force的慢,因此,不清楚作者在比较时采用了何种编码方式,但是这都不重要,因为他们的区别都还在一个数量级上。
       由于直方图大小是固定的,因此,前面的中值查找的时间复杂度是固定的,而后续的直方图更新则是o(r)的,但是注意到由于LevelV和 LevelF通常都是比较大的常数(一般为256),因此实际上,中值查找这一块的耗时占了绝对的比例。
 二、快速中值追踪
  寻找中值的过程实际上可以看成一个追求平衡的过程,假定当前搜索到的位置是V,位于V左侧所有相关值的和是Wl,位于V右侧所有相关值得和是Wr,则中值的寻找可以认为是下式:
                          
  后面的约束条件可以理解为第一次出现Wl大于Wr前。