Pipeline API Concepts
At the core of Spark MLlib is the Pipeline API, which chains the many steps of feature engineering and model training into a single reusable workflow, ensuring that training and prediction use exactly the same processing logic.
- Transformer: takes a DataFrame and returns a new DataFrame (adding or transforming columns). Examples: VectorAssembler (combines multiple columns into a feature vector), StandardScalerModel (an already-fitted scaler). Invoked via .transform(df).
- Estimator: takes a DataFrame and, once fitted, produces a Transformer (i.e. a trained model). Examples: RandomForestClassifier, StandardScaler (in its unfitted state). Calling .fit(df) returns a Model (see the minimal sketch after this list).
- Pipeline: a workflow that composes multiple Transformers and Estimators in order. Pipeline.fit(trainDF) trains each Estimator in sequence; the resulting PipelineModel.transform(testDF) applies each Transformer in sequence.
- ParamGrid / CrossValidator: hyperparameter grid search plus k-fold cross-validation. ParamGridBuilder defines the parameter space; CrossValidator cross-validates every parameter combination and selects the best-performing model.
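A minimal sketch of the fit/transform contract on a toy one-column DataFrame (it assumes an active spark session, just as the examples below do):
from pyspark.ml.feature import VectorAssembler, StandardScaler

toy_df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["x"])

# Transformer: stateless, simply rewrites the DataFrame
vec_df = VectorAssembler(inputCols=["x"], outputCol="v").transform(toy_df)

# Estimator: fit() learns the column statistics and returns a Model (itself a Transformer)
scaler_model = StandardScaler(inputCol="v", outputCol="z", withMean=True).fit(vec_df)
scaler_model.transform(vec_df).show()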
Feature Engineering
from pyspark.ml.feature import (
    VectorAssembler, StringIndexer, OneHotEncoder,
    StandardScaler, MinMaxScaler, Bucketizer, SQLTransformer
)

# 1. String indexing: map a categorical (string) column to numeric indices
gender_idx = StringIndexer(
    inputCol="gender",
    outputCol="gender_idx",
    handleInvalid="keep"  # unseen categories: keep (own index) / skip (drop rows) / error (raise)
)

# 2. One-hot encoding: turn the index into a sparse vector (avoids implying an ordinal relationship)
gender_ohe = OneHotEncoder(
    inputCols=["gender_idx"],
    outputCols=["gender_ohe"]
)

# 3. Feature assembly: combine the numeric columns and the one-hot vector into a single feature vector
assembler = VectorAssembler(
    inputCols=["age", "total_orders", "avg_spend", "days_since_login", "gender_ohe"],
    outputCol="features_raw",
    handleInvalid="skip"  # drop rows containing nulls
)

# 4. Standardization (zero mean, unit variance)
scaler = StandardScaler(
    inputCol="features_raw",
    outputCol="features",
    withMean=True,
    withStd=True
)
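These four stages can be chained into a single feature pipeline and applied in one pass; a minimal sketch, assuming df is a DataFrame containing the gender and numeric columns named above:
from pyspark.ml import Pipeline

# fit() trains the stateful stages (indexer, encoder, scaler); transform() applies all four in order
feature_pipeline = Pipeline(stages=[gender_idx, gender_ohe, assembler, scaler])
featurized = feature_pipeline.fit(df).transform(df)
featurized.select("features").show(5, truncate=False)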
Classification, Regression, and Clustering Algorithms
from pyspark.ml.classification import (
    LogisticRegression, RandomForestClassifier,
    GBTClassifier, LinearSVC
)
from pyspark.ml.regression import LinearRegression, RandomForestRegressor
from pyspark.ml.clustering import KMeans, BisectingKMeans

# Random forest classifier
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="label",  # target column (0/1 binary classification)
    numTrees=100,
    maxDepth=10,
    seed=42
)

# Gradient-boosted trees (GBT)
gbt = GBTClassifier(
    featuresCol="features",
    labelCol="label",
    maxIter=50,
    stepSize=0.1
)

# K-Means clustering (unsupervised)
kmeans = KMeans(
    featuresCol="features",
    k=5,  # number of clusters
    maxIter=20,
    seed=42
)
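All of these estimators share the same fit/transform contract; a hedged sketch, assuming train_df and test_df already carry the features column (and label for the classifiers):
# Supervised: fit on the training set, score the test set
rf_model = rf.fit(train_df)
rf_model.transform(test_df).select("label", "prediction", "probability").show(5)

# Unsupervised: K-Means ignores any label column and assigns a cluster per row
kmeans_model = kmeans.fit(train_df)
kmeans_model.transform(train_df).select("features", "prediction").show(5)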
Model Evaluation: CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

evaluator = BinaryClassificationEvaluator(
    labelCol="label",
    metricName="areaUnderROC"  # AUC-ROC
)

# Hyperparameter grid
param_grid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [50, 100, 200]) \
    .addGrid(rf.maxDepth, [5, 10, 15]) \
    .build()  # 3×3 = 9 combinations in total

# 3-fold cross-validation
# (`pipeline` is assumed to be a Pipeline whose final stage is the rf estimator above,
#  so that the grid's rf.* parameters apply)
cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=3,
    parallelism=4  # evaluate 4 parameter combinations in parallel
)
cv_model = cv.fit(train_df)
print(f"Best AUC: {max(cv_model.avgMetrics):.4f}")
Hands-On: A Complete User Churn Prediction Pipeline
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# 1. Load the data: user-behavior features
# label=1 means churned within 30 days (no login), label=0 means active
df = spark.read.parquet("user_features.parquet")

# 2. Feature engineering
cat_cols = ["gender", "city_tier", "device_type"]
num_cols = ["age", "total_orders", "avg_spend", "days_since_login",
            "session_count", "cart_abandon_rate"]

# String indexing
indexers = [StringIndexer(inputCol=c, outputCol=c + "_idx", handleInvalid="keep")
            for c in cat_cols]

# Assemble all features
all_features = num_cols + [c + "_idx" for c in cat_cols]
assembler = VectorAssembler(inputCols=all_features, outputCol="features_raw", handleInvalid="skip")
scaler = StandardScaler(inputCol="features_raw", outputCol="features")

# Model
gbt = GBTClassifier(featuresCol="features", labelCol="label", maxIter=100, seed=42)

# 3. Build the Pipeline
pipeline = Pipeline(stages=indexers + [assembler, scaler, gbt])

# 4. Train/test split
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

# 5. Train
model = pipeline.fit(train_df)

# 6. Evaluate
predictions = model.transform(test_df)
evaluator = BinaryClassificationEvaluator(labelCol="label")
auc = evaluator.evaluate(predictions)
print(f"AUC-ROC: {auc:.4f}")

# 7. Inspect feature importances
gbt_model = model.stages[-1]
feature_importance = sorted(
    zip(all_features, gbt_model.featureImportances.toArray()),
    key=lambda x: -x[1]
)
for feat, imp in feature_importance[:10]:
    print(f"  {feat:30s} {imp:.4f}")

# 8. Save the fitted Pipeline model
model.write().overwrite().save("s3://models/churn_gbt_v1")
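For later batch scoring, the saved PipelineModel reloads and applies all stages as one unit; a hedged sketch, where new_users_df is a hypothetical DataFrame with the same raw input columns as the training data:
from pyspark.ml import PipelineModel

loaded = PipelineModel.load("s3://models/churn_gbt_v1")
scored = loaded.transform(new_users_df)  # new_users_df: hypothetical, same schema as training input
scored.select("prediction", "probability").show(5)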
MLlib vs. Pandas + scikit-learn: MLlib pays off when the training data runs to hundreds of millions of rows or more. For datasets of a few million rows, plain Pandas + scikit-learn is usually faster (no distributed-scheduling overhead). A common best practice: do the feature engineering in Spark, collect the feature matrix with toPandas(), then train with scikit-learn, as sketched below.
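A minimal sketch of that hybrid pattern, assuming featurized is a Spark DataFrame produced by a feature pipeline like the one above (the name is illustrative):
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

# Collect the feature matrix to the driver (only safe at modest scale)
pdf = featurized.select("label", "features").toPandas()
X = np.array([v.toArray() for v in pdf["features"]])  # pyspark.ml vectors -> numpy matrix
y = pdf["label"].values

clf = GradientBoostingClassifier().fit(X, y)  # train locally with scikit-learn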