numpy(8)质量控制

第 8 章 质量控制

Posted by Hilda on February 12, 2025

前7章以及其他补充已经整理如下:

第1章NumPy入门

第2章NumPy基础

第3章常用函数

广播与广播机制

第4章便捷函数

第5章矩阵和通用函数

第6章深入学习NumPy模块

第7章专用函数


有些程序员只在产品中做测试。如果你不是他们中的一员,你可能会对单元测试的概念耳熟能详。单元测试是由程序员编写的自动测试模块,用来测试他或者她的代码。这些单元测试可以测试某个函数或函数中的某个独立的部分。每一个单元测试仅仅对一小部分代码进行测试。单元测试可以带来诸多好处,如提高代码质量、可重复性测试等,使软件副作用更为清晰。 Python 本身对单元测试就有良好的支持。此外, NumPy 中也有numpy.testing包可以支持NumPy代码的单元测试。

TDD(Test Driven Development,测试驱动的开发)是软件开发史上最重要的里程碑之一。TDD主要专注于自动单元测试,它的目标是尽最大限度自动化测试代码。如果代码被改动,我们仍可以运行测试并捕捉可能存在的问题。换言之,测试对于已经存在的功能模块依然有效。

注:TDD(Test-Driven Development,测试驱动开发) 是一种软件开发方法论,它强调在编写生产代码之前,先编写测试代码。换句话说,开发者在开始实现功能前,先编写测试用例,然后通过编写足够的代码来使这些测试通过,最后进行重构。

TDD 通常遵循以下的 红绿重构(Red-Green-Refactor)循环:

  1. 红色(Red)
    • 编写一个失败的测试用例。测试用例应该描述你希望实现的功能或行为。
    • 由于功能还没有实现,测试会失败。
  2. 绿色(Green)
    • 编写足够的代码,使得测试用例能够通过。
    • 代码的实现目标是使得测试通过,因此此时代码通常是最简单的解决方案。
  3. 重构(Refactor)
    • 在测试通过后,对代码进行重构,以提高代码质量,保持代码的简洁和可维护性。
    • 在重构过程中,确保已有的测试依然通过,验证没有破坏已有的功能。

尽管初期开发可能会稍慢,但长期来看,TDD 有助于提高代码的可维护性和减少bug。

本章涵盖以下内容:

  • 单元测试;
  • 断言机制;
  • 浮点数精度。

8.1 断言函数

单元测试通常使用断言函数作为测试的组成部分。在进行数值计算时,我们经常遇到比较两个近似相等的浮点数这样的基本问题。整数之间的比较很简单,但浮点数却非如此,这是由于计算机对浮点数的表示本身就是不精确的numpy.testing包中有很多实用的工具函数考虑了浮点数比较的问题,可以测试前提是否成立。

为什么计算机对浮点数的表示本身就是不精确?

1
2
3
4
a = 0.1 + 0.2
b = 0.3
print(a == b)  # 输出: False

计算机中的数字是通过有限数量的二进制位来表示的,而浮点数是计算机用来表示实数的一种近似方式。在计算机内存中,每个数字(包括浮点数)都有固定的存储空间(比如 32 位或 64 位)。这意味着即使我们需要表示一个无限小数,计算机也只能将它截断为有限位数,因此无法完美精确地表示所有实数。另外,由于浮点数的尾数部分(即有效数字的精度)是有限的,因此对于不能被精确表示的数字,计算机会将它们四舍五入,产生误差。

因此,浮点数的精度问题是计算机科学中的一个固有问题,尤其是在涉及高精度计算时

函数 描述
assert_almost_equal 如果两个数字的近似程度没有达到指定精度,就抛出异常
assert_approx_equal 如果两个数字的近似程度没有达到指定有效数字,就抛出异常
assert_array_almost_equal 如果两个数组中元素的近似程度没有达到指定精度,就抛出异常
assert_array_equal 如果两个数组对象不相同,就抛出异常
assert_array_less 两个数组组必须形状一致,并且第一个数组的元素严格小于第二个数组的元素,否则抛出异常
assert_equal 如果两个对象不相同,就抛出异常
assert_raises 若用填写的参数调用函数没有抛出指定的异常,则测试不通过
assert_warns 若没有抛出指定的警告,则测试不通过
assert_string_equal 判断两个字符串变量完全相同
assert_allclose 如果两个对象的近似程度超过了指定的容差限,就抛出异常

注:assert_* 开头的函数通常用于验证测试中的各种条件,如比较两个数字、数组或对象是否相等,是否触发预期的异常或警告,等等。

注:特别注意np.nan的情况。

补充1:assert_almost_equal

assert_almost_equal 是一个用于 单元测试 的函数,它通常用于 比较两个数字数组,并验证它们是否相等到一定的精度。这个函数常用于确保在计算时,结果与预期值非常接近,通常是浮点数计算的比较。

函数签名:

1
numpy.testing.assert_almost_equal(actual, desired, decimal=6, err_msg='', verbose=True)

参数:

actual: 需要检查的实际值,通常是你计算得到的值。

desired: 期望的值,即你希望实际值接近的值。

decimal(可选,默认值为 6): 期望的精度,表示小数点后要比较的位数。例如,decimal=6 表示实际值和期望值的小数点后最多允许有 6 位差异。

err_msg(可选): 如果断言失败时显示的错误信息。你可以自定义错误消息。

verbose(可选,默认值为 True): 是否显示详细的输出。如果为 True,会显示更多的调试信息。

返回值:如果 actualdesired 在指定的精度范围内相等,则函数什么都不返回,测试通过。如果它们的差异超过了指定的精度,函数会引发 AssertionError,并显示相应的错误信息。

例1:假设你有一个浮点数计算结果,并希望确保它与预期值非常接近,可以使用 assert_almost_equal 来验证。

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np

# 实际值
actual_value = 0.3333333

# 期望值
desired_value = 1 / 3

# 比较两个数字,最多允许小数点后 6 位差异
np.testing.assert_almost_equal(actual_value, desired_value, decimal=6)

print("测试通过!")  # 输出:测试通过!

例2:比较数组

1
2
3
4
5
6
7
8
9
10
11
12
import numpy as np

# 实际数组
actual_array = np.array([0.3333333, 0.6666667, 1.0])

# 期望数组
desired_array = np.array([1 / 3, 2 / 3, 1])

# 比较数组中的每个元素,最多允许小数点后 6 位差异
np.testing.assert_almost_equal(actual_array, desired_array, decimal=6)

print("数组测试通过!")  # 输出:数组测试通过!

例3:浮点数之间的差异超过精度限制

注意:需要注意的是,如果在指定位数上数值相差1则仍然不会报错

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np

# 实际值
actual_value = 0.3333344

# 期望值
desired_value = 1 / 3

# 比较两个数字,最多允许小数点后 6 位差异
try:
    np.testing.assert_almost_equal(actual_value, desired_value, decimal=6)
    print("测试通过!")# 输出:测试通过!
except AssertionError as e:
    print(f"测试失败: {e}")

下面就会报错:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np

# 实际值
actual_value = 0.3333444

# 期望值
desired_value = 1 / 3

# 比较两个数字,最多允许小数点后 6 位差异
try:
    np.testing.assert_almost_equal(actual_value, desired_value, decimal=6)
    print("测试通过!")# 输出:测试通过!
except AssertionError as e:
    print(f"测试失败: {e}")

image-20250210225915596

若有NAN:

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
from numpy.testing import assert_almost_equal

a = np.nan
b = np.nan

# 对于NaN,assert_almost_equal默认认为它们相等
try:
    assert_almost_equal(a, b)
    print("Test passed: NaN values are considered equal")  # 输出:Test passed: NaN values are considered equal
except AssertionError:
    print("AssertionError: NaN values cannot be compared using assert_almost_equal")


补充2:assert_approx_equal

assert_approx_equal 是一个用于单元测试的函数,通常用于比较两个数字或数组,检查它们是否在 近似相等 的情况下通过测试。与 assert_almost_equal 不同,assert_approx_equal 用于判断两个数字是否在给定的 容差 内接近,而不要求它们完全相等到某个固定的小数位数。

函数签名:

1
numpy.testing.assert_approx_equal(actual, desired, significant=7, err_msg='', verbose=True)

参数:

actual: 实际值,通常是你计算得到的值。

desired: 期望的值,即你希望实际值接近的值。

significant(可选,默认值为 7): 允许的 有效数字位数significant=7 表示允许两个数字的差异在 有效数字的第 7 位 内。有效数字决定了两个数字的接近程度。

err_msg(可选): 如果断言失败时显示的错误信息。你可以自定义错误消息。

verbose(可选,默认值为 True): 是否显示详细的输出。如果为 True,会显示更多的调试信息。

返回值:如果 actualdesired 的差异在指定的有效数字范围内,则函数什么都不返回,测试通过。如果它们的差异超过了指定的有效数字,函数会引发 AssertionError,并显示相应的错误信息。

例1:

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np

# 实际值
actual_value = 0.3333344

# 期望值
desired_value = 1 / 3  # 即 0.3333333...

# 比较两个数字,最多允许 7 位有效数字的差异
np.testing.assert_approx_equal(actual_value, desired_value, significant=7)

print("测试通过!")

image-20250210230214140

注:将上面的actual_value改为 ` 0.3333334`,则测试通过,即如果在有效数字的指定位数上数值相差1则仍然不会报错


补充3:assert_array_almost_equal

assert_array_almost_equal 是一个用于 单元测试 的函数,通常用于比较两个 数组 中的元素,并检查它们是否在给定的精度范围内相等。该函数适用于浮点数数组的比较,特别是在需要考虑微小的舍入误差时。

函数签名:

1
numpy.testing.assert_array_almost_equal(actual, desired, decimal=6, err_msg='', verbose=True)

参数:

actual: 实际值,通常是你计算得到的数组。

desired: 期望的值,即你希望实际值接近的数组。

decimal(可选,默认值为 6): 允许的精度,表示小数点后要比较的位数。例如,decimal=6 表示实际值和期望值的小数点后最多允许有 6 位差异。

err_msg(可选): 如果断言失败时显示的错误信息。你可以自定义错误消息。

verbose(可选,默认值为 True): 是否显示详细的输出。如果为 True,会显示更多的调试信息。

返回中:如果 actualdesired 数组中的元素在指定的精度范围内相等,则函数什么都不返回,测试通过。如果它们的差异超过了指定的精度,函数会引发 AssertionError,并显示相应的错误信息。

例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np

# 实际数组
actual_array = np.array([0.3333333, 0.6666677, 1.0])

# 期望数组
desired_array = np.array([1 / 3, 2 / 3, 1])

# 比较数组中的每个元素,最多允许小数点后 6 位差异
np.testing.assert_array_almost_equal(actual_array, desired_array, decimal=6)

print("测试通过!") # 输出:测试通过!

注:如果在指定位数上数值相差1则仍然不会报错(例如上面指定6位,但是第六位不一样,还是不报错,测试通过)

补充4:assert_array_equal

assert_array_equal 是 NumPy 中的一个测试函数,用于比较两个数组是否 完全相等。它在单元测试中非常有用,通常用于检查两个数组是否具有相同的形状和元素。

函数签名:

1
numpy.testing.assert_array_equal(actual, desired, err_msg='', verbose=True)

参数:

actual: 实际值,通常是你计算得到的数组。

desired: 期望的值,即你希望实际值和期望值相等的数组。

err_msg(可选): 如果断言失败时显示的错误信息。你可以自定义错误消息。

verbose(可选,默认值为 True): 是否显示详细的输出。如果为 True,会显示更多的调试信息。

返回值:

如果 actualdesired 数组完全相等(形状和元素相同),则函数什么都不返回,测试通过。如果它们不相等,函数会引发 AssertionError,并显示相应的错误信息。

例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np

# 实际数组
actual_array = np.array([1, 2, 3, 4, 5])

# 期望数组
desired_array = np.array([1, 2, 3, 4, 7])

# 比较两个数组是否完全相等
np.testing.assert_array_equal(actual_array, desired_array)

print("测试通过!")

报错:

image-20250210230831050

再看一个例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np

# 实际数组
actual_array = np.array([1, 2, 3, 4, 5])

# 期望数组
desired_array = np.array([[1, 2], [3, 4], [5]])  # 形状不同

# 比较两个数组是否完全相等
try:
    np.testing.assert_array_equal(actual_array, desired_array)
    print("测试通过!")
except AssertionError as e:
    print(f"测试失败: {e}")

报错是形状不匹配:报错不是Assertion Error

image-20250210230915906


补充5:assert_array_less

assert_array_less 是 NumPy 中的一个测试函数,用于检查 一个数组中的元素是否严格小于另一个数组中的对应元素,或者 数组中所有元素是否都小于给定的标量。它通常用于单元测试中,以确保数组中的元素满足某个比较条件。

函数签名:

1
numpy.testing.assert_array_less(a, b, err_msg='', verbose=True)

参数:

a: 实际值,可以是一个数组,或与 b 相同形状的数组。需要检查的数组中的元素。

b: 期望的值。可以是一个数组,或一个标量。如果是数组,则其形状必须与 a 相同。如果是标量,则检查数组中所有元素是否小于该标量。

err_msg(可选): 如果断言失败时显示的错误信息。你可以自定义错误消息。

verbose(可选,默认值为 True): 是否显示详细的输出。如果为 True,会显示更多的调试信息。

返回值:

如果 a 中的所有元素都严格小于 b 中对应的元素(或者 a 中所有元素都小于标量 b),则函数什么都不返回,测试通过。如果任意一个元素不满足该条件,函数会引发 AssertionError,并显示相应的错误信息。

例如:测试通过的例子

1
2
3
4
5
6
7
8
9
10
11
12
import numpy as np

# 实际数组
a = np.array([1, 2, 3, 4, 5])

# 期望数组
b = np.array([6, 7, 8, 9, 10])

# 检查 a 中的每个元素是否都严格小于 b 中对应的元素
np.testing.assert_array_less(a, b)

print("测试通过!")  # 输出:测试通过!

下面是2个测试不通过的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np

# 实际数组
a = np.array([1, 2, 3, 4, 6])

# 期望数组
b = np.array([5, 5, 5, 5, 5])

# 检查 a 中的每个元素是否都严格小于 b 中对应的元素
try:
    np.testing.assert_array_less(a, b)
    print("测试通过!")
except AssertionError as e:
    print(f"测试失败: {e}")

image-20250210231528685

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import numpy as np

# 实际数组
a = np.array([1, 2, 3, 4, 5])

# 期望标量
b = 4  # 检查是否所有元素都小于 4

# 检查 a 中的每个元素是否都严格小于 4
try:
    np.testing.assert_array_less(a, b)
    print("测试通过!")
except AssertionError as e:
    print(f"测试失败: {e}")

image-20250210231606635


补充6:assert_equal

assert_equal 是一个用于 单元测试 的函数,它用于比较 两个对象是否完全相等。它常用于测试中,确保预期的结果与实际的计算结果完全一致。

函数签名:

1
numpy.testing.assert_equal(actual, desired, err_msg='', verbose=True)

参数:

actual: 需要检查的实际值,通常是你计算得到的值。

desired: 期望的值,即你希望实际值和期望值相等的值。

err_msg(可选): 如果断言失败时显示的错误信息。你可以自定义错误消息。

verbose(可选,默认值为 True): 是否显示详细的输出。如果为 True,会显示更多的调试信息。

返回值:如果 actualdesired 完全相等(包括类型、形状和元素),则函数什么都不返回,测试通过。如果它们不相等,函数会引发 AssertionError,并显示相应的错误信息。

例子:

1
2
3
4
5
6
7
8
9
10
11
12
import numpy as np

# 实际数组
actual_array = np.array([1, 2, 3, 4, 5])

# 期望数组
desired_array = np.array([1, 2, 3, 4, 7])

# 比较两个数组是否完全相等
np.testing.assert_equal(actual_array, desired_array)

print("数组测试通过!")

image-20250210231937098


补充7:assert_raises

assert_raises 是一个用于单元测试的函数,用于检查在执行某个代码块时是否会引发指定的异常。它通常用于验证当输入不合法或者发生某些错误时,程序是否正确地抛出了预期的异常。

函数签名:

1
numpy.testing.assert_raises(expected_exception, callable_obj, *args, **kwargs)

参数:

  • expected_exception: 期望的异常类型,可以是一个异常类(例如 ValueErrorTypeError 等),或者自定义的异常类。
  • callable_obj: 可调用对象(如函数、方法),它应该是一个函数或可以调用的对象,这个函数会在测试过程中执行。
  • *args: 可选参数,传递给 callable_obj 的位置参数。
  • **kwargs: 可选参数,传递给 callable_obj 的关键字参数。

返回值:

  • 如果执行 callable_obj 时抛出了指定的异常 expected_exception,则测试通过。
  • 如果没有抛出该异常,或者抛出了不同的异常,测试失败并抛出 AssertionError

例子1:假设你编写了一个函数,当输入负数时会抛出 ValueError 异常。

1
2
3
4
5
6
7
8
9
10
11
12
import numpy as np

# 测试函数
def test_negative_input(x):
    if x < 0:
        raise ValueError("输入不能是负数")
    return x

# 检查当输入负数时是否抛出 ValueError 异常
np.testing.assert_raises(ValueError, test_negative_input, -1)

print("测试通过!")  # 输出:测试通过!

例子2:报异常:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np

# 测试函数
def test_negative_input(x):
    if x < 0:
        raise ValueError("输入不能是负数")
    return x

# 检查当输入 5 时是否抛出 ValueError 异常
try:
    np.testing.assert_raises(ValueError, test_negative_input, 5)
    print("测试通过!")
except AssertionError as e:
    print(f"测试失败: {e}")

image-20250210232543807

由于没有抛出期望的 ValueError 异常,测试失败并引发 AssertionError

例子3:

1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np

# 测试函数
def test_zero_division(x):
    return 10 / x

# 检查当传入 0 时是否抛出 ZeroDivisionError 异常
try:
    np.testing.assert_raises(ValueError, test_zero_division, 0)
    print("测试通过!")
except AssertionError as e:
    print(f"测试失败: {e}")

image-20250210232725826

test_zero_division(0) 会抛出 ZeroDivisionError,而不是 ValueError

因为我们期望 ValueError,但是实际上抛出了 ZeroDivisionError,测试失败。

补充8:assert_warns

assert_warns 是一个用于 单元测试 的函数,它用于检查在执行某个代码块时是否会触发指定的警告。通常情况下,警告并不直接导致程序崩溃,但我们可能需要确保在某些特定条件下,程序会给出相应的警告。

函数签名:

1
numpy.testing.assert_warns(expected_warning, callable_obj, *args, **kwargs)

参数:

  • expected_warning: 期望的警告类型,可以是一个警告类(例如 DeprecationWarningUserWarning 等),或者自定义的警告类。
  • callable_obj: 可调用对象(如函数、方法),它应该是一个函数或可以调用的对象,这个函数会在测试过程中执行。
  • *args: 可选参数,传递给 callable_obj 的位置参数。
  • **kwargs: 可选参数,传递给 callable_obj 的关键字参数。

返回值:

  • 如果执行 callable_obj 时触发了指定的警告 expected_warning,则测试通过。
  • 如果没有触发该警告,或者触发了不同类型的警告,测试失败并抛出 AssertionError

这个函数特别适合用于当你希望测试一些过时的或将来可能发生变化的功能时。


例子:

1
2
3
4
5
6
7
8
9
10
import numpy as np
import warnings

def my_function():
    warnings.warn("This is a deprecated function", DeprecationWarning)

# 使用 assert_warns 测试是否会触发 DeprecationWarning
np.testing.assert_warns(DeprecationWarning, my_function)

print("测试成功!")

image-20250211113853262

在这个例子中,my_function 会触发一个 DeprecationWarningassert_warns 会确保这个警告被引发。如果没有引发该警告,测试将失败。

补充9:assert_string_equal

assert_string_equal 是 NumPy 中 numpy.testing 模块下的一个函数,用于比较两个字符串是否相等。它通常用于单元测试中,以确保两个字符串的值完全相同。如果字符串不相等,则测试会失败。assert_string_equal 用来断言两个字符串是否完全一致。它会比较两个字符串的内容,若不相等则抛出 AssertionError

函数签名:

1
numpy.testing.assert_string_equal(actual, desired)

参数:

actual:实际值,通常是你测试中的结果字符串。

desired:期望的字符串值。


例子:

1
2
3
4
5
6
7
8
9
import numpy as np

def greet(name):
    return f"Hello, {name}!"

# 测试 greet 函数
np.testing.assert_string_equal(greet("Alice"), "Hello, Alice!")
np.testing.assert_string_equal(greet("Bob"), "Hello, Bob!")

运行结果:代码没有报错。

在这个例子中,我们有一个 greet 函数,它根据传入的名字返回一个问候字符串。assert_string_equal 被用来验证函数返回的字符串是否与我们期望的字符串相同。

注:assert_string_equal 是严格的比较,即两个字符串必须完全相同(包括大小写、空格等)。如果有任何不同,测试就会失败。

如果你只关心字符串的一部分内容或不在乎空格、大小写差异,可能需要使用其他方法(如正则表达式或将字符串转换为小写后比较)。

补充10:assert_allclose

assert_allclose 是 NumPy 中 numpy.testing 模块下的一个函数,用于比较两个数组(或数值)是否在指定的容差范围内相等。它是用于浮动数值比较的常用方法,特别适用于那些可能因为精度问题而导致微小差异的情况。

函数签名:

1
numpy.testing.assert_allclose(actual, desired, rtol=1e-5, atol=1e-8, equal_nan=False)

参数:

  • actual:实际值,可以是一个数组或单个数值。
  • desired:期望值,可以是一个数组或单个数值。
  • rtol:相对容差(relative tolerance),默认为 1e-5,表示允许的相对误差。
  • atol:绝对容差(absolute tolerance),默认为 1e-8,表示允许的绝对误差。
  • equal_nan:是否将 NaN 视为相等。默认为 False,即 NaN 不会被认为是相等的。如果设置为 True,则会将所有 NaN 值视为相等。

assert_allclose 将通过以下公式来判断两个值是否相近:\(abs(actual−desired)≤atol+rtol×abs(desired)\)

如果满足该条件,则认为两个值在容忍范围内相等。


例子:不报错:

1
2
3
4
5
6
7
8
9
import numpy as np

# 实际结果和期望结果有微小差异
actual = np.array([1.0, 2.0, 3.0001])
desired = np.array([1.0, 2.0, 3.0])

# 使用 assert_allclose 来判断是否在允许的容差范围内
np.testing.assert_allclose(actual, desired, rtol=1e-4, atol=1e-6)

再看一个例子:不报错:

1
2
3
4
5
6
actual = np.array([1.0, 2.0, np.nan])
desired = np.array([1.0, 2.0, np.nan])

# 设置 equal_nan=True,表示 NaN 值是相等的
np.testing.assert_allclose(actual, desired, equal_nan=True)

下面是报错的例子:

1
2
3
4
5
6
actual = np.array([1.0, 2.0, 3.0])
desired = np.array([1.0, 2.0, 4.0])

# 这个测试会失败,因为差异超过了容差
np.testing.assert_allclose(actual, desired, rtol=1e-4, atol=1e-6)

image-20250211132058455

8.2 动手实践:使用 assert_almost_equal 断言近似相等

假设你有两个很接近的数字。我们用assert_almost_equal函数来检查它们是否近似相等。

(1) 调用函数,指定较低的精度(小数点后7位):

1
np.testing.assert_almost_equal(0.123456789, 0.123456780, decimal = 7)

image-20250211134421554

注意,这里没有抛出异常

(2) 调用函数,指定较高的精度(小数点后8位):

1
np.testing.assert_almost_equal(0.123456789, 0.123456780, decimal = 8)

image-20250211134520658

注:指定的这一位不同也不会报错。改为decimal=9,会报错:

image-20250211134610883

刚才做了些什么 : 我们使用NumPy testing包中的assert_almost_equal函数在不同的精度要求下检查了两个浮点数0.123456789和0.123456780是否近似相等。


突击测验:指定精度

问题1 以下哪一个是assert_almost_equal函数的参数,用来指定小数点后的精度?

(1) decimal (2) precision (3) tolerance (4) significant

答案:(1)decimal(可选,默认值为 6): 期望的精度,表示小数点后要比较的位数。例如,decimal=6 表示实际值和期望值的小数点后最多允许有 6 位差异。

8.3 近似相等

如果两个数字的近似程度没有达到指定的有效数字要求,assert_approx_equal函数将抛出异常。该函数触发异常的条件如下:

1
numpy.testing.assert_approx_equal(actual, desired, significant=7, err_msg='', verbose=True)
\[\vert actual - expected \vert >= 10^{-(significant - 1)}\]

8.4 动手实践:使用 assert_approx_equal 断言近似相等

我们仍使用前面“动手实践”教程中的数字,并使用assert_approx_equal函数对它们进行比较。

(1) 调用函数,指定较低的有效数字位:

1
print("Significance 8",np.testing.assert_approx_equal(0.123456789, 0.123456780, significant=8))

image-20250212143728233

没有异常。

(2) 调用函数,指定较高的有效数字位:

1
print("Significance 9",np.testing.assert_approx_equal(0.123456789, 0.123456780, significant=9))

image-20250212143948588

如上抛出了一个异常。

刚才做了些什么 我们使用numpy.testing包中的assert_approx_equal函数在不同的精度要求下检查了两个浮点数0.123456789和0.123456780是否近似相等。


8.5 数组近似相等

如果两个数组中元素的近似程度没有达到指定的精度要求, assert_array_ almost_equal函数将抛出异常。该函数首先检查两个数组的形状是否一致,然后逐一比较两个数组中的元素:

1
numpy.testing.assert_array_almost_equal(actual, desired, decimal=6, err_msg='', verbose=True)
\[\vert expected - actual \vert < 0.5 * 10^{-decimal }\]

8.6 动手实践:断言数组近似相等

我们使用前面“动手实践”教程中的数字,并各加上一个0来构造两个数组。

(1) 调用函数,指定较低的精度:

1
2
print("Decimal 8 :", np.testing.assert_array_almost_equal([0, 0.123456789], [0, 
0.123456780], decimal = 8))

image-20250212144447748

(2) 调用函数,指定较高的精度:

1
2
print("Decimal 9 :", np.testing.assert_array_almost_equal([0, 0.123456789], [0, 
0.123456780], decimal = 9))

报出异常:

image-20250212144525083

刚才做了些什么 : 我们使用NumPy中的assert_array_almost_equal函数比较了两个数组。

勇敢出发:比较形状不一致的数组

使用NumPy的assert_array_almost_equal函数比较两个形状不一致的数组。


1
2
3
4
5
n1 = np.random.randint(0, 10, size = (3, 4))
n2 = np.random.randint(0, 10, size = (3, 5))
display(n1)
display(n2)
print(np.testing.assert_array_almost_equal(n1, n2))

image-20250212144816761

因为默认精度是decimal = 6,所以报错是

1
AssertionError: Arrays are not almost equal to 6 decimals

并且说明了形状不匹配:

1
(shapes (3, 4), (3, 5) mismatch)

8.7 数组相等

如果两个数组对象不相同,assert_array_equal函数将抛出异常。两个数组相等必须形状一致且元素也严格相等,允许数组中存在NaN元素。

此外,比较数组也可以使用assert_allclose函数。该函数有参数atol(absolute tolerance,绝对容差限)和rtol(relative tolerance,相对容差限)。对于两个数组a和b,将测试是否满足以下等式:

image-20250212170255163

1
2
numpy.testing.assert_array_equal(actual, desired, err_msg='', verbose=True)
numpy.testing.assert_allclose(actual, desired, rtol=1e-5, atol=1e-8, equal_nan=False)

8.8 动手实践:比较数组

我们使用刚刚提到的函数来比较两个数组。我们仍使用前面“动手实践”教程中的数组,并增加一个NaN元素。

(1) 调用assert_allclose函数:

1
2
print("pass,", np.testing.assert_allclose([0, 0.123456789, np.nan], [0, 0.123456780, 
np.nan], rtol = 1e-7, atol = 0))

输出pass, None,即通过。

(2) 调用assert_array_equal函数:

1
print("fail,", np.testing.assert_array_equal([0, 0.123456789, np.nan], [0, 0.123456780, np.nan]))

image-20250212151912620

刚才做了些什么 : 我们分别使用assert_allcloseassert_array_equal函数比较了两个数组。


8.9 数组排序

两个数组必须形状一致并且第一个数组的元素严格小于第二个数组的元素,否则assert_ array_less函数将抛出异常。

8.10 动手实践:核对数组排序

我们来检查一个数组是否严格大于另一个数组。

1
numpy.testing.assert_array_less(a, b, err_msg='', verbose=True)

比较a是不是严格小于b。

(1) 调用assert_array_less函数比较两个有严格顺序的数组:

1
print("pass:", np.testing.assert_array_less([0, 0.123456789, np.nan], [1, 0.23456780, np.nan]))

image-20250212152221473

(2) 调用assert_array_less函数,使测试不通过:

1
print("fail:", np.testing.assert_array_less([0, 0.123456789, np.nan], [0, 0.123456780, np.nan]))

image-20250212152326222

错误信息是:

1
2
AssertionError: 
Arrays are not less-ordered

刚才做了些什么 : 我们使用assert_array_less函数比较了两个数组的大小顺序。

8.11 对象比较

如果两个对象不相同,assert_equal函数将抛出异常。这里的对象不一定为NumPy数组,也可以是Python中的列表、元组或字典。

1
numpy.testing.assert_equal(actual, desired, err_msg='', verbose=True)

8.12 动手实践:比较对象

假设你需要比较两个元组。我们可以用assert_equal函数来完成。

(1) 调用assert_equal函数:

1
print("equal?", np.testing.assert_equal((1, 2), (1, 3)))

image-20250212152616873

注:item=1 指的是在比较两个序列(或者数组)时,出现不相等的元素的位置或索引。即2与3不同。

刚才做了些什么 : 我们使用assert_equal函数比较了两个元组——两个元组并不相同,因此抛出了异常。

8.13 字符串比较

assert_string_equal函数断言两个字符串变量完全相同。如果测试不通过,将会抛出异常并显示两个字符串之间的差异。该函数区分字符大小写。

1
numpy.testing.assert_string_equal(actual, desired)

8.14 动手实践:比较字符串

比较两个均为NumPy的字符串。

(1) 调用assert_string_equal函数,比较一个字符串和其自身。显然,该测试应通过:

1
print("pass,", np.testing.assert_string_equal("NumPy","NumPy"))

image-20250212152852371

(2) 调用assert_string_equal函数,比较一个字符串和另一个字母完全相同但大小写有区别的字符串。该测试应抛出异常:

1
print("fail,", np.testing.assert_string_equal("NumPy XXXX","Numpy XXXX"))

image-20250212153006350

报错:AssertionError: Differences in strings:


刚才做了些什么 我们使用assert_string_equal函数比较了两个字符串。当字符大小写不匹配时抛出 异常。


8.15 浮点数比较

浮点数在计算机中是以不精确的方式表示的,这给比较浮点数带来了问题。

NumPy中的assert_array_almost_equal_nulpassert_array_max_ulp函数可以提供可靠的浮点数比较功能。

ULP是Unit of Least Precision的缩写,即浮点数的最小精度单位。根据IEEE 754标准,四则运算的误差必须保持在半个ULP之内。你可以用刻度尺来做对比。公制刻度尺的刻度通常精确到毫米,而更高精度的部分只能估读,误差上界通常认为是最小刻度值的一半,即半毫米。

机器精度(machine epsilon)是指浮点运算中的相对舍入误差上界。机器精度等于ULP相对于1的值。NumPy中的finfo函数可以获取机器精度。Python标准库也可以给出机器精度值,并应该与NumPy给出的结果一致。


8.16 动手实践:使用 assert_array_almost_equal_nulp 比较浮点数

我们在实践中学习assert_array_almost_equal_nulp函数。

(1) 使用finfo函数确定机器精度:

1
2
eps = np.finfo(float).eps
eps

image-20250212153330280

(2) 使用assert_array_almost_equal_nulp函数比较两个近似相等的浮点数1.01.0 * eps(epsilon),然后对1.0 + 2 * eps做同样的比较:

1
2
print("1", np.testing.assert_array_almost_equal_nulp(1.0, 1.0 + eps) )
print("2", np.testing.assert_array_almost_equal_nulp(1.0, 1.0 + 2 * eps) )

image-20250212153459760

刚才做了些什么 : 我们使用finfo函数获取了机器精度。随后,我们使用assert_array_almost_equal_nulp函数比较了1.01.0 + eps,测试通过,再加上一个机器精度则抛出了异常。

8.17 多 ULP 的浮点数比较

assert_array_max_ulp函数可以指定ULP的数量作为允许的误差上界。参数maxulp接受整数作为ULP数量的上限,默认值为1。

8.18 动手实践:设置 maxulp 并比较浮点数

我们仍使用前面“动手实践”教程中比较的浮点数,但在需要的时候设置maxulp为2。

(1) 使用finfo函数确定机器精度:

1
2
eps = np.finfo(float).eps
eps

image-20250212153632928

(2) 与前面的“动手实践”教程做相同的比较,但这里我们使用assert_array_max_ulp函数和适当的maxulp参数值:

1
2
print("1", np.testing.assert_array_max_ulp(1.0, 1.0 + eps) )
print("2", np.testing.assert_array_max_ulp(1.0, 1 + 2 * eps, maxulp=2) )

image-20250212153728473

上面测试都是通过的。

刚才做了些什么 :我们仍比较了前面“动手实践”教程中的浮点数,但在第二次比较时将maxulp设置为2。我们使用assert_array_max_ulp函数和适当的maxulp参数值通过了比较测试,并返回了指定的ULP数量。

8.19 单元测试

单元测试是对代码的一小部分进行自动化测试的单元,通常是一个函数或方法。Python中有用于单元测试的PyUnit API(Application Programming Interface,应用程序编程接口)。作为NumPy 用户,我们还可以使用前面学习过的断言函数。

8.20 动手实践:编写单元测试

我们将为一个简单的阶乘函数编写测试代码,检查所谓的程序主逻辑以及非法输入的情况。

(1) 首先,我们编写一个阶乘函数:

1
2
3
4
5
6
def factorial(n):  
    if n == 0:  
        return 1 
    if n < 0: 
        raise(ValueError, "Unexpected negative value" )
    return np.arange(1, n+1).cumprod() 

代码中使用了我们已经掌握的创建数组和累乘计算函数arange和cumprod,并增加了一些边界条件的判断。

(2) 现在我们来编写单元测试。编写一个包含单元测试的类,继承Python标准库unittest模块中的TestCase类。我们对阶乘函数进行如下调用测试:

注:我修改了原书的部分代码。但是功能还是差不多的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import unittest
import numpy as np

def factorial(n):  
    if n == 0:  
        return 1  
    if n < 0:  
        raise ValueError("Unexpected negative value")  # 负数抛出 ValueError
    return np.arange(1, n+1).cumprod()  

class FactorialTest(unittest.TestCase):
    def test_factorial(self): 
        # 计算3的阶乘,测试通过
        self.assertEqual(6, factorial(3)[-1]) 
        np.testing.assert_equal(np.array([1, 2, 6]), factorial(3))  

    def test_zero(self): 
        # 计算0的阶乘,测试通过
        self.assertEqual(1, factorial(0))  

    def test_negative(self): 
        # 计算负数的阶乘,这里故意写错:应该抛出 ValueError,但我们错误地期望它抛出 IndexError
        with self.assertRaises(IndexError):  # 故意写错,让测试失败
            factorial(-10)  

# 运行测试
unittest.main(argv=[''], exit=False)

我们有意使得其中一项测试不通过,输出结果如下所示:

image-20250212161458628

刚才做了些什么 : 我们对阶乘函数的程序主逻辑代码进行了测试,并有意使得边界条件的测试不通过。

8.21 nose 和测试装饰器

鼻子(nose)是长在嘴上方的器官,人和动物的呼吸和闻都依赖于它。nose同时也是一种Python框架,使得(单元)测试更加容易。

nose可以帮助你组织测试代码。根据nose的文档,“任何能够匹配testMatch正则表达式(默认为(?:^|[b_.-])[Tt]est)的Python源代码文件、文件夹或库都将被收集用于测试”。nose充分利用了装饰器(decorator)。Python装饰器是有一定含义的对函数或方法的注解。numpy.testing模块中有很多装饰器。


image-20250212161631044

此外,我们还可以调用decorate_methods函数,将装饰器应用到能够匹配正则表达式或字符串的类方法上。


8.22 动手实践:使用测试装饰器

我们将直接在测试函数上使用setastest装饰器。我们在另一个方法上也使用该装饰器,但将其禁用。此外,我们还将跳过一个测试,并使得另一个测试不通过。如果你仍未安装nose,请先完成安装步骤。

(1) 使用setuptools安装nose:

1
easy_install nose 

或者使用pip安装:

1
!pip install nose

image-20250212161807051

此处,说明一点。nose 作为测试框架,依赖于 setastestskipifknownfailureif 等 Nose 相关的装饰器。然而,Nose(nosetests)已经被废弃,因此需要 迁移到 pytestunittest 来在 Jupyter Notebook 里进行测试。

由于 nosetests 不能直接在 Jupyter Notebook 中运行,我选择 pytest 作为替代方案,或者使用 unittest 进行测试。


方法 1:转换为 pytest

安装pytest:

1
2
import sys
!{sys.executable} -m pip install pytest

image-20250212162616693

可以在 Jupyter Notebook 里 先把测试代码保存为 .py 文件,然后使用 IPython 的魔法命令 !pytest%run 运行测试。

首先保存 .py 文件后用 !pytest 运行

在 Jupyter Notebook 里执行以下代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# 1. 把测试代码写入 test_script.py
with open("test_script.py", "w") as f:
    f.write("""
import pytest

@pytest.mark.skip(reason="This test is set to False")
def test_false():
    pass  # 这个测试会被跳过

def test_true():
    pass  # 这个测试会被执行

@pytest.mark.skipif(True, reason="This test is skipped due to condition being True")
def test_skip():
    pass  # 这个测试会被跳过

@pytest.mark.xfail(reason="This test is expected to fail")
def test_alwaysfail():
    assert False  # 这个测试会失败,但 pytest 认为是 "expected failure"

class TestClass:
    def test_true2(self):
        pass  # 这个测试会被执行

class TestClass2:
    def test_false2(self):
        pass  # 这个测试也会被执行
""")

# 2. 运行 pytest 进行测试
!pytest test_script.py --maxfail=1 --disable-warnings -q

image-20250212163022863

方法 2:转换为 unittest

unittest.skipunittest.expectedFailure 来代替 skipifknownfailureif

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import unittest

class TestCases(unittest.TestCase):

    @unittest.skip("This test is set to False")
    def test_false(self):
        pass  # 这个测试会被跳过

    def test_true(self):
        pass  # 这个测试会被执行

    @unittest.skipIf(True, "This test is skipped due to condition being True")
    def test_skip(self):
        pass  # 这个测试会被跳过

    @unittest.expectedFailure
    def test_alwaysfail(self):
        self.assertFalse(True)  # 这个测试会失败,但 unittest 认为是 "预期失败"

class TestClass(unittest.TestCase):
    def test_true2(self):
        pass  # 这个测试会被执行

class TestClass2(unittest.TestCase):
    def test_false2(self):
        pass  # 这个测试也会被执行

# 在 Jupyter Notebook 中运行 unittest
unittest.main(argv=[''], exit=False)

image-20250212163247570

8.23 文档字符串

文档字符串(docstring)是内嵌在Python代码中的类似交互式会话的字符串。这些字符串可以用于某些测试,也可以仅用于提供使用示例。numpy.testing模块中有一个函数可以运行这些测试。

8.24 动手实践:执行文档字符串测试

我们来编写一个简单的计算阶乘的例子,但不考虑所有的边界条件。换言之,编写一些测试不能通过的例子。

(1) 文档字符串看起来就像你在Python shell中看到的文本一样(包括命令提示符)。我们将有意使得其中一项测试不通过,看看会发生什么。

1
2
3
4
5
6
7
8
9
"""
Test for the factorial of 3 that should pass.
>>> factorial(3)
6

Test for the factorial of 0 that should fail.
>>> factorial(0)
1
"""

(2) 我们将用下面这一行NumPy代码来计算阶乘:

1
return np.arange(1, n+1).cumprod()[-1] 

为了演示目的,这行代码有时会出错。

(3) 我们可以在Python shell中通过调用numpy.testing模块的rundocs函数,从而执行文档字符串测试。

1
2
3
4
import numpy as np
import doctest

doctest.testmod()

完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import numpy as np
import doctest

def factorial(n):
    """
    Test for the factorial of 3 that should pass.
    >>> factorial(3)
    6

    Test for the factorial of 0 that should fail.
    >>> factorial(0)
    1
    """
    return np.arange(1, n+1).cumprod()[-1]

doctest.testmod()

测试结果:

image-20250212165338431

刚才做了些什么 : 我们编写了一个文档字符串测试,在对应的阶乘函数中没有考虑0和负数的情况。我们使用numpy.testing模块中的rundocs函数执行了测试,并得到了“索引错误”的结果。

8.25 本章小结

在本章中,我们学习了代码测试和NumPy中的测试工具。 涵盖的内容包括单元测试、文档字符串测试、断言函数和浮点数精度。大部分NumPy断言函数都与浮点数精度有关。我们演示了可以被nose使用的Numpy装饰器的用法。装饰器使得测试更加容易使用,并体现开发者的意图。

下一章将要讨论的是Matplotlib——开源的Python科学可视化和绘图工具库。