The problem is this: we want to build a distributed machine-learning training platform on top of PySpark, and XGBoost is an indispensable model, but pyspark.ml provides no corresponding API, so we need to find a workaround.
Test code (for basic PySpark usage, see: https://cloud.tencent.com/developer/article/1436179):
#!/usr/bin/env python
# -*- coding:utf8 -*-
"""
-------------------------------------------------
Description : pyspark test
Author : liupeng
Date : 2019/7/23
-------------------------------------------------
"""
import os
import sys
import time
import pandas as pd
import numpy as np

# The xgboost4j JARs must be on the classpath before the SparkContext is created,
# so set PYSPARK_SUBMIT_ARGS before importing start_pyspark (which builds the session).
os.environ['PYSPARK_SUBMIT_ARGS'] = '--jars xgboost4j-spark-0.72.jar,xgboost4j-0.72.jar pyspark-shell'

from start_pyspark import spark, sc, sqlContext
import pyspark.sql.types as typ
import pyspark.ml.feature as ft
from pyspark.sql.functions import isnan, isnull
# import findspark
# findspark.init()
import pyspark
from pyspark.sql.session import SparkSession
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql.functions import col

# Ship the Python wrapper to the cluster before importing it, e.g.:
# spark.sparkContext.addPyFile("hdfs:///tmp/rd/lp/sparkxgb.zip")
from sparkxgb import XGBoostEstimator
schema = StructType(
[StructField("PassengerId", DoubleType()),
StructField("Survival", DoubleType()),
StructField("Pclass", DoubleType()),
StructField("Name", StringType()),
StructField("Sex", StringType()),
StructField("Age", DoubleType()),
StructField("SibSp", DoubleType()),
StructField("Parch", DoubleType()),
StructField("Ticket", StringType()),
StructField("Fare", DoubleType()),
StructField("Cabin", StringType()),
StructField("Embarked", StringType())
])
df_raw = spark\
.read\
.option("header", "true")\
.schema(schema)\
.csv("hdfs:///tmp/rd/lp/titanic/train.csv")
df_raw.show(2)
df = df_raw.na.fill(0)
sexIndexer = StringIndexer()\
.setInputCol("Sex")\
.setOutputCol("SexIndex")\
.setHandleInvalid("keep")
cabinIndexer = StringIndexer()\
.setInputCol("Cabin")\
.setOutputCol("CabinIndex")\
.setHandleInvalid("keep")
embarkedIndexer = StringIndexer()\
 .setInputCol("Embarked")\
 .setOutputCol("EmbarkedIndex")\
 .setHandleInvalid("keep")
vectorAssembler = VectorAssembler()\
.setInputCols(["Pclass", "Age", "SibSp", "Parch", "Fare"])\
.setOutputCol("features")
xgboost = XGBoostEstimator( featuresCol="features", labelCol="Survival", predictionCol="prediction")
# pipeline = Pipeline().setStages([sexIndexer, cabinIndexer, embarkedIndexer, vectorAssembler, xgboost])
pipeline = Pipeline(stages=[
vectorAssembler,
xgboost
])
trainDF, testDF = df.randomSplit([0.8, 0.2], seed=24)
trainDF.show(2)
model = pipeline.fit(trainDF)
print (88888888888888888888)
model.transform(testDF).select(col("PassengerId"), col("prediction")).show()
print (9999999999999999999)
'''
# Define and train model
xgboost = XGBoostEstimator(
# General Params
nworkers=1, nthread=1, checkpointInterval=-1, checkpoint_path="",
use_external_memory=False, silent=0, missing=float("nan"),
# Column Params
featuresCol="features", labelCol="label", predictionCol="prediction",
weightCol="weight", baseMarginCol="baseMargin",
# Booster Params
booster="gbtree", base_score=0.5, objective="binary:logistic", eval_metric="error",
num_class=2, num_round=2, seed=None,
# Tree Booster Params
eta=0.3, gamma=0.0, max_depth=6, min_child_weight=1.0, max_delta_step=0.0, subsample=1.0,
colsample_bytree=1.0, colsample_bylevel=1.0, reg_lambda=0.0, alpha=0.0, tree_method="auto",
sketch_eps=0.03, scale_pos_weight=1.0, grow_policy='depthwise', max_bin=256,
# Dart Booster Params
sample_type="uniform", normalize_type="tree", rate_drop=0.0, skip_drop=0.0,
# Linear Booster Params
lambda_bias=0.0
)
'''
'''
xgboost_model = xgboost.fit(trainDF)
# Transform test set
xgboost_model.transform(testDF).show()
# Write model/classifier
xgboost.write().overwrite().save("xgboost_class_test")
xgboost_model.write().overwrite().save("xgboost_class_test.model")
'''
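If you also want a quick quality check on the hold-out split, the standard pyspark.ml evaluator works directly on the transformed DataFrame. A minimal sketch, assuming the column names used above (Survival as the label, prediction produced by XGBoost); accuracy is just one possible metric:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Score the test split with the fitted pipeline and measure accuracy.
predictions = model.transform(testDF)
evaluator = MulticlassClassificationEvaluator(
    labelCol="Survival", predictionCol="prediction", metricName="accuracy")
print("test accuracy:", evaluator.evaluate(predictions))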
start_pyspark.py
#!/usr/bin/env python
# -*- coding:utf8 -*-
"""
-------------------------------------------------
Description : model prediction interface
Author : liupeng
Date : 2019/7/23
-------------------------------------------------
"""
import os
import sys
'''
# The directories below are the Spark and Java installation directories on your own machine
os.environ['SPARK_HOME'] = "/Users/***/spark-2.4.3-bin-hadoop2.7/"
sys.path.append("/Users/***/spark-2.4.3-bin-hadoop2.7/bin")
sys.path.append("/Users/***/spark-2.4.3-bin-hadoop2.7/python")
sys.path.append("/Users/***/spark-2.4.3-bin-hadoop2.7/python/pyspark")
sys.path.append("/Users/***/spark-2.4.3-bin-hadoop2.7/python/lib")
sys.path.append("/Users/***/spark-2.4.3-bin-hadoop2.7/python/lib/pyspark.zip")
sys.path.append("/Users/***/spark-2.4.3-bin-hadoop2.7/lib/py4j-0.9-src.zip")
# sys.path.append("/Library/Java/JavaVirtualMachines/jdk1.8.0_144.jdk/Contents/Home")
os.environ['JAVA_HOME'] = "/Library/Java/JavaVirtualMachines/jdk1.8.0_181.jdk/Contents/Home"
'''
from pyspark.sql import SparkSession, SQLContext
from pyspark import SparkConf, SparkContext
#conf = SparkConf().setMaster("local").setAppName("My App")
conf = SparkConf().setMaster("yarn").setAppName("My App")
sc = SparkContext(conf = conf)
spark = SparkSession.builder.appName('CalculatingGeoDistances').getOrCreate()
sqlContext = SQLContext(sparkContext=sc)
Cluster submission test:
nohup /di_software/emr-package/spark-2.4.3-bin-hadoop2.7/bin/spark-submit --master yarn --jars /home/di/liupeng/qdxgboost/xgboost4j-0.72.jar,/home/di/liupeng/qdxgboost/xgboost4j-spark-0.72.jar /home/di/liupeng/qdxgboost/test_xgboost.py > output_spark.log 2>&1 &
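Note that when submitting to YARN, the sparkxgb Python wrapper has to reach the driver and executors as well, not just the two JARs. A minimal sketch of one way to do it (the HDFS path is the same one used in the test code above and is only an example; alternatively, pass the zip via --py-files on spark-submit):
# Register the wrapper with the running SparkContext before importing it.
spark.sparkContext.addPyFile("hdfs:///tmp/rd/lp/sparkxgb.zip")
from sparkxgb import XGBoostEstimator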
Main reference: PySpark + XGBoost: https://towardsdatascience.com/pyspark-and-xgboost-integration-tested-on-the-kaggle-titanic-dataset-4e75a568bdb (requires Spark 2.3 or later)
Saving and loading the model in non-grid-search mode:
from pyspark.ml import Pipeline, PipelineModel
from sparkxgb import XGBoostEstimator, XGBoostClassificationModel

feature_path = '/tmp/rd/lp/model27'
model_path = '/tmp/rd/lp/model28'

xgboost = XGBoostEstimator(featuresCol="features", labelCol="Survival", predictionCol="prediction",
                           num_round=100)

# Fit and save the feature pipeline and the XGBoost model separately.
pipeline = Pipeline(stages=[vectorAssembler])
s_model = pipeline.fit(trainDF)
train_data = s_model.transform(trainDF)
s_model.write().overwrite().save(feature_path)

xgb_model = xgboost.fit(train_data)
xgb_model.write().overwrite().save(model_path)

# Load both back and run a prediction.
feature_model = PipelineModel.load(feature_path)
train_data = feature_model.transform(trainDF)
model = XGBoostClassificationModel.load(model_path)
res = model.transform(train_data)
print('res:', res.collect())
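The snippet above deliberately avoids grid search; if you do want a parameter sweep, the regular pyspark.ml CrossValidator can wrap the same estimator. A minimal sketch, assuming the vectorAssembler, xgboost, and trainDF objects defined earlier, and assuming the sparkxgb wrapper exposes eta and max_depth as ordinary Params (the parameter list in the test code suggests it does, but check your build). Saving the resulting model is exactly what the non-grid-search workaround above is for:
from pyspark.ml import Pipeline
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Hypothetical grid values, for illustration only.
param_grid = ParamGridBuilder() \
    .addGrid(xgboost.eta, [0.1, 0.3]) \
    .addGrid(xgboost.max_depth, [4, 6]) \
    .build()

evaluator = MulticlassClassificationEvaluator(
    labelCol="Survival", predictionCol="prediction", metricName="accuracy")

cv = CrossValidator(estimator=Pipeline(stages=[vectorAssembler, xgboost]),
                    estimatorParamMaps=param_grid,
                    evaluator=evaluator,
                    numFolds=3)
cv_model = cv.fit(trainDF)
print("best fold-average accuracy:", max(cv_model.avgMetrics))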