numpy-形状操作

Posted by Hilda on July 16, 2025

1.数组变形

数组变形是指改变数组的维度(shape),同时保持其数据不变。这是 NumPy 中一个非常常见的操作,允许您将数据以不同的组织形式进行处理。

  • ndarray.reshape(new_shape) 方法:
    • 这是最常用的变形方法。它返回一个具有新形状的数组,但通常不复制数据,而是创建一个原始数组的视图。只有当新旧形状的数据存储方式不兼容时(例如,需要重新排列内存中的数据),才会复制数据。
    • new_shape 参数是一个整数元组,指定了新数组的每个维度的大小。
    • 新形状中的元素总数必须与原始数组中的元素总数相同。
    • 自动推断维度 (-1):new_shape 中,您可以将一个维度指定为 -1。NumPy 会根据数组的总元素数和其余维度的大小自动推断出该维度的大小。这在您不确定某个维度具体大小,但知道其他维度时非常方便。例如,arr.reshape(-1, 5) 表示将数组重塑为 N 行 5 列,N 会自动计算。
1
2
3
4
5
6
7
8
9
10
11
12
13
arr = np.random.randint(0, 10, size=(3, 4, 5))
print(arr)
# 变形为2维数组  12 * 5
arr1 = arr.reshape(12, 5)
print(arr1)
print(arr1.base is arr)  # True
print(arr1.flags.owndata)  # False
# -1自动推断维度
arr2 = arr.reshape(-1, 5)
print(arr2.shape)  # (12, 5)
# 视图的特性
arr1[0, 0] = 9999
print(arr)   # 影响了原数组
  • ndarray.flatten()ndarray.ravel() 方法:
    • 这两个方法都可以将多维数组展平为一维数组。
    • flatten() 总是返回一个副本
    • ravel() 通常返回一个视图(如果可能),只有在必要时才返回副本。因此,ravel() 通常更推荐用于性能考虑。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
arr1 = np.arange(1, 7).reshape(2, 3)
arr1
"""
array([[1, 2, 3],
       [4, 5, 6]])
"""
arr2 = arr1.flatten()
arr2   
"""
array([1, 2, 3, 4, 5, 6])
"""
arr3 = arr1.ravel()
arr3
"""
array([1, 2, 3, 4, 5, 6])
"""
arr2.base is arr1  # False
print(arr3.base is arr1)  # False
arr3[0] = 9999
arr1   # ravel会影响原数组
arr2[1] = 888
arr1   # flattern不会影响原数组

数组变形的本质是在不改变底层数据存储的情况下,改变 NumPy 数组对象的元数据(shapestrides)。

  • shape 定义了数组在每个维度上的大小。
  • strides 定义了在每个维度上移动一个元素需要跳过的字节数。例如,对于一个行主序(row-major order)的二维数组,strides 会告诉你从一行跳到下一行需要跳过多少字节,以及从一列跳到下一列需要跳过多少字节。

当调用 reshape() 时,NumPy 会尝试计算新的 shapestrides,使得新数组仍然可以“视图”到原始数组的内存。如果能够成功计算出这样的 shapestrides(即数据在内存中仍然是连续的,或者可以通过简单的步长调整来访问),那么 reshape() 就会返回一个视图。这种情况下,新数组的 base 属性会指向原始数组,flags.owndataFalse

只有当新的形状要求数据在内存中以一种不连续的方式排列,或者需要重新组织数据以适应新的维度顺序时,reshape() 才不得不复制数据。例如,如果原始数组是 C-contiguous(行主序),而新形状要求 F-contiguous(列主序),则可能需要复制。但在大多数常见情况下,reshape() 返回视图。

自动推断维度 -1 的原理是:如果数组总共有 N 个元素,并且您指定了 k−1 个维度的大小,那么第 k 个维度的大小就是 N 除以已指定维度的乘积。

选择题

  1. 给定一个 NumPy 数组 arr = np.arange(12),以下哪种 reshape 操作是无效的?

    A. arr.reshape(3, 4) B. arr.reshape(2, 2, 3) C. arr.reshape(-1, 6) D. arr.reshape(5, 3)

    答案:D。arr 有 12 个元素。\(5 \times 3 = 15\),元素数量不匹配。

  2. 关于 ndarray.reshape() 方法,以下哪项描述是正确的?

    A. 它总是返回原始数组的副本。 B. 它总是返回原始数组的视图。

    C. 它可能返回视图或副本,具体取决于新旧形状是否兼容内存布局。 D. 它只能用于将数组展平为一维。

    答案:C

编程题

  1. 创建一个 1*10 的 NumPy 数组,包含 0 到 9 的整数。
  2. 将其变形为以下形状:
    • 2*5
    • 5*2
    • 10*1
  3. 对于每次变形,打印新数组的形状,并验证它是否是原始数组的视图(检查 base 属性)。
  4. 2*5 形状的数组的第一个元素修改为 100,然后打印原始数组以观察变化。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
arr1 = np.arange(0, 10)
arr1
"""
array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
"""
arr2 = arr1.reshape(2, 5)
arr2
"""
array([[0, 1, 2, 3, 4],
       [5, 6, 7, 8, 9]])
"""
arr3 = arr1.reshape(5, 2)
arr3
"""
array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]])
"""

arr4 = arr1.reshape(10, 1)
arr4
"""
array([[0],
       [1],
       [2],
       [3],
       [4],
       [5],
       [6],
       [7],
       [8],
       [9]])
"""

arr2.base is arr1   # True
arr3.base is arr1    # True
arr4.base is arr1    # True
arr2[0, 0] = 100
arr1
"""
array([100,   1,   2,   3,   4,   5,   6,   7,   8,   9])
"""

2.数组转置

数组转置是改变数组维度的顺序。对于二维数组,这相当于数学上的矩阵转置,即行变成列,列变成行。对于高维数组,转置允许您任意重新排列维度。

  • .T 属性:
    • 这是二维数组最常用的转置方式。对于二维数组 arrarr.T 返回其转置。
    • 对于高维数组,.T 相当于 np.transpose(arr, axes=(..., 1, 0)),它会反转所有维度的顺序。
  • np.transpose(arr, axes=None) 函数:
    • 这是一个更通用的转置函数,可以用于任意维度的数组。
    • arr:要转置的数组。
    • axes:一个元组或列表,指定新轴的顺序。例如,对于一个三维数组 (d0, d1, d2)
      • axes=(0, 1, 2):不改变顺序。
      • axes=(1, 0, 2):交换第一个和第二个维度。
      • axes=(2, 0, 1):将第三个维度放到第一个位置,第一个维度放到第二个位置,第二个维度放到第三个位置。
    • 如果 axesNone(默认),则 np.transpose 会反转维度的顺序,与 .T 的行为一致。

【1】二维转置

1
2
3
4
5
6
7
8
arr = np.random.randint(0, 10, (3, 5))
print(arr)
arr_T = arr.T
print(arr_T)
# 验证  .T是返回  视图
print(arr_T.base is arr)   # True
arr_T[0, 0] = 777
print(arr)   # 会影响原数组

【2】三维数组转置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
arr = np.random.randint(0, 10, (3, 6, 4))
print(arr)
"""
[[[0 1 9 5]
  [4 2 6 8]
  [3 8 2 5]
  [1 0 9 9]
  [0 8 6 3]
  [4 6 1 7]]

 [[3 3 2 3]
  [2 9 6 5]
  [4 6 0 5]
  [2 2 1 0]
  [0 0 3 4]
  [4 4 2 2]]

 [[0 5 3 0]
  [9 1 9 8]
  [2 7 6 7]
  [3 9 8 3]
  [9 9 2 2]
  [7 8 5 1]]]
"""
# 交换维度0, 1
arr2 = np.transpose(arr, axes=(1, 0, 2))
print(arr2)
"""
[[[0 1 9 5]
  [3 3 2 3]
  [0 5 3 0]]

 [[4 2 6 8]
  [2 9 6 5]
  [9 1 9 8]]

 [[3 8 2 5]
  [4 6 0 5]
  [2 7 6 7]]

 [[1 0 9 9]
  [2 2 1 0]
  [3 9 8 3]]

 [[0 8 6 3]
  [0 0 3 4]
  [9 9 2 2]]

 [[4 6 1 7]
  [4 4 2 2]
  [7 8 5 1]]]
"""
print(arr2.base is arr)   # True

选择题

  1. 给定一个 NumPy 数组 arr = np.array([[1, 2], [3, 4]]),执行 arr.T 后,arr[0, 1] 的值是什么?

    A. 1 B. 2 C. 3 D. 4

    答案:C

  2. 对于一个形状为 (2, 3, 4) 的三维 NumPy 数组 data,以下哪个 np.transpose 调用会使其形状变为 (4, 2, 3)

    A. np.transpose(data, axes=(0, 1, 2)) B. np.transpose(data, axes=(2, 0, 1))

    C. np.transpose(data, axes=(1, 2, 0)) D. np.transpose(data)

    答案:B

编程题

  1. 创建一个形状为 (2, 3, 4) 的三维 NumPy 数组,元素为 0 到 23 的整数。
  2. 执行以下转置操作:
    • 使用 .T 属性进行转置。
    • 使用 np.transpose 将第一个维度和第三个维度交换(即 (d0, d1, d2) 变为 (d2, d1, d0))。
  3. 打印每次操作后的数组形状,并验证它们是否是原始数组的视图
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
arr = np.arange(0, 24).reshape(2, 3, 4)
print(arr)
"""
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]
"""

arr_T = arr.T
print(arr_T.shape)   # (4, 3, 2)
arr_ = np.transpose(arr, axes=(2, 1, 0))
print(arr_.shape)   # (4, 3, 2)
print(arr_T.base is arr)   # False
print(arr_.base is arr)   # False

3.数组堆叠

数组堆叠是将多个数组沿着某个轴组合成一个更大的数组。这与 concatenate 密切相关,但提供了更方便的函数来处理常见堆叠场景。

  • np.concatenate((arr1, arr2, ...), axis=0) 函数:
    • 这是最通用的数组连接函数。它将一系列(元组或列表)数组沿着指定轴连接起来。
    • axis 参数指定了连接的轴。
      • axis=0 (默认):沿着第一个轴(行)连接数组。所有数组除了连接轴之外的其他轴的长度必须相同。
      • axis=1:沿着第二个轴(列)连接数组。所有数组除了连接轴之外的其他轴的长度必须相同。
    • 所有要连接的数组必须具有相同的形状,除了在连接轴上的维度。
    • concatenate 总是返回一个新数组(副本)。
  • np.vstack((arr1, arr2, ...)) 函数(Vertical Stack):
    • 沿着垂直方向(行方向,即 axis=0)堆叠数组。
    • 等价于 np.concatenate((arr1, arr2, ...), axis=0)
    • 所有数组必须具有相同的列数(即除了第一个维度之外的形状必须相同)。
  • np.hstack((arr1, arr2, ...)) 函数(Horizontal Stack):
    • 沿着水平方向(列方向,即 axis=1)堆叠数组。
    • 等价于 np.concatenate((arr1, arr2, ...), axis=1)
    • 所有数组必须具有相同的行数(即除了第二个维度之外的形状必须相同)。
  • np.dstack((arr1, arr2, ...)) 函数(Depth Stack):
    • 沿着深度方向(第三个维度,即 axis=2)堆叠数组。
    • 对于二维数组,它会在第三个维度上增加一个维度,然后进行堆叠。
    • 例如,两个 (M, N) 数组 AB 堆叠后会得到一个 (M, N, 2) 的数组。
  • np.stack((arr1, arr2, ...), axis=0) 函数:
    • concatenate 不同,stack 会在新轴上堆叠数组。这意味着它会增加一个维度。
    • 例如,两个形状为 (M, N) 的数组,np.stack((arr1, arr2), axis=0) 会得到一个形状为 (2, M, N) 的数组。
    • stack 要求所有输入数组具有完全相同的形状。

数组堆叠操作的核心是创建一块新的内存区域,并将所有输入数组的数据按照指定的轴顺序复制到这块新内存中。

  • 内存分配: NumPy 会计算所有输入数组合并后所需的总内存大小,并分配一块新的连续内存。
  • 数据复制: 输入数组的元素会按照 axis 参数指定的顺序,逐个或逐块地复制到新分配的内存中。
  • 新数组对象: 一个新的 NumPy 数组对象被创建,其 shapestrides 会根据堆叠后的新结构进行设置。
  • 总是副本: 由于涉及将多个分散的数组数据整合到一块新的连续内存中,堆叠操作总是返回一个副本,而不是视图。

(1)np.concatenate

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
arr1 = np.array([[1, 2, 3]])
arr2 = np.array([[4, 5, 6]])
arr3 = np.array([[7, 8, 9]])
concat1 = np.concatenate([arr1, arr2], axis = 0)   # 行串联
print(concat1)
"""
[[1 2 3]
 [4 5 6]]
"""
# axis=1表示对应行的数组进行拼接
concat2 = np.concatenate([arr1.T, arr2.T, arr3.T], axis = 1) 
print(concat2)
"""
[[1 4 7]
 [2 5 8]
 [3 6 9]]
"""
concat3 = np.concatenate([arr1, arr2, arr3], axis = 1)
print(concat3)
"""
[[1 2 3 4 5 6 7 8 9]]
"""

如何理解axis这个参数:

所说的第一个维度就是沿着x方向进行拼接,也就是把矩阵和矩阵上下拼接;第二个维度就是沿着y方向进行拼接,也就是把矩阵和矩阵左右拼接;第三个维度就是沿着z方向进行拼接,也就是把矩阵和矩阵合在一起。(x,y方向就是正常的坐标轴方向)

注意: 拼接时候一定要注意维度,就好比axis=0,要进行上下拼接,那么两个矩阵的列数一定要相同;axis=1就是行数相同;axis=2就是行列数均相同。

(2)np.vstack

相当于axis=0,列数要相同

1
2
3
4
5
6
7
8
9
arr1 = np.arange(0, 3)
arr2 = np.arange(4, 10).reshape(-1, 3)
arr_vstack = np.vstack((arr1, arr2))
arr_vstack
"""
array([[0, 1, 2],
       [4, 5, 6],
       [7, 8, 9]])
"""

(3)np.hstack

相当于axis = 1,行数相同

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
arr1 = np.arange(0,12).reshape(3, -1)
arr1
"""
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
"""
arr2 = np.arange(9, 15).reshape(3, -1)
arr2
"""
array([[ 9, 10],
       [11, 12],
       [13, 14]])
"""
arr_hstack = np.hstack((arr1, arr2))
arr_hstack
"""
array([[ 0,  1,  2,  3,  9, 10],
       [ 4,  5,  6,  7, 11, 12],
       [ 8,  9, 10, 11, 13, 14]])
"""

(4)np.stack

该函数主要用来提升维度。

axis参数指定新轴在结果尺寸中的索引。例如,如果axis=0,它将是第一个维度,如果axis=-1,它将是最后一个维度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
arr1 = np.arange(0, 3)
arr2 = np.arange(4, 7)
arr3 = np.arange(7, 10)
arr_axis0 = np.stack((arr1, arr2, arr3), axis=0)
arr_axis0
"""
array([[0, 1, 2],
       [4, 5, 6],
       [7, 8, 9]])
"""
arr_axis1 = np.stack((arr1, arr2, arr3), axis=1)
arr_axis1
"""
array([[0, 4, 7],
       [1, 5, 8],
       [2, 6, 9]])
"""

假设要转变的张量数组arrays的长度为N,其中的每个张量数组的形状为(A, B, C)。

如果轴axis=0,则转变后的张量的形状为(N, A, B, C)。

如果轴axis=1,则转变后的张量的形状为(A, N, B, C)。

如果轴axis=2,则转变后的张量的形状为(A, B, N, C)。其它情况依次类推。

image-20250716143606251

(5)np.dstack

dstack专门用于在第三个维度上进行堆叠,会将(M, N)的数组变成(M, N, K)

例如:

1
2
3
4
5
6
7
8
9
10
11
12
13
arr1 = np.array([[1, 2, 3], [4, 5, 6]])   # 2 * 3
arr2 = np.array([[7, 8, 9], [10, 11, 12]])   # 2 * 3
arr_dstack = np.dstack((arr1, arr2))   # 2 * 3 * 2
arr_dstack   
"""
array([[[ 1,  7],
        [ 2,  8],
        [ 3,  9]],

       [[ 4, 10],
        [ 5, 11],
        [ 6, 12]]])
"""

选择题

  1. 给定以下两个 NumPy 数组:

    1
    2
    3
    
    import numpy as np
    a = np.array([1, 2])
    b = np.array([3, 4])
    

    执行 np.vstack((a, b)) 后,结果数组的形状是什么?

    A. (2,) B. (1, 4) C. (2, 2) D. (4,)

    答案:C

  2. 以下哪个函数会将两个形状为 (M, N) 的二维数组 arr1arr2 堆叠成一个形状为 (M, N, 2) 的三维数组?

    A. np.concatenate((arr1, arr2), axis=2) B. np.vstack((arr1, arr2))

    C. np.hstack((arr1, arr2)) D. np.dstack((arr1, arr2))

    答案:D

编程题

  1. 创建两个形状为 (3, 4) 的 NumPy 数组,分别包含 0 到 11 和 12 到 23 的整数。
  2. 使用 np.concatenate 将它们沿着行(axis=0)堆叠。
  3. 使用 np.hstack 将它们沿着列(axis=1)堆叠。
  4. 创建一个新的 3*4 数组,包含 24 到 35 的整数。使用 np.stack 将这三个 3*4 数组沿着一个新的轴(例如 axis=0)堆叠成一个三维数组。
  5. 打印每次操作后的数组及其形状。

参考:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
arr1 = np.arange(0, 12).reshape(3, 4)
arr1
"""
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
"""
arr2 = np.arange(12, 24).reshape(3, 4)
arr2
"""
array([[12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])
"""
res1 = np.concatenate((arr1, arr2), axis=0)
res1
"""
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])
"""
res2 = np.hstack((arr1, arr2))
res2
"""
array([[ 0,  1,  2,  3, 12, 13, 14, 15],
       [ 4,  5,  6,  7, 16, 17, 18, 19],
       [ 8,  9, 10, 11, 20, 21, 22, 23]])
"""
arr3 = np.arange(24, 36).reshape(3, 4)
arr3
"""
array([[24, 25, 26, 27],
       [28, 29, 30, 31],
       [32, 33, 34, 35]])
"""
res3 = np.stack((arr1, arr2, arr3), axis=0)  # 3 * 3 * 4
res3
"""
array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]],

       [[24, 25, 26, 27],
        [28, 29, 30, 31],
        [32, 33, 34, 35]]])
"""

4.数组拆分

数组拆分是将一个数组沿着某个轴拆分成多个子数组。这是堆叠操作的逆过程。

  • np.split(arr, indices_or_sections, axis=0) 函数:
    • 这是最通用的数组拆分函数。
    • arr:要拆分的数组。
    • indices_or_sections
      • 如果是一个整数 N,表示将数组沿着指定轴平均拆分成 N 个子数组。如果数组在该轴上的长度不能被 N 整除,则会引发错误。
      • 如果是一个一维整数数组(或列表),表示拆分的断点(索引)。例如,[2, 5] 表示在索引 25 处进行拆分,结果会是三个子数组:arr[:, :2], arr[:, 2:5], arr[:, 5:]
    • axis:指定拆分的轴。
      • axis=0 (默认):沿着行方向拆分。
      • axis=1:沿着列方向拆分。
    • split 总是返回一个列表,其中包含拆分后的子数组。这些子数组通常是原始数组的视图
  • np.vsplit(arr, indices_or_sections) 函数(Vertical Split):
    • 沿着垂直方向(行方向,即 axis=0)拆分数组。
    • 等价于 np.split(arr, indices_or_sections, axis=0)
  • np.hsplit(arr, indices_or_sections) 函数(Horizontal Split):
    • 沿着水平方向(列方向,即 axis=1)拆分数组。
    • 等价于 np.split(arr, indices_or_sections, axis=1)
  • np.dsplit(arr, indices_or_sections) 函数(Depth Split):
    • 沿着深度方向(第三个维度,即 axis=2)拆分数组。
    • 等价于 np.split(arr, indices_or_sections, axis=2)

数组拆分操作通过创建多个新的数组对象来实现,这些新数组对象共享原始数组的底层数据。

  • 视图创建: 当您拆分一个数组时,NumPy 不会复制数据。相反,它会创建多个新的 NumPy 数组对象,每个对象都代表原始数组中相应部分的视图。
  • shapestrides 调整: 每个新的子数组对象都有自己的 shapestrides,这些元数据被设置为正确地指向原始数据缓冲区中的相应部分。
  • base 属性: 拆分后的子数组的 base 属性会指向原始数组,flags.owndata 会是 False
  • 返回列表: 拆分函数返回一个包含这些视图的 Python 列表。

(1)np.split

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
arr = np.arange(0, 30).reshape(6, 5)
arr
"""
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29]])
"""
s1 = np.split(arr, indices_or_sections=2, axis=0)
s1
"""
[array([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]]),
 array([[15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]])]
"""
s2 = np.split(arr, indices_or_sections=[2, 5], axis=0)
s2
"""
[array([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]]),
 array([[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]]),
 array([[25, 26, 27, 28, 29]])]
"""
s3 = np.split(arr, indices_or_sections=[2, 3], axis=1)
s3
"""
[array([[ 0,  1],
        [ 5,  6],
        [10, 11],
        [15, 16],
        [20, 21],
        [25, 26]]),
 array([[ 2],
        [ 7],
        [12],
        [17],
        [22],
        [27]]),
 array([[ 3,  4],
        [ 8,  9],
        [13, 14],
        [18, 19],
        [23, 24],
        [28, 29]])]
"""

(2)np.vsplit

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
arr
"""
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29]])
"""
v = np.vsplit(arr, indices_or_sections=3)
v
"""
[array([[0, 1, 2, 3, 4],
        [5, 6, 7, 8, 9]]),
 array([[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]]),
 array([[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29]])]
"""

(3)np.hsplit

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
arr
"""
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29]])
"""
h = np.hsplit(arr, indices_or_sections=[2, 4])
h
"""
[array([[ 0,  1],
        [ 5,  6],
        [10, 11],
        [15, 16],
        [20, 21],
        [25, 26]]),
 array([[ 2,  3],
        [ 7,  8],
        [12, 13],
        [17, 18],
        [22, 23],
        [27, 28]]),
 array([[ 4],
        [ 9],
        [14],
        [19],
        [24],
        [29]])]
"""

(4)np.dsplit

这个函数是np.split(..., axis=2)的便捷函数。

注意:np.dsplit 要求数组的第三个维度(axis=2)能够被请求的分割数整除。如果数组 arr.shape 是 (2, 2, 3),尝试沿第三个轴(大小为3)分成2部分,但3除以2不能整除(3 ÷ 2 = 1.5),此时就会报错。下面是一个正确的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
arr = np.array(
    [
        [[1, 2, 3], [4, 5, 6]],
        [[7, 8, 9], [10, 11, 12]]
    ]
)
arr.shape  # (2, 2, 3)
res = np.dsplit(arr, 3)
res
"""
[array([[[ 1],
         [ 4]],
 
        [[ 7],
         [10]]]),
 array([[[ 2],
         [ 5]],
 
        [[ 8],
         [11]]]),
 array([[[ 3],
         [ 6]],
 
        [[ 9],
         [12]]])]
"""

选择题

  1. 给定一个 NumPy 数组 data = np.arange(12).reshape(3, 4),执行 np.split(data, 3, axis=0) 后,结果列表中的每个子数组的形状是什么?

    A. (1, 4) B. (3, 1) C. (4,) D. (3, 4)

    答案:A

  2. 以下哪个函数用于将一个数组沿着第三个维度(深度)拆分?

    A. np.vsplit() B. np.hsplit() C. np.split(..., axis=2) D. np.dsplit() E. C 和 D 都是正确的。

    答案:E

编程题

  1. 创建一个形状为 (8, 6) 的 NumPy 数组,包含 0 到 47 的整数。
  2. 使用 np.vsplit 将其平均拆分成 4 个子数组。
  3. 使用 np.hsplit 将其在列索引 2 和 4 处拆分。
  4. 打印每次操作后得到的子数组列表的长度,以及每个子数组的形状。
  5. 验证拆分后的子数组是否是原始数组的视图。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
arr = np.arange(0, 48).reshape(8, 6)
arr
"""
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29],
       [30, 31, 32, 33, 34, 35],
       [36, 37, 38, 39, 40, 41],
       [42, 43, 44, 45, 46, 47]])
"""
arr_v = np.vsplit(arr, 4)
arr_v
"""
[array([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]]),
 array([[12, 13, 14, 15, 16, 17],
        [18, 19, 20, 21, 22, 23]]),
 array([[24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35]]),
 array([[36, 37, 38, 39, 40, 41],
        [42, 43, 44, 45, 46, 47]])]
"""
arr_h = np.hsplit(arr, [2, 4])
arr_h
"""
[array([[ 0,  1],
        [ 6,  7],
        [12, 13],
        [18, 19],
        [24, 25],
        [30, 31],
        [36, 37],
        [42, 43]]),
 array([[ 2,  3],
        [ 8,  9],
        [14, 15],
        [20, 21],
        [26, 27],
        [32, 33],
        [38, 39],
        [44, 45]]),
 array([[ 4,  5],
        [10, 11],
        [16, 17],
        [22, 23],
        [28, 29],
        [34, 35],
        [40, 41],
        [46, 47]])]
"""
for i, a in enumerate(arr_h):
    print(f"{i}-----\n{a}")
    print(a.base)   
    """
    [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47]
    """
    # base 属性指向原始数组 arr 的数据缓冲区
    # 原始数组 arr 是通过 np.arange(0, 48).reshape(8, 6) 创建的,其底层数据是一个一维数组 [0, 1, ..., 47]