spark3.0版本中sparkSQL自定义聚合函数(UDAF)
spark3.0版本可以继承Aggregator> 1.继承import org.apache.spark.sql.expressions.Aggregator,定义泛型>IN:输入的数据类型>BUF:缓冲区的数据类型>OUT:输出的数据类型> 2.重写方法> 3.注册自定义聚合函数>spark.udf.register("函数名称",functions.udaf(new MyAgeAvg()))
·
spark3.0之前的版本中sparkSQL自定义聚合函数要继承UserDefinedAggregateFunction
类,重写8个方法,具体使用方法可参考https://blog.csdn.net/weixin_43866709/article/details/88914871
但是该类是弱类型的,实现逻辑的时候容易出错。
spark3.0版本可以继承Aggregator
1.继承import org.apache.spark.sql.expressions.Aggregator,定义泛型
IN:输入的数据类型
BUF:缓冲区的数据类型
OUT:输出的数据类型
2.重写方法
3.注册自定义聚合函数
spark.udf.register(“函数名称”,functions.udaf(new MyAgeAvg()))
具体实现案例如下,实现一个简单个求平均值的自定义聚合函数:
package com.zsz.spark.sql
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Encoders, SparkSession, functions}
object Spark_SparkSQL_UDAF1 {
def main(args: Array[String]): Unit = {
val conf: SparkConf = new SparkConf().setMaster("local").setAppName("newUDAF")
val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
// 隐式转换
import spark.implicits._
val rdd: RDD[(String, Int)] = spark.sparkContext.makeRDD(List(("zhangsan", 20), ("lisi", 30), ("wangwu", 40)))
val df: DataFrame = rdd.toDF("username", "age")
df.createTempView("user")
// 注册自定义聚合函数
spark.udf.register("MyAgeAvg",functions.udaf(new MyAgeAvg()))
spark.sql("select MyAgeAvg(age) from user").show()
spark.close()
}
/**
* 自定义聚合函数类:计算年龄的平均值
* 1.继承import org.apache.spark.sql.expressions.Aggregator,定义泛型
* IN:输入的数据类型
* BUF:缓冲区的数据类型
* OUT:输出的数据类型
* 2.重写方法
*/
case class Buff( var total:Long, var count:Long )
class MyAgeAvg extends Aggregator[Long,Buff,Long]{
// 初始值
override def zero: Buff = {
Buff(0L,0L)
}
// 缓冲区数据计算
override def reduce(buff: Buff, in: Long): Buff = {
buff.total += in
buff.count += 1
buff
}
// 合并缓冲区
override def merge(buff1: Buff, buff2: Buff): Buff = {
buff1.total += buff2.total
buff1.count += buff2.count
buff1
}
// 输出值计算
override def finish(buff: Buff): Long = {
buff.total/buff.count
}
// 缓冲区编码设置
override def bufferEncoder: Encoder[Buff] = {
Encoders.product
}
// 输出编码
override def outputEncoder: Encoder[Long] = {
Encoders.scalaLong
}
}
}
更多推荐
已为社区贡献2条内容
所有评论(0)