Flink 自定义各种UDF函数实践

在 Flink 中,用户自定义函数(User-defined Functions,UDF)是一个非常重要的特性,一些系统内置函数无法解决的需求,我们都可以用自定义 UDF 来实现。
总结下来,可以分为4种:UDF,UDTF,UDAF,UDTAF。
下面我们通过实际代码来讲解怎么使用吧。

一、自定义UDF

1. 概述

自定义 UDF(Scalar Functions),可以把 1 个标量值映射成 1 个标量值,即:一行输入,一行输出
想要实现自定义 UDF,需要继承org.apache.flink.table.functions.ScalarFunction 并且实现eval()方法。

2. 自定义UDF

 1import org.apache.flink.table.functions.ScalarFunction;
2
3public class UpperFunction extends ScalarFunction {
4
5    public String eval(String s) {
6
7        // 将输入字符串转换为大写
8        return s.toUpperCase();
9    }
10}

3. 使用

 1import com.eyesmoons.sql.functions.UpperFunction;
2import org.apache.flink.api.common.functions.MapFunction;
3import org.apache.flink.streaming.api.datastream.DataStream;
4import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
5import org.apache.flink.table.api.Table;
6import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
7import org.apache.flink.types.Row;
8
9public class UdfDemo {
10    public static void main(String[] args) throws Exception {
11        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment().setParallelism(1);
12        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
13        // 注册UDF函数
14        tableEnv.registerFunction("my_upper"new UpperFunction());
15
16        // 创建输入流
17        DataStream inputStream = env.fromElements("hello""world""flink");
18
19        // 在 DataStream API 中使用自定义 UDF
20        DataStream resultStream = inputStream.map(new MapFunction() {
21            @Override
22            public String map(String value) throws Exception {
23                // 调用自定义 UDF
24                return new UpperFunction().eval(value);
25            }
26        });
27        // 打印结果
28        resultStream.print("DataStream");
29
30        // 在 SQL 查询中使用自定义 UDF
31        tableEnv.createTemporaryView("my_table", resultStream, "words");
32        Table table = tableEnv.sqlQuery("SELECT my_upper(words) FROM my_table");
33        DataStream resultTable = tableEnv.toDataStream(table);
34        // 打印结果
35        resultTable.print("table");
36
37        env.execute();
38    }
39}

二、自定义UDTF

1. 概述

跟自定义 UDF 一样,自定义 UDTF 的输入参数也是1个。
但是跟标量函数不同的是,它可以返回任意多行,即:一行输入,多行输出
要定义一个 UDTF,需要继承 org.apache.flink.table.functions.TableFunction,可以通过实现多个名为 eval 的方法对求值方法进行重载。
在 Table API 中,UDTF 是通过 .joinLateral(...) 或者 .leftOuterJoinLateral(...) 来使用的。joinLateral 算子会把外表(算子左侧的表)的每一行跟跟表值函数返回的所有行(位于算子右侧)进行 (cross)join。leftOuterJoinLateral 算子也是把外表(算子左侧的表)的每一行跟表值函数返回的所有行(位于算子右侧)进行(cross)join,并且如果表值函数返回 0 行也会保留外表的这一行。
在 SQL 里面用 JOIN 或者 以 ON TRUE 为条件的 LEFT JOIN 来配合 LATERAL TABLE() 的使用。

2. 自定义 UDTF

 1import org.apache.flink.table.annotation.DataTypeHint;
2import org.apache.flink.table.annotation.FunctionHint;
3import org.apache.flink.table.functions.TableFunction;
4import org.apache.flink.types.Row;
5
6//hint暗示,主要作用为类型推断时使用
7@FunctionHint(output = @DataTypeHint("ROW"))
8public class SplitUDTF extends TableFunction<Row{
9
10    public void eval(String str) {
11        // 按下划线分割输入字符串,并输出每个单词
12        for (String s : str.split("_")) {
13            collect(Row.of(s));
14        }
15    }
16}

3. 使用

① 先定义一个对象:

 1import lombok.Data;
2
3@Data
4public class WaterSensor {
5    private String id;
6    private Long length;
7    private Integer ct;
8
9    public WaterSensor() {
10    }
11
12    public WaterSensor(String id, Long length, Integer ct) {
13        this.id = id;
14        this.length = length;
15        this.ct = ct;
16    }
17}

② 写一个测试类:

 1import com.eyesmoons.sql.functions.SplitUDTF;
2import org.apache.flink.streaming.api.datastream.DataStreamSource;
3import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
4import org.apache.flink.table.api.Table;
5import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
6
7import static org.apache.flink.table.api.Expressions.$;
8import static org.apache.flink.table.api.Expressions.call;
9
10public class UDTFDemo {
11
12    public static void main(String[] args) throws Exception {
13        //1.获取执行环境
14        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment().setParallelism(1);
15        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
16
17        //2.读取文件得到DataStream
18        DataStreamSource waterSensorDataStreamSource = env.fromElements(new WaterSensor("sensor_1"1000L10),
19                new WaterSensor("sensor_1"2000L20),
20                new WaterSensor("sensor_2"3000L30),
21                new WaterSensor("sensor_1"4000L40),
22                new WaterSensor("sensor_1"5000L50),
23                new WaterSensor("sensor_2"6000L60));
24
25        //3.将流转换为动态表
26        Table table = tableEnv.fromDataStream(waterSensorDataStreamSource);
27
28        //4先注册再使用
29        tableEnv.createTemporarySystemFunction("split", SplitUDTF.class);
30
31        // table api中使用
32        table.joinLateral(call("split", $("id")))
33                .select($("id"), $("word"))
34                .execute()
35                .print();
36
37        // SQL 中使用
38        tableEnv.executeSql("select id, word from " + table + ", lateral table(split(id))").print();
39    }
40
41}

三、自定义UDAF

1. 概述

自定义聚合函数(UDAGG)是把一个表(一行或者多行,每行可以有一列或者多列)聚合成一个标量值。即:多行输入,一行输出
自定义聚合函数是通过继承 AggregateFunction 来实现的。AggregateFunction 的工作过程如下。
① 首先,它需要一个 accumulator,它是一个数据结构,存储了聚合的中间结果。通过调用 AggregateFunctioncreateAccumulator() 方法创建一个空的 accumulator。
② 接下来,对于每一行数据,会调用 accumulate() 方法来更新 accumulator。当所有的数据都处理完了之后,通过调用 getValue 方法来计算和返回最终的结果。
所以,必须要实现以下几个方法:
1createAccumulator()
2accumulate()
3getValue()

2. 自定义 UDAF

① 自定义group_concat
这个函数其实 Flink 中可以使用 listagg 来达到同样的目的。
 1import org.apache.commons.collections.CollectionUtils;
2import org.apache.flink.table.functions.AggregateFunction;
3import org.slf4j.Logger;
4import org.slf4j.LoggerFactory;
5
6import java.util.ArrayList;
7import java.util.List;
8
9/**
10 * 自定义group_concat
11 * group_concat_list(list, separator), separator不填则默认为逗号
12 * AggregateFunction, T表示聚合输出的结果类型,ACC表示聚合的中间状态类型
13 */

14public class GroupConcatList extends AggregateFunction<StringGroupConcatList.AggregateList{
15
16    private static final Logger LOG = LoggerFactory.getLogger(GroupConcatList.class);
17
18    public static class AggregateList {
19        public List columnList;
20        public String delimiter;
21    }
22
23    /**
24     * 返回聚合结果
25     *
26     * @param acc ACC类型的累加器
27     */

28    @Override
29    public String getValue(GroupConcatList.AggregateList acc) {
30        if (CollectionUtils.isEmpty(acc.columnList)) {
31            return "";
32        }
33        return String.join(acc.delimiter, acc.columnList);
34    }
35
36    /**
37     * 创建累加器
38     *
39     * @return 累加器类型ACC
40     */

41    @Override
42    public GroupConcatList.AggregateList createAccumulator() {
43        GroupConcatList.AggregateList acc = new GroupConcatList.AggregateList();
44        acc.columnList = new ArrayList<>();
45        return acc;
46    }
47
48    /**
49     * 更新累加器
50     *
51     * @param acc   当前累加器,类型为ACC
52     * @param param 可变字符串,第一个字符串为值,第二个字符串为间隔符,若无间隔符则默认为逗号
53     */

54    public void accumulate(GroupConcatList.AggregateList acc, String... param) {
55        if (param.length == 1) {
56            acc.columnList.add(param[0]);
57            acc.delimiter = ",";
58        } else if (param.length == 2) {
59            acc.columnList.add(param[0]);
60            acc.delimiter = param[1];
61        } else {
62            LOG.error("参数错误");
63        }
64    }
65
66    /**
67     * 回撤相关操作
68     */

69    public void retract(GroupConcatList.AggregateList acc, String... param) {
70        acc.columnList.remove(param[0]);
71    }
72
73    /**
74     * 重置操作
75     */

76    public void resetAccumulator(GroupConcatList.AggregateList acc) {
77        acc.columnList.clear();
78    }
79}

② 自定义去重的group_concat

 1import org.apache.commons.collections.CollectionUtils;
2import org.apache.flink.table.functions.AggregateFunction;
3import org.slf4j.Logger;
4import org.slf4j.LoggerFactory;
5
6import java.util.ArrayList;
7import java.util.HashSet;
8import java.util.List;
9import java.util.Set;
10
11/**
12 * 自定义去重的group_concat
13 * group_concat_set(list, separator),separator不填则默认为逗号
14 */

15public class GroupConcatSet extends AggregateFunction<StringGroupConcatSet.AggregateList{
16
17    private static final Logger LOG = LoggerFactory.getLogger(GroupConcatSet.class);
18
19    public static class AggregateList {
20        public List columnList;
21        public String delimiter;
22    }
23
24    /**
25     * 返回聚合结果
26     *
27     * @param acc ACC类型的累加器
28     */

29    @Override
30    public String getValue(GroupConcatSet.AggregateList acc) {
31        if (CollectionUtils.isEmpty(acc.columnList)) {
32            return "";
33        }
34        Set set = new HashSet<>(acc.columnList);
35        return String.join(acc.delimiter, set);
36    }
37
38    /**
39     * 创建累加器
40     *
41     * @return 累加器类型ACC
42     */

43    @Override
44    public GroupConcatSet.AggregateList createAccumulator() {
45        GroupConcatSet.AggregateList acc = new GroupConcatSet.AggregateList();
46        acc.columnList = new ArrayList<>();
47        return acc;
48    }
49
50    /**
51     * 更新累加器
52     *
53     * @param acc   当前累加器,类型为ACC
54     * @param param 可变字符串,第一个字符串为值,第二个字符串为间隔符,若无间隔符则默认为逗号
55     */

56    public void accumulate(GroupConcatSet.AggregateList acc, String... param) {
57        if (param.length == 1) {
58            acc.columnList.add(param[0]);
59            acc.delimiter = ",";
60        } else if (param.length == 2) {
61            acc.columnList.add(param[0]);
62            acc.delimiter = param[1];
63        } else {
64            LOG.error("参数错误");
65        }
66    }
67
68    /**
69     * 回撤相关操作
70     */

71    public void retract(GroupConcatSet.AggregateList acc, String... param) {
72        acc.columnList.remove(param[0]);
73
74    }
75
76    /**
77     * 重置操作
78     */

79    public void resetAccumulator(GroupConcatSet.AggregateList acc) {
80        acc.columnList.clear();
81    }
82}

3. 使用

① 先定义一个对象

 1import lombok.Getter;
2import lombok.NoArgsConstructor;
3
4@NoArgsConstructor
5@Getter
6public class User {
7    private Integer id;
8    private Integer age;
9    private String gender;
10
11    public User(Integer id, Integer age, String gender) {
12        this.id = id;
13        this.age = age;
14        this.gender = gender;
15    }
16
17    public void setId(Integer id) {
18        this.id = id;
19    }
20
21    public void setAge(Integer age) {
22        this.age = age;
23    }
24
25    public void setGender(String gender) {
26        this.gender = gender;
27    }
28}

② 写一个测试类

 1import com.eyesmoons.sql.functions.GroupConcatList;
2import com.eyesmoons.sql.functions.GroupConcatSet;
3import org.apache.flink.api.common.functions.FilterFunction;
4import org.apache.flink.api.java.tuple.Tuple2;
5import org.apache.flink.streaming.api.datastream.DataStream;
6import org.apache.flink.streaming.api.datastream.DataStreamSource;
7import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
8import org.apache.flink.table.api.Table;
9import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
10import org.apache.flink.types.Row;
11
12public class FlinkSqlDemo {
13    public static void main(String[] args) throws Exception {
14        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
15        env.setParallelism(1);
16
17        DataStreamSource inputStream = env.fromElements("1,21,男""2,22,女""3,23,男""4,22,男""5,22,女");
18        DataStream dataStream = inputStream.map(line -> {
19            String[] fields = line.split(",");
20            return new User(Integer.parseInt(fields[0]), Integer.parseInt(fields[1]), fields[2]);
21        });
22        //创建表的执行环境
23        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
24        // 创建表
25        Table dataTable = tableEnv.fromDataStream(dataStream);
26        // 注册udaf函数:多行输入,一行输出
27        tableEnv.createTemporarySystemFunction("group_concat_list"new GroupConcatList());
28        tableEnv.createTemporarySystemFunction("group_concat_set"new GroupConcatSet());
29        // 执行sql
30        tableEnv.createTemporaryView("t_user", dataTable);// within group(order by age)
31        String sql = "select age, listagg(gender) as genders from t_user group by age";
32//        String sql = "select age, group_concat_list(gender) as genders from t_user group by age";
33//        String sql = "select age, group_concat_set(gender, ',') as genders from t_user group by age";
34        Table resultSqlTable = tableEnv.sqlQuery(sql);
35        // 打印结果
36        tableEnv.toRetractStream(resultSqlTable, Row.class).filter((FilterFunction>) row -> row.f0).print("sql");
37        env.execute();
38    }
39}

四、自定义UDTAF

1. 概述

自定义表值聚合函数(UDTAGG)可以把一个表(一行或者多行,每行有一列或者多列)聚合成另一张表,结果中可以有多行多列,即:多行输入,多行输出

2. 自定义UDTAF

① 先定义一个对象

1public class vCTop2 {
2    public Integer first = Integer.MIN_VALUE;
3    public Integer second = Integer.MIN_VALUE;
4}

② 自定义UDTAF

 1import com.eyesmoons.sql.vCTop2;
2import org.apache.flink.api.java.tuple.Tuple2;
3import org.apache.flink.table.functions.TableAggregateFunction;
4import org.apache.flink.util.Collector;
5
6public class MyUDTAF extends TableAggregateFunction<Tuple2<IntegerInteger>, vCTop2{
7
8    //创建累加器
9    @Override
10    public vCTop2 createAccumulator() {
11        return new vCTop2();
12    }
13
14    //比较数据,如果当前数据大于累加器中存的数据则替换,并将原累加器中的数据往下(第二)赋值
15    public void accumulate(vCTop2 acc, Integer value) {
16        if (value > acc.first) {
17            acc.second = acc.first;
18            acc.first = value;
19        } else if (value > acc.second) {
20            acc.second = value;
21        }
22    }
23
24    //计算(排名)
25    public void emitValue(vCTop2 acc, Collector> out) {
26        // emit the value and rank
27        if (acc.first != Integer.MIN_VALUE) {
28            out.collect(Tuple2.of(acc.first, 1));
29        }
30        if (acc.second != Integer.MIN_VALUE) {
31            out.collect(Tuple2.of(acc.second, 2));
32        }
33    }
34}

3. 使用

 1import com.eyesmoons.sql.functions.MyUDTAF;
2import org.apache.flink.streaming.api.datastream.DataStreamSource;
3import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
4import org.apache.flink.table.api.Table;
5import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
6
7import static org.apache.flink.table.api.Expressions.$;
8import static org.apache.flink.table.api.Expressions.call;
9
10public class UDTAFDemo {
11    public static void main(String[] args) throws Exception {
12        //1.获取执行环境
13        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
14        env.setParallelism(1);
15        StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);
16
17        //2.读取文件得到DataStream
18        DataStreamSource waterSensorDataStreamSource = env.fromElements(new WaterSensor("sensor_1"1000L10),
19                new WaterSensor("sensor_1"2000L20),
20                new WaterSensor("sensor_2"3000L30),
21                new WaterSensor("sensor_1"4000L40),
22                new WaterSensor("sensor_1"5000L50),
23                new WaterSensor("sensor_2"6000L60));
24
25        // 3.将流转换为动态表
26        Table table = tableEnv.fromDataStream(waterSensorDataStreamSource);
27
28        // 4.先注册再使用
29        tableEnv.createTemporarySystemFunction("my_udtaf", MyUDTAF.class);
30
31        // table API
32        table.groupBy($("id"))
33                .flatAggregate(call("my_udtaf", $("ct")).as("top""rank"))
34                .select($("id"), $("top"), $("rank"))
35                .execute()
36                .print();
37
38    }
39}

五、Flink Sql 中使用

以上分享了在 Flink 的 table api 中自定义函数的用法以及示例。总结一下,分为四类:
① UDF
一条输入,一条输出。类似于 DataStream 中的 map 方法
② UDTF
一条输入,多条输出。类似于 DataStream 中的 flatMap 方法
③ UDAF
多条输入,一条输出。主要用于聚合的。
④ UDTAF
多条输入,多条输出。主要用于 table 的聚合。
除此之外,这些自定义函数也可以直接在 SQL 中使用,使用步骤如下:
① 将自定义 UDF 函数的实现逻辑,打包成 jar 包
1mvn clean packagen

② 上传 jar 包 到 flink 工程的 lib 目录下

③ 使用 Flink 中的 SQL API 直接进行使用

 1CREATE TABLE test(
2    id INT,
3    name STRING
4    PRIMARY KEY (idNOT ENFORCED
5  ) WITH (
6    'connector' = 'jdbc',
7    'url' = 'jdbc:mysql://127.0.0.1:3306/demo',
8    'username' = 'demo',
9    'password' = '123456',
10    'table-name' = 'test'
11  );
12//创建 function
13CREATE FUNCTION split as 'com.demo.udf.SplitUdf';
14//使用
15SELECT idnamesplit(nameas newName  from test;



往期推荐



用代码实例讲解Flink WaterMark

Flink 多维实时分析的真实案例

5种常见的Flink维表Join方案

Flink确定TaskManager个数以及内存计算

请使用浏览器的分享功能分享到微信等