PySpark DataFrame 添加自增 ID
在用 Spark
处理数据的时候,经常需要给全量数据增加一列自增 ID 序号
,在存入数据库的时候,自增 ID
也常常是一个很关键的要素。
在 DataFrame 的 API 中没有实现这一功能,所以只能通过其他方式实现,或者转成 RDD 再用 RDD 的 zipWithIndex
算子实现。
下面呢就介绍三种实现方式。
创建 DataFrame 对象
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
| from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame( [ {"name": "Alice", "age": 18}, {"name": "Sitoi", "age": 22}, {"name": "Shitao", "age": 22}, {"name": "Tom", "age": 7}, {"name": "De", "age": 17}, {"name": "Apple", "age": 45} ] ) df.show()
|
输出:
1 2 3 4 5 6 7 8 9 10
| +---+------+ |age| name| +---+------+ | 18| Alice| | 22| Sitoi| | 22|Shitao| | 7| Tom| | 17| De| | 45| Apple| +---+------+
|
方式一:monotonically_increasing_id()
使用自带函数 monotonically_increasing_id()
创建,由于 spark 会有分区
,所以生成的 ID 保证单调增加且唯一,但不是连续的
。
优点:对于没有分区的文件,处理速度快。
缺点:由于 spark 的分区,会导致,ID 不是连续增加。
1 2
| df = df.withColumn("id", monotonically_increasing_id()) df.show()
|
输出:
1 2 3 4 5 6 7 8 9 10
| +---+------+-----------+ |age| name| id| +---+------+-----------+ | 18| Alice| 8589934592| | 22| Sitoi|17179869184| | 22|Shitao|25769803776| | 7| Tom|42949672960| | 17| De|51539607552| | 45| Apple|60129542144| +---+------+-----------+
|
如果读取本地的单个 CSV 文件 或 JSON 文件,ID 会是连续增加且唯一的。
方法二:窗口函数
利用窗口函数
:设置窗口函数的分区以及排序,因为是全局排序而不是分组排序,所有分区依据为空,排序规则没有特殊要求也可以随意填写
优点:保证 ID 连续增加且唯一
。
缺点:运行速度满,并且数据量过大会爆内存
,需要排序
,会改变原始数据顺序。
1 2 3 4 5
| from pyspark.sql.functions import row_number
spec = Window.partitionBy().orderBy("age") df = df.withColumn("id", row_number().over(spec)) df.show()
|
输出:
1 2 3 4 5 6 7 8 9 10
| +---+------+---+ |age| name| id| +---+------+---+ | 7| Tom| 1| | 17| De| 2| | 18| Alice| 3| | 22| Sitoi| 4| | 22|Shitao| 5| | 45| Apple| 6| +---+------+---+
|
方法三:RDD 的 zipWithIndex 算子
转成 RDD 再用 RDD 的 zipWithIndex 算子实现
优点:保证 ID 连续 增加且唯一。
缺点:运行速度慢。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| from pyspark.sql import SparkSession from pyspark.sql.functions import monotonically_increasing_id from pyspark.sql.types import StructField, LongType
spark = SparkSession.builder.getOrCreate()
schema = df.schema.add(StructField("id", LongType())) rdd = df.rdd.zipWithIndex()
def flat(l): for k in l: if not isinstance(k, (list, tuple)): yield k else: yield from flat(k)
rdd = rdd.map(lambda x: list(flat(x))) df = spark.createDataFrame(rdd, schema) df.show()
|
输出:
1 2 3 4 5 6 7 8 9 10
| +---+------+---+ |age| name| id| +---+------+---+ | 18| Alice| 0| | 22| Sitoi| 1| | 22|Shitao| 2| | 7| Tom| 3| | 17| De| 4| | 45| Apple| 5| +---+------+---+
|