package com.frank.sparktest.java;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
public class MedianUdaf extends UserDefinedAggregateFunction {
    private StructType inputSchema;
    private StructType bufferSchema;

    public MedianUdaf() {
        // Input: a single nullable integer column named "nums".
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("nums", DataTypes.IntegerType, true));
        inputSchema = DataTypes.createStructType(inputFields);
        // Buffer: a single String column that accumulates the seen values
        // as a comma-separated list.
        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("datas", DataTypes.StringType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    @Override
    public boolean deterministic() {
        return true;
    }

    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        // The buffer schema declares only one column (a String), so there is
        // nothing at index 1 to update; start the single field out empty.
        buffer.update(0, "");
    }

    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            String current = buffer.getString(0);
            String value = String.valueOf(input.getInt(0));
            // Avoid a leading comma while the buffer is still empty.
            buffer.update(0, current.isEmpty() ? value : current + "," + value);
        }
    }

    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        // Both buffers hold comma-separated strings, so buffer2 must be read
        // with getString(0), not getInt(0).
        String left = buffer1.getString(0);
        String right = buffer2.getString(0);
        if (right != null && !right.isEmpty()) {
            buffer1.update(0, left.isEmpty() ? right : left + "," + right);
        }
    }

    @Override
    public Object evaluate(Row buffer) {
        String data = buffer.getString(0);
        if (data == null || data.isEmpty()) {
            return null;
        }
        List<Integer> list = new ArrayList<>();
        for (String s : Arrays.asList(data.split(","))) {
            list.add(Integer.valueOf(s));
        }
        Collections.sort(list);
        int size = list.size();
        // Odd count: the median is the middle element, at 0-based index size / 2.
        if (size % 2 == 1) {
            return (double) list.get(size / 2);
        }
        // Even count: average of the two middle elements; divide by 2.0 so the
        // result is a real Double, matching the declared DoubleType.
        return (list.get(size / 2 - 1) + list.get(size / 2)) / 2.0;
    }
}
The class above is a complete snippet and can be used as-is.
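Note that UserDefinedAggregateFunction has been deprecated since Spark 3.0 in favor of the typed Aggregator API. A minimal sketch of the same median logic on that API follows; the class name MedianAggregator and the choice of a Kryo-encoded ArrayList buffer are illustrative assumptions, not part of the original code.

package com.frank.sparktest.java;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
import java.util.ArrayList;
import java.util.Collections;

public class MedianAggregator extends Aggregator<Integer, ArrayList<Integer>, Double> {
    @Override
    public ArrayList<Integer> zero() {
        // Empty buffer: no values seen yet.
        return new ArrayList<>();
    }

    @Override
    public ArrayList<Integer> reduce(ArrayList<Integer> buf, Integer value) {
        if (value != null) {
            buf.add(value);
        }
        return buf;
    }

    @Override
    public ArrayList<Integer> merge(ArrayList<Integer> b1, ArrayList<Integer> b2) {
        b1.addAll(b2);
        return b1;
    }

    @Override
    public Double finish(ArrayList<Integer> buf) {
        if (buf.isEmpty()) {
            return null;
        }
        Collections.sort(buf);
        int size = buf.size();
        return size % 2 == 1
                ? (double) buf.get(size / 2)
                : (buf.get(size / 2 - 1) + buf.get(size / 2)) / 2.0;
    }

    @Override
    @SuppressWarnings("unchecked")
    public Encoder<ArrayList<Integer>> bufferEncoder() {
        // Kryo serializes the ArrayList buffer; a bean or list encoder would also work.
        return Encoders.kryo((Class<ArrayList<Integer>>) (Class<?>) ArrayList.class);
    }

    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}

On Spark 3.x this would be registered through functions.udaf rather than passed to udf().register directly, for example:
spark.udf().register("median", org.apache.spark.sql.functions.udaf(new MedianAggregator(), Encoders.INT()));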
Below is a test program:
package com.frank.sparktest.java;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import java.util.stream.IntStream;

public class DemoUDAF {
    public static void main(String[] args) {
        SQLContext sqlContext = SparkSession.builder().master("local").getOrCreate().sqlContext();
        // generate(start, end): the inclusive integer range [start, end] as an array.
        sqlContext.udf().register("generate",
                (Integer start, Integer end) -> IntStream.rangeClosed(start, end).boxed().toArray(),
                DataTypes.createArrayType(DataTypes.IntegerType));
        // Register the UDAF under the name "median" (the original "media" was a
        // typo) and actually exercise it in a query, which the original test
        // program never did.
        sqlContext.udf().register("median", new MedianUdaf());
        sqlContext.sql("select generate(1,10)").show();
        sqlContext.sql("select median(n) from (select explode(generate(1,10)) as n) t").show();
    }
}
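With the values 1 through 10, the median query should print 5.5: the count is even, and the average of the two middle elements 5 and 6 is 5.5.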