024 随机选择算法,BFPRT算法

本博客讲解了在无序数组中寻找第K大元素的两种方法:随机选择算法和BFPRT算法。随机选择算法基于快排思想,期望时间复杂度为O(n),但证明较复杂。BFPRT算法通过选取“中位数的中位数”作为pivot,保证最坏情况下时间复杂度也为O(n),但实现稍复杂。博客提供了两种算法的Java代码,并强调BFPRT算法的意义在于其规避最坏情况,优化算法的思想。

Posted by Hilda on March 30, 2025

前置知识:讲解023-随机快速排序 笔记

下面的我是指左老师本人

本节视频链接

本节leetcode链接:https://leetcode.cn/problems/xx4gT2/description/

本题与主站 215 题相同: https://leetcode-cn.com/problems/kth-largest-element-in-an-array/

无序数组中寻找第K大的数

给定整数数组 nums 和整数 k,请返回数组中第 k 个最大的元素。

请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

你必须设计并实现时间复杂度为 O(n) 的算法解决此问题。

利用改写快排的方法,时间复杂度O(n),额外空间复杂度O(1)

上面问题的解法就是随机选择算法,是常考内容!本视频定性讲述,定量证明略,算法导论-9.2有详细证明

不要慌!

随机快速排序、随机选择算法,时间复杂度的证明理解起来很困难,只需记住结论,但并不会对后续的算法学习造成什么影响

因为数学很好才能理解的算法和数据结构其实比较少,绝大部分的内容都只需要高中数学的基础就能理解

算法导论第9章,还有一个BFPRT算法,不用随机选择一个数的方式,也能做到时间复杂度O(n),额外空间复杂度O(log n)

早些年我还讲这个算法,不过真的很冷门,很少在笔试、面试、比赛场合出现,所以算了。有兴趣的同学可以研究一下


LCR 076. 数组中的第 K 个最大元素

LCR 076. 数组中的第 K 个最大元素

首先,解决了第XX大,也就类似于解决了第XX小元素的问题。

例如有100个数,找第2大和找第98小,是一样的。


首先说明一个思路:

如果要找到第1小,那么就是求排序完的数组最后,0位置的数;如果要找第16小,就是找排序完的数组最后,15位置的数。

image-20250329204205661

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class Solution {
    public static int first, last;

    public int findKthLargest(int[] nums, int k) {
        return randomizedSelect(nums, nums.length - k);
    }
    // 从小到大排完序  第i位置的数

    public int randomizedSelect(int[] arr, int i) {
        int ans = 0;
        for (int l = 0, r = arr.length - 1; l <= r;) {
            // 随机
            int x = arr[l + (int)(Math.random() * (r - l + 1))];
            partition(arr, l, r, x);

            if (i < first) {
                r = first - 1;
            } else if (i > last) {
                l = last + 1;
            } else {
                ans = arr[i];
                break;
            }
        }
        return ans;
    }

    public void partition(int[] arr, int l, int r, int x) {
        first = l;
        last = r;
        int i = l;
        while (i <= last) {
            if (arr[i] < x) {
                swap(arr, first++, i++);
            } else if (arr[i] == x) {
                i++;
            } else {
                swap(arr, last--, i);
            }
        }
    }

    public void swap(int[] arr, int i, int j) {
        int temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }
}

目前已经是最优解了,虽然上面排名不高,但是已经做到时间复杂度O(n),空间复杂度O(1)了。

但是这个方法时间复杂度的证明(算法导论-9.2有详细证明)是需要计算期望的。(要考虑概率,但是下面的bfprt算法不涉及概率)

bfprt算法

BFPRT算法的名字来源于它的五位发明者:Blum, Floyd, Pratt, Rivest, Tarjan。这个算法的全称是“Blum-Floyd-Pratt-Rivest-Tarjan算法”,通常简称为BFPRT算法。它是由这五位计算机科学家在1973年合作提出的一种用于解决线性时间选择问题的算法,具体来说是找到一个数组中第k小的元素。

Manuel Blum:著名计算机科学家,以计算复杂性和密码学领域的贡献闻名。

Robert Floyd:图灵奖得主,以算法设计和分析(如Floyd-Warshall算法)著称。

Vaughan Pratt:在算法和数据结构领域有重要贡献,例如Pratt证书用于素数判定。

Ron Rivest:RSA加密算法的共同发明者之一,在算法和密码学领域影响深远。

Robert Tarjan:图算法大师,以发明如Tarjan算法(用于强连通分量)而闻名。

BFPRT通过精心选择pivot(枢轴),保证了最坏情况下的时间复杂度为O(n)。其核心思想是将数组分成若干小组,计算每组的中位数,再递归地从这些中位数中选出总体中位数作为pivot。

步骤如下:

  • 分组:将数组 A 划分为若干组,每组包含 5 个元素。如果数组长度 n 不是5的倍数,最后一组可能少于5个元素。分组只需遍历数组一遍,分配到组中,时间复杂度为 O(n)。分组通常是通过遍历数组并将元素分配到不同的组(可以用数组或列表表示)来完成的。如果不是通过遍历数组并将元素分配到不同的组,那么就是O(1)。
  • 排序:计算每组的中位数。对每组的5个元素进行排序,找出每组的中位数(第3小的元素)。这些中位数将用于进一步选择一个“优秀”的枢轴。每组有5个元素,排序一个小组的时间复杂度是常数,例如插入排序为O(\(5^2\))=O(1)。总共有 \(⌈n/5⌉\) 个小组,因此总时间复杂度为\(⌈n/5⌉ * O(1) = O(n)\)。如果不足5个,可以拿上中位数作为中位数。例如数组1, 2, 3, 4中上中位数就是2。
  • 将步骤2中得到的每组中位数组成一个新数组(称为 n),其长度为 \(⌈n/5⌉\)。对这个新数组递归调用BFPRT算法,找到 M 的中位数(即“中位数的中位数”),这个值将作为最终的枢轴 pivot。
    • 补充:本来这个问题要解决的是int bfprt(arr, k),即给定数组arr,返回第k大/小的数(返回类型是int类型)。现在只需要调用bfprt(m, N / 10)就可以得到新的pivot了。为什么是N/10?因为刚才说了,数组M的长度就是N/5,中位数不就是排完序,其位置是\(\frac{N}{5}*\frac{1}{2}=\frac{N}{10}\)。
    • 得到的这个枢纽就是“中位数中的中位数”,所以这个算法也叫“中位数的中位数算法”。
  • 后面的过程和随机选择算法做这道题类似了。其实上面这3步都是在做一件事:挑选枢纽。

为什么这个方法好?

可以思考对比下随机选择算法,随机选的情况下,无法控制小于枢纽、大于枢纽这些数的规模,但是BFPRT可以做到控制这样的规模。

对于随机选择排序,在证明其复杂度时,需要考虑很差的情况(这个时候时间复杂度是O(\(n^2\))),然后考虑其他各种情况,但是这些情况都是等概率的,然后求期望。

而BFPRT算法通过选择一个接近中位数的“优秀”枢轴(pivot),避免了快速选择算法(Quickselect)在最坏情况下的O(\(n^2\))时间复杂度,保证了O(n)的最坏情况时间复杂度。

详细来说,数组m中每个中位数代表一个5元素小组,小组中至少有3个元素(包括中位数本身)小于等于或大于等于 pivot。因此,至少有 \(⌈⌈n/5⌉/2⌉ * 3\)个元素小于等于 pivot,同样至少有这么多元素大于等于 pivot。

计算可得,至少 \(3n/10\) 个元素在 pivot 的一侧(具体取决于边界调整),剩下部分最多为\(7n/10\)。

则时间复杂度满足:\(T(n) = T(⌈n/5⌉) + T(7n/10) + O(n)\)

递归部分\(T(⌈n/5⌉) + T(7n/10)\)的规模系数为 \(1/5 + 7/10 = 9/10 < 1\)。

非递归部分为 O(n)。

根据主定理,a = 2,b = 10/7(保守估计),c = 1,\(a/b^c = 2/(10/7) = 14/10 > 1\),但实际 \(9/10 < 1\),表明总时间受线性项支配。

最终解得 T(n) = O(n)。

简单来说,就是每一次挑选出一个pivot,可以淘汰掉至少3/10这么多数。在剩下的7/10中寻找所谓的第k大/小的数(注意,这里所说的是最多,即最多要在剩下7/10中找)


回顾整个过程,说明时间复杂度:

(1)分组:5个数一组,时间复杂度O(N),如果不是通过遍历数组并将元素分配到不同的组,那么就是O(1)。

(2)每一小组排序,选出每个组的中位数,时间复杂度O(N)

(3)每个小组的中位数组成的数组m排序,时间复杂度O(N)

(4)调用bfprt(m, N / 10),得到优秀的枢纽,递归规模T(N/5)

(5)选出的pivot,利用荷兰国旗思想,把整个数组划分成三个部分:

  • <pivot
  • ==pivot
  • >pivot

其时间复杂度也是O(N)

(6)剩下递归的规模T(7N/10)。如果上面这一步,没有使得==pivot的规模“套住“k,那么还需要递归,那么根据上面已有的描述,得到的规模就是T(7N/10)

所以得到最终的表达式:\(T(n) = T(⌈n/5⌉) + T(7n/10) + O(n)\)


虽然master公式无法解决子过程规模不一样的情况(算法导论9.3详细证明为什么这个表达式的时间复杂度是O(N))

但是其实面试阶段,只需要记住,\(T(n) = T(⌈n/5⌉) + T(7n/10) + O(n)\)其时间复杂度就是O(N)

问题1:为什么要5个数一组?

3个数一组不行吗?为什么要5个数呢?为什么不7个数?

回答:1, 5, 7都能得出类似的表达式,也可以根据数学证明出(不会证没关系,但是数学家可以证明的)时间复杂度就是O(N)

其实无关痛痒,只不过这个算法是5个大牛发明的,大佬们喜欢用5个就5个呗。

问题2:看似平庸实则影响深远

BFPRT算法在计算机科学领域确实具有很高的地位,尤其是在算法设计与分析的理论研究中。它不仅解决了选择问题(selection problem)的最坏情况时间复杂度优化,还引入了一种重要的算法设计思想——“分而治之”与“优秀枢轴选择”的结合。这种思想对后续的算法研究和优化产生了深远影响。

简单来说,这个算法提出的思想:选择一个确定的、能够淘汰一定比例的枢纽,去优化整个行为,进而规避掉最差的情况,使得算法拥有严格的优秀时间复杂度。

这种方法保证了划分的“质量”,使得递归树的深度受控,从而将最坏情况时间复杂度优化到 O(n)。

通过结构化的预处理(如分组和中位数计算),将随机化的不确定性转化为确定性的保证。

这种方法影响了后续算法设计,例如在确定性算法中追求最坏情况性能的优化(如一些几何算法或图算法)。

“用分组和中位数选择一个好枢轴”,通过分治和数学保证,让算法在任何情况下都高效运行。这种思想就像是“与其随机猜一个答案,不如精心挑一个靠谱的起点”,对算法设计来说是一种智慧的体现。

代码

用上面的BFPRT算法解决力扣LCR 076题,解法如下:

image-20250330163545195

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class Solution {
        // 解法2:BFPRT算法
    // 算法步骤:
    // 1。分组:5个一组
    // 2。每个组利用插入排序进行排序,选出中位数
    // 3。组成中位数数组m,中位数数组再挑选中位数,得到pivot  ---》中位数的中位数算法
    // 4。利用 荷兰国旗问题  得到<x ==x >x 三个部分
    // 看要找的k/len-k是不是在 first  last 中
    // <x        ==x        >x
    //       first  last
    // 不在first与last之间递归找

    public int first, last;
    public int findKthLargest(int[] nums, int k) {
        return bfprt(nums, 0, nums.length - 1, nums.length - k);
    }

    public int bfprt(int[] arr, int L, int R, int index) {
        if (L == R) return arr[L];
        int pivot = mediansOfMedian(arr, L, R);
        partition(arr, L, R, pivot);
        if (index < first) {
            return bfprt(arr, L, first - 1, index);
        } else if (index > last) {
            return bfprt(arr, last + 1, R, index);
        } else {
            return arr[index];
        }
    }
    public void partition(int[] arr, int l, int r, int x) {
        first = l;
        last = r;
        int i = l;
        while (i <= last) {
            if (arr[i] < x) {
                swap(arr, first++, i++);
            } else if (arr[i] == x) {
                i++;
            } else {
                swap(arr, last--, i);
            }
        }
    }

    public int mediansOfMedian(int[] arr, int L, int R) {
        int size = R - L + 1, offset = size % 5 == 0 ? 0 : 1;
        int[] m = new int[size / 5 + offset];
        for (int team = 0; team < m.length ; team++) {
            int teamFirst = L + team * 5;
            m[team] = getMedian(arr, teamFirst, Math.min(R, teamFirst + 4));
        }
        return bfprt(m, 0, m.length - 1, m.length / 2);
    }
    public int getMedian(int[] arr, int L, int R) {
        insertionSort(arr, L, R);
        return arr[L + ((R - L) >> 1)];
    }

    public void insertionSort(int[] arr, int L, int R) {
        for (int i = L + 1; i <= R; i++) {
            for (int j = i - 1; j >= L && arr[j] > arr[j + 1]; j--) {
                swap(arr, j, j + 1);
            }
        }
    }

    public void swap(int[] arr, int i ,int j) {
        int temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }

    
}

bfprt递归改迭代

image-20250330164753342

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class Solution {
    // 解法2:BFPRT算法(迭代版)
    // 算法步骤:
    // 1。分组:5个一组
    // 2。每个组利用插入排序进行排序,选出中位数
    // 3。组成中位数数组m,中位数数组再挑选中位数,得到pivot  ---》中位数的中位数算法
    // 4。利用 荷兰国旗问题  得到<x ==x >x 三个部分
    // 看要找的k/len-k是不是在 first  last 中
    // <x        ==x        >x
    //       first  last
    // 不在first与last之间递归找

    public int first, last;

    public int findKthLargest(int[] nums, int k) {
        return bfprt(nums, 0, nums.length - 1, nums.length - k);
    }

    public int bfprt(int[] arr, int L, int R, int index) {
        if (L == R) return arr[L];
        int ans = arr[L];
        while (L <= R) {
            int pivot = mediansOfMedian(arr, L, R);
            partition(arr, L, R, pivot);
            if (index < first) {
                R = first - 1;
            } else if (index > last) {
                L = last + 1;
            } else {
                ans = arr[index];
                break;
            }
        }
        return ans;

    }

    public void partition(int[] arr, int l, int r, int x) {
        first = l;
        last = r;
        int i = l;
        while (i <= last) {
            if (arr[i] < x) {
                swap(arr, first++, i++);
            } else if (arr[i] == x) {
                i++;
            } else {
                swap(arr, last--, i);
            }
        }
    }

    public int mediansOfMedian(int[] arr, int L, int R) {
        int size = R - L + 1, offset = size % 5 == 0 ? 0 : 1;
        int[] m = new int[size / 5 + offset];
        for (int team = 0; team < m.length; team++) {
            int teamFirst = L + team * 5;
            m[team] = getMedian(arr, teamFirst, Math.min(R, teamFirst + 4));
        }
        return bfprt(m, 0, m.length - 1, m.length / 2);
    }

    public int getMedian(int[] arr, int L, int R) {
        insertionSort(arr, L, R);
        return arr[L + ((R - L) >> 1)];
    }

    public void insertionSort(int[] arr, int L, int R) {
        for (int i = L + 1; i <= R; i++) {
            for (int j = i - 1; j >= L && arr[j] > arr[j + 1]; j--) {
                swap(arr, j, j + 1);
            }
        }
    }

    public void swap(int[] arr, int i, int j) {
        int temp = arr[i];
        arr[i] = arr[j];
        arr[j] = temp;
    }
}