pandas超级重要函数-transform函数

返回与输入组具有相同形状的结果的transform函数

Posted by Hilda on February 28, 2025

[TOC]

NumPy 博客总结:

《Python数据分析基础教程:NumPy学习指南(第2版)》所有章节阅读笔记+代码

70道NumPy 面试题(题目+答案)

pandas博客总结:

pandas(1)数据预处理

pandas(2)数据分析

pandas(3)常用函数操作

pandas(4)大数据处理技巧

【力扣】pandas入门15题


aggregation会返回数据的缩减版本,而transform能返回完整数据的某一变换版本供我们重组。这样的transformation,输出的形状和输入一致。

如果想要体会下transform函数的魅力,建议可以通过一道力扣的题目体会:184. 部门工资最高的员工

transform() 函数的主要作用是在分组数据上进行逐元素的转换,并保持原始数据的索引结构。 它与 apply() 函数类似,但 apply() 函数可以返回任意形状的结果,而 transform() 必须返回与输入组具有相同形状的结果。

函数签名:

1
df.groupby(by)[column].transform(func, *args, **kwargs)
  • func: 要应用的函数。 它可以是:
    • 内置函数: 例如 sum, mean, max, min 等。 但通常不直接使用内置函数,而是结合 lambda 表达式或自定义函数使用。
    • lambda 表达式: 用于定义简单的匿名函数。
    • 自定义函数: 可以定义更复杂的函数,进行更精细的转换。
  • *args: 传递给 func 的位置参数。
  • **kwargs: 传递给 func 的关键字参数。(kw=keyword)

标准化 (Standardization) 或归一化 (Normalization)

将每个组的数据缩放到一个特定的范围,或者使其具有特定的均值和标准差。

1
2
3
4
5
6
7
8
9
import pandas as pd

data = {'Category': ['A', 'A', 'B', 'B', 'A', 'B'],
        'Value': [10, 15, 20, 25, 12, 22]}
df = pd.DataFrame(data)
display(df)
# 对每个 Category 组的 Value 列进行标准化 (Z-score)
df['Value_Standardized'] = df.groupby('Category')['Value'].transform(lambda x: (x - x.mean()) / x.std())
print(df)

image-20250228132418866

缺失值填充 (Missing Value Imputation)

使用 transform() 基于组的均值、中位数或其他统计量来填充缺失值,结合fillna函数

1
2
3
4
5
6
7
8
9
10
import pandas as pd
import numpy as np

data = {'Category': ['A', 'A', 'B', 'B', 'A', 'B'],
        'Value': [10, np.nan, 20, 25, 12, np.nan]}
df = pd.DataFrame(data)
display(df)
# 用每个 Category 组的均值填充 Value 列的缺失值
df['Value_Filled'] = df.groupby('Category')['Value'].transform(lambda x: x.fillna(x.mean()))
display(df)

image-20250228132631439

计算组内的排名 (Ranking)

练习 力扣:184. 部门工资最高的员工

1
2
3
4
5
6
7
8
9
import pandas as pd

data = {'Category': ['A', 'A', 'B', 'B', 'A', 'B'],
        'Value': [10, 15, 20, 25, 12, 22]}
df = pd.DataFrame(data)
display(df)
# 计算每个 Category 组内 Value 列的排名
df['Value_Rank'] = df.groupby('Category')['Value'].transform(lambda x: x.rank())
display(df)

image-20250228132754738

创建滞后 (Lagged) 或超前 (Lead) 值

使用 shift() 函数结合 transform()

1
2
3
4
5
6
7
8
9
import pandas as pd

data = {'Category': ['A', 'A', 'B', 'B', 'A', 'B'],
        'Value': [10, 15, 20, 25, 12, 22]}
df = pd.DataFrame(data)
display(df)
# 创建每个 Category 组内 Value 列的滞后 1 期值
df['Value_Lagged'] = df.groupby('Category')['Value'].transform(lambda x: x.shift(1))
display(df)

image-20250228133048305

shift(1) 是 Pandas Series 的一个方法,它将 Series 中的每个值向下移动一位。 换句话说,它将当前行的值替换为前一行的值。 第一个值由于没有前一行,会被替换为 NaN (Not a Number),表示缺失值。

注意

  • 传递给 transform() 的函数必须能够处理 Series 或 DataFrame 作为输入,并且返回相同长度的 Series 或 DataFrame。
  • 避免在 transform() 中使用循环,因为它会降低性能。尽量使用 Pandas 内置的函数或矢量化操作。

理解矢量化思维是提高 Pandas 代码性能的关键。目标是将操作表达为可以一次性应用于整个数据结构的操作,而不是逐个元素地处理。