SparkSQL
SparkSQL Overview
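SparkSQL is Spark's module for working with structured data. It builds the DataFrame and Dataset abstractions on top of RDDs, lets SQL queries and programmatic transformations be mixed freely, supports user-defined functions (UDF/UDAF), and reads from and writes to external sources such as JSON files, relational databases via JDBC, and Hive. The examples below exercise each of these pieces in turn.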
SparkSQL Core Programming
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Basic DataFrame/Dataset creation and conversions between RDD, DataFrame and Dataset
*/
object Spark_SparkSQL_Basic {
def main(args: Array[String]): Unit = {
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._
// TODO: Create a DataFrame
val df = spark.read.json("input/user.json")
// df.show()
df.createOrReplaceTempView("user")
// spark.sql("select * from user").show
// When DataFrame operations involve conversions (such as the $"col" syntax), the implicit conversion rules must be in scope (import spark.implicits._)
// df.select("age", "username").show
// df.select($"age" + 1).show
// DataSet
val seq = Seq(1, 2, 3, 4)
val ds = seq.toDS()
// ds.show
val rdd = spark.sparkContext.makeRDD(List((1, "zhangsan", 30), (2, "lisi", 40)))
val df2 = rdd.toDF("id", "name", "age")
val rowRDD = df2.rdd
val ds2 = df2.as[User]
val df3 = ds2.toDF()
val ds3 = rdd.map {
case (id, name, age) => {
User(id, name, age)
}
}.toDS()
val userRDD = ds3.rdd
// Close the environment
spark.close()
}
case class User(id: Int, name: String, age: Int)
}
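Every example in this section reads input/user.json. The file itself is not included in these notes; a minimal sketch of what it is assumed to look like, with one JSON object per line and the username and age fields that the queries above reference:

{"username": "zhangsan", "age": 30}
{"username": "lisi", "age": 40}
{"username": "wangwu", "age": 50}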
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Registering and using a simple UDF in SQL
*/
object Spark_SparkSQL_UDF {
def main(args: Array[String]): Unit = {
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._
val df = spark.read.json("input/user.json")
df.createOrReplaceTempView("user")
// Register a UDF that prepends the literal "Name" to the username
spark.udf.register("prefixName", (name: String) => {
"Name" + name
})
spark.sql("select age, prefixName(username) from user").show()
spark.close()
}
}
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Weakly-typed UDAF (UserDefinedAggregateFunction) that computes the average age
*/
object Spark_SparkSQL_UDAF {
def main(args: Array[String]): Unit = {
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._
val df = spark.read.json("input/user.json")
df.createOrReplaceTempView("user")
spark.udf.register("avgAge", new MyAvgUDAF())
spark.sql("select avgAge(age) from user").show
spark.close()
}
/*
Custom aggregate function class: computes the average age
1. Extend UserDefinedAggregateFunction
2. Override its methods (8 of them)
*/
class MyAvgUDAF extends UserDefinedAggregateFunction{
// Structure (schema) of the input data
override def inputSchema: StructType = {
StructType(
Array(
StructField("age", LongType)
)
)
}
// Structure (schema) of the aggregation buffer
override def bufferSchema: StructType = {
StructType(
Array(
StructField("total", LongType),
StructField("count", LongType)
)
)
}
// Data type of the computed result
override def dataType: DataType = LongType
// Whether the function is deterministic: the same input always produces the same result
override def deterministic: Boolean = true
// Initialize the buffer
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0L)
buffer.update(1, 0L)
}
// Update the buffer with an input row
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0, buffer.getLong(0) + input.getLong(0))
buffer.update(1, buffer.getLong(1) + 1)
}
// Merge two buffers
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
}
// Compute the final result from the buffer
override def evaluate(buffer: Row): Any = {
buffer.getLong(0) / buffer.getLong(1)
}
}
}
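Note that UserDefinedAggregateFunction is deprecated since Spark 3.0; the recommended replacement is the strongly-typed Aggregator registered through functions.udaf, which the next example shows.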
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Strongly-typed UDAF based on Aggregator, registered for SQL use via functions.udaf
*/
object Spark_SparkSQL_UDF1 {
def main(args: Array[String]): Unit = {
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._
val df = spark.read.json("input/user.json")
df.createOrReplaceTempView("user")
spark.udf.register("ageAvg", functions.udaf(new MyAvgUDAF()))
spark.sql("select ageAvg(age) from user").show
spark.close()
}
/*
Custom aggregate function class: computes the average age
1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters
IN  : input type  Long
BUF : buffer type Buff
OUT : output type Long
2. Override its methods (6 of them)
*/
case class Buff(var total: Long, var count: Long)
class MyAvgUDAF extends Aggregator[Long, Buff, Long] {
// Initial value of the buffer
override def zero: Buff = {
Buff(0L, 0L)
}
// Update the buffer with an input value
override def reduce(b: Buff, a: Long): Buff = {
b.total = b.total + a
b.count = b.count + 1
b
}
// Merge two buffers
override def merge(b1: Buff, b2: Buff): Buff = {
b1.total = b1.total + b2.total
b1.count = b1.count + b2.count
b1
}
// Compute the final result
override def finish(reduction: Buff): Long = {
reduction.total / reduction.count
}
// Encoders for the buffer and the output
override def bufferEncoder: Encoder[Buff] = Encoders.product
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
}
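functions.udaf was introduced in Spark 3.0, so this way of registering a strongly-typed Aggregator for use in SQL requires Spark 3.0 or later; on older versions the Aggregator has to be used through the DSL, as the next example shows.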
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Strongly-typed UDAF used through the DataFrame DSL (toColumn), as required before Spark 3.0
*/
object Spark_SparkSQL_UDF2 {
def main(args: Array[String]): Unit = {
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._
val df = spark.read.json("input/user.json")
// In earlier Spark versions, a strongly-typed UDAF could not be used directly in SQL
// SQL & DSL
// Before Spark 3.0, strongly-typed Aggregator functions were used through the DSL syntax instead
val ds = df.as[User]
// Convert the UDAF into a Column object that can be selected in a query
val udafCol = new MyAvgUDAF().toColumn
ds.select(udafCol).show
spark.close()
}
/*
Custom aggregate function class: computes the average age
1. Extend org.apache.spark.sql.expressions.Aggregator and define the type parameters
IN  : input type  User
BUF : buffer type Buff
OUT : output type Long
2. Override its methods (6 of them)
*/
case class User(username: String, age: Long)
case class Buff(var total: Long, var count: Long)
class MyAvgUDAF extends Aggregator[User, Buff, Long]{
// Initial value of the buffer
override def zero: Buff = {
Buff(0L, 0L)
}
// Update the buffer with an input User
override def reduce(b: Buff, a: User): Buff = {
b.total = b.total + a.age
b.count = b.count + 1
b
}
// Merge two buffers
override def merge(b1: Buff, b2: Buff): Buff = {
b1.total = b1.total + b2.total
b1.count = b1.count + b2.count
b1
}
// Compute the final result
override def finish(reduction: Buff): Long = {
reduction.total / reduction.count
}
// Encoders for the buffer and the output
override def bufferEncoder: Encoder[Buff] = Encoders.product
override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}
}
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{SaveMode, SparkSession}
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Reading from and writing to MySQL through the JDBC data source
*/
object Spark_SparkSQL_JDBC {
def main(args: Array[String]): Unit = {
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
import spark.implicits._
// Read data from MySQL
val df = spark.read
.format("jdbc")
.option("url", "jdbc:mysql://hadoop102:3306/test")
.option("driver", "com.mysql.jdbc.Driver")
.option("user", "root")
.option("password", "********")
.option("dbtable", "users")
.load()
// df.show
// Write data to MySQL
val df2 = spark.sparkContext.makeRDD(List((3, "wangwu", 30))).toDF("id", "name", "age")
df2.write
.format("jdbc")
.option("url", "jdbc:mysql://hadoop102:3306/test")
.option("driver", "com.mysql.jdbc.Driver")
.option("user", "root")
.option("password", "******")
.option("dbtable", "users")
.mode(SaveMode.Append)
.save()
spark.close()
}
}
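The same read can also be expressed with the jdbc shortcut on DataFrameReader; a minimal sketch, assuming the same hadoop102 MySQL instance and users table as above:

import java.util.Properties

val props = new Properties()
props.put("user", "root")
props.put("password", "******")
props.put("driver", "com.mysql.jdbc.Driver")

// Read the whole users table through the jdbc shortcut
val usersDF = spark.read.jdbc("jdbc:mysql://hadoop102:3306/test", "users", props)
usersDF.show()

Either way, the MySQL JDBC driver (for example the mysql-connector-java artifact) has to be on the classpath at runtime.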
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Connecting SparkSQL to an external Hive
*/
object Spark_SparkSQL_Hive {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "root")
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
import spark.implicits._
// Use SparkSQL to connect to an external Hive
// 1. Copy hive-site.xml onto the classpath
// 2. Enable Hive support (enableHiveSupport)
// 3. Add the required dependencies (including the MySQL driver used by the metastore)
spark.sql("show databases").show
spark.close()
}
}
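For step 3, Hive support lives in a separate Spark artifact; a minimal sbt sketch (a Maven pom works the same way), assuming a sparkVersion value defined elsewhere in the build and a 5.x MySQL driver for the metastore:

// build.sbt (sketch): Spark SQL with Hive support plus the MySQL driver used by the metastore
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql"  % sparkVersion,
  "org.apache.spark" %% "spark-hive" % sparkVersion,
  "mysql" % "mysql-connector-java" % "5.1.49"
)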
SparkSQL Hands-On Project
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Creating the Hive tables and loading the test data for the project
*/
object Spark_SparkSQL_Test {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "lotuslaw")
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
import spark.implicits._
spark.sql("use db_hive")
// Prepare the data
spark.sql(
"""
|CREATE TABLE IF NOT EXISTS `user_visit_action`(
| `date` string,
| `user_id` bigint,
| `session_id` string,
| `page_id` bigint,
| `action_time` string,
| `search_keyword` string,
| `click_category_id` bigint,
| `click_product_id` bigint,
| `order_category_ids` string,
| `order_product_ids` string,
| `pay_category_ids` string,
| `pay_product_ids` string,
| `city_id` bigint)
|row format delimited fields terminated by '\t'
|""".stripMargin
)
spark.sql(
"""
|load data local inpath 'input/user_visit_action.txt' into table db_hive.user_visit_action
|""".stripMargin
)
spark.sql(
"""
|CREATE TABLE IF NOT EXISTS `product_info`(
| `product_id` bigint,
| `product_name` string,
| `extend_info` string)
|row format delimited fields terminated by '\t'
|""".stripMargin
)
spark.sql(
"""
|load data local inpath 'input/product_info.txt' into table db_hive.product_info
|""".stripMargin
)
spark.sql(
"""
|CREATE TABLE IF NOT EXISTS `city_info`(
| `city_id` bigint,
| `city_name` string,
| `area` string)
|row format delimited fields terminated by '\t'
|""".stripMargin
)
spark.sql(
"""
|load data local inpath 'input/city_info.txt' into table db_hive.city_info
|""".stripMargin
)
spark.sql(
"""
|select * from city_info
|""".stripMargin).show
spark.close()
}
}
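The three load data statements expect tab-separated text files under input/ that match the schemas above. The files are not part of these notes; a purely hypothetical line of input/city_info.txt (city_id, city_name and area separated by tabs) would look like:

1	北京	华北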
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Top 3 most-clicked products per area, implemented as a single SQL query
*/
object Spark_SparkSQL_Test1 {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "lotuslaw")
// TODO: Create the SparkSQL runtime environment
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkSQL")
val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
spark.sql("use db_hive")
spark.sql(
"""
|select
| *
|from (
| select
| *,
| rank() over( partition by area order by clickCnt desc ) as rank
| from (
| select
| area,
| product_name,
| count(*) as clickCnt
| from (
| select
| a.*,
| p.product_name,
| c.area,
| c.city_name
| from user_visit_action a
| join product_info p on a.click_product_id = p.product_id
| join city_info c on a.city_id = c.city_id
| where a.click_product_id > -1
| ) t1 group by area, product_name
| ) t2
|) t3 where rank <= 3
""".stripMargin).show
spark.close()
}
}
package com.lotuslaw.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
/**
* @author: lotuslaw
* @version: V1.0
* @package: com.lotuslaw.spark.sql
* @create: 2021-12-02 20:05
* @description: Top 3 most-clicked products per area, with a city-distribution remark produced by a custom UDAF
*/
object Spark_SparkSQL_Test2 {
def main(args: Array[String]): Unit = {
System.setProperty("HADOOP_USER_NAME", "lotuslaw")
val sparkConf = new SparkConf().setMaster("local[*]").setAppName("sparkSQL")
val spark = SparkSession.builder().enableHiveSupport().config(sparkConf).getOrCreate()
spark.sql("use db_hive")
// Query the base data (click actions joined with products and cities)
spark.sql(
"""
| select
| a.*,
| p.product_name,
| c.area,
| c.city_name
| from user_visit_action a
| join product_info p on a.click_product_id = p.product_id
| join city_info c on a.city_id = c.city_id
| where a.click_product_id > -1
""".stripMargin).createOrReplaceTempView("t1")
// Aggregate the data by area and product
// The UDAF works on the rows inside each group by area, product_name group
spark.udf.register("cityRemark", functions.udaf(new CityRemarkUDAF()))
spark.sql(
"""
| select
| area,
| product_name,
| count(*) as clickCnt,
| cityRemark(city_name) as city_remark
| from t1 group by area, product_name
""".stripMargin).createOrReplaceTempView("t2")
// Rank products by click count within each area
spark.sql(
"""
| select
| *,
| rank() over( partition by area order by clickCnt desc ) as rank
| from t2
""".stripMargin).createOrReplaceTempView("t3")
// Take the top 3
spark.sql(
"""
| select
| *
| from t3 where rank <= 3
""".stripMargin).show(false)
spark.close()
}
case class Buffer(var total: Long, var cityMap: mutable.Map[String, Long])
// Custom aggregate function: builds the city remark
// 1. Extend Aggregator and define the type parameters
// IN  : city name
// BUF : Buffer => [total click count, Map[(city, cnt), (city, cnt)]]
// OUT : remark string
// 2. Override its methods (6 of them)
class CityRemarkUDAF extends Aggregator[String, Buffer, String] {
// Initialize the buffer
override def zero: Buffer = {
Buffer(0, mutable.Map[String, Long]())
}
// Update the buffer with one city name
override def reduce(buff: Buffer, city: String): Buffer = {
buff.total += 1
val newCount = buff.cityMap.getOrElse(city, 0L) + 1
buff.cityMap.update(city, newCount)
buff
}
// Merge two buffers
override def merge(buff1: Buffer, buff2: Buffer): Buffer = {
buff1.total += buff2.total
val map1 = buff1.cityMap
val map2 = buff2.cityMap
map2.foreach {
case (city, cnt) => {
val newCount = map1.getOrElse(city, 0L) + cnt
map1.update(city, newCount)
}
}
buff1.cityMap = map1
buff1
}
// Build the remark string from the aggregated statistics
override def finish(buff: Buffer): String = {
val remarkList = ListBuffer[String]()
val totalcnt = buff.total
val cityMap = buff.cityMap
// Sort cities by click count in descending order and keep the top 2
val cityCntList = cityMap.toList.sortWith(
(left, right) => {
left._2 > right._2
}
).take(2)
val hasMore = cityMap.size > 2
var rsum = 0L
cityCntList.foreach {
case (city, cnt) => {
val r = cnt * 100 / totalcnt
remarkList.append(s"${city} ${r}%")
rsum += r
}
}
// Fold the remaining cities into a single "其他" (others) entry
if (hasMore) {
remarkList.append(s"其他 ${100 - rsum}%")
}
remarkList.mkString(", ")
}
override def bufferEncoder: Encoder[Buffer] = Encoders.product
override def outputEncoder: Encoder[String] = Encoders.STRING
}
}
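With cityRemark registered, the final query adds a city_remark column whose values look like, for example, 北京 63%, 天津 10%, 其他 27% (the numbers are hypothetical): the two most-clicked cities in the area with their share of the clicks, and everything else folded into 其他 (others). Because the shares are computed with integer division, the 其他 entry also absorbs the rounding remainder.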