前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >spark 编写udaf函数求中位数

spark 编写udaf函数求中位数

作者头像
yiduwangkai
发布2019-09-17 15:59:16
1.1K0
发布2019-09-17 15:59:16
举报
文章被收录于专栏:大数据进阶
代码语言:java
复制
package com.frank.sparktest.java;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class MedianUdaf extends UserDefinedAggregateFunction {

    private StructType inputSchema;
    private StructType bufferSchema;

    public MedianUdaf(){
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("nums",DataTypes.IntegerType,true));
        inputSchema=DataTypes.createStructType(inputFields);
        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("datas",DataTypes.StringType,true));
        bufferSchema=DataTypes.createStructType(bufferFields);
    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0,0);
        buffer.update(1,0);
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)){
            buffer.update(0,buffer.getString(0)+","+input.getInt(0));
        }
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        buffer1.update(0,buffer1.getString(0)+","+buffer2.getInt(0));
    }

    @Override
    public Object evaluate(Row buffer) {
        List<Integer> list = new ArrayList<Integer>();
        List<String> stringList = Arrays.asList(buffer.getString(0).split(","));
        for (String s : stringList){
            list.add(Integer.valueOf(s));
        }
        Collections.sort(list);
        int size = list.size();
        int num=0;
        if(size % 2 == 1) {
            num = list.get((size / 2)+1);
        }
        if(size %2  == 0) {
            num = (list.get(size / 2)+list.get((size / 2)+1))/2;
        }
        return num;
    }

}

上面是代码段,可以直接拿来使用

下面是测试程序

代码语言:javascript
复制
package com.frank.sparktest.java;

import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;

import java.io.IOException;
import java.util.stream.IntStream;

public class DemoUDAF {

    public static void main(String[] args) throws IOException {
        SQLContext sqlContext = SparkSession.builder().master("local").getOrCreate().sqlContext();
        sqlContext.udf().register("generate", (Integer start, Integer end)-> IntStream.range(start, end+1).boxed().toArray(), DataTypes.createArrayType(DataTypes.IntegerType));
        sqlContext.udf().register("media",new MedianUdaf());
        sqlContext.sql("select generate(1,10)").show();
    }
}
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档