bpf: Introduce dynamic program extensions

Introduce dynamic program extensions. The users can load additional BPF
functions and replace global functions in previously loaded BPF programs while
these programs are executing.

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
diff --git a/include/linux/bpf.h b/include/linux/bpf.h
index 8e3b8f4..05d1661 100644
--- a/include/linux/bpf.h
+++ b/include/linux/bpf.h
@@ -465,7 +465,8 @@ void notrace __bpf_prog_exit(struct bpf_prog *prog, u64 start);
 enum bpf_tramp_prog_type {
 	BPF_TRAMP_FENTRY,
 	BPF_TRAMP_FEXIT,
-	BPF_TRAMP_MAX
+	BPF_TRAMP_MAX,
+	BPF_TRAMP_REPLACE, /* more than MAX: not a progs_hlist[] index */
 };
 
 struct bpf_trampoline {
@@ -480,6 +481,11 @@ struct bpf_trampoline {
 		void *addr;
 		bool ftrace_managed;
 	} func;
+	/* if !NULL this is BPF_PROG_TYPE_EXT program that extends another BPF
+	 * program by replacing one of its functions. func.addr is the address
+	 * of the function it replaced.
+	 */
+	struct bpf_prog *extension_prog;
 	/* list of BPF programs using this trampoline */
 	struct hlist_head progs_hlist[BPF_TRAMP_MAX];
 	/* Number of attached programs. A counter per kind. */
@@ -1107,6 +1113,8 @@ int btf_check_func_arg_match(struct bpf_verifier_env *env, int subprog,
 			     struct bpf_reg_state *regs);
 int btf_prepare_func_args(struct bpf_verifier_env *env, int subprog,
 			  struct bpf_reg_state *reg);
+int btf_check_type_match(struct bpf_verifier_env *env, struct bpf_prog *prog,
+			 struct btf *btf, const struct btf_type *t);
 
 struct bpf_prog *bpf_prog_by_id(u32 id);
 
diff --git a/include/linux/bpf_types.h b/include/linux/bpf_types.h
index 9f326e6..c81d4ec 100644
--- a/include/linux/bpf_types.h
+++ b/include/linux/bpf_types.h
@@ -68,6 +68,8 @@ BPF_PROG_TYPE(BPF_PROG_TYPE_SK_REUSEPORT, sk_reuseport,
 #if defined(CONFIG_BPF_JIT)
 BPF_PROG_TYPE(BPF_PROG_TYPE_STRUCT_OPS, bpf_struct_ops,
 	      void *, void *)
+BPF_PROG_TYPE(BPF_PROG_TYPE_EXT, bpf_extension,
+	      void *, void *)
 #endif
 
 BPF_MAP_TYPE(BPF_MAP_TYPE_ARRAY, array_map_ops)
diff --git a/include/linux/btf.h b/include/linux/btf.h
index 881e9b7..5c1ea99 100644
--- a/include/linux/btf.h
+++ b/include/linux/btf.h
@@ -107,6 +107,11 @@ static inline u16 btf_type_vlen(const struct btf_type *t)
 	return BTF_INFO_VLEN(t->info);
 }
 
+static inline u16 btf_func_linkage(const struct btf_type *t)
+{
+	return BTF_INFO_VLEN(t->info); /* FUNC keeps linkage in vlen bits */
+}
+
 static inline bool btf_type_kflag(const struct btf_type *t)
 {
 	return BTF_INFO_KFLAG(t->info);
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 033d90a..e81628e 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -180,6 +180,7 @@ enum bpf_prog_type {
 	BPF_PROG_TYPE_CGROUP_SOCKOPT,
 	BPF_PROG_TYPE_TRACING,
 	BPF_PROG_TYPE_STRUCT_OPS,
+	BPF_PROG_TYPE_EXT,
 };
 
 enum bpf_attach_type {
diff --git a/kernel/bpf/btf.c b/kernel/bpf/btf.c
index 832b5d7..32963b6 100644
--- a/kernel/bpf/btf.c
+++ b/kernel/bpf/btf.c
@@ -276,6 +276,11 @@ static const char * const btf_kind_str[NR_BTF_KINDS] = {
 	[BTF_KIND_DATASEC]	= "DATASEC",
 };
 
+static const char *btf_type_str(const struct btf_type *t)
+{
+	return btf_kind_str[BTF_INFO_KIND(t->info)];
+}
+
 struct btf_kind_operations {
 	s32 (*check_meta)(struct btf_verifier_env *env,
 			  const struct btf_type *t,
@@ -4115,6 +4120,148 @@ int btf_distill_func_proto(struct bpf_verifier_log *log,
 	return 0;
 }
 
+/* Compare BTFs of two functions assuming only scalars and pointers to context.
+ * t1 points to BTF_KIND_FUNC in btf1
+ * t2 points to BTF_KIND_FUNC in btf2
+ * Returns:
+ * -EINVAL - prototype mismatch or non-global function (reason is logged)
+ * -EFAULT - verifier bug
+ * 0 - 99% match. The last 1% is validated by the verifier.
+ */
+static int btf_check_func_type_match(struct bpf_verifier_log *log,
+				     struct btf *btf1, const struct btf_type *t1,
+				     struct btf *btf2, const struct btf_type *t2)
+{
+	const struct btf_param *args1, *args2;
+	const char *fn1, *fn2, *s1, *s2;
+	u32 nargs1, nargs2, i;
+
+	fn1 = btf_name_by_offset(btf1, t1->name_off);
+	fn2 = btf_name_by_offset(btf2, t2->name_off);
+
+	if (btf_func_linkage(t1) != BTF_FUNC_GLOBAL) {
+		bpf_log(log, "%s() is not a global function\n", fn1);
+		return -EINVAL;
+	}
+	if (btf_func_linkage(t2) != BTF_FUNC_GLOBAL) {
+		bpf_log(log, "%s() is not a global function\n", fn2);
+		return -EINVAL;
+	}
+
+	t1 = btf_type_by_id(btf1, t1->type);
+	if (!t1 || !btf_type_is_func_proto(t1))
+		return -EFAULT;
+	t2 = btf_type_by_id(btf2, t2->type);
+	if (!t2 || !btf_type_is_func_proto(t2))
+		return -EFAULT;
+
+	args1 = (const struct btf_param *)(t1 + 1);
+	nargs1 = btf_type_vlen(t1);
+	args2 = (const struct btf_param *)(t2 + 1);
+	nargs2 = btf_type_vlen(t2);
+
+	if (nargs1 != nargs2) {
+		bpf_log(log, "%s() has %d args while %s() has %d args\n",
+			fn1, nargs1, fn2, nargs2);
+		return -EINVAL;
+	}
+
+	t1 = btf_type_skip_modifiers(btf1, t1->type, NULL);
+	t2 = btf_type_skip_modifiers(btf2, t2->type, NULL);
+	if (t1->info != t2->info) {
+		bpf_log(log,
+			"Return type %s of %s() doesn't match type %s of %s()\n",
+			btf_type_str(t1), fn1,
+			btf_type_str(t2), fn2);
+		return -EINVAL;
+	}
+
+	for (i = 0; i < nargs1; i++) {
+		t1 = btf_type_skip_modifiers(btf1, args1[i].type, NULL);
+		t2 = btf_type_skip_modifiers(btf2, args2[i].type, NULL);
+
+		if (t1->info != t2->info) {
+			bpf_log(log, "arg%d in %s() is %s while %s() has %s\n",
+				i, fn1, btf_type_str(t1),
+				fn2, btf_type_str(t2));
+			return -EINVAL;
+		}
+		if (btf_type_has_size(t1) && t1->size != t2->size) {
+			bpf_log(log,
+				"arg%d in %s() has size %d while %s() has %d\n",
+				i, fn1, t1->size,
+				fn2, t2->size);
+			return -EINVAL;
+		}
+
+		/* global functions are validated with scalars and pointers
+		 * to context only. And only global functions can be replaced.
+		 * Hence type check only those types.
+		 */
+		if (btf_type_is_int(t1) || btf_type_is_enum(t1))
+			continue;
+		if (!btf_type_is_ptr(t1)) {
+			bpf_log(log,
+				"arg%d in %s() has unrecognized type\n",
+				i, fn1);
+			return -EINVAL;
+		}
+		t1 = btf_type_skip_modifiers(btf1, t1->type, NULL);
+		t2 = btf_type_skip_modifiers(btf2, t2->type, NULL);
+		if (!btf_type_is_struct(t1)) {
+			bpf_log(log,
+				"arg%d in %s() is not a pointer to context\n",
+				i, fn1);
+			return -EINVAL;
+		}
+		if (!btf_type_is_struct(t2)) {
+			bpf_log(log,
+				"arg%d in %s() is not a pointer to context\n",
+				i, fn2);
+			return -EINVAL;
+		}
+		/* This is an optional check to make program writing easier.
+		 * Compare names of structs and report an error to the user.
+		 * btf_prepare_func_args() already checked that t2 struct
+		 * is a context type. btf_prepare_func_args() will check
+		 * later that t1 struct is a context type as well.
+		 */
+		s1 = btf_name_by_offset(btf1, t1->name_off);
+		s2 = btf_name_by_offset(btf2, t2->name_off);
+		if (strcmp(s1, s2)) {
+			bpf_log(log,
+				"arg%d %s(struct %s *) doesn't match %s(struct %s *)\n",
+				i, fn1, s1, fn2, s2);
+			return -EINVAL;
+		}
+	}
+	return 0;
+}
+
+/* Compare BTF of prog's first func_info entry with target function t2 */
+int btf_check_type_match(struct bpf_verifier_env *env, struct bpf_prog *prog,
+			 struct btf *btf2, const struct btf_type *t2)
+{
+	struct bpf_verifier_log *log = &env->log;
+	struct bpf_prog_aux *aux = prog->aux;
+	const struct btf_type *func;
+	u32 func_id;
+
+	if (!aux->func_info) {
+		bpf_log(log, "Program extension requires BTF\n");
+		return -EINVAL;
+	}
+
+	/* a missing or non-FUNC type id is reported as -EFAULT */
+	func_id = aux->func_info[0].type_id;
+	if (!func_id)
+		return -EFAULT;
+	func = btf_type_by_id(aux->btf, func_id);
+	if (!func || !btf_type_is_func(func))
+		return -EFAULT;
+	return btf_check_func_type_match(log, aux->btf, func, btf2, t2);
+}
+
 /* Compare BTF of a function with given bpf_reg_state.
  * Returns:
  * EFAULT - there is a verifier bug. Abort verification.
@@ -4224,6 +4371,7 @@ int btf_prepare_func_args(struct bpf_verifier_env *env, int subprog,
 {
 	struct bpf_verifier_log *log = &env->log;
 	struct bpf_prog *prog = env->prog;
+	enum bpf_prog_type prog_type = prog->type;
 	struct btf *btf = prog->aux->btf;
 	const struct btf_param *args;
 	const struct btf_type *t;
@@ -4261,6 +4409,8 @@ int btf_prepare_func_args(struct bpf_verifier_env *env, int subprog,
 		bpf_log(log, "Verifier bug in function %s()\n", tname);
 		return -EFAULT;
 	}
+	if (prog_type == BPF_PROG_TYPE_EXT)
+		prog_type = prog->aux->linked_prog->type;
 
 	t = btf_type_by_id(btf, t->type);
 	if (!t || !btf_type_is_func_proto(t)) {
@@ -4296,7 +4446,7 @@ int btf_prepare_func_args(struct bpf_verifier_env *env, int subprog,
 			continue;
 		}
 		if (btf_type_is_ptr(t) &&
-		    btf_get_prog_ctx_type(log, btf, t, prog->type, i)) {
+		    btf_get_prog_ctx_type(log, btf, t, prog_type, i)) {
 			reg[i + 1].type = PTR_TO_CTX;
 			continue;
 		}
diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
index c26a714..4aaea62 100644
--- a/kernel/bpf/syscall.c
+++ b/kernel/bpf/syscall.c
@@ -1924,13 +1924,15 @@ bpf_prog_load_check_attach(enum bpf_prog_type prog_type,
 		switch (prog_type) {
 		case BPF_PROG_TYPE_TRACING:
 		case BPF_PROG_TYPE_STRUCT_OPS:
+		case BPF_PROG_TYPE_EXT:
 			break;
 		default:
 			return -EINVAL;
 		}
 	}
 
-	if (prog_fd && prog_type != BPF_PROG_TYPE_TRACING)
+	if (prog_fd && prog_type != BPF_PROG_TYPE_TRACING &&
+	    prog_type != BPF_PROG_TYPE_EXT)
 		return -EINVAL;
 
 	switch (prog_type) {
@@ -1973,6 +1975,10 @@ bpf_prog_load_check_attach(enum bpf_prog_type prog_type,
 		default:
 			return -EINVAL;
 		}
+	case BPF_PROG_TYPE_EXT:
+		if (expected_attach_type)
+			return -EINVAL;
+		/* fallthrough */
 	default:
 		return 0;
 	}
@@ -2175,7 +2181,8 @@ static int bpf_tracing_prog_attach(struct bpf_prog *prog)
 	int tr_fd, err;
 
 	if (prog->expected_attach_type != BPF_TRACE_FENTRY &&
-	    prog->expected_attach_type != BPF_TRACE_FEXIT) {
+	    prog->expected_attach_type != BPF_TRACE_FEXIT &&
+	    prog->type != BPF_PROG_TYPE_EXT) {
 		err = -EINVAL;
 		goto out_put_prog;
 	}
@@ -2242,12 +2249,14 @@ static int bpf_raw_tracepoint_open(const union bpf_attr *attr)
 
 	if (prog->type != BPF_PROG_TYPE_RAW_TRACEPOINT &&
 	    prog->type != BPF_PROG_TYPE_TRACING &&
+	    prog->type != BPF_PROG_TYPE_EXT &&
 	    prog->type != BPF_PROG_TYPE_RAW_TRACEPOINT_WRITABLE) {
 		err = -EINVAL;
 		goto out_put_prog;
 	}
 
-	if (prog->type == BPF_PROG_TYPE_TRACING) {
+	if (prog->type == BPF_PROG_TYPE_TRACING ||
+	    prog->type == BPF_PROG_TYPE_EXT) {
 		if (attr->raw_tracepoint.name) {
 			/* The attach point for this category of programs
 			 * should be specified via btf_id during program load.
diff --git a/kernel/bpf/trampoline.c b/kernel/bpf/trampoline.c
index 79a0441..194f25a 100644
--- a/kernel/bpf/trampoline.c
+++ b/kernel/bpf/trampoline.c
@@ -5,6 +5,12 @@
 #include <linux/filter.h>
 #include <linux/ftrace.h>
 
+/* dummy _ops. The verifier will operate on target program's ops. */
+const struct bpf_verifier_ops bpf_extension_verifier_ops = {
+};
+const struct bpf_prog_ops bpf_extension_prog_ops = {
+};
+
 /* btf_vmlinux has ~22k attachable functions. 1k htab is enough. */
 #define TRAMPOLINE_HASH_BITS 10
 #define TRAMPOLINE_TABLE_SIZE (1 << TRAMPOLINE_HASH_BITS)
@@ -186,8 +192,10 @@ static enum bpf_tramp_prog_type bpf_attach_type_to_tramp(enum bpf_attach_type t)
 	switch (t) {
 	case BPF_TRACE_FENTRY:
 		return BPF_TRAMP_FENTRY;
-	default:
+	case BPF_TRACE_FEXIT:
 		return BPF_TRAMP_FEXIT;
+	default:
+		return BPF_TRAMP_REPLACE;
 	}
 }
 
@@ -200,6 +208,26 @@ int bpf_trampoline_link_prog(struct bpf_prog *prog)
 	tr = prog->aux->trampoline;
 	kind = bpf_attach_type_to_tramp(prog->expected_attach_type);
 	mutex_lock(&tr->mutex);
+	if (kind == BPF_TRAMP_REPLACE) {
+		/* EBUSY if an extension or fentry/fexit is already attached */
+		if (tr->extension_prog ||
+		    tr->progs_cnt[BPF_TRAMP_FENTRY] +
+		    tr->progs_cnt[BPF_TRAMP_FEXIT]) {
+			err = -EBUSY;
+			goto out;
+		}
+		err = bpf_arch_text_poke(tr->func.addr, BPF_MOD_JUMP, NULL,
+					 prog->bpf_func);
+		/* a failed poke must not leave the trampoline marked busy */
+		if (!err)
+			tr->extension_prog = prog;
+		goto out;
+	}
+	if (tr->extension_prog) {
+		/* cannot attach fentry/fexit if extension prog is attached */
+		err = -EBUSY;
+		goto out;
+	}
 	if (tr->progs_cnt[BPF_TRAMP_FENTRY] + tr->progs_cnt[BPF_TRAMP_FEXIT]
 	    >= BPF_MAX_TRAMP_PROGS) {
 		err = -E2BIG;
@@ -232,9 +260,17 @@ int bpf_trampoline_unlink_prog(struct bpf_prog *prog)
 	tr = prog->aux->trampoline;
 	kind = bpf_attach_type_to_tramp(prog->expected_attach_type);
 	mutex_lock(&tr->mutex);
+	if (kind == BPF_TRAMP_REPLACE) {
+		WARN_ON_ONCE(!tr->extension_prog);
+		err = bpf_arch_text_poke(tr->func.addr, BPF_MOD_JUMP,
+					 tr->extension_prog->bpf_func, NULL);
+		tr->extension_prog = NULL;
+		goto out;
+	}
 	hlist_del(&prog->aux->tramp_hlist);
 	tr->progs_cnt[kind]--;
 	err = bpf_trampoline_update(prog->aux->trampoline);
+out:
 	mutex_unlock(&tr->mutex);
 	return err;
 }
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index ca17dccc..5e4876e 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -9564,7 +9564,7 @@ static int do_check_common(struct bpf_verifier_env *env, int subprog)
 			subprog);
 
 	regs = state->frame[state->curframe]->regs;
-	if (subprog) {
+	if (subprog || env->prog->type == BPF_PROG_TYPE_EXT) {
 		ret = btf_prepare_func_args(env, subprog, regs);
 		if (ret)
 			goto out;
@@ -9737,6 +9737,7 @@ static int check_struct_ops_btf_id(struct bpf_verifier_env *env)
 static int check_attach_btf_id(struct bpf_verifier_env *env)
 {
 	struct bpf_prog *prog = env->prog;
+	bool prog_extension = prog->type == BPF_PROG_TYPE_EXT;
 	struct bpf_prog *tgt_prog = prog->aux->linked_prog;
 	u32 btf_id = prog->aux->attach_btf_id;
 	const char prefix[] = "btf_trace_";
@@ -9752,7 +9753,7 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
 	if (prog->type == BPF_PROG_TYPE_STRUCT_OPS)
 		return check_struct_ops_btf_id(env);
 
-	if (prog->type != BPF_PROG_TYPE_TRACING)
+	if (prog->type != BPF_PROG_TYPE_TRACING && !prog_extension)
 		return 0;
 
 	if (!btf_id) {
@@ -9788,8 +9789,38 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
 			return -EINVAL;
 		}
 		conservative = aux->func_info_aux[subprog].unreliable;
+		if (prog_extension) {
+			if (conservative) {
+				verbose(env,
+					"Cannot replace static functions\n");
+				return -EINVAL;
+			}
+			if (!prog->jit_requested) {
+				verbose(env,
+					"Extension programs should be JITed\n");
+				return -EINVAL;
+			}
+			env->ops = bpf_verifier_ops[tgt_prog->type];
+		}
+		if (!tgt_prog->jited) {
+			verbose(env, "Can attach to only JITed progs\n");
+			return -EINVAL;
+		}
+		if (tgt_prog->type == prog->type) {
+			/* Cannot fentry/fexit another fentry/fexit program.
+			 * Cannot attach program extension to another extension.
+			 * It's ok to replace fentry/fexit with extension.
+			 * It's ok to attach fentry/fexit to extension program.
+			 */
+			verbose(env, "Cannot recursively attach\n");
+			return -EINVAL;
+		}
 		key = ((u64)aux->id) << 32 | btf_id;
 	} else {
+		if (prog_extension) {
+			verbose(env, "Cannot replace kernel functions\n");
+			return -EINVAL;
+		}
 		key = btf_id;
 	}
 
@@ -9827,6 +9858,10 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
 		prog->aux->attach_func_proto = t;
 		prog->aux->attach_btf_trace = true;
 		return 0;
+	default:
+		if (!prog_extension)
+			return -EINVAL;
+		/* fallthrough */
 	case BPF_TRACE_FENTRY:
 	case BPF_TRACE_FEXIT:
 		if (!btf_type_is_func(t)) {
@@ -9834,6 +9869,9 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
 				btf_id);
 			return -EINVAL;
 		}
+		if (prog_extension &&
+		    btf_check_type_match(env, prog, btf, t))
+			return -EINVAL;
 		t = btf_type_by_id(btf, t->type);
 		if (!btf_type_is_func_proto(t))
 			return -EINVAL;
@@ -9857,18 +9895,6 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
 		if (ret < 0)
 			goto out;
 		if (tgt_prog) {
-			if (!tgt_prog->jited) {
-				/* for now */
-				verbose(env, "Can trace only JITed BPF progs\n");
-				ret = -EINVAL;
-				goto out;
-			}
-			if (tgt_prog->type == BPF_PROG_TYPE_TRACING) {
-				/* prevent cycles */
-				verbose(env, "Cannot recursively attach\n");
-				ret = -EINVAL;
-				goto out;
-			}
 			if (subprog == 0)
 				addr = (long) tgt_prog->bpf_func;
 			else
@@ -9890,8 +9916,6 @@ static int check_attach_btf_id(struct bpf_verifier_env *env)
 		if (ret)
 			bpf_trampoline_put(tr);
 		return ret;
-	default:
-		return -EINVAL;
 	}
 }
 
@@ -9961,10 +9985,6 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr,
 		goto skip_full_check;
 	}
 
-	ret = check_attach_btf_id(env);
-	if (ret)
-		goto skip_full_check;
-
 	env->strict_alignment = !!(attr->prog_flags & BPF_F_STRICT_ALIGNMENT);
 	if (!IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))
 		env->strict_alignment = true;
@@ -10001,6 +10021,10 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr,
 	if (ret < 0)
 		goto skip_full_check;
 
+	ret = check_attach_btf_id(env);
+	if (ret)
+		goto skip_full_check;
+
 	ret = check_cfg(env);
 	if (ret < 0)
 		goto skip_full_check;
diff --git a/tools/include/uapi/linux/bpf.h b/tools/include/uapi/linux/bpf.h
index 033d90a..e81628e 100644
--- a/tools/include/uapi/linux/bpf.h
+++ b/tools/include/uapi/linux/bpf.h
@@ -180,6 +180,7 @@ enum bpf_prog_type {
 	BPF_PROG_TYPE_CGROUP_SOCKOPT,
 	BPF_PROG_TYPE_TRACING,
 	BPF_PROG_TYPE_STRUCT_OPS,
+	BPF_PROG_TYPE_EXT,
 };
 
 enum bpf_attach_type {
diff --git a/tools/lib/bpf/bpf.c b/tools/lib/bpf/bpf.c
index ed42b00..c6dafe5 100644
--- a/tools/lib/bpf/bpf.c
+++ b/tools/lib/bpf/bpf.c
@@ -237,7 +237,8 @@ int bpf_load_program_xattr(const struct bpf_load_program_attr *load_attr,
 	attr.expected_attach_type = load_attr->expected_attach_type;
 	if (attr.prog_type == BPF_PROG_TYPE_STRUCT_OPS) {
 		attr.attach_btf_id = load_attr->attach_btf_id;
-	} else if (attr.prog_type == BPF_PROG_TYPE_TRACING) {
+	} else if (attr.prog_type == BPF_PROG_TYPE_TRACING ||
+		   attr.prog_type == BPF_PROG_TYPE_EXT) {
 		attr.attach_btf_id = load_attr->attach_btf_id;
 		attr.attach_prog_fd = load_attr->attach_prog_fd;
 	} else {
diff --git a/tools/lib/bpf/libbpf.c b/tools/lib/bpf/libbpf.c
index 3afaca9..3f4ecfd 100644
--- a/tools/lib/bpf/libbpf.c
+++ b/tools/lib/bpf/libbpf.c
@@ -4816,7 +4816,8 @@ load_program(struct bpf_program *prog, struct bpf_insn *insns, int insns_cnt,
 	load_attr.license = license;
 	if (prog->type == BPF_PROG_TYPE_STRUCT_OPS) {
 		load_attr.attach_btf_id = prog->attach_btf_id;
-	} else if (prog->type == BPF_PROG_TYPE_TRACING) {
+	} else if (prog->type == BPF_PROG_TYPE_TRACING ||
+		   prog->type == BPF_PROG_TYPE_EXT) {
 		load_attr.attach_prog_fd = prog->attach_prog_fd;
 		load_attr.attach_btf_id = prog->attach_btf_id;
 	} else {
@@ -4899,7 +4900,8 @@ int bpf_program__load(struct bpf_program *prog, char *license, __u32 kern_ver)
 {
 	int err = 0, fd, i, btf_id;
 
-	if (prog->type == BPF_PROG_TYPE_TRACING) {
+	if (prog->type == BPF_PROG_TYPE_TRACING ||
+	    prog->type == BPF_PROG_TYPE_EXT) {
 		btf_id = libbpf_find_attach_btf_id(prog->section_name,
 						   prog->expected_attach_type,
 						   prog->attach_prog_fd);
@@ -5075,7 +5077,8 @@ __bpf_object__open(const char *path, const void *obj_buf, size_t obj_buf_sz,
 
 		bpf_program__set_type(prog, prog_type);
 		bpf_program__set_expected_attach_type(prog, attach_type);
-		if (prog_type == BPF_PROG_TYPE_TRACING)
+		if (prog_type == BPF_PROG_TYPE_TRACING ||
+		    prog_type == BPF_PROG_TYPE_EXT)
 			prog->attach_prog_fd = OPTS_GET(opts, attach_prog_fd, 0);
 	}
 
@@ -6144,6 +6147,7 @@ BPF_PROG_TYPE_FNS(xdp, BPF_PROG_TYPE_XDP);
 BPF_PROG_TYPE_FNS(perf_event, BPF_PROG_TYPE_PERF_EVENT);
 BPF_PROG_TYPE_FNS(tracing, BPF_PROG_TYPE_TRACING);
 BPF_PROG_TYPE_FNS(struct_ops, BPF_PROG_TYPE_STRUCT_OPS);
+BPF_PROG_TYPE_FNS(extension, BPF_PROG_TYPE_EXT);
 
 enum bpf_attach_type
 bpf_program__get_expected_attach_type(struct bpf_program *prog)
@@ -6243,6 +6247,10 @@ static const struct bpf_sec_def section_defs[] = {
 		.expected_attach_type = BPF_TRACE_FEXIT,
 		.is_attach_btf = true,
 		.attach_fn = attach_trace),
+	SEC_DEF("replace/", EXT,
+		.expected_attach_type = 0,
+		.is_attach_btf = true,
+		.attach_fn = attach_trace),
 	BPF_PROG_SEC("xdp",			BPF_PROG_TYPE_XDP),
 	BPF_PROG_SEC("perf_event",		BPF_PROG_TYPE_PERF_EVENT),
 	BPF_PROG_SEC("lwt_in",			BPF_PROG_TYPE_LWT_IN),
diff --git a/tools/lib/bpf/libbpf.h b/tools/lib/bpf/libbpf.h
index 01639f9..2a5e3b0 100644
--- a/tools/lib/bpf/libbpf.h
+++ b/tools/lib/bpf/libbpf.h
@@ -318,6 +318,7 @@ LIBBPF_API int bpf_program__set_xdp(struct bpf_program *prog);
 LIBBPF_API int bpf_program__set_perf_event(struct bpf_program *prog);
 LIBBPF_API int bpf_program__set_tracing(struct bpf_program *prog);
 LIBBPF_API int bpf_program__set_struct_ops(struct bpf_program *prog);
+LIBBPF_API int bpf_program__set_extension(struct bpf_program *prog);
 
 LIBBPF_API enum bpf_prog_type bpf_program__get_type(struct bpf_program *prog);
 LIBBPF_API void bpf_program__set_type(struct bpf_program *prog,
@@ -339,6 +340,7 @@ LIBBPF_API bool bpf_program__is_xdp(const struct bpf_program *prog);
 LIBBPF_API bool bpf_program__is_perf_event(const struct bpf_program *prog);
 LIBBPF_API bool bpf_program__is_tracing(const struct bpf_program *prog);
 LIBBPF_API bool bpf_program__is_struct_ops(const struct bpf_program *prog);
+LIBBPF_API bool bpf_program__is_extension(const struct bpf_program *prog);
 
 /*
  * No need for __attribute__((packed)), all members of 'bpf_map_def'
diff --git a/tools/lib/bpf/libbpf.map b/tools/lib/bpf/libbpf.map
index 64ec71b..b035122 100644
--- a/tools/lib/bpf/libbpf.map
+++ b/tools/lib/bpf/libbpf.map
@@ -228,7 +228,9 @@
 		bpf_prog_attach_xattr;
 		bpf_program__attach;
 		bpf_program__name;
+		bpf_program__is_extension;
 		bpf_program__is_struct_ops;
+		bpf_program__set_extension;
 		bpf_program__set_struct_ops;
 		btf__align_of;
 		libbpf_find_kernel_btf;
diff --git a/tools/lib/bpf/libbpf_probes.c b/tools/lib/bpf/libbpf_probes.c
index 8cc992b..b782ebe 100644
--- a/tools/lib/bpf/libbpf_probes.c
+++ b/tools/lib/bpf/libbpf_probes.c
@@ -107,6 +107,7 @@ probe_load(enum bpf_prog_type prog_type, const struct bpf_insn *insns,
 	case BPF_PROG_TYPE_CGROUP_SOCKOPT:
 	case BPF_PROG_TYPE_TRACING:
 	case BPF_PROG_TYPE_STRUCT_OPS:
+	case BPF_PROG_TYPE_EXT:
 	default:
 		break;
 	}
diff --git a/tools/testing/selftests/bpf/prog_tests/fexit_bpf2bpf.c b/tools/testing/selftests/bpf/prog_tests/fexit_bpf2bpf.c
index 7d3740d..0afa380 100644
--- a/tools/testing/selftests/bpf/prog_tests/fexit_bpf2bpf.c
+++ b/tools/testing/selftests/bpf/prog_tests/fexit_bpf2bpf.c
@@ -26,7 +26,7 @@ static void test_fexit_bpf2bpf_common(const char *obj_file,
 
 	link = calloc(sizeof(struct bpf_link *), prog_cnt);
 	prog = calloc(sizeof(struct bpf_program *), prog_cnt);
-	result = malloc(prog_cnt * sizeof(u64));
+	result = malloc((prog_cnt + 32 /* spare */) * sizeof(u64));
 	if (CHECK(!link || !prog || !result, "alloc_memory",
 		  "failed to alloc memory"))
 		goto close_prog;
@@ -106,8 +106,25 @@ static void test_target_yes_callees(void)
 				  prog_name);
 }
 
+static void test_func_replace(void)
+{
+	const char *prog_name[] = {
+		"fexit/test_pkt_access",
+		"fexit/test_pkt_access_subprog1",
+		"fexit/test_pkt_access_subprog2",
+		"fexit/test_pkt_access_subprog3",
+		"replace/get_skb_len",
+		"replace/get_skb_ifindex",
+	};
+	test_fexit_bpf2bpf_common("./fexit_bpf2bpf.o",
+				  "./test_pkt_access.o",
+				  ARRAY_SIZE(prog_name),
+				  prog_name);
+}
+
 void test_fexit_bpf2bpf(void)
 {
 	test_target_no_callees();
 	test_target_yes_callees();
+	test_func_replace();
 }
diff --git a/tools/testing/selftests/bpf/progs/fexit_bpf2bpf.c b/tools/testing/selftests/bpf/progs/fexit_bpf2bpf.c
index 7c17ee1..8b8c34c 100644
--- a/tools/testing/selftests/bpf/progs/fexit_bpf2bpf.c
+++ b/tools/testing/selftests/bpf/progs/fexit_bpf2bpf.c
@@ -94,4 +94,28 @@ int BPF_PROG(test_subprog3, int val, struct sk_buff *skb, int ret)
 	test_result_subprog3 = 1;
 	return 0;
 }
+
+__u64 test_get_skb_len = 0;
+SEC("replace/get_skb_len")
+int new_get_skb_len(struct __sk_buff *skb)
+{
+	int len = skb->len;
+
+	if (len != 74)
+		return 0;
+	test_get_skb_len = 1;
+	return 74; /* original get_skb_len() returns skb->len */
+}
+
+__u64 test_get_skb_ifindex = 0;
+SEC("replace/get_skb_ifindex")
+int new_get_skb_ifindex(int val, struct __sk_buff *skb, int var)
+{
+	int ifindex = skb->ifindex;
+
+	if (ifindex != 1 || val != 3 || var != 1)
+		return 0;
+	test_get_skb_ifindex = 1;
+	return 3; /* original get_skb_ifindex() returns val * ifindex * var */
+}
 char _license[] SEC("license") = "GPL";