futex: Rework SET_SLOTS

Let SET_SLOTS take precedence over default scaling; once the user sets a
size, stick with it.

Notably, SET_SLOTS with 0 slots sets fph->hash_mask to 0, which makes
__futex_hash() return global hash buckets. Once in this state it is
impossible to recover, so further SET_SLOTS requests fail with -EBUSY.

Also, let prctl() users wait for and retry the rehash, so that once
prctl() returns the new size is in effect.

Signed-off-by: Peter Zijlstra (Intel) <peterz@infradead.org>
diff --git a/kernel/futex/core.c b/kernel/futex/core.c
index 7516542..4b9a4fb0 100644
--- a/kernel/futex/core.c
+++ b/kernel/futex/core.c
@@ -61,6 +61,8 @@ struct futex_private_hash {
 	rcuref_t	users;
 	unsigned int	hash_mask;
 	struct rcu_head	rcu;
+	void		*mm;		/* owning mm; key for wait_var_event()/wake_up_var() */
+	bool		custom;		/* size was set via PR_FUTEX_HASH_SET_SLOTS */
 	struct futex_hash_bucket queues[];
 };
 
@@ -192,12 +194,6 @@ static bool __futex_pivot_hash(struct mm_struct *mm,
 
 	fph = mm->futex_phash;
 	if (fph) {
-		if (fph->hash_mask >= new->hash_mask) {
-			/* It was increased again while we were waiting */
-			kvfree(new);
-			return true;
-		}
-
 		if (rcuref_read(&fph->users) != 0) {
 			mm->futex_phash_new = new;
 			return false;
@@ -207,6 +203,7 @@ static bool __futex_pivot_hash(struct mm_struct *mm,
 	}
 	rcu_assign_pointer(mm->futex_phash, new);
 	kvfree_rcu(fph, rcu);
+	wake_up_var(mm);	/* let prctl() waiters in futex_hash_allocate() re-check */
 	return true;
 }
 
@@ -258,11 +255,15 @@ bool futex_private_hash_get(struct futex_private_hash *fph)
 
 void futex_private_hash_put(struct futex_private_hash *fph)
 {
-	/*
-	 * Ignore the result; the DEAD state is picked up
-	 * when rcuref_get() starts failing.
-	 */
-	bool __maybe_unused ignore = rcuref_put(&fph->users);
+	/*
+	 * Snapshot ->mm before dropping the reference: once the last
+	 * reference is gone, a concurrent pivot may free @fph. The last
+	 * put wakes prctl() waiters sleeping in wait_var_event(mm, ...).
+	 */
+	struct mm_struct *mm = fph->mm;
+
+	if (rcuref_put(&fph->users))
+		wake_up_var(mm);
 }
 
 struct futex_hash_bucket *futex_hash(union futex_key *key)
@@ -1402,71 +1396,165 @@ void futex_hash_free(struct mm_struct *mm)
 	kvfree(mm->futex_phash);
 }
 
-static int futex_hash_allocate(unsigned int hash_slots)
+/*
+ * Wait condition for prctl() callers of futex_hash_allocate(): true when
+ * no replacement hash is queued in mm->futex_phash_new, or when the current
+ * hash has dropped its last user so the queued replacement can be installed.
+ */
+static bool futex_pivot_pending(struct mm_struct *mm)
 {
-	struct futex_private_hash *fph, *hb_tofree = NULL;
+	struct futex_private_hash *cur;
+
+	guard(rcu)();
+
+	/* Nothing queued: nothing to wait for. */
+	if (!mm->futex_phash_new)
+		return true;
+
+	cur = rcu_dereference(mm->futex_phash);
+	return !rcuref_read(&cur->users);
+}
+
+/*
+ * Return true if @a should be discarded in favour of @b. A user-provided
+ * (custom) size beats a default one. Otherwise a zero-sized hash (private
+ * hash disabled) wins, then the larger hash; on equal non-zero sizes @a
+ * is kept.
+ */
+static bool futex_hash_less(struct futex_private_hash *a,
+			    struct futex_private_hash *b)
+{
+	/* user provided always wins */
+	if (a->custom != b->custom)
+		return b->custom;
+
+	/* zero-sized hash wins */
+	if (!b->hash_mask)
+		return true;
+	if (!a->hash_mask)
+		return false;
+
+	/* keep the biggest */
+	return a->hash_mask < b->hash_mask;
+}
+
+/*
+ * Install a private futex hash with @hash_slots buckets for current->mm.
+ *
+ * @hash_slots: 0 to stop using a private hash (futexes then use the global
+ *              hash; this cannot be undone), or a power of two no larger
+ *              than the global hash.
+ * @custom:     true for PR_FUTEX_HASH_SET_SLOTS. A custom size takes
+ *              precedence over default scaling, and the caller waits for
+ *              a pending replacement to be installed before returning.
+ *
+ * Return: 0 on success or when the request lost to a preferred one,
+ * -EINVAL for an invalid @hash_slots, -EBUSY if @custom and the private
+ * hash is already disabled, -ENOMEM on allocation failure.
+ */
+static int futex_hash_allocate(unsigned int hash_slots, bool custom)
+{
 	struct mm_struct *mm = current->mm;
-	size_t alloc_size;
+	struct futex_private_hash *fph;
 	int i;
 
-	if (hash_slots == 0)
-		hash_slots = 16;
-	hash_slots = clamp(hash_slots, 2, futex_hashmask + 1);
-	if (!is_power_of_2(hash_slots))
-		hash_slots = rounddown_pow_of_two(hash_slots);
+	/* 0, or a power of two no larger than the global hash. */
+	if (hash_slots && (hash_slots == 1 || !is_power_of_2(hash_slots) ||
+			   hash_slots > futex_hashmask + 1))
+		return -EINVAL;
 
-	if (unlikely(check_mul_overflow(hash_slots, sizeof(struct futex_hash_bucket),
-					&alloc_size)))
-		return -ENOMEM;
+	/*
+	 * Once the private hash has been disabled (zero slots, futexes use the
+	 * global hash) there is no way back.
+	 */
+	scoped_guard(rcu) {
+		fph = rcu_dereference(mm->futex_phash);
+		if (fph && !fph->hash_mask) {
+			if (custom)
+				return -EBUSY;
+			return 0;
+		}
+	}
 
-	if (unlikely(check_add_overflow(alloc_size, sizeof(struct futex_private_hash),
-					&alloc_size)))
-		return -ENOMEM;
-
-	fph = kvmalloc(alloc_size, GFP_KERNEL_ACCOUNT);
+	fph = kvzalloc(struct_size(fph, queues, hash_slots), GFP_KERNEL_ACCOUNT);
 	if (!fph)
 		return -ENOMEM;
 
 	rcuref_init(&fph->users, 1);
-	fph->hash_mask = hash_slots - 1;
+	fph->hash_mask = hash_slots ? hash_slots - 1 : 0;
+	fph->custom = custom;
+	fph->mm = mm;
 
 	for (i = 0; i < hash_slots; i++)
 		futex_hash_bucket_init(&fph->queues[i], fph);
 
-	scoped_guard(mutex, &mm->futex_hash_lock) {
-		if (mm->futex_phash && !mm->futex_phash_new) {
-			/*
-			 * If we have an existing hash, but do not yet have
-			 * allocated a replacement hash, drop the initial
-			 * reference on the existing hash.
-			 *
-			 * Ignore the return value; removal is serialized by
-			 * mm->futex_hash_lock which we currently hold.
-			 */
-			futex_private_hash_put(mm->futex_phash);
-		}
-
-		if (mm->futex_phash_new) {
-			/*
-			 * If we already have a replacement hash pending;
-			 * keep the larger hash.
-			 */
-			if (mm->futex_phash_new->hash_mask <= fph->hash_mask) {
-				hb_tofree = mm->futex_phash_new;
-			} else {
-				hb_tofree = fph;
-				fph = mm->futex_phash_new;
-			}
-			mm->futex_phash_new = NULL;
-		}
-
+	if (custom) {
 		/*
-		 * Will set mm->futex_phash_new on failure;
-		 * futex_get_private_hash() will try again.
+		 * Only let prctl() wait / retry; don't unduly delay clone().
 		 */
-		__futex_pivot_hash(mm, fph);
+again:
+		wait_var_event(mm, futex_pivot_pending(mm));
 	}
-	kvfree(hb_tofree);
+
+	scoped_guard(mutex, &mm->futex_hash_lock) {
+		struct futex_private_hash *free __free(kvfree) = NULL;
+		struct futex_private_hash *cur, *new;
+
+		cur = rcu_dereference_protected(mm->futex_phash,
+						lockdep_is_held(&mm->futex_hash_lock));
+		new = mm->futex_phash_new;
+		mm->futex_phash_new = NULL;
+
+		if (fph) {
+			/*
+			 * Re-check under the lock; the lockless checks done before
+			 * allocating may be stale. A disabled private hash stays
+			 * disabled, and default scaling must not replace a
+			 * preferred (user-provided or larger) hash.
+			 */
+			if (cur && (!cur->hash_mask ||
+				    (!custom && futex_hash_less(fph, cur)))) {
+				free = fph;
+				mm->futex_phash_new = new;
+				return custom ? -EBUSY : 0;
+			}
+
+			if (cur && !new) {
+				/*
+				 * If we have an existing hash, but do not yet have
+				 * allocated a replacement hash, drop the initial
+				 * reference on the existing hash.
+				 */
+				futex_private_hash_put(cur);
+			}
+
+			if (new) {
+				/*
+				 * Two updates raced; throw out the lesser one.
+				 */
+				if (futex_hash_less(new, fph)) {
+					free = new;
+					new = fph;
+				} else {
+					free = fph;
+				}
+			} else {
+				new = fph;
+			}
+			/* Consumed; a retry via 'again' must not reuse it. */
+			fph = NULL;
+		}
+
+		if (new) {
+			/*
+			 * Will set mm->futex_phash_new on failure;
+			 * futex_get_private_hash() will try again. prctl()
+			 * callers wait for the current hash to drain and retry.
+			 */
+			if (!__futex_pivot_hash(mm, new) && custom)
+				goto again;
+		}
+	}
 	return 0;
 }
 
@@ -1479,10 +1526,17 @@ int futex_hash_allocate_default(void)
 		return 0;
 
 	scoped_guard(rcu) {
-		threads = min_t(unsigned int, get_nr_threads(current), num_online_cpus());
+		threads = min_t(unsigned int,
+				get_nr_threads(current),
+				num_online_cpus());
+
 		fph = rcu_dereference(current->mm->futex_phash);
-		if (fph)
+		if (fph) {
+			if (fph->custom)	/* a SET_SLOTS size is never rescaled */
+				return 0;
+
 			current_buckets = fph->hash_mask + 1;
+		}
 	}
 
 	/*
@@ -1495,7 +1549,7 @@ int futex_hash_allocate_default(void)
 	if (current_buckets >= buckets)
 		return 0;
 
-	return futex_hash_allocate(buckets);
+	return futex_hash_allocate(buckets, false);	/* default sizing, not custom */
 }
 
 static int futex_hash_get_slots(void)
@@ -1511,7 +1565,7 @@ static int futex_hash_get_slots(void)
 
 #else
 
-static int futex_hash_allocate(unsigned int hash_slots)
+static int futex_hash_allocate(unsigned int hash_slots, bool custom) /* stub: always succeeds */
 {
 	return 0;
 }
@@ -1528,7 +1582,8 @@ int futex_hash_prctl(unsigned long arg2, unsigned long arg3)
 
 	switch (arg2) {
 	case PR_FUTEX_HASH_SET_SLOTS:
-		ret = futex_hash_allocate(arg3);
+		/* Reject sizes that would be truncated to futex_hash_allocate()'s unsigned int. */
+		ret = arg3 > UINT_MAX ? -EINVAL : futex_hash_allocate(arg3, true);
 		break;
 
 	case PR_FUTEX_HASH_GET_SLOTS: