bpf: kptr

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index afe3d0d..365ad3e 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -5086,6 +5086,26 @@ union bpf_attr {
  *	Return
  *		0 on success, or a negative error in case of failure. On error
  *		*dst* buffer is zeroed out.
+ *
+ * long bpf_kptr_try_set(void *kptr, void *ptr)
+ *	Description
+ *		**ptr** is a refcnt-ed ptr_to_btf_id.
+ *		Atomically: if (\*kptr == NULL) { \*kptr = ptr; refcount_inc(&ptr->refcnt); }
+ *	Return
+ *		Returns -EBUSY if \*kptr != NULL.
+ *
+ * void *bpf_kptr_get(void *kptr)
+ *	Description
+ *		if (\*kptr != NULL && refcount_inc_not_zero(&(\*kptr)->refcnt)) return \*kptr;
+ *	Return
+ *		Returns NULL if \*kptr is NULL or if the refcount already dropped to zero.
+ *
+ * void *bpf_kptr_xchg(void *kptr, void *ptr)
+ *	Description
+ *		Does xchg(kptr, ptr); where **ptr** is a refcnt-ed ptr_to_btf_id or NULL.
+ *	Return
+ *		Returns previous value of \*kptr which can be NULL.
+ *
  */
 #define __BPF_FUNC_MAPPER(FN)		\
 	FN(unspec),			\
@@ -5280,6 +5300,9 @@ union bpf_attr {
 	FN(xdp_load_bytes),		\
 	FN(xdp_store_bytes),		\
 	FN(copy_from_user_task),	\
+	FN(kptr_try_set),		\
+	FN(kptr_get),			\
+	FN(kptr_xchg),			\
 	/* */
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
index 4e5969f..0c3b806 100644
--- a/kernel/bpf/helpers.c
+++ b/kernel/bpf/helpers.c
@@ -1377,6 +1377,67 @@ void bpf_timer_cancel_and_free(void *val)
 	kfree(t);
 }
 
+BPF_CALL_2(bpf_kptr_get, void **, kptr, int, refcnt_off)
+{
+	void *ptr = READ_ONCE(*kptr);
+
+	if (!ptr)
+		return 0;
+	/* The object's refcnt may already be zero if another cpu did
+	 * ptr2 = bpf_kptr_xchg(); bpf_*_release(ptr2); concurrently,
+	 * hence _not_zero: fail the get instead of reviving a dying object.
+	 */
+	if (!refcount_inc_not_zero((refcount_t *)(ptr + refcnt_off)))
+		return 0;
+	return (long) ptr;
+}
+
+static const struct bpf_func_proto bpf_kptr_get_proto = {
+	.func		= bpf_kptr_get,
+	.gpl_only	= false,
+	.ret_type	= RET_PTR_TO_BTF_ID_OR_NULL,
+	.arg1_type	= ARG_PTR_TO_MAP_VALUE,
+};
+
+BPF_CALL_2(bpf_kptr_xchg, void **, kptr, void *, ptr)
+{
+	/* ptr is either NULL or a referenced ptr_to_btf_id (refcnt >= 1).
+	 * Returned old value is NULL or a ptr with refcnt >= 1 that the
+	 * caller must release. NOTE(review): proto below leaves arg2
+	 * undeclared - presumably verifier-enforced; confirm. */
+	return (long) xchg(kptr, ptr);
+}
+
+static const struct bpf_func_proto bpf_kptr_xchg_proto = {
+	.func		= bpf_kptr_xchg,
+	.gpl_only	= false,
+	.ret_type	= RET_PTR_TO_BTF_ID_OR_NULL,
+	.arg1_type	= ARG_PTR_TO_MAP_VALUE,
+};
+
+BPF_CALL_3(bpf_kptr_try_set, void **, kptr, void *, ptr, int, refcnt_off)
+{
+	/* ptr is a referenced ptr_to_btf_id (refcnt >= 1) from bpf_*_lookup().
+	 * Take the reference _before_ publishing with cmpxchg(): once the
+	 * pointer is visible another cpu may bpf_kptr_xchg() + release it,
+	 * so the refcount must already cover the map's copy. */
+	refcount_inc((refcount_t *)(ptr + refcnt_off));
+	if (cmpxchg(kptr, NULL, ptr)) {
+		/* lost the race, slot was non-NULL; drop our extra ref (refcnt >= 2) */
+		refcount_dec((refcount_t *)(ptr + refcnt_off));
+		return -EBUSY;
+	}
+	return 0;
+}
+
+static const struct bpf_func_proto bpf_kptr_try_set_proto = {
+	.func		= bpf_kptr_try_set,
+	.gpl_only	= false,
+	.ret_type	= RET_INTEGER,
+	.arg1_type	= ARG_PTR_TO_MAP_VALUE,
+	.arg2_type	= ARG_PTR_TO_BTF_ID,
+};
+
 const struct bpf_func_proto bpf_get_current_task_proto __weak;
 const struct bpf_func_proto bpf_get_current_task_btf_proto __weak;
 const struct bpf_func_proto bpf_probe_read_user_proto __weak;
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index afe3d0d..365ad3e 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -5086,6 +5086,26 @@ union bpf_attr {
  *	Return
  *		0 on success, or a negative error in case of failure. On error
  *		*dst* buffer is zeroed out.
+ *
+ * long bpf_kptr_try_set(void *kptr, void *ptr)
+ *	Description
+ *		**ptr** is a refcnt-ed ptr_to_btf_id.
+ *		Atomically: if (\*kptr == NULL) { \*kptr = ptr; refcount_inc(&ptr->refcnt); }
+ *	Return
+ *		Returns -EBUSY if \*kptr != NULL.
+ *
+ * void *bpf_kptr_get(void *kptr)
+ *	Description
+ *		if (\*kptr != NULL && refcount_inc_not_zero(&(\*kptr)->refcnt)) return \*kptr;
+ *	Return
+ *		Returns NULL if \*kptr is NULL or if the refcount already dropped to zero.
+ *
+ * void *bpf_kptr_xchg(void *kptr, void *ptr)
+ *	Description
+ *		Does xchg(kptr, ptr); where **ptr** is a refcnt-ed ptr_to_btf_id or NULL.
+ *	Return
+ *		Returns previous value of \*kptr which can be NULL.
+ *
  */
 #define __BPF_FUNC_MAPPER(FN)		\
 	FN(unspec),			\
@@ -5280,6 +5300,9 @@ union bpf_attr {
 	FN(xdp_load_bytes),		\
 	FN(xdp_store_bytes),		\
 	FN(copy_from_user_task),	\
+	FN(kptr_try_set),		\
+	FN(kptr_get),			\
+	FN(kptr_xchg),			\
 	/* */
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
diff --git a/tools/testing/selftests/bpf/progs/test_bpf_nf.c b/tools/testing/selftests/bpf/progs/test_bpf_nf.c
index f00a973..4229a4c9 100644
--- a/tools/testing/selftests/bpf/progs/test_bpf_nf.c
+++ b/tools/testing/selftests/bpf/progs/test_bpf_nf.c
@@ -32,6 +32,18 @@ struct nf_conn *bpf_skb_ct_lookup(struct __sk_buff *, struct bpf_sock_tuple *, u
 				  struct bpf_ct_opts___local *, u32) __ksym;
 void bpf_ct_release(struct nf_conn *) __ksym;
 
+/* map value with a kernel pointer slot driven by the bpf_kptr_* helpers */
+struct map_val {
+	struct nf_conn *ct;
+};
+
+struct {
+	__uint(type, BPF_MAP_TYPE_HASH);
+	__uint(max_entries, 1);
+	__type(key, __u32);
+	__type(value, struct map_val);
+} hash SEC(".maps");
+
 static __always_inline void
 nf_ct_test(struct nf_conn *(*func)(void *, struct bpf_sock_tuple *, u32,
 				   struct bpf_ct_opts___local *, u32),
@@ -44,6 +55,30 @@ nf_ct_test(struct nf_conn *(*func)(void *, struct bpf_sock_tuple *, u32,
 	__builtin_memset(&bpf_tuple, 0, sizeof(bpf_tuple.ipv4));
 
 	ct = func(ctx, NULL, 0, &opts_def, sizeof(opts_def));
+	if (ct) {
+		struct map_val *val;
+		struct nf_conn *ct2, *ct3;
+		__u32 zero = 0;
+
+		/* bpf_map_lookup_elem() may return NULL; the verifier
+		 * rejects dereferencing an unchecked map value.
+		 */
+		val = bpf_map_lookup_elem(&hash, &zero);
+		if (val) {
+			bpf_kptr_try_set(&val->ct, ct);
+
+			/* both helpers may return NULL; release only real refs */
+			ct2 = bpf_kptr_get(&val->ct);
+			if (ct2)
+				bpf_ct_release(ct2);
+
+			ct3 = bpf_kptr_xchg(&val->ct, NULL);
+			if (ct3)
+				bpf_ct_release(ct3);
+		}
+		bpf_ct_release(ct);
+	}
+
+	ct = func(ctx, NULL, 0, &opts_def, sizeof(opts_def));
 	if (ct)
 		bpf_ct_release(ct);
 	else