io_uring: stop SQPOLL submit on creator's death

When the creator of SQPOLL io_uring dies (i.e. sqo_task), we don't want
its internals like ->files and ->mm to be poked by the SQPOLL task, it
have never been nice and recently got racy. That can happen when the
owner undergoes destruction and SQPOLL tasks tries to submit new
requests in parallel, and so calls io_sq_thread_acquire*().

That patch halts SQPOLL submissions when sqo_task dies by introducing
sqo_dead flag. Once set, the SQPOLL task must not do any submission,
which is synchronised by uring_lock as well as the new flag.

The tricky part is to make sure that disabling always happens, that
means either the ring is discovered by creator's do_exit() -> cancel,
or if the final close() happens before it's done by the creator. The
last is guaranteed by the fact that for SQPOLL the creator task and only
it holds exactly one file note, so either it pins up to do_exit() or
removed by the creator on the final put in flush. (see comments in
uring_flush() around file->f_count == 2).

One more place that can trigger io_sq_thread_acquire_*() is
__io_req_task_submit(). Shoot off requests on sqo_dead there, even
though actually we don't need to. That's because cancellation of
sqo_task should wait for the request before going any further.

note 1: io_disable_sqo_submit() does io_ring_set_wakeup_flag() so the
caller would enter the ring to get an error, but it still doesn't
guarantee that the flag won't be cleared.

note 2: if final __userspace__ close happens not from the creator
task, the file note will pin the ring until the task dies.

Fixed: b1b6b5a30dce8 ("kernel/io_uring: cancel io_uring before task works")
Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index f39671a..2f305c0 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -262,6 +262,7 @@
 		unsigned int		drain_next: 1;
 		unsigned int		eventfd_async: 1;
 		unsigned int		restricted: 1;
+		unsigned int		sqo_dead: 1;
 
 		/*
 		 * Ring buffer of indices into array of io_uring_sqe, which is
@@ -2160,12 +2161,11 @@
 static void __io_req_task_submit(struct io_kiocb *req)
 {
 	struct io_ring_ctx *ctx = req->ctx;
-	bool fail;
 
-	fail = __io_sq_thread_acquire_mm(ctx) ||
-		__io_sq_thread_acquire_files(ctx);
 	mutex_lock(&ctx->uring_lock);
-	if (!fail)
+	if (!ctx->sqo_dead &&
+	    !__io_sq_thread_acquire_mm(ctx) &&
+	    !__io_sq_thread_acquire_files(ctx))
 		__io_queue_sqe(req, NULL);
 	else
 		__io_req_task_cancel(req, -EFAULT);
@@ -6954,7 +6954,8 @@
 		if (!list_empty(&ctx->iopoll_list))
 			io_do_iopoll(ctx, &nr_events, 0);
 
-		if (to_submit && likely(!percpu_ref_is_dying(&ctx->refs)))
+		if (to_submit && !ctx->sqo_dead &&
+		    likely(!percpu_ref_is_dying(&ctx->refs)))
 			ret = io_submit_sqes(ctx, to_submit);
 		mutex_unlock(&ctx->uring_lock);
 	}
@@ -8712,6 +8713,10 @@
 {
 	mutex_lock(&ctx->uring_lock);
 	percpu_ref_kill(&ctx->refs);
+
+	if (WARN_ON_ONCE((ctx->flags & IORING_SETUP_SQPOLL) && !ctx->sqo_dead))
+		ctx->sqo_dead = 1;
+
 	/* if force is set, the ring is going away. always drop after that */
 	ctx->cq_overflow_flushed = 1;
 	if (ctx->rings)
@@ -8874,6 +8879,18 @@
 	}
 }
 
+static void io_disable_sqo_submit(struct io_ring_ctx *ctx)
+{
+	WARN_ON_ONCE(ctx->sqo_task != current);
+
+	mutex_lock(&ctx->uring_lock);
+	ctx->sqo_dead = 1;
+	mutex_unlock(&ctx->uring_lock);
+
+	/* make sure callers enter the ring to get error */
+	io_ring_set_wakeup_flag(ctx);
+}
+
 /*
  * We need to iteratively cancel requests, in case a request has dependent
  * hard links. These persist even for failure of cancelations, hence keep
@@ -8885,6 +8902,8 @@
 	struct task_struct *task = current;
 
 	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
+		/* for SQPOLL only sqo_task has task notes */
+		io_disable_sqo_submit(ctx);
 		task = ctx->sq_data->thread;
 		atomic_inc(&task->io_uring->in_idle);
 		io_sq_thread_park(ctx->sq_data);
@@ -9056,6 +9075,7 @@
 static int io_uring_flush(struct file *file, void *data)
 {
 	struct io_uring_task *tctx = current->io_uring;
+	struct io_ring_ctx *ctx = file->private_data;
 
 	if (!tctx)
 		return 0;
@@ -9071,7 +9091,16 @@
 	if (atomic_long_read(&file->f_count) != 2)
 		return 0;
 
-	io_uring_del_task_file(file);
+	if (ctx->flags & IORING_SETUP_SQPOLL) {
+		/* there is only one file note, which is owned by sqo_task */
+		WARN_ON_ONCE((ctx->sqo_task == current) ==
+			     !xa_load(&tctx->xa, (unsigned long)file));
+
+		io_disable_sqo_submit(ctx);
+	}
+
+	if (!(ctx->flags & IORING_SETUP_SQPOLL) || ctx->sqo_task == current)
+		io_uring_del_task_file(file);
 	return 0;
 }
 
@@ -9145,8 +9174,9 @@
 
 #endif /* !CONFIG_MMU */
 
-static void io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
+static int io_sqpoll_wait_sq(struct io_ring_ctx *ctx)
 {
+	int ret = 0;
 	DEFINE_WAIT(wait);
 
 	do {
@@ -9155,6 +9185,11 @@
 
 		prepare_to_wait(&ctx->sqo_sq_wait, &wait, TASK_INTERRUPTIBLE);
 
+		if (unlikely(ctx->sqo_dead)) {
+			ret = -EOWNERDEAD;
+			goto out;
+		}
+
 		if (!io_sqring_full(ctx))
 			break;
 
@@ -9162,6 +9197,8 @@
 	} while (!signal_pending(current));
 
 	finish_wait(&ctx->sqo_sq_wait, &wait);
+out:
+	return ret;
 }
 
 static int io_get_ext_arg(unsigned flags, const void __user *argp, size_t *argsz,
@@ -9235,10 +9272,16 @@
 	if (ctx->flags & IORING_SETUP_SQPOLL) {
 		io_cqring_overflow_flush(ctx, false, NULL, NULL);
 
+		ret = -EOWNERDEAD;
+		if (unlikely(ctx->sqo_dead))
+			goto out;
 		if (flags & IORING_ENTER_SQ_WAKEUP)
 			wake_up(&ctx->sq_data->wait);
-		if (flags & IORING_ENTER_SQ_WAIT)
-			io_sqpoll_wait_sq(ctx);
+		if (flags & IORING_ENTER_SQ_WAIT) {
+			ret = io_sqpoll_wait_sq(ctx);
+			if (ret)
+				goto out;
+		}
 		submitted = to_submit;
 	} else if (to_submit) {
 		ret = io_uring_add_task_file(ctx, f.file);
@@ -9665,6 +9708,7 @@
 	trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags);
 	return ret;
 err:
+	io_disable_sqo_submit(ctx);
 	io_ring_ctx_wait_and_kill(ctx);
 	return ret;
 }