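WIP: KVM: arm64: Skip FP put/load during nested_switch() transitions

Widen vcpu->arch.sflags from u8 to u16 and add a TRANSIENT_PUT_LOAD
state flag in the newly available bit 8. nested_switch() sets the flag
around its kvm_arch_vcpu_put() / kvm_arch_vcpu_load() pair, and both
functions skip kvm_arch_vcpu_put_fp() / kvm_arch_vcpu_load_fp() while
the flag is set.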

Signed-off-by: Marc Zyngier <maz@kernel.org>
diff --git a/arch/arm64/include/asm/kvm_host.h b/arch/arm64/include/asm/kvm_host.h
index a33f599..2615946 100644
--- a/arch/arm64/include/asm/kvm_host.h
+++ b/arch/arm64/include/asm/kvm_host.h
@@ -686,7 +686,7 @@ struct kvm_vcpu_arch {
 	u8 iflags;
 
 	/* State flags for kernel bookkeeping, unused by the hypervisor code */
-	u8 sflags;
+	u16 sflags;
 
 	/*
 	 * Don't run the guest (internal implementation need).
@@ -895,7 +895,8 @@ struct kvm_vcpu_arch {
 #define PMUSERENR_ON_CPU	__vcpu_single_flag(sflags, BIT(6))
 /* WFI instruction trapped */
 #define IN_WFI			__vcpu_single_flag(sflags, BIT(7))
-
+/* In a non-preemptible put/load transition: FP put/load is skipped */
+#define TRANSIENT_PUT_LOAD	__vcpu_single_flag(sflags, BIT(8))
 
 /* Pointer to the vcpu's SVE FFR for sve_{save,load}_state() */
 #define vcpu_sve_pffr(vcpu) (kern_hyp_va((vcpu)->arch.sve_state) +	\
diff --git a/arch/arm64/kvm/arm.c b/arch/arm64/kvm/arm.c
index a7ca776..8481c1af 100644
--- a/arch/arm64/kvm/arm.c
+++ b/arch/arm64/kvm/arm.c
@@ -606,7 +606,8 @@ void kvm_arch_vcpu_load(struct kvm_vcpu *vcpu, int cpu)
 	kvm_timer_vcpu_load(vcpu);
 	if (has_vhe())
 		kvm_vcpu_load_vhe(vcpu);
-	kvm_arch_vcpu_load_fp(vcpu);
+	if (!vcpu_get_flag(vcpu, TRANSIENT_PUT_LOAD))
+		kvm_arch_vcpu_load_fp(vcpu);
 	kvm_vcpu_pmu_restore_guest(vcpu);
 	if (kvm_arm_is_pvtime_enabled(&vcpu->arch))
 		kvm_make_request(KVM_REQ_RECORD_STEAL, vcpu);
@@ -632,7 +633,8 @@ void kvm_arch_vcpu_load(struct kvm_vcpu *vcpu, int cpu)
 void kvm_arch_vcpu_put(struct kvm_vcpu *vcpu)
 {
 	kvm_arch_vcpu_put_debug_state_flags(vcpu);
-	kvm_arch_vcpu_put_fp(vcpu);
+	if (!vcpu_get_flag(vcpu, TRANSIENT_PUT_LOAD))
+		kvm_arch_vcpu_put_fp(vcpu);
 	if (has_vhe())
 		kvm_vcpu_put_vhe(vcpu);
 	kvm_timer_vcpu_put(vcpu);
diff --git a/arch/arm64/kvm/emulate-nested.c b/arch/arm64/kvm/emulate-nested.c
index e384913..bccb3a1 100644
--- a/arch/arm64/kvm/emulate-nested.c
+++ b/arch/arm64/kvm/emulate-nested.c
@@ -2300,11 +2300,16 @@ static void nested_switch(struct kvm_vcpu *vcpu, nv_switch_fn fn, const union nv
 	if (!is_hyp_ctxt(vcpu))
 		__kvm_adjust_pc(vcpu);
 
+	vcpu_set_flag(vcpu, TRANSIENT_PUT_LOAD);
+
 	kvm_arch_vcpu_put(vcpu);
 
 	fn(vcpu, data);
 
 	kvm_arch_vcpu_load(vcpu, smp_processor_id());
+
+	vcpu_clear_flag(vcpu, TRANSIENT_PUT_LOAD);
+
 	preempt_enable();
 }