Chapter 02

PySpark DataFrame API

掌握 Spark 最核心的数据操作接口,从读取多种格式数据到复杂的列转换,以用户行为日志分析为实战

SparkSession:唯一入口

SparkSession 是 Spark 2.0 引入的统一入口,整合了旧版本中 SparkContext(RDD 操作)、SQLContext(SQL 查询)、HiveContext(Hive 集成)三个对象。一个应用中通常只创建一个 SparkSession 实例。

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *

spark = SparkSession.builder \
    .appName("UserBehaviorAnalysis") \
    .master("local[4]") \           # local[4] 使用 4 个本地线程
    .config("spark.sql.shuffle.partitions", "8") \  # 本地开发调小分区数
    .config("spark.driver.memory", "4g") \
    .enableHiveSupport() \          # 可选:启用 Hive 元数据
    .getOrCreate()

# 访问底层 SparkContext(需要时)
sc = spark.sparkContext
print(sc.defaultParallelism)    # 默认并行度 = CPU 核数

读取数据

Spark 通过统一的 DataFrameReader API 读取各种数据源,格式如下:spark.read.format(...).option(...).load(path),也可以用简化的链式方法。

读取 CSV

# 方式一:简化链式方法
df_csv = spark.read \
    .option("header", "true") \        # 第一行是列名
    .option("inferSchema", "true") \   # 自动推断类型(生产建议显式定义)
    .option("encoding", "UTF-8") \
    .csv("data/user_events.csv")

# 方式二:显式定义 Schema(推荐生产使用,避免全扫描推断类型)
schema = StructType([
    StructField("user_id",    LongType(),      True),
    StructField("event_type", StringType(),    True),
    StructField("item_id",    LongType(),      True),
    StructField("timestamp",  TimestampType(), True),
    StructField("amount",     DoubleType(),    True),
])
df = spark.read.csv("data/events.csv", header=True, schema=schema)

读取 JSON / Parquet / JDBC

# JSON(每行一个 JSON 对象,即 NDJSON 格式)
df_json = spark.read.json("data/logs/*.json")  # 支持通配符

# Parquet(推荐格式:列式存储,自带 Schema,压缩高效)
df_parquet = spark.read.parquet("data/events.parquet")
df_part = spark.read.parquet("s3://bucket/events/dt=2024-01-*")  # 分区裁剪

# JDBC(读取关系型数据库,如 MySQL / PostgreSQL)
df_jdbc = spark.read \
    .format("jdbc") \
    .option("url", "jdbc:mysql://host:3306/shop") \
    .option("dbtable", "orders") \
    .option("user", "reader") \
    .option("password", "secret") \
    .option("numPartitions", "16") \      # 并行读取分区数
    .option("partitionColumn", "order_id") \
    .option("lowerBound", "1") \
    .option("upperBound", "10000000") \
    .load()

DataFrame 基础操作

查看数据:show / printSchema / describe

df.show(5)               # 显示前5行(默认20行),截断长字符串
df.show(5, truncate=False) # 不截断
df.printSchema()         # 打印列名、类型、是否可为 null
df.describe().show()      # 数值列的统计摘要(count/mean/std/min/max)
df.dtypes                # 返回 [(列名, 类型字符串)] 列表
df.count()               # 行数(Action,会触发计算)
df.columns               # 列名列表

选择列:select

# select 支持多种写法
df.select("user_id", "event_type")          # 字符串列名
df.select(df.user_id, df.event_type)        # 列对象(.属性访问)
df.select(F.col("user_id"), F.col("event_type"))  # F.col()(推荐)

# 在 select 中进行计算
df.select(
    "user_id",
    F.col("amount") * 0.9.alias("discounted_amount"),
    F.upper(F.col("event_type")).alias("event_upper"),
)

过滤行:filter / where

# filter 和 where 等价,推荐用 filter
df.filter(F.col("event_type") == "purchase")
df.filter("event_type = 'purchase'")  # SQL 字符串写法也支持
df.filter(
    (F.col("amount") > 100) &
    (F.col("event_type").isin(["purchase", "refund"])) &
    F.col("user_id").isNotNull()
)

新增/修改列:withColumn

df = df \
    .withColumn("year", F.year("timestamp")) \
    .withColumn("month", F.month("timestamp")) \
    .withColumn("is_vip", F.col("amount") > 1000) \
    .withColumn("amount_rmb", F.col("amount") * 7.2) \
    .withColumnRenamed("user_id", "uid") \  # 重命名
    .drop("redundant_col")                 # 删除列

类型系统

Spark 有完整的类型系统,在定义 Schema 时需要明确指定。常用类型:

Python 类型Spark 类型说明
strStringType()字符串,UTF-8
intIntegerType() / LongType()32/64 位整数
floatFloatType() / DoubleType()32/64 位浮点
boolBooleanType()布尔值
datetimeTimestampType()时间戳,含时区
dateDateType()日期,不含时间
listArrayType(elementType)数组
dictMapType(keyType, valueType)字典
structStructType([StructField...])嵌套结构体

实战:用户行为日志分析

以一份包含 1 亿条用户行为日志的 CSV 文件为例,分析各类型事件数量、活跃用户、高价值购买行为。

from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import *

spark = SparkSession.builder.appName("UserLog").getOrCreate()

# 1. 定义 Schema 并读取数据
schema = StructType([
    StructField("user_id",    LongType(),      True),
    StructField("event",      StringType(),    True),   # click/view/purchase/cart
    StructField("item_id",    LongType(),      True),
    StructField("category",   StringType(),    True),
    StructField("amount",     DoubleType(),    True),
    StructField("ts",         TimestampType(), True),
])

df = spark.read.csv("user_events.csv", header=True, schema=schema)

# 2. 数据清洗:过滤无效行
df_clean = df \
    .filter(F.col("user_id").isNotNull()) \
    .filter(F.col("event").isin(["click", "view", "purchase", "cart"])) \
    .withColumn("date", F.to_date("ts")) \
    .withColumn("hour", F.hour("ts"))

# 3. 缓存清洗后的数据(后续多次使用)
df_clean.cache()

# 4. 各事件类型统计
df_clean.groupBy("event") \
    .count() \
    .orderBy(F.desc("count")) \
    .show()

# 5. 日活跃用户数(DAU)
dau = df_clean.groupBy("date") \
    .agg(F.countDistinct("user_id").alias("dau")) \
    .orderBy("date")
dau.show(10)

# 6. 高价值购买用户(总消费 > 10000)
high_value = df_clean \
    .filter(F.col("event") == "purchase") \
    .groupBy("user_id") \
    .agg(
        F.sum("amount").alias("total_spent"),
        F.count("*").alias("order_count")
    ) \
    .filter(F.col("total_spent") > 10000) \
    .orderBy(F.desc("total_spent"))

print(f"高价值用户数: {high_value.count()}")
high_value.show(20)

# 7. 保存结果为 Parquet(分区写入)
high_value.write \
    .mode("overwrite") \
    .parquet("output/high_value_users")

生产建议:显式定义 Schema 而非 inferSchema。推断 Schema 需要扫描一遍数据(采样),在大文件上会增加启动延迟,且对于复杂嵌套结构推断常常不准确。