SparkSession:唯一入口
SparkSession 是 Spark 2.0 引入的统一入口,整合了旧版本中 SparkContext(RDD 操作)、SQLContext(SQL 查询)、HiveContext(Hive 集成)三个对象。一个应用中通常只创建一个 SparkSession 实例。
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *
spark = SparkSession.builder \
.appName("UserBehaviorAnalysis") \
.master("local[4]") \ # local[4] 使用 4 个本地线程
.config("spark.sql.shuffle.partitions", "8") \ # 本地开发调小分区数
.config("spark.driver.memory", "4g") \
.enableHiveSupport() \ # 可选:启用 Hive 元数据
.getOrCreate()
# 访问底层 SparkContext(需要时)
sc = spark.sparkContext
print(sc.defaultParallelism) # 默认并行度 = CPU 核数
读取数据
Spark 通过统一的 DataFrameReader API 读取各种数据源,格式如下:spark.read.format(...).option(...).load(path),也可以用简化的链式方法。
读取 CSV
# 方式一:简化链式方法
df_csv = spark.read \
.option("header", "true") \ # 第一行是列名
.option("inferSchema", "true") \ # 自动推断类型(生产建议显式定义)
.option("encoding", "UTF-8") \
.csv("data/user_events.csv")
# 方式二:显式定义 Schema(推荐生产使用,避免全扫描推断类型)
schema = StructType([
StructField("user_id", LongType(), True),
StructField("event_type", StringType(), True),
StructField("item_id", LongType(), True),
StructField("timestamp", TimestampType(), True),
StructField("amount", DoubleType(), True),
])
df = spark.read.csv("data/events.csv", header=True, schema=schema)
读取 JSON / Parquet / JDBC
# JSON(每行一个 JSON 对象,即 NDJSON 格式)
df_json = spark.read.json("data/logs/*.json") # 支持通配符
# Parquet(推荐格式:列式存储,自带 Schema,压缩高效)
df_parquet = spark.read.parquet("data/events.parquet")
df_part = spark.read.parquet("s3://bucket/events/dt=2024-01-*") # 分区裁剪
# JDBC(读取关系型数据库,如 MySQL / PostgreSQL)
df_jdbc = spark.read \
.format("jdbc") \
.option("url", "jdbc:mysql://host:3306/shop") \
.option("dbtable", "orders") \
.option("user", "reader") \
.option("password", "secret") \
.option("numPartitions", "16") \ # 并行读取分区数
.option("partitionColumn", "order_id") \
.option("lowerBound", "1") \
.option("upperBound", "10000000") \
.load()
DataFrame 基础操作
查看数据:show / printSchema / describe
df.show(5) # 显示前5行(默认20行),截断长字符串
df.show(5, truncate=False) # 不截断
df.printSchema() # 打印列名、类型、是否可为 null
df.describe().show() # 数值列的统计摘要(count/mean/std/min/max)
df.dtypes # 返回 [(列名, 类型字符串)] 列表
df.count() # 行数(Action,会触发计算)
df.columns # 列名列表
选择列:select
# select 支持多种写法
df.select("user_id", "event_type") # 字符串列名
df.select(df.user_id, df.event_type) # 列对象(.属性访问)
df.select(F.col("user_id"), F.col("event_type")) # F.col()(推荐)
# 在 select 中进行计算
df.select(
"user_id",
F.col("amount") * 0.9.alias("discounted_amount"),
F.upper(F.col("event_type")).alias("event_upper"),
)
过滤行:filter / where
# filter 和 where 等价,推荐用 filter
df.filter(F.col("event_type") == "purchase")
df.filter("event_type = 'purchase'") # SQL 字符串写法也支持
df.filter(
(F.col("amount") > 100) &
(F.col("event_type").isin(["purchase", "refund"])) &
F.col("user_id").isNotNull()
)
新增/修改列:withColumn
df = df \
.withColumn("year", F.year("timestamp")) \
.withColumn("month", F.month("timestamp")) \
.withColumn("is_vip", F.col("amount") > 1000) \
.withColumn("amount_rmb", F.col("amount") * 7.2) \
.withColumnRenamed("user_id", "uid") \ # 重命名
.drop("redundant_col") # 删除列
类型系统
Spark 有完整的类型系统,在定义 Schema 时需要明确指定。常用类型:
| Python 类型 | Spark 类型 | 说明 |
|---|---|---|
| str | StringType() | 字符串,UTF-8 |
| int | IntegerType() / LongType() | 32/64 位整数 |
| float | FloatType() / DoubleType() | 32/64 位浮点 |
| bool | BooleanType() | 布尔值 |
| datetime | TimestampType() | 时间戳,含时区 |
| date | DateType() | 日期,不含时间 |
| list | ArrayType(elementType) | 数组 |
| dict | MapType(keyType, valueType) | 字典 |
| struct | StructType([StructField...]) | 嵌套结构体 |
实战:用户行为日志分析
以一份包含 1 亿条用户行为日志的 CSV 文件为例,分析各类型事件数量、活跃用户、高价值购买行为。
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import *
spark = SparkSession.builder.appName("UserLog").getOrCreate()
# 1. 定义 Schema 并读取数据
schema = StructType([
StructField("user_id", LongType(), True),
StructField("event", StringType(), True), # click/view/purchase/cart
StructField("item_id", LongType(), True),
StructField("category", StringType(), True),
StructField("amount", DoubleType(), True),
StructField("ts", TimestampType(), True),
])
df = spark.read.csv("user_events.csv", header=True, schema=schema)
# 2. 数据清洗:过滤无效行
df_clean = df \
.filter(F.col("user_id").isNotNull()) \
.filter(F.col("event").isin(["click", "view", "purchase", "cart"])) \
.withColumn("date", F.to_date("ts")) \
.withColumn("hour", F.hour("ts"))
# 3. 缓存清洗后的数据(后续多次使用)
df_clean.cache()
# 4. 各事件类型统计
df_clean.groupBy("event") \
.count() \
.orderBy(F.desc("count")) \
.show()
# 5. 日活跃用户数(DAU)
dau = df_clean.groupBy("date") \
.agg(F.countDistinct("user_id").alias("dau")) \
.orderBy("date")
dau.show(10)
# 6. 高价值购买用户(总消费 > 10000)
high_value = df_clean \
.filter(F.col("event") == "purchase") \
.groupBy("user_id") \
.agg(
F.sum("amount").alias("total_spent"),
F.count("*").alias("order_count")
) \
.filter(F.col("total_spent") > 10000) \
.orderBy(F.desc("total_spent"))
print(f"高价值用户数: {high_value.count()}")
high_value.show(20)
# 7. 保存结果为 Parquet(分区写入)
high_value.write \
.mode("overwrite") \
.parquet("output/high_value_users")
生产建议:显式定义 Schema 而非 inferSchema。推断 Schema 需要扫描一遍数据(采样),在大文件上会增加启动延迟,且对于复杂嵌套结构推断常常不准确。