2-Spark Learning Notes 2


SparkSQL

SparkSQL Overview

SparkSQL Core Programming

package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_Basic {

  def main(args: Array[String]): Unit = {
    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    import spark.implicits._

    // TODO: Create a DataFrame
    val df = spark.read.json("input/user.json")
//    df.show()

    df.createOrReplaceTempView("user")
//    spark.sql("select * from user").show

    // When DataFrame operations involve conversions, the implicit conversion rules must be in scope (import spark.implicits._)
//    df.select("age", "username").show

//    df.select($"age" + 1).show

    // DataSet
    val seq = Seq(1, 2, 3, 4)
    val ds = seq.toDS()
//    ds.show

    val rdd = spark.sparkContext.makeRDD(List((1, "zhangsan", 30), (2, "lisi", 40)))
    val df2 = rdd.toDF("id", "name", "age")
    val rowRDD = df2.rdd

    val ds2 = df2.as[User]
    val df3 = ds2.toDF()

    val ds3 = rdd.map {
      case (id, name, age) => {
        User(id, name, age)
      }
    }.toDS()

    val userRDD = ds3.rdd


    // Close the environment
    spark.close()

  }

  case class User(id: Int, name: String, age: Int)
}
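
Every example in this section reads input/user.json. A minimal file matching the queried schema (username: String, age: Long) would look like the following, one JSON object per line; the values here are illustrative:

{"username": "zhangsan", "age": 30}
{"username": "lisi", "age": 40}
{"username": "wangwu", "age": 50}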

package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_UDF {

  def main(args: Array[String]): Unit = {
    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("input/user.json")
    df.createOrReplaceTempView("user")

    // Register a UDF that prepends "Name" to the username
    spark.udf.register("prefixName", (name: String) => {
      "Name" + name
    })

    spark.sql("select age, prefixName(username) from user").show()


    spark.close()
  }
}
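
The same UDF can also be applied through the DataFrame DSL without registering it by name — a minimal sketch, assuming the same SparkSession and import spark.implicits._ as above:

import org.apache.spark.sql.functions.udf

// Wrap the Scala function as a UDF and apply it directly to columns
val prefixName = udf((name: String) => "Name" + name)
df.select($"age", prefixName($"username")).show()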

package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_UDAF {

  def main(args: Array[String]): Unit = {
    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("input/user.json")

    df.createOrReplaceTempView("user")

    spark.udf.register("avgAge", new MyAvgUDAF())

    spark.sql("select avgAge(age) from user").show

    spark.close()
  }

  /*
     Custom aggregate function class: compute the average age
     1. Extend UserDefinedAggregateFunction
     2. Override its 8 methods
     */
  class MyAvgUDAF extends UserDefinedAggregateFunction{

    // Structure of the input data
    override def inputSchema: StructType = {
      StructType(
        Array(
          StructField("age", LongType)
        )
      )
    }

    // Structure of the buffer data
    override def bufferSchema: StructType = {
      StructType(
        Array(
          StructField("total", LongType),
          StructField("count", LongType)
        )
      )
    }

    // Data type of the computed result
    override def dataType: DataType = LongType

    // Stability of the function: the same input always produces the same result
    override def deterministic: Boolean = true

    // Initialize the buffer
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0L)
      buffer.update(1, 0L)
    }

    // Update the buffer with an input row
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      buffer.update(1, buffer.getLong(1) + 1)
    }

    // Merge two buffers
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
    }

    // Compute the final result
    override def evaluate(buffer: Row): Any = {
      buffer.getLong(0) / buffer.getLong(1)
    }
  }
}
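
Note: UserDefinedAggregateFunction is deprecated as of Spark 3.0.0. New code should implement the strongly typed Aggregator and register it through functions.udaf, as the next example does.
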
package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_UDF1 {

  def main(args: Array[String]): Unit = {
    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("input/user.json")
    df.createOrReplaceTempView("user")

    spark.udf.register("ageAvg", functions.udaf(new MyAvgUDAF()))

    spark.sql("select ageAvg(age) from user").show

    spark.close()
  }

  /*
     Custom aggregate function class: compute the average age
     1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters:
         IN  : input data type   Long
         BUF : buffer data type  Buff
         OUT : output data type  Long
     2. Override its 6 methods
     */
  case class Buff(var total: Long, var count: Long)
  class MyAvgUDAF extends Aggregator[Long, Buff, Long] {

    // Initialize the buffer
    override def zero: Buff = {
      Buff(0L, 0L)
    }

    // Update the buffer with the input data
    override def reduce(b: Buff, a: Long): Buff = {
      b.total = b.total + a
      b.count = b.count + 1
      b
    }

    // Merge two buffers
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.total = b1.total + b2.total
      b1.count = b1.count + b2.count
      b1
    }

    // Compute the final result
    override def finish(reduction: Buff): Long = {
      reduction.total / reduction.count
    }

    // Encoder for the buffer
    override def bufferEncoder: Encoder[Buff] = Encoders.product

    // Encoder for the output
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}
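
functions.udaf, available since Spark 3.0.0, wraps a strongly typed Aggregator into an untyped UDAF, which is why the Long-typed aggregator above can be registered and invoked from plain SQL.
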
package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_UDF2 {

  def main(args: Array[String]): Unit = {
    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    import spark.implicits._

    val df = spark.read.json("input/user.json")
    // In earlier Spark versions, a strongly typed UDAF could not be used in SQL statements
    // SQL & DSL
    // Early strongly typed aggregate functions were therefore applied through the DSL syntax
    val ds = df.as[User]

    // Convert the UDAF into a column object for the query
    val udafCol = new MyAvgUDAF().toColumn

    ds.select(udafCol).show

    spark.close()
  }

  /*
     Custom aggregate function class: compute the average age
     1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters:
         IN  : input data type   User
         BUF : buffer data type  Buff
         OUT : output data type  Long
     2. Override its 6 methods
     */
  case class User(username: String, age: Long)
  case class Buff(var total: Long, var count: Long)
  class MyAvgUDAF extends Aggregator[User, Buff, Long]{

    // Initialize the buffer
    override def zero: Buff = {
      Buff(0L, 0L)
    }

    // Update the buffer with the input data
    override def reduce(b: Buff, a: User): Buff = {
      b.total = b.total + a.age
      b.count = b.count + 1
      b
    }

    // Merge two buffers
    override def merge(b1: Buff, b2: Buff): Buff = {
      b1.total = b1.total + b2.total
      b1.count = b1.count + b2.count
      b1
    }

    // Compute the final result
    override def finish(reduction: Buff): Long = {
      reduction.total / reduction.count
    }

    // Encoder for the buffer
    override def bufferEncoder: Encoder[Buff] = Encoders.product

    // Encoder for the output
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}
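
Because toColumn yields a TypedColumn, the result column can be given a readable alias with its name method — a small sketch under the same setup:

ds.select(new MyAvgUDAF().toColumn.name("avgAge")).show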

package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{SaveMode, SparkSession}


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_JDBC {

  def main(args: Array[String]): Unit = {
    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()
    import spark.implicits._

    // Read data from MySQL
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:mysql://hadoop102:3306/test")
      .option("driver", "com.mysql.jdbc.Driver")
      .option("user", "root")
      .option("password", "********")
      .option("dbtable", "users")
      .load()

//    df.show

    // Write data to MySQL
    val df2 = spark.sparkContext.makeRDD(List((3, "wangwu", 30))).toDF("id", "name", "age")
    df2.write
      .format("jdbc")
      .option("url", "jdbc:mysql://hadoop102:3306/test")
      .option("driver", "com.mysql.jdbc.Driver")
      .option("user", "root")
      .option("password", "******")
      .option("dbtable", "users")
      .mode(SaveMode.Append)
      .save()

    spark.close()
  }
}
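
The same read can also be written with the jdbc shorthand and a java.util.Properties object — a sketch reusing the placeholder host, database, and masked credentials from above:

import java.util.Properties

val props = new Properties()
props.setProperty("user", "root")
props.setProperty("password", "********")
props.setProperty("driver", "com.mysql.jdbc.Driver")
// Equivalent to format("jdbc") + option(...) + load()
val usersDF = spark.read.jdbc("jdbc:mysql://hadoop102:3306/test", "users", props)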

package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_Hive {

  def main(args: Array[String]): Unit = {

    System.setProperty("HADOOP_USER_NAME", "root")

    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
    import spark.implicits._

    // Use SparkSQL to connect to an external Hive:
    // 1. Copy the hive-site.xml file onto the classpath
    // 2. Enable Hive support (enableHiveSupport())
    // 3. Add the corresponding dependencies (including the MySQL driver)
    spark.sql("show databases").show

    spark.close()
  }
}
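
Step 3 refers to the build dependencies. A build.sbt fragment that would cover the examples in these notes — the artifact versions are assumptions and should match your cluster:

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql"  % "3.0.0",
  "org.apache.spark" %% "spark-hive" % "3.0.0",
  "mysql" % "mysql-connector-java" % "5.1.49"  // MySQL driver for the Hive metastore and the JDBC examples
)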

SparkSQL Project in Action
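
The task: join the user-action, product, and city tables, compute the Top 3 most-clicked products in each area, and annotate each result with the share of clicks contributed by its leading cities.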

package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_Test {

  def main(args: Array[String]): Unit = {

    System.setProperty("HADOOP_USER_NAME", "lotuslaw")
    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
    import spark.implicits._

    spark.sql("use db_hive")

    // Prepare the data
    spark.sql(
      """
        |CREATE TABLE IF NOT EXISTS `user_visit_action`(
        |  `date` string,
        |  `user_id` bigint,
        |  `session_id` string,
        |  `page_id` bigint,
        |  `action_time` string,
        |  `search_keyword` string,
        |  `click_category_id` bigint,
        |  `click_product_id` bigint,
        |  `order_category_ids` string,
        |  `order_product_ids` string,
        |  `pay_category_ids` string,
        |  `pay_product_ids` string,
        |  `city_id` bigint)
        |row format delimited fields terminated by '\t'
        |""".stripMargin
    )

    spark.sql(
      """
        |load data local inpath 'input/user_visit_action.txt' into table db_hive.user_visit_action
        |""".stripMargin
    )

    spark.sql(
      """
        |CREATE TABLE IF NOT EXISTS `product_info`(
        |  `product_id` bigint,
        |  `product_name` string,
        |  `extend_info` string)
        |row format delimited fields terminated by '\t'
        |""".stripMargin
    )

    spark.sql(
      """
        |load data local inpath 'input/product_info.txt' into table db_hive.product_info
        |""".stripMargin
    )

    spark.sql(
      """
        |CREATE TABLE IF NOT EXISTS `city_info`(
        |  `city_id` bigint,
        |  `city_name` string,
        |  `area` string)
        |row format delimited fields terminated by '\t'
        |""".stripMargin
    )

    spark.sql(
      """
        |load data local inpath 'input/city_info.txt' into table db_hive.city_info
        |""".stripMargin
    )

    spark.sql(
      """
        |select * from city_info
        |""".stripMargin).show


    spark.close()
  }
}

package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_Test1 {

  def main(args: Array[String]): Unit = {

    System.setProperty("HADOOP_USER_NAME", "lotuslaw")
    // TODO: Create the SparkSQL runtime environment
    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()

    spark.sql("use db_hive")

    // Top 3 most-clicked products per area
    spark.sql(
      """
        |select
        |    *
        |from (
        |    select
        |        *,
        |        rank() over( partition by area order by clickCnt desc ) as rank
        |    from (
        |        select
        |           area,
        |           product_name,
        |           count(*) as clickCnt
        |        from (
        |            select
        |               a.*,
        |               p.product_name,
        |               c.area,
        |               c.city_name
        |            from user_visit_action a
        |            join product_info p on a.click_product_id = p.product_id
        |            join city_info c on a.city_id = c.city_id
        |            where a.click_product_id > -1
        |        ) t1 group by area, product_name
        |    ) t2
        |) t3 where rank <= 3
            """.stripMargin).show


    spark.close()
  }
}
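
Spark_SparkSQL_Test1 answers the Top 3 question in a single nested SQL statement. Spark_SparkSQL_Test2 below splits the query into staged temp views (t1 through t3) so that a custom strongly typed UDAF can compute the per-city click-share remark during the aggregation step.
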
package com.lotuslaw.spark.sql

import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

import scala.collection.mutable
import scala.collection.mutable.ListBuffer


/**
 * @author: lotuslaw
 * @version: V1.0
 * @package: com.lotuslaw.spark.sql
 * @create: 2021-12-02 20:05
 * @description:
 */
object Spark_SparkSQL_Test2 {

  def main(args: Array[String]): Unit = {
    System.setProperty("HADOOP_USER_NAME", "lotuslaw")

    val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
    val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()

    spark.sql("use db_hive")

    // Query the base data
    spark.sql(
      """
        |  select
        |     a.*,
        |     p.product_name,
        |     c.area,
        |     c.city_name
        |  from user_visit_action a
        |  join product_info p on a.click_product_id = p.product_id
        |  join city_info c on a.city_id = c.city_id
        |  where a.click_product_id > -1
            """.stripMargin).createOrReplaceTempView("t1")

    // Aggregate the data by area and product
    // The UDAF operates on the rows within each (area, product_name) group
    spark.udf.register("cityRemark", functions.udaf(new CityRemarkUDAF()))
    spark.sql(
      """
        |  select
        |     area,
        |     product_name,
        |     count(*) as clickCnt,
        |     cityRemark(city_name) as city_remark
        |  from t1 group by area, product_name
            """.stripMargin).createOrReplaceTempView("t2")

    // Rank by click count within each area
    spark.sql(
      """
        |  select
        |      *,
        |      rank() over( partition by area order by clickCnt desc ) as rank
        |  from t2
            """.stripMargin).createOrReplaceTempView("t3")

    // Take the top 3
    spark.sql(
      """
        | select
        |     *
        | from t3 where rank <= 3
            """.stripMargin).show(false)

    spark.close()
  }

  case class Buffer(var total: Long, var cityMap: mutable.Map[String, Long])

  // Custom aggregate function: build the city remark
  // 1. Extend Aggregator and define the type parameters:
  //    IN  : city name
  //    BUF : Buffer => (total click count, Map[(city, cnt), (city, cnt)])
  //    OUT : remark string
  // 2. Override its 6 methods
  class CityRemarkUDAF extends Aggregator[String, Buffer, String] {
    // Initialize the buffer
    override def zero: Buffer = {
      Buffer(0, mutable.Map[String, Long]())
    }

    // Update the buffer for one input city
    override def reduce(buff: Buffer, city: String): Buffer = {
      buff.total += 1
      val newCount = buff.cityMap.getOrElse(city, 0L) + 1
      buff.cityMap.update(city, newCount)
      buff
    }

    // Merge two buffers
    override def merge(buff1: Buffer, buff2: Buffer): Buffer = {
      buff1.total += buff2.total

      val map1 = buff1.cityMap
      val map2 = buff2.cityMap

      map2.foreach {
        case (city, cnt) => {
          val newCount = map1.getOrElse(city, 0L) + cnt
          map1.update(city, newCount)
        }
      }
      buff1.cityMap = map1
      buff1
    }

    // Generate the remark string from the aggregated statistics
    override def finish(buff: Buffer): String = {
      val remarkList = ListBuffer[String]()

      val totalcnt = buff.total
      val cityMap = buff.cityMap

      // Sort cities by click count in descending order and keep the top 2
      val cityCntList = cityMap.toList.sortWith(
        (left, right) => {
          left._2 > right._2
        }
      ).take(2)

      val hasMore = cityMap.size > 2
      var rsum = 0L
      cityCntList.foreach {
        case (city, cnt) => {
          val r = cnt * 100 / totalcnt  // integer division: the percentage is truncated
          remarkList.append(s"${city} ${r}%")
          rsum += r
        }
      }
      if (hasMore) {
        remarkList.append(s"其他 ${100 - rsum}%")
      }

      remarkList.mkString(", ")
    }

    override def bufferEncoder: Encoder[Buffer] = Encoders.product

    override def outputEncoder: Encoder[String] = Encoders.STRING
  }

}
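
For each (area, product) group, cityRemark yields a string such as "北京 63%, 天津 10%, 其他 27%" (values illustrative): the two leading cities with their truncated integer percentages, plus an 其他 ("other") bucket whenever more than two cities contributed clicks.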