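selftests/bpf: add test for program chaining

Add test_chain.o with three XDP programs and a test_chain() case in
test_progs that links them with the BPF_PROG_CHAIN_ADD/DEL commands and
checks the resulting verdicts for IPv4 and IPv6 packets.  To support it:
- add the BPF_PROG_CHAIN_* commands, their bpf_attr fields and the
  tail_call_next helper to tools/include/uapi/linux/bpf.h,
- add bpf_prog_chain_add()/bpf_prog_chain_del() to libbpf,
- declare bpf_tail_call_next() in samples/bpf/bpf_helpers.h,
- let test_progs' bpf_prog_load() load every program in an object.
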
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
diff --git a/samples/bpf/bpf_helpers.h b/samples/bpf/bpf_helpers.h
index 52de9d8..89e3a1b 100644
--- a/samples/bpf/bpf_helpers.h
+++ b/samples/bpf/bpf_helpers.h
@@ -23,6 +23,8 @@ static int (*bpf_trace_printk)(const char *fmt, int fmt_size, ...) =
 	(void *) BPF_FUNC_trace_printk;
 static void (*bpf_tail_call)(void *ctx, void *map, int index) =
 	(void *) BPF_FUNC_tail_call;
+static void (*bpf_tail_call_next)(void *ctx) =
+	(void *) BPF_FUNC_tail_call_next; /* returns only if no next prog */
 static unsigned long long (*bpf_get_smp_processor_id)(void) =
 	(void *) BPF_FUNC_get_smp_processor_id;
 static unsigned long long (*bpf_get_current_pid_tgid)(void) =
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index 1e062bb..e2e29ad 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -82,6 +82,9 @@ enum bpf_cmd {
 	BPF_PROG_ATTACH,
 	BPF_PROG_DETACH,
 	BPF_PROG_TEST_RUN,
+	BPF_PROG_CHAIN_ADD,
+	BPF_PROG_CHAIN_DEL,
+	BPF_PROG_CHAIN_GET,
 };
 
 enum bpf_map_type {
@@ -201,6 +204,12 @@ union bpf_attr {
 		__u32		repeat;
 		__u32		duration;
 	} test;
+
+	struct { /* anonymous struct used by BPF_PROG_CHAIN_* commands */
+		__u32		root_prog_fd;	/* root prog of the chain */
+		__u32		next_prog_fd;	/* prog to add or delete */
+		__u32		priority;
+	};
 } __attribute__((aligned(8)));
 
 /* BPF helper function descriptions:
@@ -532,7 +541,8 @@ union bpf_attr {
 	FN(xdp_adjust_head),		\
 	FN(probe_read_str),		\
 	FN(get_socket_cookie),		\
-	FN(get_socket_uid),
+	FN(get_socket_uid),		\
+	FN(tail_call_next),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
diff --git a/tools/lib/bpf/bpf.c b/tools/lib/bpf/bpf.c
index f84c398..7f6516c 100644
--- a/tools/lib/bpf/bpf.c
+++ b/tools/lib/bpf/bpf.c
@@ -233,3 +233,33 @@ int bpf_prog_test_run(int prog_fd, int repeat, void *data, __u32 size,
 		*duration = attr.test.duration;
 	return ret;
 }
+
+/* Issue one of the BPF_PROG_CHAIN_* commands.  The arguments are copied
+ * into the chain fields of union bpf_attr; priority is stored in the
+ * __u32 attr.priority field.
+ */
+static int bpf_prog_chain_cmd(enum bpf_cmd cmd, int root_prog_fd,
+			      int next_prog_fd, int priority)
+{
+	union bpf_attr attr;
+
+	bzero(&attr, sizeof(attr));
+	attr.root_prog_fd = root_prog_fd;
+	attr.next_prog_fd = next_prog_fd;
+	attr.priority = priority;
+	return sys_bpf(cmd, &attr, sizeof(attr));
+}
+
+/* Add next_prog_fd to the chain of root_prog_fd with the given priority. */
+int bpf_prog_chain_add(int root_prog_fd, int next_prog_fd, int priority)
+{
+	return bpf_prog_chain_cmd(BPF_PROG_CHAIN_ADD, root_prog_fd,
+				  next_prog_fd, priority);
+}
+
+/* Remove next_prog_fd from the chain of root_prog_fd. */
+int bpf_prog_chain_del(int root_prog_fd, int next_prog_fd, int priority)
+{
+	return bpf_prog_chain_cmd(BPF_PROG_CHAIN_DEL, root_prog_fd,
+				  next_prog_fd, priority);
+}
diff --git a/tools/lib/bpf/bpf.h b/tools/lib/bpf/bpf.h
index edb4dae..c1a84fc 100644
--- a/tools/lib/bpf/bpf.h
+++ b/tools/lib/bpf/bpf.h
@@ -50,5 +50,11 @@ int bpf_prog_detach(int attachable_fd, enum bpf_attach_type type);
 int bpf_prog_test_run(int prog_fd, int repeat, void *data, __u32 size,
 		      void *data_out, __u32 *size_out, __u32 *retval,
 		      __u32 *duration);
+/* Add next_prog_fd to, or remove it from, the program chain rooted at
+ * root_prog_fd (BPF_PROG_CHAIN_ADD / BPF_PROG_CHAIN_DEL).  Both return
+ * the result of the underlying bpf() command.
+ */
+int bpf_prog_chain_add(int root_prog_fd, int next_prog_fd, int priority);
+int bpf_prog_chain_del(int root_prog_fd, int next_prog_fd, int priority);
 
 #endif
diff --git a/tools/testing/selftests/bpf/Makefile b/tools/testing/selftests/bpf/Makefile
index d8d94b9..547a478 100644
--- a/tools/testing/selftests/bpf/Makefile
+++ b/tools/testing/selftests/bpf/Makefile
@@ -13,7 +13,7 @@
 
 TEST_GEN_PROGS = test_verifier test_tag test_maps test_lru_map test_lpm_map test_progs
 
-TEST_GEN_FILES = test_pkt_access.o test_xdp.o test_l4lb.o
+TEST_GEN_FILES = test_pkt_access.o test_xdp.o test_l4lb.o test_chain.o
 
 TEST_PROGS := test_kmod.sh
 
diff --git a/tools/testing/selftests/bpf/test_chain.c b/tools/testing/selftests/bpf/test_chain.c
new file mode 100644
index 0000000..3473f31
--- /dev/null
+++ b/tools/testing/selftests/bpf/test_chain.c
@@ -0,0 +1,149 @@
+/* Copyright (c) 2017 Facebook
+ *
+ * This program is free software; you can redistribute it and/or
+ * modify it under the terms of version 2 of the GNU General Public
+ * License as published by the Free Software Foundation.
+ */
+/* XDP programs exercised by test_chain() in test_progs.c.
+ *
+ * xdp_prog1 is the root of the chain: it records the packet's address
+ * family and L4 protocol in the per-cpu "pcpu" map and continues with
+ * bpf_tail_call_next().  xdp_prog2 and xdp_prog3 are the chained
+ * programs; each claims one kind of TCP packet and hands everything else
+ * to the next program in the chain.
+ */
+#include <stddef.h>
+#include <stdbool.h>
+#include <string.h>
+#include <linux/bpf.h>
+#include <linux/if_ether.h>
+#include <linux/if_packet.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+#include <linux/in.h>
+#include <linux/udp.h>
+#include <linux/tcp.h>
+#include <sys/socket.h>
+#include "bpf_helpers.h"
+
+/* unconditional byte swaps: only correct on little-endian targets */
+#define htons __builtin_bswap16
+#define ntohs __builtin_bswap16
+int _version SEC("version") = 1;
+
+/* Packet summary handed from xdp_prog1 to the chained programs. */
+struct md {
+	__u8 proto;	/* iph->protocol or ip6h->nexthdr */
+	__u8 is_ipv6;
+};
+
+/* One per-cpu slot (key 0) carrying struct md along the chain. */
+struct bpf_map_def SEC("maps") pcpu = {
+	.type = BPF_MAP_TYPE_PERCPU_ARRAY,
+	.key_size = sizeof(__u32),
+	.value_size = sizeof(struct md),
+	.max_entries = 1,
+};
+
+/* Record IPv4 metadata in md and continue with the next program in the
+ * chain.  bpf_tail_call_next() only returns when no program follows; the
+ * packet is then dropped, as is one too short for an IPv4 header.
+ */
+static __always_inline int handle_ipv4(struct xdp_md *xdp, struct md *md)
+{
+	void *data_end = (void *)(long)xdp->data_end;
+	void *data = (void *)(long)xdp->data;
+	struct iphdr *iph = data + sizeof(struct ethhdr);
+
+	md->is_ipv6 = false;
+	if (iph + 1 > data_end)
+		return XDP_DROP;
+	md->proto = iph->protocol;
+	bpf_tail_call_next(xdp);
+	return XDP_DROP;
+}
+
+/* IPv6 counterpart of handle_ipv4().  md->proto is the first nexthdr, so
+ * a packet with extension headers is not classified as TCP.
+ */
+static __always_inline int handle_ipv6(struct xdp_md *xdp, struct md *md)
+{
+	void *data_end = (void *)(long)xdp->data_end;
+	void *data = (void *)(long)xdp->data;
+	struct ipv6hdr *ip6h = data + sizeof(struct ethhdr);
+
+	md->is_ipv6 = true;
+	if (ip6h + 1 > data_end)
+		return XDP_DROP;
+	md->proto = ip6h->nexthdr;
+	bpf_tail_call_next(xdp);
+	return XDP_DROP;
+}
+
+/* Chain root: classify by EtherType and hand off to the IPv4/IPv6
+ * helpers.  A truncated Ethernet header, a failed map lookup or any other
+ * EtherType drops the packet.
+ */
+SEC("xdp_prog1")
+int _xdp_prog1(struct xdp_md *xdp)
+{
+	void *data_end = (void *)(long)xdp->data_end;
+	void *data = (void *)(long)xdp->data;
+	struct ethhdr *eth = data;
+	struct md *md;
+	__u32 key = 0;
+	__u16 h_proto;
+
+	if (eth + 1 > data_end)
+		return XDP_DROP;
+	h_proto = eth->h_proto;
+
+	md = bpf_map_lookup_elem(&pcpu, &key);
+	if (!md)
+		return XDP_DROP;
+
+	if (h_proto == htons(ETH_P_IP))
+		return handle_ipv4(xdp, md);
+	else if (h_proto == htons(ETH_P_IPV6))
+		return handle_ipv6(xdp, md);
+	else
+		return XDP_DROP;
+}
+
+/* Chained program: IPv6/TCP packets get XDP_PASS; anything else goes on
+ * to the next program, with XDP_ABORTED if bpf_tail_call_next() returns.
+ */
+SEC("xdp_prog2")
+int _xdp_prog2(struct xdp_md *xdp)
+{
+	struct md *md;
+	__u32 key = 0;
+
+	md = bpf_map_lookup_elem(&pcpu, &key);
+	if (!md)
+		return XDP_DROP;
+	if (md->is_ipv6 && md->proto == IPPROTO_TCP)
+		return XDP_PASS;
+	bpf_tail_call_next(xdp);
+	return XDP_ABORTED;
+}
+
+/* Chained program: IPv4/TCP packets get XDP_TX; anything else goes on
+ * to the next program, with XDP_ABORTED if bpf_tail_call_next() returns.
+ */
+SEC("xdp_prog3")
+int _xdp_prog3(struct xdp_md *xdp)
+{
+	struct md *md;
+	__u32 key = 0;
+
+	md = bpf_map_lookup_elem(&pcpu, &key);
+	if (!md)
+		return XDP_DROP;
+	if (!md->is_ipv6 && md->proto == IPPROTO_TCP)
+		return XDP_TX;
+	bpf_tail_call_next(xdp);
+	return XDP_ABORTED;
+}
+
+char _license[] SEC("license") = "GPL";
diff --git a/tools/testing/selftests/bpf/test_progs.c b/tools/testing/selftests/bpf/test_progs.c
index 5275d4a..dd3bbc5 100644
--- a/tools/testing/selftests/bpf/test_progs.c
+++ b/tools/testing/selftests/bpf/test_progs.c
@@ -78,7 +78,7 @@ static int bpf_prog_load(const char *file, enum bpf_prog_type type,
 {
 	struct bpf_program *prog;
 	struct bpf_object *obj;
-	int err;
+	int err, i = 0;
 
 	obj = bpf_object__open(file);
 	if (IS_ERR(obj)) {
@@ -86,14 +86,16 @@ static int bpf_prog_load(const char *file, enum bpf_prog_type type,
 		return -ENOENT;
 	}
 
 	prog = bpf_program__next(NULL, obj);
 	if (!prog) {
 		bpf_object__close(obj);
 		error_cnt++;
 		return -ENOENT;
 	}
 
-	bpf_program__set_type(prog, type);
+	/* An object may carry several programs; all get the same type. */
+	bpf_object__for_each_program(prog, obj)
+		bpf_program__set_type(prog, type);
 	err = bpf_object__load(obj);
 	if (err) {
 		bpf_object__close(obj);
@@ -102,7 +104,11 @@ static int bpf_prog_load(const char *file, enum bpf_prog_type type,
 	}
 
 	*pobj = obj;
-	*prog_fd = bpf_program__fd(prog);
+	/* prog_fd[] must have room for one fd per program in the object;
+	 * fds are stored in bpf_object__for_each_program() order.
+	 */
+	bpf_object__for_each_program(prog, obj)
+		prog_fd[i++] = bpf_program__fd(prog);
 	return 0;
 }
 
@@ -269,6 +275,79 @@ static void test_l4lb(void)
 	bpf_object__close(obj);
 }
 
+/* Add next_fd to (add != 0) or delete it from (add == 0) the chain of
+ * root_fd, counting a failure in error_cnt.
+ */
+static void chain_update(int add, int root_fd, int next_fd, int priority)
+{
+	int err;
+
+	if (add)
+		err = bpf_prog_chain_add(root_fd, next_fd, priority);
+	else
+		err = bpf_prog_chain_del(root_fd, next_fd, priority);
+	if (err) {
+		printf("%s fail\n", add ? "chain_add" : "chain_del");
+		error_cnt++;
+	}
+}
+
+/* test_chain.o holds three XDP programs: xdp_prog1 classifies the packet
+ * and calls bpf_tail_call_next(), xdp_prog2 returns XDP_PASS for IPv6/TCP
+ * and xdp_prog3 returns XDP_TX for IPv4/TCP.  The verdicts must not depend
+ * on the priorities the two chained programs are added with, and with an
+ * empty chain xdp_prog1 falls back to XDP_DROP.
+ */
+static void test_chain(void)
+{
+	const char *file = "./test_chain.o";
+	struct bpf_object *obj;
+	__u32 duration, retval;
+	int err, prog_fd[3];
+
+	/* expects prog_fd[0..2] = xdp_prog1..3 (object order) */
+	err = bpf_prog_load(file, BPF_PROG_TYPE_XDP, &obj, prog_fd);
+	if (err)
+		return;
+
+	chain_update(1, prog_fd[0], prog_fd[1], 10);
+	chain_update(1, prog_fd[0], prog_fd[2], 100);
+	err = bpf_prog_test_run(prog_fd[0], 100000, &pkt_v4, sizeof(pkt_v4),
+				NULL, NULL, &retval, &duration);
+	CHECK(err || errno || retval != XDP_TX, "ipv4_p100",
+	      "err %d errno %d retval %d duration %d\n",
+	      err, errno, retval, duration);
+	err = bpf_prog_test_run(prog_fd[0], 100000, &pkt_v6, sizeof(pkt_v6),
+				NULL, NULL, &retval, &duration);
+	CHECK(err || errno || retval != XDP_PASS, "ipv6_p10",
+	      "err %d errno %d retval %d duration %d\n",
+	      err, errno, retval, duration);
+
+	/* with an empty chain bpf_tail_call_next() returns to xdp_prog1 */
+	chain_update(0, prog_fd[0], prog_fd[1], 0);
+	chain_update(0, prog_fd[0], prog_fd[2], 0);
+	err = bpf_prog_test_run(prog_fd[0], 100000, &pkt_v4, sizeof(pkt_v4),
+				NULL, NULL, &retval, &duration);
+	CHECK(err || errno || retval != XDP_DROP, "nop",
+	      "err %d errno %d retval %d duration %d\n",
+	      err, errno, retval, duration);
+
+	/* same programs with the priorities swapped */
+	chain_update(1, prog_fd[0], prog_fd[1], 100);
+	chain_update(1, prog_fd[0], prog_fd[2], 10);
+	err = bpf_prog_test_run(prog_fd[0], 100000, &pkt_v4, sizeof(pkt_v4),
+				NULL, NULL, &retval, &duration);
+	CHECK(err || errno || retval != XDP_TX, "ipv4_p10",
+	      "err %d errno %d retval %d duration %d\n",
+	      err, errno, retval, duration);
+	err = bpf_prog_test_run(prog_fd[0], 100000, &pkt_v6, sizeof(pkt_v6),
+				NULL, NULL, &retval, &duration);
+	CHECK(err || errno || retval != XDP_PASS, "ipv6_p100",
+	      "err %d errno %d retval %d duration %d\n",
+	      err, errno, retval, duration);
+	bpf_object__close(obj);
+}
+
 int main(void)
 {
 	struct rlimit rinf = { RLIM_INFINITY, RLIM_INFINITY };
@@ -278,6 +357,7 @@ int main(void)
 	test_pkt_access();
 	test_xdp();
 	test_l4lb();
+	test_chain();
 
 	printf("Summary: %d PASSED, %d FAILED\n", pass_cnt, error_cnt);
 	return 0;