sched_ext: [WIP] Implement sub-scheduler support
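
Allow a BPF scheduler to be loaded as a sub-scheduler governing a
cgroup sub-hierarchy while another scheduler remains in charge of an
ancestor (ultimately the root). In brief:

- When ops->sub_cgroup_id is set to a non-root cgroup ID,
  bpf_scx_reg() takes the new scx_sub_enable() path instead of
  scx_enable().

- Implementing the new ops->sub_attach() callback marks a scheduler
  as willing to host sub-schedulers; find_parent_sched() picks the
  nearest such ancestor for the target cgroup.

- struct scx_sched grows cgroup and parent/children/sibling fields,
  protected by the new scx_sub_lock. Sub-schedulers show up in sysfs
  as "sub-<cgroup_id>" inside their parent's "sub" kset.

- On disable, scx_sub_disable() first shoots down all children with
  the new SCX_EXIT_PARENT reason before the scheduler itself exits.

- scx_qmap grows a -c <cgroup-path> option which uses the cgroup
  directory's inode number as the cgroup ID, e.g.:

      $ scx_qmap &                                 # root scheduler
      $ scx_qmap -c /sys/fs/cgroup/workload.slice  # sub-scheduler

Still missing: the sub-scheduler's actual disable logic (see the TODO
in scx_sub_disable()) and invoking ops->sub_attach() on the parent
when a child attaches.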
diff --git a/kernel/sched/ext.c b/kernel/sched/ext.c
index 8ccb5c7..3678384 100644
--- a/kernel/sched/ext.c
+++ b/kernel/sched/ext.c
@@ -37,6 +37,7 @@ enum scx_exit_kind {
 	SCX_EXIT_UNREG_BPF,	/* BPF-initiated unregistration */
 	SCX_EXIT_UNREG_KERN,	/* kernel-initiated unregistration */
 	SCX_EXIT_SYSRQ,		/* requested by 'S' sysrq */
+	SCX_EXIT_PARENT,	/* parent exiting */
 
 	SCX_EXIT_ERROR = 1024,	/* runtime error, error msg contains details */
 	SCX_EXIT_ERROR_BPF,	/* ERROR but triggered through scx_bpf_error() */
@@ -242,6 +243,8 @@ struct scx_dump_ctx {
 	u64			at_jiffies;
 };
 
+struct scx_sched;
+
 /**
  * struct sched_ext_ops - Operation table for BPF scheduler implementation
  *
@@ -669,6 +672,24 @@ struct sched_ext_ops {
 	void (*cgroup_set_weight)(struct cgroup *cgrp, u32 weight);
 #endif	/* CONFIG_EXT_GROUP_SCHED */
 
+#ifdef CONFIG_CGROUPS
+	/*
+	 * Hierarchical sub-scheduler support.
+	 */
+
+	/**
+	 * @sub_cgroup_id: When set to a cgroup ID greater than 1 (i.e. a
+	 * cgroup other than the root), attach the scheduler as a
+	 * sub-scheduler on the specified cgroup.
+	 */
+	u64 sub_cgroup_id;
+
+	/**
+	 * @sub_attach: Attach @sub_sch as a sub-scheduler
+	 * @sub_sch: sub-scheduler being attached
+	 *
+	 * Implementing this callback marks the scheduler as willing to
+	 * host sub-schedulers on its descendant cgroups; attach attempts
+	 * on schedulers without it fail with -EOPNOTSUPP.
+	 */
+	s32 (*sub_attach)(struct scx_sched *sub_sch);
+#endif
+
 	/*
 	 * All online ops must come before ops.cpu_online().
 	 */
@@ -850,6 +871,16 @@ struct scx_sched {
 	atomic_t		exit_kind;
 	struct scx_exit_info	*exit_info;
 
+#ifdef CONFIG_CGROUPS
+	bool			sub_disabling;	/* blocks new children */
+	struct cgroup		*cgrp;		/* cgroup this scheduler is attached to */
+	char			*cgrp_path;	/* cached path of @cgrp */
+	struct scx_sched	*parent;	/* NULL for the root scheduler */
+	struct list_head	children;	/* protected by scx_sub_lock */
+	struct list_head	sibling;	/* anchored in parent->children */
+	struct kset		*sub_kset;	/* sysfs parent of children's kobjects */
+#endif
+
 	struct kobject		kobj;
 
 	struct kthread_worker	*helper;
@@ -1131,10 +1162,15 @@ static struct scx_dump_data scx_dump_data = {
 /* /sys/kernel/sched_ext interface */
 static struct kset *scx_kset;
 
+#ifdef CONFIG_CGROUPS
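+/* serializes hierarchy updates: children/sibling lists and ->sub_disabling */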
+static DEFINE_RAW_SPINLOCK(scx_sub_lock);
+#endif
+
 #define CREATE_TRACE_POINTS
 #include <trace/events/sched_ext.h>
 
 static void process_ddsp_deferred_locals(struct rq *rq);
+static void scx_disable(struct scx_sched *sch, enum scx_exit_kind kind);
 static void scx_bpf_kick_cpu(s32 cpu, u64 flags);
 static __printf(3, 4) void __scx_exit(enum scx_exit_kind kind, s64 exit_code,
 				      const char *fmt, ...);
@@ -4421,6 +4457,11 @@ static void scx_sched_free_rcu_work(struct work_struct *work)
 	struct scx_dispatch_q *dsq;
 	int node;
 
+#ifdef CONFIG_CGROUPS
+	kfree(sch->cgrp_path);
+	if (sch->cgrp)
+		cgroup_put(sch->cgrp);
+#endif
 	kthread_stop(sch->helper->task);
 	free_percpu(sch->event_stats_cpu);
 
@@ -4731,6 +4772,8 @@ static const char *scx_exit_reason(enum scx_exit_kind kind)
 		return "unregistered from the main kernel";
 	case SCX_EXIT_SYSRQ:
 		return "disabled by sysrq-S";
+	case SCX_EXIT_PARENT:
+		return "parent exiting";
 	case SCX_EXIT_ERROR:
 		return "runtime error";
 	case SCX_EXIT_ERROR_BPF:
@@ -4742,6 +4785,53 @@ static const char *scx_exit_reason(enum scx_exit_kind kind)
 	}
 }
 
+#ifdef CONFIG_CGROUPS
+static void scx_sub_disable(struct scx_sched *sch)
+{
+	struct scx_sched *child = NULL;
+
+	raw_spin_lock_irq(&scx_sub_lock);
+	if (sch->sub_disabling) {
+		/*
+		 * This function is called from a kthread worker and there can
+		 * be only one instance running. If @sch->sub_disabling is set,
+		 * it already finished disabling.
+		 */
+		raw_spin_unlock_irq(&scx_sub_lock);
+		return;
+	}
+	sch->sub_disabling = true;	/* prevents child creation */
+
+	while ((child = list_first_entry_or_null(&sch->children,
+						 struct scx_sched, sibling))) {
+		kobject_get(&child->kobj);
+		raw_spin_unlock_irq(&scx_sub_lock);
+
+		scx_disable(child, SCX_EXIT_PARENT);
+		kthread_flush_work(&child->disable_work);
+
+		kobject_put(&child->kobj);
+		raw_spin_lock_irq(&scx_sub_lock);
+	}
+
+	raw_spin_unlock_irq(&scx_sub_lock);
+
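+	/* the root's remaining teardown continues in scx_disable_workfn() */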
+	if (!sch->parent)
+		return;
+
+	/* TODO - perform actual disabling here */
+
+	if (sch->ops.exit)
+		SCX_CALL_OP(SCX_KF_UNLOCKED, exit, NULL, sch->exit_info);
+
+	raw_spin_lock_irq(&scx_sub_lock);
+	list_del_init(&sch->sibling);
+	raw_spin_unlock_irq(&scx_sub_lock);
+
+	kobject_del(&sch->kobj);
+}
+#endif
+
 static void scx_disable_workfn(struct kthread_work *work)
 {
 	struct scx_sched *sch = container_of(work, struct scx_sched, disable_work);
@@ -4761,6 +4851,12 @@ static void scx_disable_workfn(struct kthread_work *work)
 	ei->kind = kind;
 	ei->reason = scx_exit_reason(ei->kind);
 
+#ifdef CONFIG_CGROUPS
+	/* disable all children first; a sub-scheduler's teardown ends here */
+	scx_sub_disable(sch);
+	if (sch->parent)
+		return;
+#endif
+
 	/* guarantee forward progress by bypassing scx_ops */
 	scx_bypass(true);
 
@@ -4886,21 +4982,15 @@ static void scx_disable_workfn(struct kthread_work *work)
 	scx_bypass(false);
 }
 
-static void scx_disable(enum scx_exit_kind kind)
+static void scx_disable(struct scx_sched *sch, enum scx_exit_kind kind)
 {
 	int none = SCX_EXIT_NONE;
-	struct scx_sched *sch;
 
 	if (WARN_ON_ONCE(kind == SCX_EXIT_NONE || kind == SCX_EXIT_DONE))
 		kind = SCX_EXIT_ERROR;
 
-	rcu_read_lock();
-	sch = rcu_dereference(scx_root);
-	if (sch) {
-		atomic_try_cmpxchg(&sch->exit_kind, &none, kind);
-		kthread_queue_work(sch->helper, &sch->disable_work);
-	}
-	rcu_read_unlock();
+	atomic_try_cmpxchg(&sch->exit_kind, &none, kind);
+	kthread_queue_work(sch->helper, &sch->disable_work);
 }
 
 static void dump_newline(struct seq_buf *s)
@@ -5247,7 +5337,9 @@ static __printf(3, 4) void __scx_exit(enum scx_exit_kind kind, s64 exit_code,
 	rcu_read_unlock();
 }
 
-static struct scx_sched *scx_alloc_and_add_sched(struct sched_ext_ops *ops)
+static struct scx_sched *scx_alloc_and_add_sched(struct sched_ext_ops *ops,
+						 struct cgroup *cgrp,
+						 struct scx_sched *parent)
 {
 	struct scx_sched *sch;
 	int node, ret;
@@ -5302,12 +5394,49 @@ static struct scx_sched *scx_alloc_and_add_sched(struct sched_ext_ops *ops)
 	ops->priv = sch;
 
 	sch->kobj.kset = scx_kset;
-	ret = kobject_init_and_add(&sch->kobj, &scx_ktype, NULL, "root");
-	if (ret < 0)
+
+#ifdef CONFIG_CGROUPS
+	char *buf = kzalloc(PATH_MAX, GFP_KERNEL);
+	if (!buf)
+		goto err_stop_helper;
+	cgroup_path(cgrp, buf, PATH_MAX);
+	sch->cgrp_path = kstrdup(buf, GFP_KERNEL);
+	kfree(buf);
+	if (!sch->cgrp_path)
 		goto err_stop_helper;
 
+	/*
+	 * Pin @cgrp with our own reference so that scx_sched_free_rcu_work()
+	 * can put it unconditionally; callers remain responsible for any
+	 * reference they hold themselves.
+	 */
+	cgroup_get(cgrp);
+	sch->cgrp = cgrp;
+	INIT_LIST_HEAD(&sch->children);
+	INIT_LIST_HEAD(&sch->sibling);
+
+	if (parent)
+		ret = kobject_init_and_add(&sch->kobj, &scx_ktype,
+					   &parent->sub_kset->kobj,
+					   "sub-%llu", cgroup_id(cgrp));
+	else
+		ret = kobject_init_and_add(&sch->kobj, &scx_ktype, NULL, "root");
+
+	if (ret < 0)
+		goto err_free_cgrp_path;
+
+	if (ops->sub_attach) {
+		sch->sub_kset = kset_create_and_add("sub", NULL, &sch->kobj);
+		if (!sch->sub_kset) {
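+			/* past kobject_init_and_add(); put does full teardown */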
+			kobject_put(&sch->kobj);
+			return ERR_PTR(-ENOMEM);
+		}
+	}
+#else
+	ret = kobject_init_and_add(&sch->kobj, &scx_ktype, NULL, "root");
+	if (ret < 0)
+		goto err_stop_helper;
+#endif
 	return sch;
 
+#ifdef CONFIG_CGROUPS
+err_free_cgrp_path:
+	kfree(sch->cgrp_path);
+	if (sch->cgrp)
+		cgroup_put(sch->cgrp);
+#endif
 err_stop_helper:
 	kthread_stop(sch->helper->task);
 err_free_event_stats:
@@ -5392,7 +5521,7 @@ static int scx_enable(struct sched_ext_ops *ops, struct bpf_link *link)
 		goto err_unlock;
 	}
 
-	sch = scx_alloc_and_add_sched(ops);
+	sch = scx_alloc_and_add_sched(ops, &cgrp_dfl_root.cgrp, NULL);
 	if (IS_ERR(sch)) {
 		ret = PTR_ERR(sch);
 		goto err_unlock;
@@ -5619,6 +5748,107 @@ static int scx_enable(struct sched_ext_ops *ops, struct bpf_link *link)
 	return 0;
 }
 
+#ifdef CONFIG_CGROUPS
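+/*
+ * Find the scheduler that should parent a new sub-scheduler on @cgrp: the
+ * nearest registered ancestor. Returns an ERR_PTR on conflict. The caller
+ * must pin the result (e.g. with kobject_get()) before dropping
+ * scx_sub_lock.
+ */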
+static struct scx_sched *find_parent_sched(struct cgroup *cgrp)
+{
+	struct scx_sched *parent = scx_root, *pos;
+
+	lockdep_assert_held(&scx_sub_lock);
+again:
+	/* find the nearest ancestor */
+	list_for_each_entry(pos, &parent->children, sibling) {
+		if (pos->sub_disabling)
+			continue;
+		if (cgroup_is_descendant(cgrp, pos->cgrp)) {
+			parent = pos;
+			goto again;
+		}
+	}
+
+	/* can't attach twice to the same cgroup */
+	if (parent->cgrp == cgrp)
+		return ERR_PTR(-EBUSY);
+
+	/* are sub-schedulers allowed? */
+	if (!parent->ops.sub_attach)
+		return ERR_PTR(-EOPNOTSUPP);
+
+	/* shouldn't insert between the nearest ancestor and its children */
+	list_for_each_entry(pos, &parent->children, sibling)
+		if (cgroup_is_descendant(pos->cgrp, cgrp))
+			return ERR_PTR(-EBUSY);
+
+	return parent;
+}
+
+static int scx_sub_enable(struct sched_ext_ops *ops, struct bpf_link *link)
+{
+	struct cgroup *cgrp;
+	struct scx_sched *parent, *sch;
+	int ret;
+
+	mutex_lock(&scx_enable_mutex);
+
+	if (!scx_enabled()) {
+		ret = -ENODEV;
+		goto out_unlock;
+	}
+
+	cgrp = cgroup_get_from_id(ops->sub_cgroup_id);
+	if (IS_ERR(cgrp)) {
+		ret = PTR_ERR(cgrp);
+		goto out_unlock;
+	}
+
+	raw_spin_lock_irq(&scx_sub_lock);
+	parent = find_parent_sched(cgrp);
+	if (IS_ERR(parent)) {
+		raw_spin_unlock_irq(&scx_sub_lock);
+		ret = PTR_ERR(parent);
+		goto out_cgrp_put;
+	}
+	/* pin @parent before dropping the lock */
+	kobject_get(&parent->kobj);
+	raw_spin_unlock_irq(&scx_sub_lock);
+
+	sch = scx_alloc_and_add_sched(ops, cgrp, parent);
+	if (IS_ERR(sch)) {
+		kobject_put(&parent->kobj);
+		ret = PTR_ERR(sch);
+		goto out_cgrp_put;
+	}
+
+	sch->parent = parent;
+
+	raw_spin_lock_irq(&scx_sub_lock);
+	/*
+	 * XXX - if @parent began disabling after find_parent_sched()
+	 * dropped the lock, @sch can be linked after the child drain and
+	 * would never see SCX_EXIT_PARENT. Needs a sub_disabling recheck.
+	 */
+	list_add_tail(&sch->sibling, &parent->children);
+	raw_spin_unlock_irq(&scx_sub_lock);
+
+	/* @sch is linked into the hierarchy; drop the temporary ref */
+	kobject_put(&parent->kobj);
+
+	if (sch->ops.init) {
+		ret = SCX_CALL_OP_RET(SCX_KF_UNLOCKED, init, NULL);
+		if (ret) {
+			ret = ops_sanitize_err("init", ret);
+			pr_err("sched_ext: sub ops.init() failed (%d)\n", ret);
+			/*
+			 * scx_error() would shoot down the root scheduler;
+			 * disable only this sub-scheduler instead.
+			 */
+			scx_disable(sch, SCX_EXIT_ERROR);
+			goto out_flush_disable;
+		}
+	}
+
+	kobject_uevent(&sch->kobj, KOBJ_ADD);
+	ret = 0;
+	/* @sch took its own ref on @cgrp; drop the cgroup_get_from_id() ref */
+	goto out_cgrp_put;
+
+out_cgrp_put:
+	cgroup_put(cgrp);
+out_unlock:
+	mutex_unlock(&scx_enable_mutex);
+	return ret;
+
+out_flush_disable:
+	mutex_unlock(&scx_enable_mutex);
+	kthread_flush_work(&sch->disable_work);
+	cgroup_put(cgrp);
+	/* the recorded exit info conveys the failure; don't fail the load too */
+	return 0;
+}
+#endif
 
 /********************************************************************************
  * bpf_struct_ops plumbing.
@@ -5726,6 +5956,11 @@ static int bpf_scx_init_member(const struct btf_type *t,
 	case offsetof(struct sched_ext_ops, hotplug_seq):
 		ops->hotplug_seq = *(u64 *)(udata + moff);
 		return 1;
+#ifdef CONFIG_CGROUPS
+	case offsetof(struct sched_ext_ops, sub_cgroup_id):
+		ops->sub_cgroup_id = *(u64 *)(udata + moff);
+		return 1;
+#endif
 	}
 
 	return 0;
@@ -5759,7 +5994,13 @@ static int bpf_scx_check_member(const struct btf_type *t,
 
 static int bpf_scx_reg(void *kdata, struct bpf_link *link)
 {
-	return scx_enable(kdata, link);
+	struct sched_ext_ops *ops = kdata;
+
+#ifdef CONFIG_CGROUPS
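+	/* IDs 0 (unset) and 1 (root cgroup) load as the root scheduler */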
+	if (ops->sub_cgroup_id > 1)
+		return scx_sub_enable(ops, link);
+#endif
+	return scx_enable(ops, link);
 }
 
 static void bpf_scx_unreg(void *kdata, struct bpf_link *link)
@@ -5767,7 +6008,7 @@ static void bpf_scx_unreg(void *kdata, struct bpf_link *link)
 	struct sched_ext_ops *ops = kdata;
 	struct scx_sched *sch = ops->priv;
 
-	scx_disable(SCX_EXIT_UNREG);
+	scx_disable(sch, SCX_EXIT_UNREG);
 	kthread_flush_work(&sch->disable_work);
 	kobject_put(&sch->kobj);
 }
@@ -5824,6 +6065,9 @@ static void sched_ext_ops__cgroup_move(struct task_struct *p, struct cgroup *fro
 static void sched_ext_ops__cgroup_cancel_move(struct task_struct *p, struct cgroup *from, struct cgroup *to) {}
 static void sched_ext_ops__cgroup_set_weight(struct cgroup *cgrp, u32 weight) {}
 #endif
+#ifdef CONFIG_CGROUPS
+static s32 sched_ext_ops__sub_attach(struct scx_sched *sub_sch) { return -EINVAL; }
+#endif
 static void sched_ext_ops__cpu_online(s32 cpu) {}
 static void sched_ext_ops__cpu_offline(s32 cpu) {}
 static s32 sched_ext_ops__init(void) { return -EINVAL; }
@@ -5831,6 +6075,9 @@ static void sched_ext_ops__exit(struct scx_exit_info *info) {}
 static void sched_ext_ops__dump(struct scx_dump_ctx *ctx) {}
 static void sched_ext_ops__dump_cpu(struct scx_dump_ctx *ctx, s32 cpu, bool idle) {}
 static void sched_ext_ops__dump_task(struct scx_dump_ctx *ctx, struct task_struct *p) {}
 
 static struct sched_ext_ops __bpf_ops_sched_ext_ops = {
 	.select_cpu		= sched_ext_ops__select_cpu,
@@ -5861,6 +6108,9 @@ static struct sched_ext_ops __bpf_ops_sched_ext_ops = {
 	.cgroup_cancel_move	= sched_ext_ops__cgroup_cancel_move,
 	.cgroup_set_weight	= sched_ext_ops__cgroup_set_weight,
 #endif
+#ifdef CONFIG_CGROUPS
+	.sub_attach		= sched_ext_ops__sub_attach,
+#endif
 	.cpu_online		= sched_ext_ops__cpu_online,
 	.cpu_offline		= sched_ext_ops__cpu_offline,
 	.init			= sched_ext_ops__init,
@@ -5891,7 +6141,15 @@ static struct bpf_struct_ops bpf_sched_ext_ops = {
 
 static void sysrq_handle_sched_ext_reset(u8 key)
 {
-	scx_disable(SCX_EXIT_SYSRQ);
+	struct scx_sched *sch;
+
+	rcu_read_lock();
+	sch = rcu_dereference(scx_root);
+	if (sch)
+		scx_disable(sch, SCX_EXIT_SYSRQ);
+	else
+		pr_info("sched_ext: no BPF scheduler loaded\n");
+	rcu_read_unlock();
 }
 
 static const struct sysrq_key_op sysrq_sched_ext_reset_op = {
diff --git a/tools/sched_ext/scx_qmap.bpf.c b/tools/sched_ext/scx_qmap.bpf.c
index c3cd9a1..a4cd710 100644
--- a/tools/sched_ext/scx_qmap.bpf.c
+++ b/tools/sched_ext/scx_qmap.bpf.c
@@ -828,6 +828,11 @@ void BPF_STRUCT_OPS(qmap_exit, struct scx_exit_info *ei)
 	UEI_RECORD(uei, ei);
 }
 
+s32 BPF_STRUCT_OPS(qmap_sub_attach, struct scx_sched *sub_sch)
+{
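+	/* accept any sub-scheduler; a real parent would vet @sub_sch here */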
+	return 0;
+}
+
 SCX_OPS_DEFINE(qmap_ops,
 	       .select_cpu		= (void *)qmap_select_cpu,
 	       .enqueue			= (void *)qmap_enqueue,
@@ -845,4 +850,5 @@ SCX_OPS_DEFINE(qmap_ops,
 	       .init			= (void *)qmap_init,
 	       .exit			= (void *)qmap_exit,
 	       .timeout_ms		= 5000U,
-	       .name			= "qmap");
+	       .name			= "qmap",
+	       .sub_attach		= (void *)qmap_sub_attach);
diff --git a/tools/sched_ext/scx_qmap.c b/tools/sched_ext/scx_qmap.c
index c4912ab..aa51a66 100644
--- a/tools/sched_ext/scx_qmap.c
+++ b/tools/sched_ext/scx_qmap.c
@@ -10,6 +10,7 @@
 #include <inttypes.h>
 #include <signal.h>
 #include <libgen.h>
+#include <sys/stat.h>
 #include <bpf/bpf.h>
 #include <scx/common.h>
 #include "scx_qmap.bpf.skel.h"
@@ -66,7 +67,7 @@ int main(int argc, char **argv)
 
 	skel->rodata->slice_ns = __COMPAT_ENUM_OR_ZERO("scx_public_consts", "SCX_SLICE_DFL");
 
-	while ((opt = getopt(argc, argv, "s:e:t:T:l:b:PHd:D:Spvh")) != -1) {
+	while ((opt = getopt(argc, argv, "s:e:t:T:l:b:PHc:d:D:Spvh")) != -1) {
 		switch (opt) {
 		case 's':
 			skel->rodata->slice_ns = strtoull(optarg, NULL, 0) * 1000;
@@ -92,6 +93,15 @@ int main(int argc, char **argv)
 		case 'H':
 			skel->rodata->highpri_boosting = true;
 			break;
+		case 'c': {
+			struct stat st;
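+
+			/* on cgroup2, a cgroup's ID is its directory's inode number */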
+			if (stat(optarg, &st) < 0) {
+				perror("stat");
+				return 1;
+			}
+			skel->struct_ops.qmap_ops->sub_cgroup_id = st.st_ino;
+			break;
+		}
 		case 'd':
 			skel->rodata->disallow_tgid = strtol(optarg, NULL, 0);
 			if (skel->rodata->disallow_tgid < 0)