PySpark: Random Forest
Loading the data
# Create a SparkSession first
from pyspark.sql import SparkSession
spark=SparkSession.builder.appName('random_forest').getOrCreate()
# Load the data
df=spark.read.csv('affairs.csv',inferSchema=True,header=True)
EDA
Inspecting the data structure
print((df.count(),len(df.columns)))
df.printSchema()
df.show(5)
df.describe().show(10,False)
(6366, 6)
root
|-- rate_marriage: integer (nullable = true)  (marriage rating)
|-- age: double (nullable = true)  (age)
|-- yrs_married: double (nullable = true)  (years married)
|-- children: double (nullable = true)  (number of children)
|-- religious: integer (nullable = true)  (religiosity rating)
|-- affairs: integer (nullable = true)  (whether the respondent had an affair)
| rate_marriage | age | yrs_married | children | religious | affairs |
|---|---|---|---|---|---|
| 5 | 32.0 | 6.0 | 1.0 | 3 | 0 |
| 4 | 22.0 | 2.5 | 0.0 | 2 | 0 |
| 3 | 32.0 | 9.0 | 3.0 | 3 | 1 |
| 3 | 27.0 | 13.0 | 3.0 | 1 | 1 |
| 4 | 22.0 | 2.5 | 0.0 | 1 | 1 |
| summary | rate_marriage | age | yrs_married | children | religious | affairs |
|---|---|---|---|---|---|---|
| count | 6366 | 6366 | 6366 | 6366 | 6366 | 6366 |
| mean | 4.109644989004084 | 29.082862079798932 | 9.00942507068803 | 1.3968740182218033 | 2.4261702796104303 | 0.3224945020420987 |
| stddev | 0.9614295945655025 | 6.847881883668817 | 7.280119972766412 | 1.433470828560344 | 0.8783688402641785 | 0.467467779921086 |
| min | 1 | 17.5 | 0.5 | 0.0 | 1 | 0 |
| max | 5 | 42.0 | 23.0 | 5.5 | 4 | 1 |
The respondents are about 29 years old on average and have been married for about 9 years on average.
Deeper analysis
df.groupBy('rate_marriage','affairs').count().orderBy('rate_marriage','affairs','count',ascending=True).show()
df.groupBy('religious','affairs').count().orderBy('religious','affairs','count',ascending=True).show()
df.groupBy('children','affairs').count().orderBy('children','affairs','count',ascending=True).show()
df.groupBy('affairs').mean().show()
- rate_marriage
| rate_marriage | affairs | count |
|---|---|---|
| 1 | 0 | 25 |
| 1 | 1 | 74 |
| 2 | 0 | 127 |
| 2 | 1 | 221 |
| 3 | 0 | 446 |
| 3 | 1 | 547 |
| 4 | 0 | 1518 |
| 4 | 1 | 724 |
| 5 | 0 | 2197 |
| 5 | 1 | 487 |
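Among respondents who rate their marriage low, a large share had affairs: roughly 75% at rating 1, against about 18% at rating 5.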
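The share of affairs per rating can be checked directly from the counts in the table. The following is a minimal plain-Python sketch rather than Spark code; the `counts` dictionary and the `affair_rate` helper are names introduced here for illustration, with the numbers copied from the table.

```python
# Counts copied from the rate_marriage table:
# {rate_marriage: (count with affairs == 0, count with affairs == 1)}
counts = {
    1: (25, 74),
    2: (127, 221),
    3: (446, 547),
    4: (1518, 724),
    5: (2197, 487),
}


def affair_rate(no_affair, affair):
    """Share of respondents in a group with affairs == 1."""
    return affair / (no_affair + affair)


rates = {rating: affair_rate(*pair) for rating, pair in counts.items()}

for rating, rate in sorted(rates.items()):
    print(f"rate_marriage={rating}: {rate:.1%}")
```

The same helper could be applied to the `religious` and `children` tables by swapping in their counts.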
- religious
| religious | affairs | count |
|---|---|---|
| 1 | 0 | 613 |
| 1 | 1 | 408 |
| 2 | 0 | 1448 |
| 2 | 1 | 819 |
| 3 | 0 | 1715 |
| 3 | 1 | 707 |
| 4 | 0 | 537 |
| 4 | 1 | 119 |
Respondents with lower religiosity ratings have a higher share of affairs: about 40% at rating 1, against about 18% at rating 4.
- children
| children | affairs | count |
|---|---|---|
| 0.0 | 0 | 1912 |
| 0.0 | 1 | 502 |
| 1.0 | 0 | 747 |
| 1.0 | 1 | 412 |
| 2.0 | 0 | 873 |
| 2.0 | 1 | 608 |
| 3.0 | 0 | 460 |
| 3.0 | 1 | 321 |
| 4.0 | 0 | 197 |
| 4.0 | 1 | 131 |
| 5.5 | 0 | 124 |
| 5.5 | 1 | 79 |
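Apart from a lower share of affairs among respondents with no children (about 21%, against 36% to 41% in the other groups), the number of children shows no particularly clear pattern.
- Overall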
| affairs | avg(rate_marriage) | avg(age) | avg(yrs_married) | avg(children) | avg(religious) | avg(affairs) |
|---|---|---|---|---|---|---|
| 1 | 3.6473453482708234 | 30.537018996590355 | 11.152459814905017 | 1.7289332683877252 | 2.261568436434486 | 1.0 |
| 0 | 4.329700904242986 | 28.39067934152562 | 7.989334569904939 | 1.2388128912589844 | 2.5045212149316023 | 0.0 |
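Overall, respondents who had affairs rate their marriage lower, are slightly older, have been married longer, and have lower religiosity ratings.
Feature engineering
As with logistic regression, the input columns must first be assembled into a single feature vector.
from pyspark.ml.feature import VectorAssembler
df_assembler = VectorAssembler(inputCols=['rate_marriage', 'age', 'yrs_married', 'children', 'religious'], outputCol="features")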
df = df_assembler.transform(df)
df.select(['features','affairs']).show(10,False)
| features | affairs |
|---|---|
| [5.0,32.0,6.0,1.0,3.0] | 0 |
| [4.0,22.0,2.5,0.0,2.0] | 0 |
| [3.0,32.0,9.0,3.0,3.0] | 1 |
| [3.0,27.0,13.0,3.0,1.0] | 1 |
| [4.0,22.0,2.5,0.0,1.0] | 1 |
| [4.0,37.0,16.5,4.0,3.0] | 1 |
| [5.0,27.0,9.0,1.0,1.0] | 1 |
| [4.0,27.0,9.0,0.0,2.0] | 1 |
| [5.0,37.0,23.0,5.5,2.0] | 1 |
| [5.0,37.0,23.0,5.5,2.0] | 1 |
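Conceptually, `VectorAssembler` concatenates the listed input columns of each row, in order, into one numeric vector. The plain-Python sketch below mimics that step on the first sample row; the `assemble` helper is hypothetical and only illustrates the idea, it is not part of the Spark API.

```python
INPUT_COLS = ["rate_marriage", "age", "yrs_married", "children", "religious"]


def assemble(row, input_cols):
    """Concatenate the selected columns of one row into a list of floats."""
    return [float(row[col]) for col in input_cols]


# First row of the sample printed by df.show(5)
first_row = {
    "rate_marriage": 5, "age": 32.0, "yrs_married": 6.0,
    "children": 1.0, "religious": 3, "affairs": 0,
}

features = assemble(first_row, INPUT_COLS)
print(features)  # [5.0, 32.0, 6.0, 1.0, 3.0]
```

The label column `affairs` is left out of the inputs so that it does not leak into the features.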
Building and splitting the dataset
model_df=df.select(['features','affairs'])
train_df,test_df=model_df.randomSplit([0.75,0.25])
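Without a seed, `randomSplit` produces a different partition on every run, so the metrics reported later will vary slightly between runs. A minimal sketch of a repeatable split follows; the `split_train_test` helper name and the seed value 42 are assumptions made for illustration, while `randomSplit(weights, seed)` is the standard DataFrame method.

```python
def split_train_test(model_df, train_fraction=0.75, seed=42):
    """Split a Spark DataFrame into train/test parts with a fixed seed.

    model_df is expected to be a pyspark.sql.DataFrame; the seed value is
    arbitrary and only makes the split repeatable on the same input.
    """
    weights = [train_fraction, 1.0 - train_fraction]
    train_df, test_df = model_df.randomSplit(weights, seed=seed)
    return train_df, test_df
```

Usage would be `train_df, test_df = split_train_test(model_df)`. The resulting sizes are only approximately 75% and 25%, because each row is assigned at random.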
Building and training the model
from pyspark.ml.classification import RandomForestClassifier
rf_classifier=RandomForestClassifier(labelCol='affairs',numTrees=50).fit(train_df)
rf_predictions=rf_classifier.transform(test_df)
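Evaluating the predictions
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
rf_accuracy=MulticlassClassificationEvaluator(labelCol='affairs',metricName='accuracy').evaluate(rf_predictions)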
print('The accuracy of RF on test data is {0:.0%}'.format(rf_accuracy))
rf_precision=MulticlassClassificationEvaluator(labelCol='affairs',metricName='weightedPrecision').evaluate(rf_predictions)
print('The precision rate on test data is {0:.0%}'.format(rf_precision))
rf_auc=BinaryClassificationEvaluator(labelCol='affairs').evaluate(rf_predictions)
print('The AUC on test data is {0:.0%}'.format(rf_auc))
The accuracy of RF on test data is 72%
The precision rate on test data is 70%
The AUC on test data is 73%
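Because only about 32% of respondents have `affairs = 1`, the accuracy figure is best read against a majority-class baseline. The sketch below derives that baseline from the `describe()` output; it is computed on the full dataset, so treating it as the baseline for the random test split is an approximation.

```python
# Values taken from the describe() output
total_rows = 6366
positive_share = 0.3224945020420987  # mean of the 0/1 affairs column

# A classifier that always predicts the majority class (affairs = 0)
# is correct exactly as often as that class occurs.
baseline_accuracy = max(positive_share, 1.0 - positive_share)
positive_rows = round(positive_share * total_rows)

print(f"rows with affairs = 1: {positive_rows}")
print(f"majority-class baseline accuracy: {baseline_accuracy:.1%}")
```

Against a baseline near 68%, an accuracy of 72% is a modest gain, which is one reason to report AUC alongside it.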
Feature importance
rf_classifier.featureImportances
df.schema["features"].metadata["ml_attr"]["attrs"]
SparseVector(5, {0: 0.6258, 1: 0.018, 2: 0.2381, 3: 0.0524, 4: 0.0657})
{'numeric': [{'idx': 0, 'name': 'rate_marriage'},
{'idx': 1, 'name': 'age'},
{'idx': 2, 'name': 'yrs_married'},
{'idx': 3, 'name': 'children'},
{'idx': 4, 'name': 'religious'}]}
As the output shows, rate_marriage is the most important feature, followed by yrs_married.
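Reading the importance vector against the attribute metadata by eye is error-prone, so it can help to join the two and sort. The sketch below does this in plain Python with the values copied from the two outputs; the variable names `importances`, `attrs`, and `ranked` are introduced here for illustration.

```python
# Importances copied from the SparseVector output (index -> value)
importances = {0: 0.6258, 1: 0.018, 2: 0.2381, 3: 0.0524, 4: 0.0657}

# Attribute metadata copied from the schema output
attrs = [
    {"idx": 0, "name": "rate_marriage"},
    {"idx": 1, "name": "age"},
    {"idx": 2, "name": "yrs_married"},
    {"idx": 3, "name": "children"},
    {"idx": 4, "name": "religious"},
]

# Map vector index -> column name, then rank by importance (descending)
names = {attr["idx"]: attr["name"] for attr in attrs}
ranked = sorted(
    ((names[idx], value) for idx, value in importances.items()),
    key=lambda pair: pair[1],
    reverse=True,
)

for name, value in ranked:
    print(f"{name:>13}: {value:.4f}")
```

With the live objects, the same values should be obtainable from `rf_classifier.featureImportances.toArray()` instead of being typed in by hand.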
Saving the model
# Use pwd to check the current path
rf_classifier.save("/Users/admin/Desktop/pyspark/RFmodel")
from pyspark.ml.classification import RandomForestClassificationModel
# Load from the same path that was used in save()
rf=RandomForestClassificationModel.load("/Users/admin/Desktop/pyspark/RFmodel")
model_predictions=rf.transform(test_df)
model_predictions.show()
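Calling `save` on a path that already exists raises an error, which is easy to hit when re-running the notebook. A small sketch of an overwrite-friendly save follows; the `save_model` helper is a name introduced here, and it relies on the `write().overwrite().save(path)` chain that Spark ML models expose.

```python
def save_model(model, path, overwrite=True):
    """Persist a fitted Spark ML model, optionally replacing an existing copy."""
    writer = model.write()
    if overwrite:
        # Replace whatever is already stored at `path` instead of failing
        writer = writer.overwrite()
    writer.save(path)
    return path
```

Usage would be `save_model(rf_classifier, "/Users/admin/Desktop/pyspark/RFmodel")`, after which the model can be loaded from that same path.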