import com.google.common.hash.Hashing
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.roaringbitmap.longlong.Roaring64Bitmap
import scala.collection.mutable
/**
 * Custom Spark aggregate function that counts the distinct strings found across all input
 * arrays of a group.
 *
 * Each element is hashed to a 64-bit value with Guava's SipHash-2-4. The distinct hashes are
 * de-duplicated with a [[Roaring64Bitmap]]. Spark requires the aggregation buffer to be a Spark
 * SQL type, so between calls the set of hashes is stored as an `ARRAY<BIGINT>` of distinct
 * values.
 *
 * The result counts distinct hashes. Two different strings whose 64-bit hashes collide are
 * therefore counted once; this is very unlikely, but the count is not strictly exact. Null
 * arrays and null array elements are ignored, like SQL `COUNT(DISTINCT ...)`.
 *
 * NOTE(review): every `update` rebuilds the set from the buffer and writes it back, so the
 * per-row cost grows with the number of distinct values already seen. `UserDefinedAggregateFunction`
 * is deprecated since Spark 3.0. An `Aggregator` registered via `functions.udaf` could keep the
 * bitmap as a live object and avoid the rebuild.
 */
class DistinctAggCount() extends UserDefinedAggregateFunction {

  /**
   * Input schema: one array-of-strings column whose elements are counted.
   *
   * @return the input schema
   */
  override def inputSchema: StructType =
    new StructType()
      // .add("stringInput", StringType)
      .add("arrayInput", ArrayType(StringType))

  /**
   * Intermediate buffer schema: the distinct 64-bit hashes collected so far.
   *
   * @return the buffer schema
   */
  override def bufferSchema: StructType =
    StructType(StructField("unionSet", ArrayType(LongType)) :: Nil)

  /**
   * Final result type: the number of distinct hashes.
   *
   * @return `LongType`
   */
  override def dataType: DataType = LongType

  override def deterministic: Boolean = true

  /**
   * Starts every group with an empty set of hashes. This is also the value evaluated when a
   * group receives no rows.
   *
   * @param buffer the aggregation buffer to initialize
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Array[Long]()
  }

  /**
   * Adds the hashes of one input row's values to the buffer.
   *
   * @param buffer the partial aggregate of the current group
   * @param input  a row whose column 0 holds the values to count; a null cell is skipped
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val hashes = hashesOf(input.get(0))
      // An empty (or all-null) array adds nothing; skip rebuilding an identical buffer.
      if (hashes.nonEmpty) {
        buffer.update(0, distinct(storedHashes(buffer) ++ hashes))
      }
    }
  }

  /**
   * Merges two partial aggregates by taking the union of their hash sets.
   *
   * @param buffer1 the buffer that receives the union
   * @param buffer2 the other partial aggregate
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, distinct(storedHashes(buffer1) ++ storedHashes(buffer2)))
  }

  /**
   * Computes the final result.
   *
   * @param buffer the fully merged aggregation buffer
   * @return the number of distinct hashes in the buffer
   */
  override def evaluate(buffer: Row): Long =
    Roaring64Bitmap.bitmapOf(storedHashes(buffer): _*).getLongCardinality

  /** 64-bit SipHash-2-4 of the string's characters (hashed without a charset encoding). */
  private def hashOf(s: String): Long =
    Hashing.sipHash24().hashUnencodedChars(s).asLong()

  /** Hashes every non-null value carried by one input cell. */
  private def hashesOf(value: Any): Array[Long] = value match {
    case null =>
      Array.emptyLongArray
    // Catalyst's internal array representation, handled in case it is ever passed through.
    case array: ArrayData =>
      (0 until array.numElements()).iterator
        .filter(i => !array.isNullAt(i))
        .map(i => hashOf(array.get(i, StringType).toString))
        .toArray
    // Spark passes ARRAY values to a UDAF as a scala.collection.Seq. Matching the general Seq
    // (instead of mutable.WrappedArray) also covers the ArraySeq used on Scala 2.13.
    case values: scala.collection.Seq[_] =>
      values.iterator
        .filter(_ != null)
        .map(element => hashOf(element.toString))
        .toArray
    case other =>
      Array(hashOf(other.toString))
  }

  /** Reads the stored hash array from column 0 of a buffer row; a null cell reads as empty. */
  private def storedHashes(row: Row): Array[Long] =
    if (row.isNullAt(0)) Array.emptyLongArray else row.getSeq[Long](0).toArray

  /** De-duplicates hashes through a Roaring bitmap and returns them as an array. */
  private def distinct(hashes: Array[Long]): Array[Long] =
    Roaring64Bitmap.bitmapOf(hashes: _*).toArray
}