您的当前位置:首页正文

大数据进阶必修课!Spark实战贝叶斯分类算法

2024-11-19 来源:个人技术集锦

贝叶斯算法作为机器学习算法中非常重要的一个流派,无论在学术研究还是企业应用中都很受欢迎,学习本文不需要较强的概率论基础,在理论讲解的同时,从代码实战角度加强你的理解。

3.SparkMLlib贝叶斯分类算法

3.1贝叶斯分类算法

首先简要介绍一下贝叶斯算法的重要定理——贝叶斯定理。
P ( A ) P(A) P(A)是随机事件 A A A的先验概率, P ( A ∣ B ) P(A|B) P(AB)是已知随机事件 B B B发生后 A A A的条件概率,也称作 A A A的后验概率。随机事件 B B B同理。有公式:
P ( A ∣ B ) = P ( A B ) P ( B ) P ( A | B ) = \frac { P ( A B ) } { P ( B ) } P(AB)=P(B)P(AB)
也就引出了贝叶斯定理:
P ( B ∣ A ) = P ( A ∣ B ) P ( B ) P ( A ) P ( B | A ) = \frac { P ( A | B ) P ( B ) } { P ( A ) } P(BA)=P(A)P(AB)P(B)
基于此,我们在这里提出了一种简单的分类器——朴素贝叶斯分类器。对于待分类样本,计算在此样本发生时各个类别出现的概率,找出最大的概率,就可以判定这个待分类样本属于哪个类别。朴素贝叶斯的分类过程包括了数据准备,训练和预测阶段。分类过程如下:
(1)定义待分类项 x x x的特征向量 ( a 1 , a 2 … … a k ) (a1,a2……ak) (a1,a2ak),分类的类别为 ( y 1 , y 2 … … y m ) (y1,y2……ym) (y1,y2ym)
(2)计算 P ( y j ∣ x ) P(yj|x) P(yjx) x x x发生时,属于 y 1 , y 2 … y m y1,y2…ym y1,y2ym的概率分别为多少,根据贝叶斯公式可将此问题转化为求待分类样本属于具体分类类别的概率 p ( x ∣ y j ) p ( y j ) p(x|yj)p(yj) p(xyj)p(yj),即
P ( y j ∣ x ) = P ( x ∣ y j ) P ( y j ) P ( x ) P ( y_j | x ) = \frac { P ( x | y_ j ) P ( y_ j ) } { P ( x ) } P(yjx)=P(x)P(xyj)P(yj)
(3)再根据条件概率公式:
P ( x ∣ y j ) P ( y j ) = ∏ P ( a i ∣ y j ) P ( y j ) P ( x | y_j ) P ( y_j ) = \prod P ( a_i | y_j ) P ( y_j ) P(xyj)P(yj)=P(aiyj)P(yj)
依据此概率最大项来判断 x x x所属的类别。

3.2算法源码分析

MLlib中的贝叶斯分类模型是基于朴素贝叶斯,先计算各个类别的先验概率和各类别下各个特征的条件概率。其实现的步骤如下:
(1)对训练样本统计所有标签出现次数和对应特征之和;
(2)对(标签,样本特征)形式的样本进行聚合操作,统计属于同一标签的数据;
(3)由以上的统计结果计算先验概率和条件概率,得到朴素贝叶斯分类样本;
(4)进行预测,根据模型的先验和条件概率来计算待测样本属于每个类别的概率,取最大概率作为分类依据。
MLlib中对朴素贝叶斯的实现如下:
(1)贝叶斯分类伴生对象:NativeBayes,含有train静态方法,可以设置参数创建朴素贝叶斯分类类,执行run方法进行训练,train方法主要参数如下:

  • input——训练样本格式为RDD(label,features);
  • lambda——平滑参数,
    (2)贝叶斯分类模型:NativeBayes类,含有run方法,训练贝叶斯模型,计算各个类别的先验概率和各个特征属于各个类别的条件概率;
    (3)训练模型:aggregated对样本各个类别下每一个特征值之和和次数进行统计,pi根据统计结果计算各类别先验概率,theta根据统计结果计算各个特征属于各个类别的条件概率;
    (4)贝叶斯模型:NativeBayesModel类,由模型的先验和条件概率计算样本属于各个类别的概率,以最大概率作为判决依据,主要参数包括:
  • labels——类别标签列表
  • pi——类别先验概率
  • theta——各个特征在各个类别中的条件概率
  • modelType——多项式或伯努利模型
    含有predict方法根据贝叶斯分类模型返回样本预测值RDD[Double]类型。

3.3应用实战

3.3.1数据说明

本次实战选择是sample_libsvm_data数据,以后也会经常用到,其样本格式为:
label index1:value1 index2:value2 index3:value3 …
label是样本的标签值,一对index:value表示一个特征和特征值。

3.3.2代码详解

//首先导入用到的机器学习包,NativeBayes算法包和多分类器评价包
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

//以DataFrame形式加载以libsvm格式存储的数据
val data = spark.read.format("libsvm").load("/mnt/hgfs/thunder-
download/MLlib_rep/data/sample_libsvm_data.txt")
输出结果为:
org.apache.spark.sql.DataFrame = [label: double, features: vector]

//把数据切分为训练集和测试集,比例为7:3
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)

输出结果为:

trainingData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]
testData: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, features: vector]
//加载训练数据,训练一个朴素贝叶斯模型
val model = new NaiveBayes().fit(trainingData)
//展示预测结果
val predictions = model.transform(testData)
predictions.show()

输出预测结果如下:

+-----+--------------------+--------------------+-----------+----------+
|label|            features|       rawPrediction|probability|prediction|
+-----+--------------------+--------------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|[-173678.60946628...|  [1.0,0.0]|       0.0|
|  0.0|(692,[98,99,100,1...|[-178107.24302988...|  [1.0,0.0]|       0.0|
|  0.0|(692,[100,101,102...|[-100020.80519087...|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|[-183521.85526462...|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|[-183004.12461660...|  [1.0,0.0]|       0.0|
|  0.0|(692,[128,129,130...|[-246722.96394714...|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|[-208696.01108598...|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|[-261509.59951302...|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|[-217654.71748256...|  [1.0,0.0]|       0.0|
|  0.0|(692,[181,182,183...|[-155287.07585335...|  [1.0,0.0]|       0.0|
|  1.0|(692,[99,100,101,...|[-145981.83877498...|  [0.0,1.0]|       1.0|
|  1.0|(692,[100,101,102...|[-147685.13694275...|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|[-139521.98499849...|  [0.0,1.0]|       1.0|
|  1.0|(692,[124,125,126...|[-129375.46702012...|  [0.0,1.0]|       1.0|
|  1.0|(692,[126,127,128...|[-145809.08230799...|  [0.0,1.0]|       1.0|
|  1.0|(692,[127,128,129...|[-132670.15737290...|  [0.0,1.0]|       1.0|
|  1.0|(692,[128,129,130...|[-100206.72054749...|  [0.0,1.0]|       1.0|
|  1.0|(692,[129,130,131...|[-129639.09694930...|  [0.0,1.0]|       1.0|
|  1.0|(692,[129,130,131...|[-143628.65574273...|  [0.0,1.0]|       1.0|
|  1.0|(692,[129,130,131...|[-129238.74023248...|  [0.0,1.0]|       1.0|
+-----+--------------------+--------------------+-----------+----------+
only showing top 20 rows
//计算模型的准确度
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println("Test set accuracy = " + accuracy)

输出结果,预测精度为1.0:

Test set accuracy = 1.0
显示全文