Importing the data

# First create a SparkSession
from pyspark.sql import SparkSession
spark=SparkSession.builder.appName('random_forest').getOrCreate()
# Load the CSV, inferring column types and using the first row as the header
df=spark.read.csv('affairs.csv',inferSchema=True,header=True)

EDA

Inspecting the data structure

print((df.count(),len(df.columns)))

df.printSchema()

df.show(5)
df.describe().show(10,False)

(6366, 6)
root
|-- rate_marriage: integer (nullable = true)   # marriage rating
|-- age: double (nullable = true)   # age
|-- yrs_married: double (nullable = true)   # years married
|-- children: double (nullable = true)   # number of children
|-- religious: integer (nullable = true)   # religiosity rating
|-- affairs: integer (nullable = true)   # whether the respondent has had an affair (0/1)

rate_marriage age yrs_married children religious affairs
5 32.0 6.0 1.0 3 0
4 22.0 2.5 0.0 2 0
3 32.0 9.0 3.0 3 1
3 27.0 13.0 3.0 1 1
4 22.0 2.5 0.0 1 1
summary rate_marriage age yrs_married children religious affairs
count 6366 6366 6366 6366 6366 6366
mean 4.109644989004084 29.082862079798932 9.00942507068803 1.3968740182218033 2.4261702796104303 0.3224945020420987
stddev 0.9614295945655025 6.847881883668817 7.280119972766412 1.433470828560344 0.8783688402641785 0.467467779921086
min 1 17.5 0.5 0.0 1 0
max 5 42.0 23.0 5.5 4 1

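On average the respondents are about 29 years old and have been married for about 9 years. About 32% of them have had an affair (the mean of the 0/1 affairs column).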

Deeper analysis
df.groupBy('rate_marriage','affairs').count().orderBy('rate_marriage','affairs','count',ascending=True).show()

df.groupBy('religious','affairs').count().orderBy('religious','affairs','count',ascending=True).show()

df.groupBy('children','affairs').count().orderBy('children','affairs','count',ascending=True).show()

df.groupBy('affairs').mean().show()
1. rate_marriage
rate_marriage affairs count
1 0 25
1 1 74
2 0 127
2 1 221
3 0 446
3 1 547
4 0 1518
4 1 724
5 0 2197
5 1 487

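Among respondents who rate their marriage poorly, a large share have had an affair. At a rating of 1 it is 74 of 99 people, compared with 487 of 2,684 at a rating of 5.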
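To put numbers on this, you can compute the share of affairs == 1 within each rating directly from the counts above. The plain-Python sketch below needs no Spark. The dictionary is copied from the groupBy output.

```python
# Counts copied from the groupBy('rate_marriage', 'affairs') output above:
# rating -> (count with affairs == 0, count with affairs == 1)
counts = {
    1: (25, 74),
    2: (127, 221),
    3: (446, 547),
    4: (1518, 724),
    5: (2197, 487),
}

# Share of respondents with affairs == 1 within each rating level
affair_rate = {r: yes / (no + yes) for r, (no, yes) in counts.items()}

for rating, rate in affair_rate.items():
    print(f"rate_marriage={rating}: {rate:.1%}")
```

The rate falls at every step, from about 75% at rating 1 to about 18% at rating 5. Because affairs is a 0/1 column, Spark can produce the same figures with `df.groupBy('rate_marriage').agg(F.mean('affairs'))`, after `from pyspark.sql import functions as F`.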

2. religious

religious affairs count
1 0 613
1 1 408
2 0 1448
2 1 819
3 0 1715
3 1 707
4 0 537
4 1 119

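Religiosity shows a similar pattern. The less religious respondents rate themselves, the higher the proportion who have had an affair.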
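You can check the "lower religiosity, higher share" claim the same way. This plain-Python sketch computes the per-level rate from the counts above and tests whether it drops at every step up the scale.

```python
# Counts copied from the groupBy('religious', 'affairs') output above:
# religious rating -> (affairs == 0, affairs == 1)
counts = {1: (613, 408), 2: (1448, 819), 3: (1715, 707), 4: (537, 119)}

affair_rate = {r: yes / (no + yes) for r, (no, yes) in counts.items()}

# True only if the share of affairs decreases at every step from rating 1 to rating 4
levels = sorted(affair_rate)
monotonic_decrease = all(affair_rate[a] > affair_rate[b] for a, b in zip(levels, levels[1:]))

for r in levels:
    print(f"religious={r}: {affair_rate[r]:.1%}")
print("strictly decreasing:", monotonic_decrease)
```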

3. children
children affairs count
0.0 0 1912
0.0 1 502
1.0 0 747
1.0 1 412
2.0 0 873
2.0 1 608
3.0 0 460
3.0 1 321
4.0 0 197
4.0 1 131
5.5 0 124
5.5 1 79

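The number of children shows no particularly clear trend. The one exception is that respondents with no children have a noticeably lower share of affairs (502 of 2,414) than the other groups.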
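Computing the rates from the counts above shows what "no clear trend" means here. The rate is not monotonic in the number of children. Childless respondents stand apart, and the groups with children sit within a narrow band. A plain-Python sketch:

```python
# Counts copied from the groupBy('children', 'affairs') output above:
# number of children -> (affairs == 0, affairs == 1)
counts = {0.0: (1912, 502), 1.0: (747, 412), 2.0: (873, 608),
          3.0: (460, 321), 4.0: (197, 131), 5.5: (124, 79)}

affair_rate = {c: yes / (no + yes) for c, (no, yes) in counts.items()}
levels = sorted(affair_rate)
steps = list(zip(levels, levels[1:]))

# Monotonic means the rate only goes up, or only goes down, as children increases
monotonic = (all(affair_rate[a] < affair_rate[b] for a, b in steps)
             or all(affair_rate[a] > affair_rate[b] for a, b in steps))

# Spread of the rate among respondents who have at least one child
with_children = [affair_rate[c] for c in levels if c > 0]
spread = max(with_children) - min(with_children)

for c in levels:
    print(f"children={c}: {affair_rate[c]:.1%}")
print("monotonic:", monotonic, "| spread among parents:", f"{spread:.1%}")
```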

4. Overall (mean of every column, grouped by affairs)
affairs avg(rate_marriage) avg(age) avg(yrs_married) avg(children) avg(religious) avg(affairs)
1 3.6473453482708234 30.537018996590355 11.152459814905017 1.7289332683877252 2.261568436434486 1.0
0 4.329700904242986 28.39067934152562 7.989334569904939 1.2388128912589844 2.5045212149316023 0.0

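Overall, people who have had an affair rate their marriage lower (3.65 vs. 4.33) and are slightly older (30.5 vs. 28.4). They have been married longer (11.2 vs. 8.0 years), rate themselves as less religious (2.26 vs. 2.50), and have more children on average (1.73 vs. 1.24).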

Feature engineering

As with logistic regression, the input columns must first be combined into a single feature-vector column.

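from pyspark.ml.feature import VectorAssembler

# Combine the five predictor columns into one vector column named "features"
df_assembler = VectorAssembler(inputCols=['rate_marriage', 'age', 'yrs_married', 'children', 'religious'], outputCol="features")
df = df_assembler.transform(df)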

df.select(['features','affairs']).show(10,False)
features affairs
[5.0,32.0,6.0,1.0,3.0] 0
[4.0,22.0,2.5,0.0,2.0] 0
[3.0,32.0,9.0,3.0,3.0] 1
[3.0,27.0,13.0,3.0,1.0] 1
[4.0,22.0,2.5,0.0,1.0] 1
[4.0,37.0,16.5,4.0,3.0] 1
[5.0,27.0,9.0,1.0,1.0] 1
[4.0,27.0,9.0,0.0,2.0] 1
[5.0,37.0,23.0,5.5,2.0] 1
[5.0,37.0,23.0,5.5,2.0] 1
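Conceptually, VectorAssembler concatenates the listed input columns of each row into one numeric vector, in the order given in inputCols, and converts integers to doubles. The plain-Python sketch below mimics that on the first rows shown by df.show(5). It is an illustration, not Spark's implementation. It ignores sparse vector storage and the handleInvalid option for nulls.

```python
input_cols = ['rate_marriage', 'age', 'yrs_married', 'children', 'religious']

# First rows from df.show(5) as plain dicts (affairs is the label, not a feature)
rows = [
    {'rate_marriage': 5, 'age': 32.0, 'yrs_married': 6.0, 'children': 1.0, 'religious': 3, 'affairs': 0},
    {'rate_marriage': 4, 'age': 22.0, 'yrs_married': 2.5, 'children': 0.0, 'religious': 2, 'affairs': 0},
    {'rate_marriage': 3, 'age': 32.0, 'yrs_married': 9.0, 'children': 3.0, 'religious': 3, 'affairs': 1},
]

def assemble(row, cols):
    """Concatenate the given columns of one row into a list of floats, keeping column order."""
    return [float(row[c]) for c in cols]

features = [assemble(r, input_cols) for r in rows]
for vec, r in zip(features, rows):
    print(vec, r['affairs'])
```

The order of inputCols fixes each feature's index. That is why the feature importances further down are reported by index 0 to 4.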

Building the modeling dataset and splitting it

model_df=df.select(['features','affairs'])
# About 75% of rows for training and 25% for testing; pass seed=<int> for a reproducible split
train_df,test_df=model_df.randomSplit([0.75,0.25])
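randomSplit normalizes the weights and assigns rows at random, so the two parts are only approximately 75% and 25% of the data. Without a seed, the split changes between runs, and so do the metrics reported later. The plain-Python sketch below shows the per-row assignment idea with a fixed seed. Spark's actual implementation samples per partition and differs in detail.

```python
import random

def random_split(rows, weights, seed=None):
    """Assign each row to one split independently, with probability proportional to its weight."""
    total = sum(weights)
    # Cumulative boundaries in (0, 1], e.g. [0.75, 1.0] for weights [0.75, 0.25]
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()  # uniform in [0, 1)
        for i, b in enumerate(bounds):
            if u < b:
                splits[i].append(row)
                break
        else:
            splits[-1].append(row)  # guard against floating-point round-off in the last bound
    return splits

train, test = random_split(range(6366), [0.75, 0.25], seed=42)
print(len(train), len(test))
```

In PySpark, reproducibility comes from the seed argument, e.g. `model_df.randomSplit([0.75, 0.25], seed=42)`.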

Building and training the model

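from pyspark.ml.classification import RandomForestClassifier

# Fit a 50-tree random forest (featuresCol defaults to 'features'), then score the test set
rf_classifier=RandomForestClassifier(labelCol='affairs',numTrees=50).fit(train_df)
rf_predictions=rf_classifier.transform(test_df)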
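A random forest combines the outputs of its trees. My understanding of Spark's RandomForestClassificationModel (worth checking against the Spark source) is as follows. Each tree contributes the class distribution of the leaf a row lands in, normalized to sum to 1. These distributions are summed into rawPrediction, probability is that sum normalized, and prediction is the class with the largest value. The plain-Python sketch below reproduces only that aggregation step. The leaf counts are hypothetical numbers for three trees, not output from the model above.

```python
def aggregate_trees(leaf_class_counts):
    """Combine per-tree leaf class counts into (raw_prediction, probability, prediction).

    leaf_class_counts: one list per tree, giving the training-class counts in the leaf
    that the row falls into, e.g. [n_label_0, n_label_1].
    """
    n_classes = len(leaf_class_counts[0])
    raw = [0.0] * n_classes
    for counts in leaf_class_counts:
        total = sum(counts)
        if total == 0:
            continue  # skip an empty leaf rather than divide by zero
        for k, c in enumerate(counts):
            raw[k] += c / total  # each tree casts a "soft vote" that sums to 1
    raw_total = sum(raw)
    probability = [v / raw_total for v in raw] if raw_total else raw[:]
    prediction = max(range(n_classes), key=lambda k: raw[k])
    return raw, probability, prediction

# Hypothetical leaf counts for one test row in three trees (illustrative numbers only)
raw, prob, pred = aggregate_trees([[30, 10], [5, 15], [8, 12]])
print(raw, prob, pred)
```

If this reading is right, the rawPrediction vectors in rf_predictions should sum to roughly numTrees=50.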

Prediction and evaluation

from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator

rf_accuracy=MulticlassClassificationEvaluator(labelCol='affairs',metricName='accuracy').evaluate(rf_predictions)
print('The accuracy of RF on test data is {0:.0%}'.format(rf_accuracy))


rf_precision=MulticlassClassificationEvaluator(labelCol='affairs',metricName='weightedPrecision').evaluate(rf_predictions)
print('The precision rate on test data is {0:.0%}'.format(rf_precision))

rf_auc=BinaryClassificationEvaluator(labelCol='affairs').evaluate(rf_predictions)
print('The AUC on test data is {0:.0%}'.format(rf_auc))

The accuracy of RF on test data is 72%

The precision rate on test data is 70%

The AUC on test data is 73%
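The three numbers mean the following:

- Accuracy is the share of correct predictions.
- weightedPrecision averages the per-class precision, weighting each class by how often it occurs among the true labels.
- areaUnderROC is BinaryClassificationEvaluator's default metric, computed from the rawPrediction column. It equals the probability that a randomly chosen positive row scores higher than a randomly chosen negative one, with ties counted as half.

The plain-Python sketch below implements those definitions on a small set of hypothetical labels and scores. Spark builds the ROC curve from (possibly binned) scores instead, which should give the same AUC up to that binning.

```python
def accuracy(y_true, y_pred):
    """Fraction of rows where the predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def weighted_precision(y_true, y_pred):
    """Per-class precision, weighted by each class's share of the true labels."""
    n = len(y_true)
    result = 0.0
    for label in set(y_true):
        predicted_as_label = [t for t, p in zip(y_true, y_pred) if p == label]
        precision = (predicted_as_label.count(label) / len(predicted_as_label)
                     if predicted_as_label else 0.0)
        result += precision * y_true.count(label) / n
    return result

def roc_auc(y_true, scores):
    """P(score of a random positive > score of a random negative), ties counted as 1/2."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))

# Hypothetical true labels and positive-class scores (not taken from the model above)
y_true = [0, 0, 0, 1, 1, 0, 1, 0]
scores = [0.1, 0.3, 0.6, 0.8, 0.4, 0.2, 0.7, 0.55]
y_pred = [1 if s > 0.5 else 0 for s in scores]  # predict class 1 above a 0.5 threshold

print(accuracy(y_true, y_pred), weighted_precision(y_true, y_pred), roc_auc(y_true, scores))
```

AUC is a probability, so it is more commonly reported as 0.73 than as 73%.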

Feature importance

rf_classifier.featureImportances
df.schema["features"].metadata["ml_attr"]["attrs"]

SparseVector(5, {0: 0.6258, 1: 0.018, 2: 0.2381, 3: 0.0524, 4: 0.0657})

{'numeric': [{'idx': 0, 'name': 'rate_marriage'},
{'idx': 1, 'name': 'age'},
{'idx': 2, 'name': 'yrs_married'},
{'idx': 3, 'name': 'children'},
{'idx': 4, 'name': 'religious'}]}

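rate_marriage is by far the most important feature (0.626), followed by yrs_married (0.238). The remaining three features together account for less than 0.14.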
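To read the importances by name instead of by index, you can join the two outputs above on idx. The plain-Python sketch below uses the printed values. In PySpark, `rf_classifier.featureImportances.toArray()` returns the same importances as a dense array. Spark normalizes random-forest importances to sum to 1.

```python
# Values copied from rf_classifier.featureImportances, SparseVector(5, {...})
importances = {0: 0.6258, 1: 0.018, 2: 0.2381, 3: 0.0524, 4: 0.0657}

# Copied from df.schema["features"].metadata["ml_attr"]["attrs"]
attrs = {'numeric': [{'idx': 0, 'name': 'rate_marriage'},
                     {'idx': 1, 'name': 'age'},
                     {'idx': 2, 'name': 'yrs_married'},
                     {'idx': 3, 'name': 'children'},
                     {'idx': 4, 'name': 'religious'}]}

# Flatten every attribute group (the metadata can also hold 'binary' or 'nominal' groups)
idx_to_name = {a['idx']: a['name'] for group in attrs.values() for a in group}

# Sort features from most to least important
ranked = sorted(((idx_to_name[i], w) for i, w in importances.items()),
                key=lambda item: item[1], reverse=True)

for name, weight in ranked:
    print(f"{name:<15}{weight:.4f}")
```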

Saving and loading the model

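# Run pwd (in a shell or Jupyter) to see the current working directory
# save() fails if the path already exists; rf_classifier.write().overwrite().save(path) replaces it
rf_classifier.save("/Users/admin/Desktop/pyspark/RFmodel")

from pyspark.ml.classification import RandomForestClassificationModel

# Load from the same path the model was saved to
rf=RandomForestClassificationModel.load("/Users/admin/Desktop/pyspark/RFmodel")

model_predictions=rf.transform(test_df)
model_predictions.show()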

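Tencent Cloud brings together a wealth of high-quality cloud computing usage and development experience for developers, building an open cloud computing technology ecosystem.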