blk-mq-sched: fix crash in switch error path

In elevator_switch(), if blk_mq_init_sched() fails, we attempt to fall
back to the original scheduler. However, at this point, we've already
torn down the original scheduler's tags, so this causes a crash. Doing
the fallback like the legacy elevator path is much harder for mq, so fix
it by just falling back to none, instead.

Signed-off-by: Omar Sandoval <osandov@fb.com>
Signed-off-by: Jens Axboe <axboe@fb.com>
diff --git a/block/blk-mq-sched.c b/block/blk-mq-sched.c
index 0bb13bb..e8c2ed6 100644
--- a/block/blk-mq-sched.c
+++ b/block/blk-mq-sched.c
@@ -451,7 +451,7 @@
 	return ret;
 }
 
-void blk_mq_sched_teardown(struct request_queue *q)
+static void blk_mq_sched_tags_teardown(struct request_queue *q)
 {
 	struct blk_mq_tag_set *set = q->tag_set;
 	struct blk_mq_hw_ctx *hctx;
@@ -513,10 +513,19 @@
 	return 0;
 
 err:
-	blk_mq_sched_teardown(q);
+	blk_mq_sched_tags_teardown(q);
+	q->elevator = NULL;
 	return ret;
 }
 
+void blk_mq_exit_sched(struct request_queue *q, struct elevator_queue *e)
+{
+	if (e->type->ops.mq.exit_sched)
+		e->type->ops.mq.exit_sched(e);
+	blk_mq_sched_tags_teardown(q);
+	q->elevator = NULL;
+}
+
 int blk_mq_sched_init(struct request_queue *q)
 {
 	int ret;
diff --git a/block/blk-mq-sched.h b/block/blk-mq-sched.h
index 19db25e..e704956 100644
--- a/block/blk-mq-sched.h
+++ b/block/blk-mq-sched.h
@@ -33,7 +33,7 @@
 			struct request *(*get_rq)(struct blk_mq_hw_ctx *));
 
 int blk_mq_init_sched(struct request_queue *q, struct elevator_type *e);
-void blk_mq_sched_teardown(struct request_queue *q);
+void blk_mq_exit_sched(struct request_queue *q, struct elevator_queue *e);
 
 int blk_mq_sched_init_hctx(struct request_queue *q, struct blk_mq_hw_ctx *hctx,
 			   unsigned int hctx_idx);
diff --git a/block/blk-mq.c b/block/blk-mq.c
index 72e744c..cfb7c97 100644
--- a/block/blk-mq.c
+++ b/block/blk-mq.c
@@ -2240,8 +2240,6 @@
 	struct blk_mq_hw_ctx *hctx;
 	unsigned int i;
 
-	blk_mq_sched_teardown(q);
-
 	/* hctx kobj stays in hctx */
 	queue_for_each_hw_ctx(q, hctx, i) {
 		if (!hctx)
diff --git a/block/blk-sysfs.c b/block/blk-sysfs.c
index c44b321..37f0b3a 100644
--- a/block/blk-sysfs.c
+++ b/block/blk-sysfs.c
@@ -816,7 +816,7 @@
 
 	if (q->elevator) {
 		ioc_clear_queue(q);
-		elevator_exit(q->elevator);
+		elevator_exit(q, q->elevator);
 	}
 
 	blk_exit_rl(&q->root_rl);
diff --git a/block/elevator.c b/block/elevator.c
index f236ef1..dbeecf7 100644
--- a/block/elevator.c
+++ b/block/elevator.c
@@ -252,11 +252,11 @@
 }
 EXPORT_SYMBOL(elevator_init);
 
-void elevator_exit(struct elevator_queue *e)
+void elevator_exit(struct request_queue *q, struct elevator_queue *e)
 {
 	mutex_lock(&e->sysfs_lock);
 	if (e->uses_mq && e->type->ops.mq.exit_sched)
-		e->type->ops.mq.exit_sched(e);
+		blk_mq_exit_sched(q, e);
 	else if (!e->uses_mq && e->type->ops.sq.elevator_exit_fn)
 		e->type->ops.sq.elevator_exit_fn(e);
 	mutex_unlock(&e->sysfs_lock);
@@ -941,6 +941,45 @@
 }
 EXPORT_SYMBOL_GPL(elv_unregister);
 
+static int elevator_switch_mq(struct request_queue *q,
+			      struct elevator_type *new_e)
+{
+	int ret;
+
+	blk_mq_freeze_queue(q);
+	blk_mq_quiesce_queue(q);
+
+	if (q->elevator) {
+		if (q->elevator->registered)
+			elv_unregister_queue(q);
+		ioc_clear_queue(q);
+		elevator_exit(q, q->elevator);
+	}
+
+	ret = blk_mq_init_sched(q, new_e);
+	if (ret)
+		goto out;
+
+	if (new_e) {
+		ret = elv_register_queue(q);
+		if (ret) {
+			elevator_exit(q, q->elevator);
+			goto out;
+		}
+	}
+
+	if (new_e)
+		blk_add_trace_msg(q, "elv switch: %s", new_e->elevator_name);
+	else
+		blk_add_trace_msg(q, "elv switch: none");
+
+out:
+	blk_mq_unfreeze_queue(q);
+	blk_mq_start_stopped_hw_queues(q, true);
+	return ret;
+
+}
+
 /*
  * switch to new_e io scheduler. be careful not to introduce deadlocks -
  * we don't free the old io scheduler, before we have allocated what we
@@ -953,10 +992,8 @@
 	bool old_registered = false;
 	int err;
 
-	if (q->mq_ops) {
-		blk_mq_freeze_queue(q);
-		blk_mq_quiesce_queue(q);
-	}
+	if (q->mq_ops)
+		return elevator_switch_mq(q, new_e);
 
 	/*
 	 * Turn on BYPASS and drain all requests w/ elevator private data.
@@ -968,11 +1005,7 @@
 	if (old) {
 		old_registered = old->registered;
 
-		if (old->uses_mq)
-			blk_mq_sched_teardown(q);
-
-		if (!q->mq_ops)
-			blk_queue_bypass_start(q);
+		blk_queue_bypass_start(q);
 
 		/* unregister and clear all auxiliary data of the old elevator */
 		if (old_registered)
@@ -982,53 +1015,32 @@
 	}
 
 	/* allocate, init and register new elevator */
-	if (q->mq_ops)
-		err = blk_mq_init_sched(q, new_e);
-	else
-		err = new_e->ops.sq.elevator_init_fn(q, new_e);
+	err = new_e->ops.sq.elevator_init_fn(q, new_e);
 	if (err)
 		goto fail_init;
 
-	if (new_e) {
-		err = elv_register_queue(q);
-		if (err)
-			goto fail_register;
-	}
+	err = elv_register_queue(q);
+	if (err)
+		goto fail_register;
 
 	/* done, kill the old one and finish */
 	if (old) {
-		elevator_exit(old);
-		if (!q->mq_ops)
-			blk_queue_bypass_end(q);
+		elevator_exit(q, old);
+		blk_queue_bypass_end(q);
 	}
 
-	if (q->mq_ops) {
-		blk_mq_unfreeze_queue(q);
-		blk_mq_start_stopped_hw_queues(q, true);
-	}
-
-	if (new_e)
-		blk_add_trace_msg(q, "elv switch: %s", new_e->elevator_name);
-	else
-		blk_add_trace_msg(q, "elv switch: none");
+	blk_add_trace_msg(q, "elv switch: %s", new_e->elevator_name);
 
 	return 0;
 
 fail_register:
-	if (q->mq_ops)
-		blk_mq_sched_teardown(q);
-	elevator_exit(q->elevator);
+	elevator_exit(q, q->elevator);
 fail_init:
 	/* switch failed, restore and re-register old elevator */
 	if (old) {
 		q->elevator = old;
 		elv_register_queue(q);
-		if (!q->mq_ops)
-			blk_queue_bypass_end(q);
-	}
-	if (q->mq_ops) {
-		blk_mq_unfreeze_queue(q);
-		blk_mq_start_stopped_hw_queues(q, true);
+		blk_queue_bypass_end(q);
 	}
 
 	return err;
diff --git a/include/linux/elevator.h b/include/linux/elevator.h
index aebecc4..22d39e8 100644
--- a/include/linux/elevator.h
+++ b/include/linux/elevator.h
@@ -211,7 +211,7 @@
 extern ssize_t elv_iosched_store(struct request_queue *, const char *, size_t);
 
 extern int elevator_init(struct request_queue *, char *);
-extern void elevator_exit(struct elevator_queue *);
+extern void elevator_exit(struct request_queue *, struct elevator_queue *);
 extern int elevator_change(struct request_queue *, const char *);
 extern bool elv_bio_merge_ok(struct request *, struct bio *);
 extern struct elevator_queue *elevator_alloc(struct request_queue *,