022 归并分治

归并分治将问题分解为左右子问题及跨越左右的答案。若左右有序能简化跨越部分计算,则适合归并分治。解题时融入归并排序保证左右有序。牛客小和累积和与力扣493翻转对是典型例题,关键在于高效统计跨越左右的答案。归并分治也可用线段树等解决,并能处理更复杂问题。

Posted by Hilda on March 16, 2025

概览:

原理:

1)思考一个问题在大范围上的答案,是否等于,左部分的答案 + 右部分的答案 + 跨越左右产生的答案

2)计算“跨越左右产生的答案”时,如果加上左、右各自有序这个设定,会不会获得计算的便利性

3)如果以上两点都成立,那么该问题很可能被归并分治解决(话不说满,因为总有很毒的出题人)

4)求解答案的过程中只需要加入归并排序的过程即可,因为要让左、右各自有序,来获得计算的便利性

补充:

1)一些用归并分治解决的问题,往往也可以用线段树、树状数组等解法。时间复杂度也都是最优解,这些数据结构都会在 【必备】或者【扩展】课程阶段讲到

2)本节讲述的题目都是归并分治的常规题,难度不大。归并分治不仅可以解决简单问题,还可以解决很多较难的问题,只要符合上面说的特征。比如二维空间里任何两点间的最短距离问题,这个内容会在【挺难】课程阶段里讲述。顶级公司考这个问题的也很少,因为很难,但是这个问题本身并不冷门,来自《算法导论》原题

3)还有一个常考的算法:“整块分治”。会在【必备】课程阶段讲到

聊:精妙又美丽的思想传统(不要太纠结是这么想到的,先接受它。)

牛客-小和累积和

题目链接:https://www.nowcoder.com/practice/edfe05a1d45c4ea89101d936cac32469

代码的关键是“统计部分”:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// --- 统计部分 ---
long ans = 0, sum = 0; // ans 记录总小和,sum 记录当前窗口的累加和
// x 遍历左半部分 [l, m],y 遍历右半部分 [m+1, r]
for (int x = l, y = m + 1; y <= r; y++) {
  // 对于当前 arr[y],找到左半部分中所有小于等于 arr[y] 的元素
  while (x <= m && arr[x] <= arr[y]) {
    sum += arr[x++]; // 将 arr[x] 加到 sum 中,并移动 x
  }
  // arr[y] 的小和贡献:sum 表示左半部分中小于等于 arr[y] 的元素之和
  // 因为 arr[m+1..r] 有序,y 每次右移时,sum 的值适用于当前 arr[y]
  ans += sum;
}
// 注意:当 x > m 时,剩余的 arr[y] 没有左侧更小的元素,贡献为 0(sum 已不再增加)

建议通过一个例子模拟一下就知道什么意思了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import java.io.*;

public class E0316_nowcoder_SmallSum {

    public static int MAXN = 100001;
    public static int n;//数据范围

    public static int[] arr = new int[MAXN];
    public static int[] help = new int[MAXN];


    //题目链接:https://www.nowcoder.com/practice/edfe05a1d45c4ea89101d936cac32469

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StreamTokenizer st = new StreamTokenizer(br);
        PrintWriter pw = new PrintWriter(new OutputStreamWriter(System.out));
        while (st.nextToken() != StreamTokenizer.TT_EOF) {
            n = (int) st.nval;
            for (int i = 0; i < n; i++) {
                st.nextToken();
                arr[i] = (int) st.nval;
            }
            pw.println(smallSum(0, n - 1));
        }
        pw.flush();
        pw.close();

    }

    public static long smallSum(int l, int r) {
        if (l == r) return 0;//求小和就是0
        int m = l + ((r - l) >> 1);
        return smallSum(l, m) + smallSum(m + 1, r) + merge(l, m, r);
    }

    public static long merge(int l, int m, int r) {
        //统计部分
        long ans = 0, sum = 0;
        for(int x = l, y = m + 1; y <= r;y++){
          while(x <= m && arr[x] <= arr[y]){
            sum += arr[x++];
          }
          ans += sum;
        }

        //merge部分
        int a = l, b = m + 1, i = l;
        while (a <= m && b <= r) {
            help[i++] = arr[a] <= arr[b] ? arr[a++] : arr[b++];
        }
        while(a <= m){
            help[i++] = arr[a++];
        }
        while(b <= r){
            help[i++] = arr[b++];
        }
        for (int j = l;j <= r;j++){
            arr[j] = help[j];
        }

        return ans;
    }
}

力扣493. 翻转对

493. 翻转对

给定一个数组 nums ,如果 i < jnums[i] > 2*nums[j] 我们就将 (i, j) 称作一个*重要翻转对*

你需要返回给定数组中的重要翻转对的数量。

示例 1:

1
2
输入: [1,3,2,3,1]
输出: 2

示例 2:

1
2
输入: [2,4,3,5,1]
输出: 3

注意:

  1. 给定数组的长度不会超过50000
  2. 输入数组中的所有数字都在32位整数的表示范围内。

解法-归并分治

image-20250316203957240

根据牛客小和累积和的思路,我写了下面的解法:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class Solution {
    public static int[] help = new int[50001];

    public int reversePairs(int[] nums) {
        return mergereversePairs(nums, 0, nums.length - 1);
    }

    public int mergereversePairs(int[] arr, int l, int r) {
        if (l == r)
            return 0;
        int m = l + ((r - l) >> 1);
        return mergereversePairs(arr, l, m) + mergereversePairs(arr, m + 1, r) + merge(arr, l, m, r);
    }

    public int merge(int[] arr, int l, int m, int r) {
        //统计部分
        int ans = 0;
        for (int i = l, j = m + 1; i <= m; i++) {
            while (j <= r && (long) arr[i] > ((long) arr[j] << 1)) j++;
            ans += (j - m - 1);
        }

        // merge部分
        int a = l, b = m + 1, i = l;
        while (a <= m && b <= r) {
            help[i++] = arr[a] <= arr[b] ? arr[a++] : arr[b++];
        }
        while (a <= m) {
            help[i++] = arr[a++];
        }
        while (b <= r) {
            help[i++] = arr[b++];
        }
        for (int j = l; j <= r; j++) {
            arr[j] = help[j];
        }
        return ans;
    }
}

其中,最关键就是【统计部分】的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
// --- 统计逆序对部分 ---
int ans = 0; // 记录当前合并过程中满足 nums[i] > 2 * nums[j] 的逆序对数量
// i 遍历左半部分 [l, m],j 指向右半部分 [m+1, r] 的当前元素
for (int i = l, j = m + 1; i <= m; i++) {
    // 对于当前 arr[i],找到右半部分第一个不满足 arr[i] > 2 * arr[j] 的 j
    // 因为 arr[m+1..r] 已排序,j 是单调递增的,无需每次从头扫描
    while (j <= r && (long) arr[i] > ((long) arr[j] << 1)) {
        j++; // j 右移,直到找到不满足条件的位置
    }
    // j 停止时,[m+1, j-1] 范围内的元素都满足 arr[i] > 2 * arr[j]
    // 逆序对数量为 j - (m + 1),但这里用了 j - m - 1,与 j++ 的时机一致
    ans += (j - m - 1);
    // 注意:j - m - 1 表示从 m+1 到 j-1 的元素个数,因为 j 是停止时的下一个位置
}

起初我写的很复杂,逻辑也是对的(对于无法通过的测试用例我单独测试都是可以的),但是就是“超出时间限制”,也就是说,需要简化判断条件。

1
2
3
4
5
6
7
8
9
10
11
12
//统计部分
int ans = 0;
for (int i, j = m + 1; j <= r; j++) {
    i = l;
    while (i <= m) {
        if ((long) arr[i] > (long)arr[j] * 2) {
            ans += (m - i + 1);
            break;
        }
        i++;
    }
}