KVM: arm64: Reject shared table walks in the hyp code

Exclusive table walks are the only kind of table walk supported in the
hyp code, as there is no construct like RCU available in the hypervisor.
Reject any attempt to do a shared table walk by returning an error and
letting the caller clean up the mess.

Suggested-by: Will Deacon <will@kernel.org>
Signed-off-by: Oliver Upton <oliver.upton@linux.dev>
Acked-by: Will Deacon <will@kernel.org>
Signed-off-by: Marc Zyngier <maz@kernel.org>
Link: https://lore.kernel.org/r/20221118182222.3932898-4-oliver.upton@linux.dev
diff --git a/arch/arm64/include/asm/kvm_pgtable.h b/arch/arm64/include/asm/kvm_pgtable.h
index 4b6b52e..d5cb01f 100644
--- a/arch/arm64/include/asm/kvm_pgtable.h
+++ b/arch/arm64/include/asm/kvm_pgtable.h
@@ -229,7 +229,18 @@
 	return pteref;
 }
 
-static inline void kvm_pgtable_walk_begin(struct kvm_pgtable_walker *walker) {}
+static inline int kvm_pgtable_walk_begin(struct kvm_pgtable_walker *walker)
+{
+	/*
+	 * The hypervisor lacks RCU (or a similar protection scheme), so only
+	 * non-shared table walkers are allowed; shared walks get -EPERM.
+	 */
+	if (walker->flags & KVM_PGTABLE_WALK_SHARED)
+		return -EPERM;
+
+	return 0;
+}
+
 static inline void kvm_pgtable_walk_end(struct kvm_pgtable_walker *walker) {}
 
 static inline bool kvm_pgtable_walk_lock_held(void)
@@ -247,10 +258,12 @@
 	return rcu_dereference_check(pteref, !(walker->flags & KVM_PGTABLE_WALK_SHARED));
 }
 
-static inline void kvm_pgtable_walk_begin(struct kvm_pgtable_walker *walker)
+static inline int kvm_pgtable_walk_begin(struct kvm_pgtable_walker *walker)
 {
 	if (walker->flags & KVM_PGTABLE_WALK_SHARED)
 		rcu_read_lock();
+
+	return 0;
 }
 
 static inline void kvm_pgtable_walk_end(struct kvm_pgtable_walker *walker)
diff --git a/arch/arm64/kvm/hyp/pgtable.c b/arch/arm64/kvm/hyp/pgtable.c
index d6f3753..58dbe0a 100644
--- a/arch/arm64/kvm/hyp/pgtable.c
+++ b/arch/arm64/kvm/hyp/pgtable.c
@@ -289,7 +289,10 @@
 	};
 	int r;
 
-	kvm_pgtable_walk_begin(walker);
+	r = kvm_pgtable_walk_begin(walker);
+	if (r)
+		return r;
+
 	r = _kvm_pgtable_walk(pgt, &walk_data);
 	kvm_pgtable_walk_end(walker);