PySpark DataFrame 添加自增 ID

在用 Spark 处理数据的时候,经常需要给全量数据增加一列自增 ID 序号,在存入数据库的时候,自增 ID 也常常是一个很关键的要素。
在 DataFrame 的 API 中没有实现这一功能,所以只能通过其他方式实现,或者转成 RDD 再用 RDD 的 zipWithIndex 算子实现。
下面呢就介绍三种实现方式。

创建 DataFrame 对象

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
[
{"name": "Alice", "age": 18},
{"name": "Sitoi", "age": 22},
{"name": "Shitao", "age": 22},
{"name": "Tom", "age": 7},
{"name": "De", "age": 17},
{"name": "Apple", "age": 45}
]
)
df.show()

输出:

1
2
3
4
5
6
7
8
9
10
+---+------+
|age| name|
+---+------+
| 18| Alice|
| 22| Sitoi|
| 22|Shitao|
| 7| Tom|
| 17| De|
| 45| Apple|
+---+------+

方式一:monotonically_increasing_id()

使用自带函数 monotonically_increasing_id() 创建,由于 spark 会有分区,所以生成的 ID 保证单调增加且唯一,但不是连续的

优点:对于没有分区的文件,处理速度快。
缺点:由于 spark 的分区,会导致,ID 不是连续增加。

1
2
df = df.withColumn("id", monotonically_increasing_id())
df.show()

输出:

1
2
3
4
5
6
7
8
9
10
+---+------+-----------+
|age| name| id|
+---+------+-----------+
| 18| Alice| 8589934592|
| 22| Sitoi|17179869184|
| 22|Shitao|25769803776|
| 7| Tom|42949672960|
| 17| De|51539607552|
| 45| Apple|60129542144|
+---+------+-----------+

如果读取本地的单个 CSV 文件 或 JSON 文件,ID 会是连续增加且唯一的。

方法二:窗口函数

利用窗口函数:设置窗口函数的分区以及排序,因为是全局排序而不是分组排序,所有分区依据为空,排序规则没有特殊要求也可以随意填写

优点:保证 ID 连续增加且唯一
缺点:运行速度满,并且数据量过大会爆内存,需要排序,会改变原始数据顺序。

1
2
3
4
5
from pyspark.sql.functions import row_number

spec = Window.partitionBy().orderBy("age")
df = df.withColumn("id", row_number().over(spec))
df.show()

输出:

1
2
3
4
5
6
7
8
9
10
+---+------+---+
|age| name| id|
+---+------+---+
| 7| Tom| 1|
| 17| De| 2|
| 18| Alice| 3|
| 22| Sitoi| 4|
| 22|Shitao| 5|
| 45| Apple| 6|
+---+------+---+

方法三:RDD 的 zipWithIndex 算子

转成 RDD 再用 RDD 的 zipWithIndex 算子实现

优点:保证 ID 连续 增加且唯一。
缺点:运行速度慢。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from pyspark.sql import SparkSession
from pyspark.sql.functions import monotonically_increasing_id
from pyspark.sql.types import StructField, LongType

spark = SparkSession.builder.getOrCreate()

schema = df.schema.add(StructField("id", LongType()))
rdd = df.rdd.zipWithIndex()


def flat(l):
for k in l:
if not isinstance(k, (list, tuple)):
yield k
else:
yield from flat(k)


rdd = rdd.map(lambda x: list(flat(x)))
df = spark.createDataFrame(rdd, schema)
df.show()

输出:

1
2
3
4
5
6
7
8
9
10
+---+------+---+
|age| name| id|
+---+------+---+
| 18| Alice| 0|
| 22| Sitoi| 1|
| 22|Shitao| 2|
| 7| Tom| 3|
| 17| De| 4|
| 45| Apple| 5|
+---+------+---+