io_uring: split task_work into IRQ and non-IRQ safe variants

Now that the completion lock no longer disables IRQs, we can get rid of
IRQ disabling on the task_work side as well. Provide two methods of
adding task_work: one that is IRQ safe and one that is not. The IRQ safe
variant is only used for bdev/reg file completions.
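
For illustration only (this mirrors the calls in the patch below): an
IRQ driven completion such as io_complete_rw() now queues through the
IRQ safe helper, while task context callers stay on the plain variant,
which warns if it is ever invoked from hard or soft IRQ context:

	/* bdev/reg file completion, may run in (soft)IRQ context */
	req->io_task_work.func = io_req_task_complete;
	io_req_task_work_add_irq(req);

	/* task context only, checked with WARN_ON_ONCE(in_hardirq()) */
	io_req_task_work_add(req);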

Work in progress: I don't like the branching on the locking, but I also
don't want too much duplicated code.
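
One option (only a sketch, not part of this patch) would be to hide the
branching behind a pair of hypothetical io_tw_lock()/io_tw_unlock()
helpers, e.g. for the flush side:

	/* sketch: centralise the irq_safe branching in one place */
	static inline void io_tw_lock(struct io_uring_tw *tw, bool irq_safe)
	{
		if (irq_safe)
			spin_lock_irq(&tw->task_lock);
		else
			spin_lock(&tw->task_lock);
	}

	static inline void io_tw_unlock(struct io_uring_tw *tw, bool irq_safe)
	{
		if (irq_safe)
			spin_unlock_irq(&tw->task_lock);
		else
			spin_unlock(&tw->task_lock);
	}

The add side would need an _irqsave flavour that also carries the flags,
so this mostly moves the branching around rather than removing it.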

Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 067ce27..c263430 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -468,6 +468,7 @@ struct io_uring_tw {
 
 enum {
 	IO_TW_IRQ		= 0,
+	IO_TW_NONIRQ,
 	IO_TW_NR,
 };
 
@@ -1955,20 +1956,25 @@ static void ctx_flush_and_put(struct io_ring_ctx *ctx)
 	percpu_ref_put(&ctx->refs);
 }
 
-static void tctx_task_work(struct callback_head *cb)
+static void __tctx_task_work(struct io_uring_tw *tw, bool irq_safe)
 {
 	struct io_ring_ctx *ctx = NULL;
-	struct io_uring_tw *tw = container_of(cb, struct io_uring_tw, task_work);
 
 	while (1) {
 		struct io_wq_work_node *node;
 
-		spin_lock_irq(&tw->task_lock);
+		if (irq_safe)
+			spin_lock_irq(&tw->task_lock);
+		else
+			spin_lock(&tw->task_lock);
 		node = tw->task_list.first;
 		INIT_WQ_LIST(&tw->task_list);
 		if (!node)
 			tw->task_running = false;
-		spin_unlock_irq(&tw->task_lock);
+		if (irq_safe)
+			spin_unlock_irq(&tw->task_lock);
+		else
+			spin_unlock(&tw->task_lock);
 		if (!node)
 			break;
 
@@ -1992,7 +1998,22 @@ static void tctx_task_work(struct callback_head *cb)
 	ctx_flush_and_put(ctx);
 }
 
-static void __io_req_task_work_add(struct io_uring_tw *tw, struct io_kiocb *req)
+static void tctx_task_work(struct callback_head *cb)
+{
+	struct io_uring_tw *tw = container_of(cb, struct io_uring_tw, task_work);
+
+	__tctx_task_work(tw, false);
+}
+
+static void tctx_task_work_irq(struct callback_head *cb)
+{
+	struct io_uring_tw *tw = container_of(cb, struct io_uring_tw, task_work);
+
+	__tctx_task_work(tw, true);
+}
+
+static void __io_req_task_work_add(struct io_uring_tw *tw, struct io_kiocb *req,
+				   bool irq_safe)
 {
 	struct task_struct *tsk = req->task;
 	enum task_work_notify_mode notify;
@@ -2000,12 +2021,18 @@ static void __io_req_task_work_add(struct io_uring_tw *tw, struct io_kiocb *req)
 	unsigned long flags;
 	bool running;
 
-	spin_lock_irqsave(&tw->task_lock, flags);
+	if (irq_safe)
+		spin_lock_irqsave(&tw->task_lock, flags);
+	else
+		spin_lock(&tw->task_lock);
 	wq_list_add_tail(&req->io_task_work.node, &tw->task_list);
 	running = tw->task_running;
 	if (!running)
 		tw->task_running = true;
-	spin_unlock_irqrestore(&tw->task_lock, flags);
+	if (irq_safe)
+		spin_unlock_irqrestore(&tw->task_lock, flags);
+	else
+		spin_unlock(&tw->task_lock);
 
 	/* task_work already pending, we're done */
 	if (running)
@@ -2023,10 +2050,16 @@ static void __io_req_task_work_add(struct io_uring_tw *tw, struct io_kiocb *req)
 		return;
 	}
 
-	spin_lock_irqsave(&tw->task_lock, flags);
+	if (irq_safe)
+		spin_lock_irqsave(&tw->task_lock, flags);
+	else
+		spin_lock(&tw->task_lock);
 	node = tw->task_list.first;
 	INIT_WQ_LIST(&tw->task_list);
-	spin_unlock_irqrestore(&tw->task_lock, flags);
+	if (irq_safe)
+		spin_unlock_irqrestore(&tw->task_lock, flags);
+	else
+		spin_unlock(&tw->task_lock);
 
 	while (node) {
 		req = container_of(node, struct io_kiocb, io_task_work.node);
@@ -2040,9 +2073,19 @@ static void __io_req_task_work_add(struct io_uring_tw *tw, struct io_kiocb *req)
 static void io_req_task_work_add(struct io_kiocb *req)
 {
 	struct io_uring_task *tctx = req->task->io_uring;
+	struct io_uring_tw *tw = &tctx->tw[IO_TW_NONIRQ];
+
+	WARN_ON_ONCE(in_hardirq());
+	WARN_ON_ONCE(in_serving_softirq());
+	__io_req_task_work_add(tw, req, false);
+}
+
+static void io_req_task_work_add_irq(struct io_kiocb *req)
+{
+	struct io_uring_task *tctx = req->task->io_uring;
 	struct io_uring_tw *tw = &tctx->tw[IO_TW_IRQ];
 
-	__io_req_task_work_add(tw, req);
+	__io_req_task_work_add(tw, req, true);
 }
 
 static void io_req_task_cancel(struct io_kiocb *req)
@@ -2202,7 +2245,7 @@ static inline void io_put_req_deferred(struct io_kiocb *req, int refs)
 {
 	if (req_ref_sub_and_test(req, refs)) {
 		req->io_task_work.func = io_free_req;
-		io_req_task_work_add(req);
+		io_req_task_work_add_irq(req);
 	}
 }
 
@@ -2524,7 +2567,7 @@ static void io_complete_rw(struct kiocb *kiocb, long res, long res2)
 		return;
 	req->result = res;
 	req->io_task_work.func = io_req_task_complete;
-	io_req_task_work_add(req);
+	io_req_task_work_add_irq(req);
 }
 
 static void io_complete_rw_iopoll(struct kiocb *kiocb, long res, long res2)
@@ -3200,7 +3243,8 @@ static int io_async_buf_func(struct wait_queue_entry *wait, unsigned mode,
 
 	/* submit ref gets dropped, acquire a new one */
 	req_ref_get(req);
-	io_req_task_queue(req);
+	req->io_task_work.func = io_req_task_submit;
+	io_req_task_work_add_irq(req);
 	return 1;
 }
 
@@ -4848,7 +4892,7 @@ static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll,
 	 * of executing it. We can't safely execute it anyway, as we may not
 	 * have the needed state needed for it anyway.
 	 */
-	io_req_task_work_add(req);
+	io_req_task_work_add_irq(req);
 	return 1;
 }
 
@@ -5519,7 +5563,7 @@ static enum hrtimer_restart io_timeout_fn(struct hrtimer *timer)
 	spin_unlock_irqrestore(&ctx->timeout_lock, flags);
 
 	req->io_task_work.func = io_req_task_timeout;
-	io_req_task_work_add(req);
+	io_req_task_work_add_irq(req);
 	return HRTIMER_NORESTART;
 }
 
@@ -6424,7 +6468,7 @@ static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer)
 	spin_unlock_irqrestore(&ctx->timeout_lock, flags);
 
 	req->io_task_work.func = io_req_task_link_timeout;
-	io_req_task_work_add(req);
+	io_req_task_work_add_irq(req);
 	return HRTIMER_NORESTART;
 }
 
@@ -7991,7 +8035,10 @@ static int io_uring_alloc_task_context(struct task_struct *task,
 	for (i = 0; i < IO_TW_NR; i++) {
 		spin_lock_init(&tctx->tw[i].task_lock);
 		INIT_WQ_LIST(&tctx->tw[i].task_list);
-		init_task_work(&tctx->tw[i].task_work, tctx_task_work);
+		if (i == IO_TW_IRQ)
+			init_task_work(&tctx->tw[i].task_work, tctx_task_work_irq);
+		else
+			init_task_work(&tctx->tw[i].task_work, tctx_task_work);
 	}
 
 	return 0;