sched_ext: Implement scx_bpf_consume_task()

Implement scx_bpf_consume_task() which allows consuming arbitrary tasks from
a DSQ, in any order, while iterating it in the dispatch path.
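
For example, a BPF scheduler can walk a DSQ in its dispatch callback and pull
out only the tasks it wants to run next. Below is a minimal sketch of the
intended usage pattern; MY_DSQ and is_urgent() are illustrative and not part
of this patch:

	#include <scx/common.bpf.h>

	#define MY_DSQ		0	/* illustrative user DSQ id */

	/* illustrative predicate - any per-task condition works */
	static bool is_urgent(struct task_struct *p)
	{
		return p->prio < 120;
	}

	void BPF_STRUCT_OPS(example_dispatch, s32 cpu, struct task_struct *prev)
	{
		struct task_struct *p;

		/*
		 * Walk MY_DSQ in queue order and transfer the first urgent
		 * task to the local DSQ. Tasks which raced away (dequeued or
		 * requeued after the iteration started) make
		 * scx_bpf_consume_task() return false and are skipped.
		 */
		bpf_for_each(scx_dsq, p, MY_DSQ, 0) {
			if (is_urgent(p) && scx_bpf_consume_task(BPF_FOR_EACH_ITER, p))
				return;
		}

		/* nothing urgent queued - fall back to FIFO consumption */
		scx_bpf_consume(MY_DSQ);
	}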

scx_qmap is updated to implement a rather silly cgroup-based prioritization
mechanism to demonstrate the use of DSQ iteration and selective consumption.
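
With the scx_qmap changes below, the behavior can be exercised with something
like the following (the batch size and cgroup path are arbitrary example
values; on cgroup2, the cgroup ID matches the inode number of the cgroup
directory):

	# scx_qmap -b 8 -E $(stat -c %i /sys/fs/cgroup/workload.slice)

Tasks in the matching cgroup are then pulled out of the shared DSQ ahead of
the others and counted in the new "exp=" stat.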

v2: - BPF now allows kfuncs to take pointers to iterators. Drop the now
      unnecessary hack. This makes things noticeably cleaner.

Signed-off-by: Tejun Heo <tj@kernel.org>
Reviewed-by: David Vernet <dvernet@meta.com>
Cc: Alexei Starovoitov <ast@kernel.org>
Cc: bpf@vger.kernel.org
diff --git a/kernel/sched/ext.c b/kernel/sched/ext.c
index b9bf9ee..ac6f3f7 100644
--- a/kernel/sched/ext.c
+++ b/kernel/sched/ext.c
@@ -5676,12 +5676,71 @@ __bpf_kfunc bool scx_bpf_consume(u64 dsq_id)
 	}
 }
 
+/**
+ * scx_bpf_consume_task - Transfer a task from DSQ iteration to the local DSQ
+ * @it: DSQ iterator in progress
+ * @p: task to consume
+ *
+ * Transfer @p which is on the DSQ currently iterated by @it to the current
+ * CPU's local DSQ. For the transfer to be successful, @p must still be on the
+ * DSQ and have been queued before the DSQ iteration started. This function
+ * doesn't care whether @p was obtained from the DSQ iteration. @p just has to
+ * be on the DSQ and have been queued before the iteration started.
+ *
+ * Returns %true if @p has been consumed, %false if @p had already been consumed
+ * or dequeued.
+ */
+__bpf_kfunc bool scx_bpf_consume_task(struct bpf_iter_scx_dsq *it,
+				      struct task_struct *p)
+{
+	struct bpf_iter_scx_dsq_kern *kit = (void *)it;
+	struct scx_dispatch_q *dsq = kit->dsq;
+	struct scx_dsp_ctx *dspc = this_cpu_ptr(scx_dsp_ctx);
+	struct rq *task_rq;
+
+	if (unlikely(dsq->id == SCX_DSQ_LOCAL)) {
+		scx_ops_error("local DSQ not allowed");
+		return false;
+	}
+
+	if (!scx_kf_allowed(SCX_KF_DISPATCH))
+		return false;
+
+	flush_dispatch_buf(dspc->rq);
+
+	raw_spin_lock(&dsq->lock);
+
+	/*
+	 * Did someone else get to it? @p could have already left $dsq, got
+	 * re-enqueued, or be in the process of being consumed by someone else.
+	 */
+	if (unlikely(p->scx.dsq != dsq ||
+		     u32_before(kit->dsq_seq, p->scx.dsq_seq) ||
+		     p->scx.holding_cpu >= 0))
+		goto out_unlock;
+
+	task_rq = task_rq(p);
+
+	if (dspc->rq == task_rq) {
+		consume_local_task(dspc->rq, dsq, p);
+		return true;
+	}
+
+	if (task_can_run_on_remote_rq(p, dspc->rq))
+		return consume_remote_task(dspc->rq, dsq, p, task_rq);
+
+out_unlock:
+	raw_spin_unlock(&dsq->lock);
+	return false;
+}
+
 __bpf_kfunc_end_defs();
 
 BTF_KFUNCS_START(scx_kfunc_ids_dispatch)
 BTF_ID_FLAGS(func, scx_bpf_dispatch_nr_slots)
 BTF_ID_FLAGS(func, scx_bpf_dispatch_cancel)
 BTF_ID_FLAGS(func, scx_bpf_consume)
+BTF_ID_FLAGS(func, scx_bpf_consume_task)
 BTF_KFUNCS_END(scx_kfunc_ids_dispatch)
 
 static const struct btf_kfunc_id_set scx_kfunc_set_dispatch = {
diff --git a/tools/sched_ext/include/scx/common.bpf.h b/tools/sched_ext/include/scx/common.bpf.h
index 20280df..84922a7 100644
--- a/tools/sched_ext/include/scx/common.bpf.h
+++ b/tools/sched_ext/include/scx/common.bpf.h
@@ -35,6 +35,7 @@ void scx_bpf_dispatch_vtime(struct task_struct *p, u64 dsq_id, u64 slice, u64 vt
 u32 scx_bpf_dispatch_nr_slots(void) __ksym;
 void scx_bpf_dispatch_cancel(void) __ksym;
 bool scx_bpf_consume(u64 dsq_id) __ksym;
+bool scx_bpf_consume_task(struct bpf_iter_scx_dsq *it, struct task_struct *p) __ksym __weak;
 u32 scx_bpf_reenqueue_local(void) __ksym;
 void scx_bpf_kick_cpu(s32 cpu, u64 flags) __ksym;
 s32 scx_bpf_dsq_nr_queued(u64 dsq_id) __ksym;
@@ -62,6 +63,12 @@ bool scx_bpf_task_running(const struct task_struct *p) __ksym;
 s32 scx_bpf_task_cpu(const struct task_struct *p) __ksym;
 struct rq *scx_bpf_cpu_rq(s32 cpu) __ksym;
 
+/*
+ * Use the following as @it when calling scx_bpf_consume_task() from within
+ * bpf_for_each() loops.
+ */
+#define BPF_FOR_EACH_ITER	(&___it)
+
 static inline __attribute__((format(printf, 1, 2)))
 void ___scx_bpf_bstr_format_checker(const char *fmt, ...) {}
 
diff --git a/tools/sched_ext/scx_qmap.bpf.c b/tools/sched_ext/scx_qmap.bpf.c
index 892278f..befd955 100644
--- a/tools/sched_ext/scx_qmap.bpf.c
+++ b/tools/sched_ext/scx_qmap.bpf.c
@@ -23,6 +23,7 @@
  * Copyright (c) 2022 David Vernet <dvernet@meta.com>
  */
 #include <scx/common.bpf.h>
+#include <string.h>
 
 enum consts {
 	ONE_SEC_IN_NS		= 1000000000,
@@ -37,6 +38,7 @@ const volatile u32 stall_kernel_nth;
 const volatile u32 dsp_inf_loop_after;
 const volatile u32 dsp_batch;
 const volatile bool print_shared_dsq;
+const volatile u64 exp_cgid;
 const volatile s32 disallow_tgid;
 const volatile bool suppress_dump;
 
@@ -121,7 +123,7 @@ struct {
 
 /* Statistics */
 u64 nr_enqueued, nr_dispatched, nr_reenqueued, nr_dequeued, nr_ddsp_from_enq;
-u64 nr_core_sched_execed;
+u64 nr_core_sched_execed, nr_expedited;
 u32 cpuperf_min, cpuperf_avg, cpuperf_max;
 u32 cpuperf_target_min, cpuperf_target_avg, cpuperf_target_max;
 
@@ -281,6 +283,32 @@ static void update_core_sched_head_seq(struct task_struct *p)
 		scx_bpf_error("task_ctx lookup failed");
 }
 
+static bool consume_shared_dsq(void)
+{
+	struct task_struct *p;
+	bool consumed;
+
+	if (!exp_cgid)
+		return scx_bpf_consume(SHARED_DSQ);
+
+	/*
+	 * To demonstrate the use of scx_bpf_consume_task(), implement a silly
+	 * selective priority boosting mechanism by scanning SHARED_DSQ for
+	 * tasks in the cgroup matching exp_cgid and consuming them first. This
+	 * makes a difference only when dsp_batch is larger than 1.
+	 */
+	consumed = false;
+	bpf_for_each(scx_dsq, p, SHARED_DSQ, 0) {
+		if (p->cgroups->dfl_cgrp->kn->id == exp_cgid &&
+		    scx_bpf_consume_task(BPF_FOR_EACH_ITER, p)) {
+			consumed = true;
+			__sync_fetch_and_add(&nr_expedited, 1);
+		}
+	}
+
+	return consumed || scx_bpf_consume(SHARED_DSQ);
+}
+
 void BPF_STRUCT_OPS(qmap_dispatch, s32 cpu, struct task_struct *prev)
 {
 	struct task_struct *p;
@@ -289,7 +317,7 @@ void BPF_STRUCT_OPS(qmap_dispatch, s32 cpu, struct task_struct *prev)
 	void *fifo;
 	s32 i, pid;
 
-	if (scx_bpf_consume(SHARED_DSQ))
+	if (consume_shared_dsq())
 		return;
 
 	if (dsp_inf_loop_after && nr_dispatched > dsp_inf_loop_after) {
@@ -340,7 +368,7 @@ void BPF_STRUCT_OPS(qmap_dispatch, s32 cpu, struct task_struct *prev)
 			batch--;
 			cpuc->dsp_cnt--;
 			if (!batch || !scx_bpf_dispatch_nr_slots()) {
-				scx_bpf_consume(SHARED_DSQ);
+				consume_shared_dsq();
 				return;
 			}
 			if (!cpuc->dsp_cnt)
diff --git a/tools/sched_ext/scx_qmap.c b/tools/sched_ext/scx_qmap.c
index c9ca30d..4a8ca28 100644
--- a/tools/sched_ext/scx_qmap.c
+++ b/tools/sched_ext/scx_qmap.c
@@ -20,7 +20,7 @@ const char help_fmt[] =
 "See the top-level comment in .bpf.c for more details.\n"
 "\n"
 "Usage: %s [-s SLICE_US] [-e COUNT] [-t COUNT] [-T COUNT] [-l COUNT] [-b COUNT]\n"
-"       [-P] [-d PID] [-D LEN] [-p] [-v]\n"
+"       [-P] [-E PREFIX] [-d PID] [-D LEN] [-p] [-v]\n"
 "\n"
 "  -s SLICE_US   Override slice duration\n"
 "  -e COUNT      Trigger scx_bpf_error() after COUNT enqueues\n"
@@ -29,10 +29,11 @@ const char help_fmt[] =
 "  -l COUNT      Trigger dispatch infinite looping after COUNT dispatches\n"
 "  -b COUNT      Dispatch upto COUNT tasks together\n"
 "  -P            Print out DSQ content to trace_pipe every second, use with -b\n"
+"  -E CGID       Expedite consumption of threads in a cgroup, use with -b\n"
 "  -d PID        Disallow a process from switching into SCHED_EXT (-1 for self)\n"
 "  -D LEN        Set scx_exit_info.dump buffer length\n"
 "  -S            Suppress qmap-specific debug dump\n"
-"  -p            Switch only tasks on SCHED_EXT policy instead of all\n"
+"  -p            Switch only tasks on SCHED_EXT policy intead of all\n"
 "  -v            Print libbpf debug messages\n"
 "  -h            Display this help and exit\n";
 
@@ -63,7 +64,7 @@ int main(int argc, char **argv)
 
 	skel = SCX_OPS_OPEN(qmap_ops, scx_qmap);
 
-	while ((opt = getopt(argc, argv, "s:e:t:T:l:b:Pd:D:Spvh")) != -1) {
+	while ((opt = getopt(argc, argv, "s:e:t:T:l:b:PE:d:D:Spvh")) != -1) {
 		switch (opt) {
 		case 's':
 			skel->rodata->slice_ns = strtoull(optarg, NULL, 0) * 1000;
@@ -86,6 +87,9 @@ int main(int argc, char **argv)
 		case 'P':
 			skel->rodata->print_shared_dsq = true;
 			break;
+		case 'E':
+			skel->rodata->exp_cgid = strtoull(optarg, NULL, 0);
+			break;
 		case 'd':
 			skel->rodata->disallow_tgid = strtol(optarg, NULL, 0);
 			if (skel->rodata->disallow_tgid < 0)
@@ -116,11 +120,12 @@ int main(int argc, char **argv)
 		long nr_enqueued = skel->bss->nr_enqueued;
 		long nr_dispatched = skel->bss->nr_dispatched;
 
-		printf("stats  : enq=%lu dsp=%lu delta=%ld reenq=%"PRIu64" deq=%"PRIu64" core=%"PRIu64" enq_ddsp=%"PRIu64"\n",
+		printf("stats  : enq=%lu dsp=%lu delta=%ld reenq=%"PRIu64" deq=%"PRIu64" core=%"PRIu64" enq_ddsp=%"PRIu64" exp=%"PRIu64"\n",
 		       nr_enqueued, nr_dispatched, nr_enqueued - nr_dispatched,
 		       skel->bss->nr_reenqueued, skel->bss->nr_dequeued,
 		       skel->bss->nr_core_sched_execed,
-		       skel->bss->nr_ddsp_from_enq);
+		       skel->bss->nr_ddsp_from_enq,
+		       skel->bss->nr_expedited);
 		if (__COMPAT_has_ksym("scx_bpf_cpuperf_cur"))
 			printf("cpuperf: cur min/avg/max=%u/%u/%u target min/avg/max=%u/%u/%u\n",
 			       skel->bss->cpuperf_min,