欢迎来到广州华商学院大数据系DModel实训平台,

淘宝双11数据分析与预测课程案例—步骤四:利用Spark预测回头客行为

责任编辑:bradley   发布时间:2022-06-30 19:38:37   

本教程介绍大数据课程实验案例“淘宝双11数据分析与预测”的第五个步骤,利用Spark预测回头客。在实践本步骤之前,请先完成该实验案例的第一个步骤——本地数据集上传到数据仓库Hive第二个步骤——Hive数据分析,和第三个步骤:将数据从Hive导入到MySQL,这里假设你已经完成了前面的这四个步骤。

预处理test.csv和train.csv数据集

这里列出test.csv和train.csv中字段的描述,字段定义如下:

  1. user_id | 买家id

  2. age_range | 买家年龄分段:1表示年龄

  3. gender | 性别:0表示女性,1表示男性,2和NULL表示未知

  4. merchant_id | 商家id

  5. label | 是否是回头客,0值表示不是回头客,1值表示回头客,-1值表示该用户已经超出我们所需要考虑的预测范围。NULL值只存在测试集,在测试集中表示需要预测的值。

这里需要预先处理test.csv数据集,把这test.csv数据集里label字段表示-1值剔除掉,保留需要预测的数据.并假设需要预测的数据中label字段均为1.

cd /usr/local/dbtaobao/dataset

vim predeal_test.sh

上面使用vim编辑器新建了一个predeal_test.sh脚本文件,请在这个脚本文件中加入下面代码:

#!/bin/bash
#下面设置输入文件,把用户执行predeal_test.sh命令时提供的第一个参数作为输入文件名称
infile=$1
#下面设置输出文件,把用户执行predeal_test.sh命令时提供的第二个参数作为输出文件名称
outfile=$2
#注意!!最后的$infile > $outfile必须跟在}’这两个字符的后面
awk -F "," 'BEGIN{
      id=0;
    }
    {
        if($1 && $2 && $3 && $4 && !$5){
            id=id+1;
            print $1","$2","$3","$4","1
            if(id==10000){
                exit
            }
        }
    }' $infile > $outfile

下面就可以执行predeal_test.sh脚本文件,截取测试数据集需要预测的数据到test_after.csv,命令如下:

chmod +x ./predeal_test.sh

./predeal_test.sh ./test.csv ./test_after.csv

train.csv的第一行都是字段名称,不需要第一行字段名称,这里在对train.csv做数据预处理时,删除第一行

sed -i '1d' train.csv

然后剔除掉train.csv中字段值部分字段值为空的数据。


cd /usr/local/dbtaobao/dataset

vim predeal_train.sh

上面使用vim编辑器新建了一个predeal_train.sh脚本文件,请在这个脚本文件中加入下面代码:

#!/bin/bash
#下面设置输入文件,把用户执行predeal_train.sh命令时提供的第一个参数作为输入文件名称
infile=$1
#下面设置输出文件,把用户执行predeal_train.sh命令时提供的第二个参数作为输出文件名称
outfile=$2
#注意!!最后的$infile > $outfile必须跟在}’这两个字符的后面
awk -F "," 'BEGIN{
         id=0;
    }
    {
        if($1 && $2 && $3 && $4 && ($5!=-1)){
            id=id+1;
            print $1","$2","$3","$4","$5
            if(id==10000){
                exit
            }
        }
    }' $infile > $outfile

下面就可以执行predeal_train.sh脚本文件,截取测试数据集需要预测的数据到train_after.csv,命令如下:

chmod +x ./predeal_train.sh

./predeal_train.sh ./train.csv ./train_after.csv

预测回头客

启动hadoop

请先确定Spark的运行方式,如果Spark是基于Hadoop伪分布式运行,那么请先运行Hadoop。
如果Hadoop没有运行,请执行如下命令:

cd /usr/local/hadoop/

sbin/start-dfs.sh

将两个数据集分别存取到HDFS中


bin/hadoop fs -mkdir -p /dbtaobao/dataset

bin/hadoop fs -put /usr/local/dbtaobao/dataset/train_after.csv /dbtaobao/dataset

bin/hadoop fs -put /usr/local/dbtaobao/dataset/test_after.csv /dbtaobao/dataset

启动MySQL服务

service mysql start

mysql -uroot -p #会提示让你输入数据库密码

输入密码后,你就可以进入“mysql>”命令提示符状态,然后就可以输入下面的SQL语句完成表的创建:


use dbtaobao;

create table rebuy (score varchar(40),label varchar(40));

接下来正式启动spark-shell


cd /usr/local/spark

./bin/spark-shell --jars /usr/local/spark/jars/mysql-connector-java-5.1.40/mysql-connector-java-5.1.40-bin.jar --driver-class-path /usr/local/spark/jars/mysql-connector-java-5.1.40/mysql-connector-java-5.1.40-bin.jar

支持向量机SVM分类器预测回头客

这里使用Spark MLlib自带的支持向量机SVM分类器进行预测回头客,有关更多Spark MLlib中SVM分类器的学习知识,请点击Spark入门:支持向量机SVM分类器进行学习。
在spark-shell中执行如下操作:
1.导入需要的包
首先,我们导入需要的包:

import org.apache.spark.SparkConf

import org.apache.spark.SparkContext

import org.apache.spark.mllib.regression.LabeledPoint

import org.apache.spark.mllib.linalg.{Vectors,Vector}

import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

import java.util.Properties

import org.apache.spark.sql.types._

import org.apache.spark.sql.Row

2.读取训练数据
首先,读取训练文本文件;然后,通过map将每行的数据用“,”隔开,在数据集中,每行被分成了5部分,前4部分是用户交易的3个特征(age_range,gender,merchant_id),最后一部分是用户交易的分类(label)。把这里我们用LabeledPoint来存储标签列和特征列。LabeledPoint在监督学习中常用来存储标签和特征,其中要求标签的类型是double,特征的类型是Vector。

val train_data = sc.textFile("/dbtaobao/dataset/train_after.csv")

val test_data = sc.textFile("/dbtaobao/dataset/test_after.csv")

3.构建模型


val train= train_data.map{line =>

  val parts = line.split(',')

  LabeledPoint(parts(4).toDouble,Vectors.dense(parts(1).toDouble,parts

(2).toDouble,parts(3).toDouble))

}

val test = test_data.map{line =>

  val parts = line.split(',')

  LabeledPoint(parts(4).toDouble,Vectors.dense(parts(1).toDouble,parts(2).toDouble,parts(3).toDouble))

}

接下来,通过训练集构建模型SVMWithSGD。这里的SGD即著名的随机梯度下降算法(Stochastic Gradient Descent)。设置迭代次数为1000,除此之外还有stepSize(迭代步伐大小),regParam(regularization正则化控制参数),miniBatchFraction(每次迭代参与计算的样本比例),initialWeights(weight向量初始值)等参数可以进行设置。


val numIterations = 1000

val model = SVMWithSGD.train(train, numIterations)

4.评估模型
接下来,我们清除默认阈值,这样会输出原始的预测评分,即带有确信度的结果。


model.clearThreshold()

val scoreAndLabels = test.map{point =>

  val score = model.predict(point.features)

  score+" "+point.label

}

scoreAndLabels.foreach(println)

spark-shell会打印出如下结果

......
-59045.132228013084 1.0
-81550.17634254562 1.0
-87393.69932070676 1.0
-34743.183626268634 1.0
-42541.544145105494 1.0
-75530.22669142077 1.0
-84157.31973688163 1.0
-18673.911440386535 1.0
-43765.52530945006 1.0
-80524.44350315288 1.0
-61709.836501153935 1.0
-37486.854426141384 1.0
-79793.17112276069 1.0
-21754.021986991942 1.0
-50378.971923247285 1.0
-11646.722569368836 1.0
......

如果我们设定了阀值,则会把大于阈值的结果当成正预测,小于阈值的结果当成负预测。

model.setThreshold(0.0)

scoreAndLabels.foreach(println)

  1. 把结果添加到mysql数据库中
    现在我们上面没有设定阀值的测试集结果存入到MySQL数据中。

model.clearThreshold()

val scoreAndLabels = test.map{point =>

  val score = model.predict(point.features)

  score+" "+point.label

}

//设置回头客数据

val rebuyRDD = scoreAndLabels.map(_.split(" "))

/下面要设置模式信息

val schema = StructType(List(StructField("score", StringType, true),StructField("label", StringType, true)))

//下面创建Row对象,每个Row对象都是rowRDD中的一行

val rowRDD = rebuyRDD.map(p => Row(p(0).trim, p(1).trim))

//建立起Row对象和模式之间的对应关系,也就是把数据和模式对应起来

val rebuyDF = spark.createDataFrame(rowRDD, schema)

//下面创建一个prop变量用来保存JDBC连接参数

val prop = new Properties()

prop.put("user", "root") //表示用户名是root

prop.put("password", "root") //表示密码是hadoop

prop.put("driver","com.mysql.jdbc.Driver") //表示驱动程序是com.mysql.jdbc.Driver

//下面就可以连接数据库,采用append模式,表示追加记录到数据库dbtaobao的rebuy表中

rebuyDF.write.mode("append").jdbc("jdbc:mysql://localhost:3306/dbtaobao", "dbtaobao.rebuy", prop)

到这里,第四个步骤的实验内容顺利结束。请继续访问第五个步骤《大数据案例-步骤五:利用ECharts进行数据可视化分析

☆ 大数据