Merge remote-tracking branch 'linux/master' into kvmarm-master/next

Signed-off-by: Marc Zyngier <maz@kernel.org>
diff --git a/Documentation/ABI/testing/sysfs-bus-coresight-devices-trbe b/Documentation/ABI/testing/sysfs-bus-coresight-devices-trbe
new file mode 100644
index 0000000..ad3bbc6
--- /dev/null
+++ b/Documentation/ABI/testing/sysfs-bus-coresight-devices-trbe
@@ -0,0 +1,14 @@
+What:		/sys/bus/coresight/devices/trbe<cpu>/align
+Date:		March 2021
+KernelVersion:	5.13
+Contact:	Anshuman Khandual <anshuman.khandual@arm.com>
+Description:	(Read) Shows the TRBE write pointer alignment. This value
+		is fetched from the TRBIDR register.
+
+What:		/sys/bus/coresight/devices/trbe<cpu>/flag
+Date:		March 2021
+KernelVersion:	5.13
+Contact:	Anshuman Khandual <anshuman.khandual@arm.com>
+Description:	(Read) Shows whether TRBE memory updates are accompanied by
+		updates to the access and dirty flags. This value is fetched
+		from the TRBIDR register.
diff --git a/Documentation/admin-guide/kernel-parameters.txt b/Documentation/admin-guide/kernel-parameters.txt
index 0454572..18f8bb3 100644
--- a/Documentation/admin-guide/kernel-parameters.txt
+++ b/Documentation/admin-guide/kernel-parameters.txt
@@ -2279,8 +2279,7 @@
 				   state is kept private from the host.
 				   Not valid if the kernel is running in EL2.
 
-			Defaults to VHE/nVHE based on hardware support and
-			the value of CONFIG_ARM64_VHE.
+			Defaults to VHE/nVHE based on hardware support.
 
 	kvm-arm.vgic_v3_group0_trap=
 			[KVM,ARM] Trap guest accesses to GICv3 group-0
diff --git a/Documentation/devicetree/bindings/arm/ete.yaml b/Documentation/devicetree/bindings/arm/ete.yaml
new file mode 100644
index 0000000..7f9b2d1
--- /dev/null
+++ b/Documentation/devicetree/bindings/arm/ete.yaml
@@ -0,0 +1,75 @@
+# SPDX-License-Identifier: (GPL-2.0-only OR BSD-2-Clause)
+# Copyright 2021, Arm Ltd
+%YAML 1.2
+---
+$id: "http://devicetree.org/schemas/arm/ete.yaml#"
+$schema: "http://devicetree.org/meta-schemas/core.yaml#"
+
+title: ARM Embedded Trace Extensions
+
+maintainers:
+  - Suzuki K Poulose <suzuki.poulose@arm.com>
+  - Mathieu Poirier <mathieu.poirier@linaro.org>
+
+description: |
+  Arm Embedded Trace Extension (ETE) is a per CPU trace component that
+  allows tracing the CPU execution. It overlaps with the CoreSight ETMv4
+  architecture and has extended support for future architecture changes.
+  The trace generated by the ETE could be stored via legacy CoreSight
+  components (e.g., TMC-ETR) or other means (e.g., using the per CPU
+  Arm Trace Buffer Extension (TRBE)). Since the ETE can be connected to
+  legacy CoreSight components, a node must be listed per instance, along
+  with any optional connection graph as per the coresight bindings.
+  See bindings/arm/coresight.txt.
+
+properties:
+  $nodename:
+    pattern: "^ete([0-9a-f]+)$"
+  compatible:
+    items:
+      - const: arm,embedded-trace-extension
+
+  cpu:
+    description: |
+      Handle to the cpu this ETE is bound to.
+    $ref: /schemas/types.yaml#/definitions/phandle
+
+  out-ports:
+    description: |
+      Output connections from the ETE to legacy CoreSight trace bus.
+    $ref: /schemas/graph.yaml#/properties/ports
+    properties:
+      port:
+        description: Output connection from the ETE to legacy CoreSight Trace bus.
+        $ref: /schemas/graph.yaml#/properties/port
+
+required:
+  - compatible
+  - cpu
+
+additionalProperties: false
+
+examples:
+
+# An ETE node without legacy CoreSight connections
+  - |
+    ete0 {
+      compatible = "arm,embedded-trace-extension";
+      cpu = <&cpu_0>;
+    };
+# An ETE node with legacy CoreSight connections
+  - |
+   ete1 {
+      compatible = "arm,embedded-trace-extension";
+      cpu = <&cpu_1>;
+
+      out-ports {        /* legacy coresight connection */
+         port {
+             ete1_out_port: endpoint {
+                remote-endpoint = <&funnel_in_port0>;
+             };
+         };
+      };
+   };
+
+...
diff --git a/Documentation/devicetree/bindings/arm/trbe.yaml b/Documentation/devicetree/bindings/arm/trbe.yaml
new file mode 100644
index 0000000..4402d7b
--- /dev/null
+++ b/Documentation/devicetree/bindings/arm/trbe.yaml
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: (GPL-2.0-only OR BSD-2-Clause)
+# Copyright 2021, Arm Ltd
+%YAML 1.2
+---
+$id: "http://devicetree.org/schemas/arm/trbe.yaml#"
+$schema: "http://devicetree.org/meta-schemas/core.yaml#"
+
+title: ARM Trace Buffer Extensions
+
+maintainers:
+  - Anshuman Khandual <anshuman.khandual@arm.com>
+
+description: |
+  Arm Trace Buffer Extension (TRBE) is a per CPU component
+  for storing trace generated on the CPU to memory. It is
+  accessed via CPU system registers. The software can verify
+  if it is permitted to use the component by checking the
+  TRBIDR register.
+
+properties:
+  $nodename:
+    const: "trbe"
+  compatible:
+    items:
+      - const: arm,trace-buffer-extension
+
+  interrupts:
+    description: |
+       Exactly 1 PPI must be listed. For heterogeneous systems where
+       TRBE is only supported on a subset of the CPUs, please consult
+       the arm,gic-v3 binding for details on describing a PPI partition.
+    maxItems: 1
+
+required:
+  - compatible
+  - interrupts
+
+additionalProperties: false
+
+examples:
+
+  - |
+   #include <dt-bindings/interrupt-controller/arm-gic.h>
+
+   trbe {
+     compatible = "arm,trace-buffer-extension";
+     interrupts = <GIC_PPI 15 IRQ_TYPE_LEVEL_HIGH>;
+   };
+...
diff --git a/Documentation/trace/coresight/coresight-trbe.rst b/Documentation/trace/coresight/coresight-trbe.rst
new file mode 100644
index 0000000..b9928ef
--- /dev/null
+++ b/Documentation/trace/coresight/coresight-trbe.rst
@@ -0,0 +1,38 @@
+.. SPDX-License-Identifier: GPL-2.0
+
+==============================
+Trace Buffer Extension (TRBE)
+==============================
+
+    :Author:   Anshuman Khandual <anshuman.khandual@arm.com>
+    :Date:     November 2020
+
+Hardware Description
+--------------------
+
+Trace Buffer Extension (TRBE) is a per-CPU hardware unit which captures, in
+system memory, the CPU traces generated by a corresponding per-CPU tracing
+unit. It is plugged in as a CoreSight sink device because the corresponding
+trace generators (ETE) are plugged in as source devices.
+
+The TRBE is not compliant with the CoreSight architecture specifications, but
+is driven via the CoreSight driver framework to support the ETE (which is
+CoreSight compliant) integration.
+
+Sysfs files and directories
+---------------------------
+
+The TRBE devices appear on the existing coresight bus alongside the other
+coresight devices::
+
+	>$ ls /sys/bus/coresight/devices
+	trbe0  trbe1  trbe2 trbe3
+
+The ``trbe<N>`` named TRBEs are associated with a CPU::
+
+	>$ ls /sys/bus/coresight/devices/trbe0/
+	align flag
+
+*Key file items are:*
+   * ``align``: TRBE write pointer alignment
+   * ``flag``: TRBE updates memory with access and dirty flags
diff --git a/Documentation/virt/kvm/api.rst b/Documentation/virt/kvm/api.rst
index 307f2fc..0e89a1f 100644
--- a/Documentation/virt/kvm/api.rst
+++ b/Documentation/virt/kvm/api.rst
@@ -3116,6 +3116,18 @@
 registers to their initial values.  If this is not called, KVM_RUN will
 return ENOEXEC for that vcpu.
 
+The initial values are defined as:
+	- Processor state:
+		* AArch64: EL1h, D, A, I and F bits set. All other bits
+		  are cleared.
+		* AArch32: SVC, A, I and F bits set. All other bits are
+		  cleared.
+	- General Purpose registers, including PC and SP: set to 0
+	- FPSIMD/NEON registers: set to 0
+	- SVE registers: set to 0
+	- System registers: Reset to their architecturally defined
+	  values as for a warm reset to EL1 (resp. SVC)
+
 Note that because some registers reflect machine topology, all vcpus
 should be created before this ioctl is invoked.
 
@@ -3335,7 +3347,8 @@
 flags which can include the following:
 
   - KVM_GUESTDBG_USE_SW_BP:     using software breakpoints [x86, arm64]
-  - KVM_GUESTDBG_USE_HW_BP:     using hardware breakpoints [x86, s390, arm64]
+  - KVM_GUESTDBG_USE_HW_BP:     using hardware breakpoints [x86, s390]
+  - KVM_GUESTDBG_USE_HW:        using hardware debug events [arm64]
   - KVM_GUESTDBG_INJECT_DB:     inject DB type exception [x86]
   - KVM_GUESTDBG_INJECT_BP:     inject BP type exception [x86]
   - KVM_GUESTDBG_EXIT_PENDING:  trigger an immediate guest exit [s390]
@@ -6231,7 +6244,24 @@
 :Returns: 0 on success, -EINVAL when CPU doesn't support 2nd DAWR
 
 This capability can be used to check / enable 2nd DAWR feature provided
 by POWER10 processor.
+
+7.24 KVM_CAP_VM_COPY_ENC_CONTEXT_FROM
+-------------------------------------
+
+:Architectures: x86 SEV enabled
+:Type: vm
+:Parameters: args[0] is the fd of the source vm
+:Returns: 0 on success; ENOTTY on error
+
+This capability enables userspace to copy encryption context from the vm
+indicated by the fd to the vm this is called on.
+
+This is intended to support in-guest workloads scheduled by the host. This
+allows the in-guest workload to maintain its own memory slots and keeps
+the two vms from accidentally clobbering each other through interrupts and
+MSRs.
+
 
 8. Other capabilities.
 ======================
@@ -6727,3 +6756,13 @@
 The KVM_XEN_HVM_CONFIG_RUNSTATE flag indicates that the runstate-related
 features KVM_XEN_VCPU_ATTR_TYPE_RUNSTATE_ADDR/_CURRENT/_DATA/_ADJUST are
 supported by the KVM_XEN_VCPU_SET_ATTR/KVM_XEN_VCPU_GET_ATTR ioctls.
+
+8.31 KVM_CAP_PTP_KVM
+--------------------
+
+:Architectures: arm64
+
+This capability indicates that the KVM virtual PTP service is
+supported in the host. A VMM can check whether the service is
+available to the guest on migration.
+
diff --git a/Documentation/virt/kvm/arm/index.rst b/Documentation/virt/kvm/arm/index.rst
index 3e2b2ab..78a9b67 100644
--- a/Documentation/virt/kvm/arm/index.rst
+++ b/Documentation/virt/kvm/arm/index.rst
@@ -10,3 +10,4 @@
    hyp-abi
    psci
    pvtime
+   ptp_kvm
diff --git a/Documentation/virt/kvm/arm/ptp_kvm.rst b/Documentation/virt/kvm/arm/ptp_kvm.rst
new file mode 100644
index 0000000..5d2bc45
--- /dev/null
+++ b/Documentation/virt/kvm/arm/ptp_kvm.rst
@@ -0,0 +1,25 @@
+.. SPDX-License-Identifier: GPL-2.0
+
+PTP_KVM support for arm/arm64
+=============================
+
+PTP_KVM is used for high precision time sync between host and guests.
+It relies on transferring the wall clock and counter value from the
+host to the guest using a KVM-specific hypercall.
+
+* ARM_SMCCC_HYP_KVM_PTP_FUNC_ID: 0x86000001
+
+This hypercall uses the SMC32/HVC32 calling convention:
+
+ARM_SMCCC_HYP_KVM_PTP_FUNC_ID
+    ==============    ========    =====================================
+    Function ID:      (uint32)    0x86000001
+    Arguments:        (uint32)    KVM_PTP_VIRT_COUNTER(0)
+                                  KVM_PTP_PHYS_COUNTER(1)
+    Return Values:    (int32)     NOT_SUPPORTED(-1) on error, or
+                      (uint32)    Upper 32 bits of wall clock time (r0)
+                      (uint32)    Lower 32 bits of wall clock time (r1)
+                      (uint32)    Upper 32 bits of counter (r2)
+                      (uint32)    Lower 32 bits of counter (r3)
+    Endianness:                   No Restrictions.
+    ==============    ========    =====================================
diff --git a/Documentation/virt/kvm/devices/arm-vgic-its.rst b/Documentation/virt/kvm/devices/arm-vgic-its.rst
index 6c304fd..d257edd 100644
--- a/Documentation/virt/kvm/devices/arm-vgic-its.rst
+++ b/Documentation/virt/kvm/devices/arm-vgic-its.rst
@@ -80,7 +80,7 @@
     -EFAULT  Invalid guest ram access
     -EBUSY   One or more VCPUS are running
     -EACCES  The virtual ITS is backed by a physical GICv4 ITS, and the
-	     state is not available
+	     state is not available without GICv4.1
     =======  ==========================================================
 
 KVM_DEV_ARM_VGIC_GRP_ITS_REGS
diff --git a/Documentation/virt/kvm/devices/arm-vgic-v3.rst b/Documentation/virt/kvm/devices/arm-vgic-v3.rst
index 5dd3bff..51e5e57 100644
--- a/Documentation/virt/kvm/devices/arm-vgic-v3.rst
+++ b/Documentation/virt/kvm/devices/arm-vgic-v3.rst
@@ -228,7 +228,7 @@
 
     KVM_DEV_ARM_VGIC_CTRL_INIT
       request the initialization of the VGIC, no additional parameter in
-      kvm_device_attr.addr.
+      kvm_device_attr.addr. Must be called after all VCPUs have been created.
     KVM_DEV_ARM_VGIC_SAVE_PENDING_TABLES
       save all LPI pending bits into guest RAM pending tables.
 
diff --git a/Documentation/virt/kvm/locking.rst b/Documentation/virt/kvm/locking.rst
index 0aa4817..8f5d5bc 100644
--- a/Documentation/virt/kvm/locking.rst
+++ b/Documentation/virt/kvm/locking.rst
@@ -16,6 +16,13 @@
 - kvm->slots_lock is taken outside kvm->irq_lock, though acquiring
   them together is quite rare.
 
+- The kvm->mmu_notifier_slots_lock rwsem ensures that pairs of
+  invalidate_range_start() and invalidate_range_end() callbacks
+  use the same memslots array.  kvm->slots_lock is taken outside the
+  write-side critical section of kvm->mmu_notifier_slots_lock, so
+  MMU notifiers must not take kvm->slots_lock.  No other write-side
+  critical sections should be added.
+
 On x86:
 
 - vcpu->mutex is taken outside kvm->arch.hyperv.hv_lock
@@ -38,25 +45,24 @@
 following two cases:
 
 1. Access Tracking: The SPTE is not present, but it is marked for access
-   tracking i.e. the SPTE_SPECIAL_MASK is set. That means we need to
-   restore the saved R/X bits. This is described in more detail later below.
+   tracking. That means we need to restore the saved R/X bits. This is
+   described in more detail later below.
 
-2. Write-Protection: The SPTE is present and the fault is
-   caused by write-protect. That means we just need to change the W bit of
-   the spte.
+2. Write-Protection: The SPTE is present and the fault is caused by
+   write-protect. That means we just need to change the W bit of the spte.
 
-What we use to avoid all the race is the SPTE_HOST_WRITEABLE bit and
-SPTE_MMU_WRITEABLE bit on the spte:
+What we use to avoid all the races is the Host-writable bit and the
+MMU-writable bit on the spte:
 
-- SPTE_HOST_WRITEABLE means the gfn is writable on host.
-- SPTE_MMU_WRITEABLE means the gfn is writable on mmu. The bit is set when
-  the gfn is writable on guest mmu and it is not write-protected by shadow
-  page write-protection.
+- Host-writable means the gfn is writable in the host kernel page tables and in
+  its KVM memslot.
+- MMU-writable means the gfn is writable in the guest's mmu and it is not
+  write-protected by shadow page write-protection.
 
 On fast page fault path, we will use cmpxchg to atomically set the spte W
-bit if spte.SPTE_HOST_WRITEABLE = 1 and spte.SPTE_WRITE_PROTECT = 1, or
-restore the saved R/X bits if VMX_EPT_TRACK_ACCESS mask is set, or both. This
-is safe because whenever changing these bits can be detected by cmpxchg.
+bit if spte.HOST_WRITEABLE = 1 and spte.WRITE_PROTECT = 1, or to restore the
+saved R/X bits for an access-traced spte, or both. This is safe because any
+change to these bits can be detected by cmpxchg.
 
 But we need carefully check these cases:
 
@@ -185,17 +191,17 @@
 Lockless Access Tracking:
 
 This is used for Intel CPUs that are using EPT but do not support the EPT A/D
-bits. In this case, when the KVM MMU notifier is called to track accesses to a
-page (via kvm_mmu_notifier_clear_flush_young), it marks the PTE as not-present
-by clearing the RWX bits in the PTE and storing the original R & X bits in
-some unused/ignored bits. In addition, the SPTE_SPECIAL_MASK is also set on the
-PTE (using the ignored bit 62). When the VM tries to access the page later on,
-a fault is generated and the fast page fault mechanism described above is used
-to atomically restore the PTE to a Present state. The W bit is not saved when
-the PTE is marked for access tracking and during restoration to the Present
-state, the W bit is set depending on whether or not it was a write access. If
-it wasn't, then the W bit will remain clear until a write access happens, at
-which time it will be set using the Dirty tracking mechanism described above.
+bits. In this case, PTEs are tagged as A/D disabled (using ignored bits), and
+when the KVM MMU notifier is called to track accesses to a page (via
+kvm_mmu_notifier_clear_flush_young), it marks the PTE not-present in hardware
+by clearing the RWX bits in the PTE and storing the original R & X bits in more
+unused/ignored bits. When the VM tries to access the page later on, a fault is
+generated and the fast page fault mechanism described above is used to
+atomically restore the PTE to a Present state. The W bit is not saved when the
+PTE is marked for access tracking and during restoration to the Present state,
+the W bit is set depending on whether or not it was a write access. If it
+wasn't, then the W bit will remain clear until a write access happens, at which
+time it will be set using the Dirty tracking mechanism described above.
 
 3. Reference
 ------------
diff --git a/MAINTAINERS b/MAINTAINERS
index c80ad73..96d7ce9 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -1761,6 +1761,8 @@
 F:	Documentation/devicetree/bindings/arm/coresight-cpu-debug.txt
 F:	Documentation/devicetree/bindings/arm/coresight-cti.yaml
 F:	Documentation/devicetree/bindings/arm/coresight.txt
+F:	Documentation/devicetree/bindings/arm/ete.yaml
+F:	Documentation/devicetree/bindings/arm/trbe.yaml
 F:	Documentation/trace/coresight/*
 F:	drivers/hwtracing/coresight/*
 F:	include/dt-bindings/arm/coresight-cti-dt.h
@@ -9764,10 +9766,10 @@
 KERNEL VIRTUAL MACHINE FOR ARM64 (KVM/arm64)
 M:	Marc Zyngier <maz@kernel.org>
 R:	James Morse <james.morse@arm.com>
-R:	Julien Thierry <julien.thierry.kdev@gmail.com>
+R:	Alexandru Elisei <alexandru.elisei@arm.com>
 R:	Suzuki K Poulose <suzuki.poulose@arm.com>
 L:	linux-arm-kernel@lists.infradead.org (moderated for non-subscribers)
-L:	kvmarm@lists.cs.columbia.edu
+L:	kvmarm@lists.cs.columbia.edu (moderated for non-subscribers)
 S:	Maintained
 T:	git git://git.kernel.org/pub/scm/linux/kernel/git/kvmarm/kvmarm.git
 F:	arch/arm64/include/asm/kvm*
diff --git a/arch/arm/include/asm/hypervisor.h b/arch/arm/include/asm/hypervisor.h
index df85243..bd61502 100644
--- a/arch/arm/include/asm/hypervisor.h
+++ b/arch/arm/include/asm/hypervisor.h
@@ -4,4 +4,7 @@
 
 #include <asm/xen/hypervisor.h>
 
+void kvm_init_hyp_services(void);
+bool kvm_arm_hyp_service_available(u32 func_id);
+
 #endif
diff --git a/arch/arm64/Kconfig b/arch/arm64/Kconfig
index e4e1b65..9ec09f9 100644
--- a/arch/arm64/Kconfig
+++ b/arch/arm64/Kconfig
@@ -1426,19 +1426,6 @@
 	  built with binutils >= 2.25 in order for the new instructions
 	  to be used.
 
-config ARM64_VHE
-	bool "Enable support for Virtualization Host Extensions (VHE)"
-	default y
-	help
-	  Virtualization Host Extensions (VHE) allow the kernel to run
-	  directly at EL2 (instead of EL1) on processors that support
-	  it. This leads to better performance for KVM, as they reduce
-	  the cost of the world switch.
-
-	  Selecting this option allows the VHE feature to be detected
-	  at runtime, and does not affect processors that do not
-	  implement this feature.
-
 endmenu
 
 menu "ARMv8.2 architectural features"
@@ -1694,7 +1681,6 @@
 config ARM64_SVE
 	bool "ARM Scalable Vector Extension support"
 	default y
-	depends on !KVM || ARM64_VHE
 	help
 	  The Scalable Vector Extension (SVE) is an extension to the AArch64
 	  execution state which complements and extends the SIMD functionality
@@ -1723,12 +1709,6 @@
 	  booting the kernel.  If unsure and you are not observing these
 	  symptoms, you should assume that it is safe to say Y.
 
-	  CPUs that support SVE are architecturally required to support the
-	  Virtualization Host Extensions (VHE), so the kernel makes no
-	  provision for supporting SVE alongside KVM without VHE enabled.
-	  Thus, you will need to enable CONFIG_ARM64_VHE if you want to support
-	  KVM in the same kernel image.
-
 config ARM64_MODULE_PLTS
 	bool "Use PLTs to allow module memory to spill over into vmalloc area"
 	depends on MODULES
diff --git a/arch/arm64/include/asm/assembler.h b/arch/arm64/include/asm/assembler.h
index ca31594..34ddd8a 100644
--- a/arch/arm64/include/asm/assembler.h
+++ b/arch/arm64/include/asm/assembler.h
@@ -15,6 +15,7 @@
 #include <asm-generic/export.h>
 
 #include <asm/asm-offsets.h>
+#include <asm/asm-bug.h>
 #include <asm/cpufeature.h>
 #include <asm/cputype.h>
 #include <asm/debug-monitors.h>
@@ -270,12 +271,24 @@
  * provide the system wide safe value from arm64_ftr_reg_ctrel0.sys_val
  */
 	.macro	read_ctr, reg
+#ifndef __KVM_NVHE_HYPERVISOR__
 alternative_if_not ARM64_MISMATCHED_CACHE_TYPE
 	mrs	\reg, ctr_el0			// read CTR
 	nop
 alternative_else
 	ldr_l	\reg, arm64_ftr_reg_ctrel0 + ARM64_FTR_SYSVAL
 alternative_endif
+#else
+alternative_if_not ARM64_KVM_PROTECTED_MODE
+	ASM_BUG()
+alternative_else_nop_endif
+alternative_cb kvm_compute_final_ctr_el0
+	movz	\reg, #0
+	movk	\reg, #0, lsl #16
+	movk	\reg, #0, lsl #32
+	movk	\reg, #0, lsl #48
+alternative_cb_end
+#endif
 	.endm
 
 
@@ -676,11 +689,11 @@
 	.endm
 
 /*
- * Set SCTLR_EL1 to the passed value, and invalidate the local icache
+ * Set SCTLR_ELx to the @reg value, and invalidate the local icache
  * in the process. This is called when setting the MMU on.
  */
-.macro set_sctlr_el1, reg
-	msr	sctlr_el1, \reg
+.macro set_sctlr, sreg, reg
+	msr	\sreg, \reg
 	isb
 	/*
 	 * Invalidate the local I-cache so that any instructions fetched
@@ -692,6 +705,14 @@
 	isb
 .endm
 
+.macro set_sctlr_el1, reg
+	set_sctlr sctlr_el1, \reg
+.endm
+
+.macro set_sctlr_el2, reg
+	set_sctlr sctlr_el2, \reg
+.endm
+
 /*
  * Check whether to yield to another runnable task from kernel mode NEON code
  * (which runs with preemption disabled).
diff --git a/arch/arm64/include/asm/barrier.h b/arch/arm64/include/asm/barrier.h
index c3009b0..5a8367a 100644
--- a/arch/arm64/include/asm/barrier.h
+++ b/arch/arm64/include/asm/barrier.h
@@ -23,6 +23,7 @@
 #define dsb(opt)	asm volatile("dsb " #opt : : : "memory")
 
 #define psb_csync()	asm volatile("hint #17" : : : "memory")
+#define tsb_csync()	asm volatile("hint #18" : : : "memory")
 #define csdb()		asm volatile("hint #20" : : : "memory")
 
 #define spec_bar()	asm volatile(ALTERNATIVE("dsb nsh\nisb\n",		\
diff --git a/arch/arm64/include/asm/cpufeature.h b/arch/arm64/include/asm/cpufeature.h
index 61177ba..338840c 100644
--- a/arch/arm64/include/asm/cpufeature.h
+++ b/arch/arm64/include/asm/cpufeature.h
@@ -63,6 +63,23 @@
 	s64		safe_val; /* safe value for FTR_EXACT features */
 };
 
+/*
+ * Describe the early feature override to the core override code:
+ *
+ * @val			Values that are to be merged into the final
+ *			sanitised value of the register. Only the bitfields
+ *			set to 1 in @mask are valid
+ * @mask		Mask of the features that are overridden by @val
+ *
+ * A @mask field set to full-1 indicates that the corresponding field
+ * in @val is a valid override.
+ *
+ * A @mask field set to full-0 with the corresponding @val field set
+ * to full-0 denotes that this field has no override
+ *
+ * A @mask field set to full-0 with the corresponding @val field set
+ * to full-1 denotes that this field has an invalid override.
+ */
 struct arm64_ftr_override {
 	u64		val;
 	u64		mask;
diff --git a/arch/arm64/include/asm/el2_setup.h b/arch/arm64/include/asm/el2_setup.h
index d77d358..bda9189 100644
--- a/arch/arm64/include/asm/el2_setup.h
+++ b/arch/arm64/include/asm/el2_setup.h
@@ -65,6 +65,19 @@
 						// use EL1&0 translation.
 
 .Lskip_spe_\@:
+	/* Trace buffer */
+	ubfx	x0, x1, #ID_AA64DFR0_TRBE_SHIFT, #4
+	cbz	x0, .Lskip_trace_\@		// Skip if TraceBuffer is not present
+
+	mrs_s	x0, SYS_TRBIDR_EL1
+	and	x0, x0, TRBIDR_PROG
+	cbnz	x0, .Lskip_trace_\@		// If TRBE is available at EL2
+
+	mov	x0, #(MDCR_EL2_E2TB_MASK << MDCR_EL2_E2TB_SHIFT)
+	orr	x2, x2, x0			// allow the EL1&0 translation
+						// to own it.
+
+.Lskip_trace_\@:
 	msr	mdcr_el2, x2			// Configure debug traps
 .endm
 
diff --git a/arch/arm64/include/asm/fpsimd.h b/arch/arm64/include/asm/fpsimd.h
index bec5f14..ff3879a 100644
--- a/arch/arm64/include/asm/fpsimd.h
+++ b/arch/arm64/include/asm/fpsimd.h
@@ -130,6 +130,15 @@
 	sysreg_clear_set(cpacr_el1, 0, CPACR_EL1_ZEN_EL0EN);
 }
 
+#define sve_cond_update_zcr_vq(val, reg)			\
+	do {	/* write @reg only if its LEN field changes */	\
+		u64 __old_zcr = read_sysreg_s((reg));		\
+		u64 __new_zcr = (__old_zcr & ~ZCR_ELx_LEN_MASK) |	\
+				((val) & ZCR_ELx_LEN_MASK);	\
+		if (__new_zcr != __old_zcr)			\
+			write_sysreg_s(__new_zcr, (reg));	\
+	} while (0)
+
 /*
  * Probing and setup functions.
  * Calls to these functions must be serialised with one another.
@@ -159,6 +168,8 @@
 static inline void sve_user_disable(void) { BUILD_BUG(); }
 static inline void sve_user_enable(void) { BUILD_BUG(); }
 
+#define sve_cond_update_zcr_vq(val, reg) do { } while (0)
+
 static inline void sve_init_vq_map(void) { }
 static inline void sve_update_vq_map(void) { }
 static inline int sve_verify_vq_map(void) { return 0; }
diff --git a/arch/arm64/include/asm/fpsimdmacros.h b/arch/arm64/include/asm/fpsimdmacros.h
index af43367..a256399 100644
--- a/arch/arm64/include/asm/fpsimdmacros.h
+++ b/arch/arm64/include/asm/fpsimdmacros.h
@@ -6,6 +6,8 @@
  * Author: Catalin Marinas <catalin.marinas@arm.com>
  */
 
+#include <asm/assembler.h>
+
 .macro fpsimd_save state, tmpnr
 	stp	q0, q1, [\state, #16 * 0]
 	stp	q2, q3, [\state, #16 * 2]
@@ -230,8 +232,7 @@
 		str		w\nxtmp, [\xpfpsr, #4]
 .endm
 
-.macro sve_load nxbase, xpfpsr, xvqminus1, nxtmp, xtmp2
-		sve_load_vq	\xvqminus1, x\nxtmp, \xtmp2
+.macro __sve_load nxbase, xpfpsr, nxtmp
  _for n, 0, 31,	_sve_ldr_v	\n, \nxbase, \n - 34
 		_sve_ldr_p	0, \nxbase
 		_sve_wrffr	0
@@ -242,3 +243,8 @@
 		ldr		w\nxtmp, [\xpfpsr, #4]
 		msr		fpcr, x\nxtmp
 .endm
+
+.macro sve_load nxbase, xpfpsr, xvqminus1, nxtmp, xtmp2
+		sve_load_vq	\xvqminus1, x\nxtmp, \xtmp2
+		__sve_load	\nxbase, \xpfpsr, \nxtmp
+.endm
diff --git a/arch/arm64/include/asm/hyp_image.h b/arch/arm64/include/asm/hyp_image.h
index 737ded6..b4b3076 100644
--- a/arch/arm64/include/asm/hyp_image.h
+++ b/arch/arm64/include/asm/hyp_image.h
@@ -10,11 +10,15 @@
 #define __HYP_CONCAT(a, b)	a ## b
 #define HYP_CONCAT(a, b)	__HYP_CONCAT(a, b)
 
+#ifndef __KVM_NVHE_HYPERVISOR__
 /*
  * KVM nVHE code has its own symbol namespace prefixed with __kvm_nvhe_,
  * to separate it from the kernel proper.
  */
 #define kvm_nvhe_sym(sym)	__kvm_nvhe_##sym
+#else
+#define kvm_nvhe_sym(sym)	sym
+#endif
 
 #ifdef LINKER_SCRIPT
 
@@ -56,6 +60,9 @@
  */
 #define KVM_NVHE_ALIAS(sym)	kvm_nvhe_sym(sym) = sym;
 
+/* Defines a linker script alias for KVM nVHE hyp symbols */
+#define KVM_NVHE_ALIAS_HYP(first, sec)	kvm_nvhe_sym(first) = kvm_nvhe_sym(sec);
+
 #endif /* LINKER_SCRIPT */
 
 #endif /* __ARM64_HYP_IMAGE_H__ */
diff --git a/arch/arm64/include/asm/hypervisor.h b/arch/arm64/include/asm/hypervisor.h
index f9cc1d0..0ae427f 100644
--- a/arch/arm64/include/asm/hypervisor.h
+++ b/arch/arm64/include/asm/hypervisor.h
@@ -4,4 +4,7 @@
 
 #include <asm/xen/hypervisor.h>
 
+void kvm_init_hyp_services(void);
+bool kvm_arm_hyp_service_available(u32 func_id);
+
 #endif
diff --git a/arch/arm64/include/asm/kvm_arm.h b/arch/arm64/include/asm/kvm_arm.h
index 94d4025..692c9049 100644
--- a/arch/arm64/include/asm/kvm_arm.h
+++ b/arch/arm64/include/asm/kvm_arm.h
@@ -278,6 +278,8 @@
 #define CPTR_EL2_DEFAULT	CPTR_EL2_RES1
 
 /* Hyp Debug Configuration Register bits */
+#define MDCR_EL2_E2TB_MASK	(UL(0x3))
+#define MDCR_EL2_E2TB_SHIFT	(UL(24))
 #define MDCR_EL2_TTRF		(1 << 19)
 #define MDCR_EL2_TPMS		(1 << 14)
 #define MDCR_EL2_E2PB_MASK	(UL(0x3))
diff --git a/arch/arm64/include/asm/kvm_asm.h b/arch/arm64/include/asm/kvm_asm.h
index a7ab84f..cf8df03 100644
--- a/arch/arm64/include/asm/kvm_asm.h
+++ b/arch/arm64/include/asm/kvm_asm.h
@@ -57,6 +57,12 @@
 #define __KVM_HOST_SMCCC_FUNC___kvm_get_mdcr_el2		12
 #define __KVM_HOST_SMCCC_FUNC___vgic_v3_save_aprs		13
 #define __KVM_HOST_SMCCC_FUNC___vgic_v3_restore_aprs		14
+#define __KVM_HOST_SMCCC_FUNC___pkvm_init			15
+#define __KVM_HOST_SMCCC_FUNC___pkvm_create_mappings		16
+#define __KVM_HOST_SMCCC_FUNC___pkvm_create_private_mapping	17
+#define __KVM_HOST_SMCCC_FUNC___pkvm_cpu_set_vector		18
+#define __KVM_HOST_SMCCC_FUNC___pkvm_prot_finalize		19
+#define __KVM_HOST_SMCCC_FUNC___pkvm_mark_hyp			20
 
 #ifndef __ASSEMBLY__
 
@@ -154,6 +160,9 @@
 	unsigned long tpidr_el2;
 	unsigned long stack_hyp_va;
 	phys_addr_t pgd_pa;
+	unsigned long hcr_el2;
+	unsigned long vttbr;
+	unsigned long vtcr;
 };
 
 /* Translate a kernel address @ptr into its equivalent linear mapping */
diff --git a/arch/arm64/include/asm/kvm_host.h b/arch/arm64/include/asm/kvm_host.h
index 3d10e65..748bd03 100644
--- a/arch/arm64/include/asm/kvm_host.h
+++ b/arch/arm64/include/asm/kvm_host.h
@@ -94,7 +94,7 @@
 	/* The last vcpu id that ran on each physical CPU */
 	int __percpu *last_vcpu_ran;
 
-	struct kvm *kvm;
+	struct kvm_arch *arch;
 };
 
 struct kvm_arch_memory_slot {
@@ -315,6 +315,8 @@
 		struct kvm_guest_debug_arch regs;
 		/* Statistical profiling extension */
 		u64 pmscr_el1;
+		/* Self-hosted trace */
+		u64 trfcr_el1;
 	} host_debug_state;
 
 	/* VGIC state */
@@ -372,8 +374,10 @@
 };
 
 /* Pointer to the vcpu's SVE FFR for sve_{save,load}_state() */
-#define vcpu_sve_pffr(vcpu) ((void *)((char *)((vcpu)->arch.sve_state) + \
-				      sve_ffr_offset((vcpu)->arch.sve_max_vl)))
+#define vcpu_sve_pffr(vcpu) (kern_hyp_va((vcpu)->arch.sve_state) +	\
+			     sve_ffr_offset((vcpu)->arch.sve_max_vl))
+
+#define vcpu_sve_max_vq(vcpu)	sve_vq_from_vl((vcpu)->arch.sve_max_vl)
 
 #define vcpu_sve_state_size(vcpu) ({					\
 	size_t __size_ret;						\
@@ -382,7 +386,7 @@
 	if (WARN_ON(!sve_vl_valid((vcpu)->arch.sve_max_vl))) {		\
 		__size_ret = 0;						\
 	} else {							\
-		__vcpu_vq = sve_vq_from_vl((vcpu)->arch.sve_max_vl);	\
+		__vcpu_vq = vcpu_sve_max_vq(vcpu);			\
 		__size_ret = SVE_SIG_REGS_SIZE(__vcpu_vq);		\
 	}								\
 									\
@@ -400,6 +404,8 @@
 #define KVM_ARM64_GUEST_HAS_PTRAUTH	(1 << 7) /* PTRAUTH exposed to guest */
 #define KVM_ARM64_PENDING_EXCEPTION	(1 << 8) /* Exception pending */
 #define KVM_ARM64_EXCEPT_MASK		(7 << 9) /* Target EL/MODE */
+#define KVM_ARM64_DEBUG_STATE_SAVE_SPE	(1 << 12) /* Save SPE context if active  */
+#define KVM_ARM64_DEBUG_STATE_SAVE_TRBE	(1 << 13) /* Save TRBE context if active  */
 
 /*
  * When KVM_ARM64_PENDING_EXCEPTION is set, KVM_ARM64_EXCEPT_MASK can
@@ -582,15 +588,11 @@
 			      struct kvm_vcpu_events *events);
 
 #define KVM_ARCH_WANT_MMU_NOTIFIER
-int kvm_unmap_hva_range(struct kvm *kvm,
-			unsigned long start, unsigned long end, unsigned flags);
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte);
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end);
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva);
 
 void kvm_arm_halt_guest(struct kvm *kvm);
 void kvm_arm_resume_guest(struct kvm *kvm);
 
+#ifndef __KVM_NVHE_HYPERVISOR__
 #define kvm_call_hyp_nvhe(f, ...)						\
 	({								\
 		struct arm_smccc_res res;				\
@@ -630,9 +632,13 @@
 									\
 		ret;							\
 	})
+#else /* __KVM_NVHE_HYPERVISOR__ */
+#define kvm_call_hyp(f, ...) f(__VA_ARGS__)
+#define kvm_call_hyp_ret(f, ...) f(__VA_ARGS__)
+#define kvm_call_hyp_nvhe(f, ...) f(__VA_ARGS__)
+#endif /* __KVM_NVHE_HYPERVISOR__ */
 
 void force_vm_exit(const cpumask_t *mask);
-void kvm_mmu_wp_memory_region(struct kvm *kvm, int slot);
 
 int handle_exit(struct kvm_vcpu *vcpu, int exception_index);
 void handle_exit_early(struct kvm_vcpu *vcpu, int exception_index);
@@ -692,19 +698,6 @@
 	ctxt_sys_reg(cpu_ctxt, MPIDR_EL1) = read_cpuid_mpidr();
 }
 
-static inline bool kvm_arch_requires_vhe(void)
-{
-	/*
-	 * The Arm architecture specifies that implementation of SVE
-	 * requires VHE also to be implemented.  The KVM code for arm64
-	 * relies on this when SVE is present:
-	 */
-	if (system_supports_sve())
-		return true;
-
-	return false;
-}
-
 void kvm_arm_vcpu_ptrauth_trap(struct kvm_vcpu *vcpu);
 
 static inline void kvm_arch_hardware_unsetup(void) {}
@@ -713,6 +706,7 @@
 static inline void kvm_arch_vcpu_block_finish(struct kvm_vcpu *vcpu) {}
 
 void kvm_arm_init_debug(void);
+void kvm_arm_vcpu_init_debug(struct kvm_vcpu *vcpu);
 void kvm_arm_setup_debug(struct kvm_vcpu *vcpu);
 void kvm_arm_clear_debug(struct kvm_vcpu *vcpu);
 void kvm_arm_reset_debug_ptr(struct kvm_vcpu *vcpu);
@@ -734,6 +728,10 @@
 	return (!has_vhe() && attr->exclude_host);
 }
 
+/* Flags for host debug state */
+void kvm_arch_vcpu_load_debug_state_flags(struct kvm_vcpu *vcpu);
+void kvm_arch_vcpu_put_debug_state_flags(struct kvm_vcpu *vcpu);
+
 #ifdef CONFIG_KVM /* Avoid conflicts with core headers if CONFIG_KVM=n */
 static inline int kvm_arch_vcpu_run_pid_change(struct kvm_vcpu *vcpu)
 {
@@ -771,5 +769,12 @@
 	(test_bit(KVM_ARM_VCPU_PMU_V3, (vcpu)->arch.features))
 
 int kvm_trng_call(struct kvm_vcpu *vcpu);
+#ifdef CONFIG_KVM
+extern phys_addr_t hyp_mem_base;
+extern phys_addr_t hyp_mem_size;
+void __init kvm_hyp_reserve(void);
+#else
+static inline void kvm_hyp_reserve(void) { }
+#endif
 
 #endif /* __ARM64_KVM_HOST_H__ */
diff --git a/arch/arm64/include/asm/kvm_hyp.h b/arch/arm64/include/asm/kvm_hyp.h
index 32ae676..9d60b30 100644
--- a/arch/arm64/include/asm/kvm_hyp.h
+++ b/arch/arm64/include/asm/kvm_hyp.h
@@ -90,6 +90,8 @@
 
 void __fpsimd_save_state(struct user_fpsimd_state *fp_regs);
 void __fpsimd_restore_state(struct user_fpsimd_state *fp_regs);
+void __sve_save_state(void *sve_pffr, u32 *fpsr);
+void __sve_restore_state(void *sve_pffr, u32 *fpsr);
 
 #ifndef __KVM_NVHE_HYPERVISOR__
 void activate_traps_vhe_load(struct kvm_vcpu *vcpu);
@@ -100,10 +102,20 @@
 
 bool kvm_host_psci_handler(struct kvm_cpu_context *host_ctxt);
 
-void __noreturn hyp_panic(void);
 #ifdef __KVM_NVHE_HYPERVISOR__
 void __noreturn __hyp_do_panic(struct kvm_cpu_context *host_ctxt, u64 spsr,
 			       u64 elr, u64 par);
 #endif
 
+#ifdef __KVM_NVHE_HYPERVISOR__
+void __pkvm_init_switch_pgd(phys_addr_t phys, unsigned long size,
+			    phys_addr_t pgd, void *sp, void *cont_fn);
+int __pkvm_init(phys_addr_t phys, unsigned long size, unsigned long nr_cpus,
+		unsigned long *per_cpu_base, u32 hyp_va_bits);
+void __noreturn __host_enter(struct kvm_cpu_context *host_ctxt);
+#endif
+
+extern u64 kvm_nvhe_sym(id_aa64mmfr0_el1_sys_val);
+extern u64 kvm_nvhe_sym(id_aa64mmfr1_el1_sys_val);
+
 #endif /* __ARM64_KVM_HYP_H__ */
diff --git a/arch/arm64/include/asm/kvm_mmu.h b/arch/arm64/include/asm/kvm_mmu.h
index 9087385..25ed956 100644
--- a/arch/arm64/include/asm/kvm_mmu.h
+++ b/arch/arm64/include/asm/kvm_mmu.h
@@ -121,6 +121,8 @@
 void kvm_compute_layout(void);
 void kvm_apply_hyp_relocations(void);
 
+#define __hyp_pa(x) (((phys_addr_t)(x)) + hyp_physvirt_offset)
+
 static __always_inline unsigned long __kern_hyp_va(unsigned long v)
 {
 	asm volatile(ALTERNATIVE_CB("and %0, %0, #1\n"
@@ -166,7 +168,15 @@
 
 phys_addr_t kvm_mmu_get_httbr(void);
 phys_addr_t kvm_get_idmap_vector(void);
-int kvm_mmu_init(void);
+int kvm_mmu_init(u32 *hyp_va_bits);
+
+static inline void *__kvm_vector_slot2addr(void *base,
+					   enum arm64_hyp_spectre_vector slot)
+{
+	int idx = slot - (slot != HYP_VECTOR_DIRECT);
+
+	return base + (idx * SZ_2K);
+}
 
 struct kvm;
 
@@ -262,9 +272,9 @@
  * Must be called from hyp code running at EL2 with an updated VTTBR
  * and interrupts disabled.
  */
-static __always_inline void __load_guest_stage2(struct kvm_s2_mmu *mmu)
+static __always_inline void __load_stage2(struct kvm_s2_mmu *mmu, unsigned long vtcr)
 {
-	write_sysreg(kern_hyp_va(mmu->kvm)->arch.vtcr, vtcr_el2);
+	write_sysreg(vtcr, vtcr_el2);
 	write_sysreg(kvm_get_vttbr(mmu), vttbr_el2);
 
 	/*
@@ -275,5 +285,14 @@
 	asm(ALTERNATIVE("nop", "isb", ARM64_WORKAROUND_SPECULATIVE_AT));
 }
 
+static __always_inline void __load_guest_stage2(struct kvm_s2_mmu *mmu)
+{
+	__load_stage2(mmu, kern_hyp_va(mmu->arch)->vtcr);
+}
+
+static inline struct kvm *kvm_s2_mmu_to_kvm(struct kvm_s2_mmu *mmu)
+{
+	return container_of(mmu->arch, struct kvm, arch);
+}
 #endif /* __ASSEMBLY__ */
 #endif /* __ARM64_KVM_MMU_H__ */
diff --git a/arch/arm64/include/asm/kvm_pgtable.h b/arch/arm64/include/asm/kvm_pgtable.h
index 8886d43..c3674c4 100644
--- a/arch/arm64/include/asm/kvm_pgtable.h
+++ b/arch/arm64/include/asm/kvm_pgtable.h
@@ -11,22 +11,79 @@
 #include <linux/kvm_host.h>
 #include <linux/types.h>
 
+#define KVM_PGTABLE_MAX_LEVELS		4U
+
+static inline u64 kvm_get_parange(u64 mmfr0)
+{
+	u64 parange = cpuid_feature_extract_unsigned_field(mmfr0,
+				ID_AA64MMFR0_PARANGE_SHIFT);
+	if (parange > ID_AA64MMFR0_PARANGE_MAX)
+		parange = ID_AA64MMFR0_PARANGE_MAX;
+
+	return parange;
+}
+
 typedef u64 kvm_pte_t;
 
 /**
+ * struct kvm_pgtable_mm_ops - Memory management callbacks.
+ * @zalloc_page:	Allocate a single zeroed memory page. The @arg parameter
+ *			can be used by the walker to pass a memcache. The
+ *			initial refcount of the page is 1.
+ * @zalloc_pages_exact:	Allocate an exact number of zeroed memory pages. The
+ *			@size parameter is in bytes, and is rounded-up to the
+ *			next page boundary. The resulting allocation is
+ *			physically contiguous.
+ * @free_pages_exact:	Free an exact number of memory pages previously
+ *			allocated by zalloc_pages_exact.
+ * @get_page:		Increment the refcount on a page.
+ * @put_page:		Decrement the refcount on a page. When the refcount
+ *			reaches 0 the page is automatically freed.
+ * @page_count:		Return the refcount of a page.
+ * @phys_to_virt:	Convert a physical address into a virtual address mapped
+ *			in the current context.
+ * @virt_to_phys:	Convert a virtual address mapped in the current context
+ *			into a physical address.
+ */
+struct kvm_pgtable_mm_ops {
+	void*		(*zalloc_page)(void *arg);
+	void*		(*zalloc_pages_exact)(size_t size);
+	void		(*free_pages_exact)(void *addr, size_t size);
+	void		(*get_page)(void *addr);
+	void		(*put_page)(void *addr);
+	int		(*page_count)(void *addr);
+	void*		(*phys_to_virt)(phys_addr_t phys);
+	phys_addr_t	(*virt_to_phys)(void *addr);
+};
+
+/**
+ * enum kvm_pgtable_stage2_flags - Stage-2 page-table flags.
+ * @KVM_PGTABLE_S2_NOFWB:	Don't enforce Normal-WB even if the CPUs have
+ *				ARM64_HAS_STAGE2_FWB.
+ * @KVM_PGTABLE_S2_IDMAP:	Only use identity mappings.
+ */
+enum kvm_pgtable_stage2_flags {
+	KVM_PGTABLE_S2_NOFWB			= BIT(0),
+	KVM_PGTABLE_S2_IDMAP			= BIT(1),
+};
+
+/**
  * struct kvm_pgtable - KVM page-table.
  * @ia_bits:		Maximum input address size, in bits.
  * @start_level:	Level at which the page-table walk starts.
  * @pgd:		Pointer to the first top-level entry of the page-table.
+ * @mm_ops:		Memory management callbacks.
  * @mmu:		Stage-2 KVM MMU struct. Unused for stage-1 page-tables.
  */
 struct kvm_pgtable {
 	u32					ia_bits;
 	u32					start_level;
 	kvm_pte_t				*pgd;
+	struct kvm_pgtable_mm_ops		*mm_ops;
 
 	/* Stage-2 only */
 	struct kvm_s2_mmu			*mmu;
+	enum kvm_pgtable_stage2_flags		flags;
 };
 
 /**
@@ -50,6 +107,16 @@
 #define PAGE_HYP_DEVICE		(PAGE_HYP | KVM_PGTABLE_PROT_DEVICE)
 
 /**
+ * struct kvm_mem_range - Range of Intermediate Physical Addresses
+ * @start:	Start of the range.
+ * @end:	End of the range.
+ */
+struct kvm_mem_range {
+	u64 start;
+	u64 end;
+};
+
+/**
  * enum kvm_pgtable_walk_flags - Flags to control a depth-first page-table walk.
  * @KVM_PGTABLE_WALK_LEAF:		Visit leaf entries, including invalid
  *					entries.
@@ -86,10 +153,12 @@
  * kvm_pgtable_hyp_init() - Initialise a hypervisor stage-1 page-table.
  * @pgt:	Uninitialised page-table structure to initialise.
  * @va_bits:	Maximum virtual address bits.
+ * @mm_ops:	Memory management callbacks.
  *
  * Return: 0 on success, negative error code on failure.
  */
-int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits);
+int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits,
+			 struct kvm_pgtable_mm_ops *mm_ops);
 
 /**
  * kvm_pgtable_hyp_destroy() - Destroy an unused hypervisor stage-1 page-table.
@@ -123,17 +192,41 @@
 			enum kvm_pgtable_prot prot);
 
 /**
- * kvm_pgtable_stage2_init() - Initialise a guest stage-2 page-table.
+ * kvm_get_vtcr() - Helper to construct VTCR_EL2
+ * @mmfr0:	Sanitized value of SYS_ID_AA64MMFR0_EL1 register.
+ * @mmfr1:	Sanitized value of SYS_ID_AA64MMFR1_EL1 register.
+ * @phys_shift:	Value to set in VTCR_EL2.T0SZ.
+ *
+ * The VTCR value is common across all the physical CPUs on the system.
+ * We use system wide sanitised values to fill in different fields,
+ * except for Hardware Management of Access Flags. HA Flag is set
+ * unconditionally on all CPUs, as it is safe to run with or without
+ * the feature and the bit is RES0 on CPUs that don't support it.
+ *
+ * Return: VTCR_EL2 value
+ */
+u64 kvm_get_vtcr(u64 mmfr0, u64 mmfr1, u32 phys_shift);
+
+/**
+ * kvm_pgtable_stage2_init_flags() - Initialise a guest stage-2 page-table.
  * @pgt:	Uninitialised page-table structure to initialise.
- * @kvm:	KVM structure representing the guest virtual machine.
+ * @arch:	Arch-specific KVM structure representing the guest virtual
+ *		machine.
+ * @mm_ops:	Memory management callbacks.
+ * @flags:	Stage-2 configuration flags.
  *
  * Return: 0 on success, negative error code on failure.
  */
-int kvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm *kvm);
+int kvm_pgtable_stage2_init_flags(struct kvm_pgtable *pgt, struct kvm_arch *arch,
+				  struct kvm_pgtable_mm_ops *mm_ops,
+				  enum kvm_pgtable_stage2_flags flags);
+
+#define kvm_pgtable_stage2_init(pgt, arch, mm_ops) \
+	kvm_pgtable_stage2_init_flags(pgt, arch, mm_ops, 0)
 
 /**
  * kvm_pgtable_stage2_destroy() - Destroy an unused guest stage-2 page-table.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  *
  * The page-table is assumed to be unreachable by any hardware walkers prior
  * to freeing and therefore no TLB invalidation is performed.
@@ -142,13 +235,13 @@
 
 /**
  * kvm_pgtable_stage2_map() - Install a mapping in a guest stage-2 page-table.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  * @addr:	Intermediate physical address at which to place the mapping.
  * @size:	Size of the mapping.
  * @phys:	Physical address of the memory to map.
  * @prot:	Permissions and attributes for the mapping.
- * @mc:		Cache of pre-allocated GFP_PGTABLE_USER memory from which to
- *		allocate page-table pages.
+ * @mc:		Cache of pre-allocated and zeroed memory from which to allocate
+ *		page-table pages.
  *
  * The offset of @addr within a page is ignored, @size is rounded-up to
  * the next page boundary and @phys is rounded-down to the previous page
@@ -170,11 +263,31 @@
  */
 int kvm_pgtable_stage2_map(struct kvm_pgtable *pgt, u64 addr, u64 size,
 			   u64 phys, enum kvm_pgtable_prot prot,
-			   struct kvm_mmu_memory_cache *mc);
+			   void *mc);
+
+/**
+ * kvm_pgtable_stage2_set_owner() - Unmap and annotate pages in the IPA space to
+ *				    track ownership.
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
+ * @addr:	Base intermediate physical address to annotate.
+ * @size:	Size of the annotated range.
+ * @mc:		Cache of pre-allocated and zeroed memory from which to allocate
+ *		page-table pages.
+ * @owner_id:	Unique identifier for the owner of the page.
+ *
+ * By default, all page-tables are owned by identifier 0. This function can be
+ * used to mark portions of the IPA space as owned by other entities. When a
+ * stage 2 is used with identity-mappings, these annotations allow the
+ * page-table data structure to be used as a simple rmap.
+ *
+ * Return: 0 on success, negative error code on failure.
+ */
+int kvm_pgtable_stage2_set_owner(struct kvm_pgtable *pgt, u64 addr, u64 size,
+				 void *mc, u8 owner_id);
 
 /**
  * kvm_pgtable_stage2_unmap() - Remove a mapping from a guest stage-2 page-table.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  * @addr:	Intermediate physical address from which to remove the mapping.
  * @size:	Size of the mapping.
  *
@@ -194,7 +307,7 @@
 /**
  * kvm_pgtable_stage2_wrprotect() - Write-protect guest stage-2 address range
  *                                  without TLB invalidation.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  * @addr:	Intermediate physical address from which to write-protect,
  * @size:	Size of the range.
  *
@@ -211,7 +324,7 @@
 
 /**
  * kvm_pgtable_stage2_mkyoung() - Set the access flag in a page-table entry.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  * @addr:	Intermediate physical address to identify the page-table entry.
  *
  * The offset of @addr within a page is ignored.
@@ -225,7 +338,7 @@
 
 /**
  * kvm_pgtable_stage2_mkold() - Clear the access flag in a page-table entry.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  * @addr:	Intermediate physical address to identify the page-table entry.
  *
  * The offset of @addr within a page is ignored.
@@ -244,7 +357,7 @@
 /**
  * kvm_pgtable_stage2_relax_perms() - Relax the permissions enforced by a
  *				      page-table entry.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  * @addr:	Intermediate physical address to identify the page-table entry.
  * @prot:	Additional permissions to grant for the mapping.
  *
@@ -263,7 +376,7 @@
 /**
  * kvm_pgtable_stage2_is_young() - Test whether a page-table entry has the
  *				   access flag set.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  * @addr:	Intermediate physical address to identify the page-table entry.
  *
  * The offset of @addr within a page is ignored.
@@ -276,7 +389,7 @@
  * kvm_pgtable_stage2_flush_range() - Clean and invalidate data cache to Point
  * 				      of Coherency for guest stage-2 address
  *				      range.
- * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init().
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
  * @addr:	Intermediate physical address from which to flush.
  * @size:	Size of the range.
  *
@@ -311,4 +424,23 @@
 int kvm_pgtable_walk(struct kvm_pgtable *pgt, u64 addr, u64 size,
 		     struct kvm_pgtable_walker *walker);
 
+/**
+ * kvm_pgtable_stage2_find_range() - Find a range of Intermediate Physical
+ *				     Addresses with compatible permission
+ *				     attributes.
+ * @pgt:	Page-table structure initialised by kvm_pgtable_stage2_init*().
+ * @addr:	Address that must be covered by the range.
+ * @prot:	Protection attributes that the range must be compatible with.
+ * @range:	Range structure used to limit the search space at call time and
+ *		that will hold the result.
+ *
+ * The offset of @addr within a page is ignored. An IPA is compatible with @prot
+ * iff its corresponding stage-2 page-table entry has default ownership and, if
+ * valid, is mapped with protection attributes identical to @prot.
+ *
+ * Return: 0 on success, negative error code on failure.
+ */
+int kvm_pgtable_stage2_find_range(struct kvm_pgtable *pgt, u64 addr,
+				  enum kvm_pgtable_prot prot,
+				  struct kvm_mem_range *range);
 #endif	/* __ARM64_KVM_PGTABLE_H__ */
diff --git a/arch/arm64/include/asm/pgtable-prot.h b/arch/arm64/include/asm/pgtable-prot.h
index 9a65fb5..079f4e9 100644
--- a/arch/arm64/include/asm/pgtable-prot.h
+++ b/arch/arm64/include/asm/pgtable-prot.h
@@ -71,10 +71,10 @@
 #define PAGE_KERNEL_EXEC	__pgprot(PROT_NORMAL & ~PTE_PXN)
 #define PAGE_KERNEL_EXEC_CONT	__pgprot((PROT_NORMAL & ~PTE_PXN) | PTE_CONT)
 
-#define PAGE_S2_MEMATTR(attr)						\
+#define PAGE_S2_MEMATTR(attr, has_fwb)					\
 	({								\
 		u64 __val;						\
-		if (cpus_have_const_cap(ARM64_HAS_STAGE2_FWB))		\
+		if (has_fwb)						\
 			__val = PTE_S2_MEMATTR(MT_S2_FWB_ ## attr);	\
 		else							\
 			__val = PTE_S2_MEMATTR(MT_S2_ ## attr);		\
diff --git a/arch/arm64/include/asm/sections.h b/arch/arm64/include/asm/sections.h
index 2f36b16..e4ad9db 100644
--- a/arch/arm64/include/asm/sections.h
+++ b/arch/arm64/include/asm/sections.h
@@ -13,6 +13,7 @@
 extern char __hyp_text_start[], __hyp_text_end[];
 extern char __hyp_rodata_start[], __hyp_rodata_end[];
 extern char __hyp_reloc_begin[], __hyp_reloc_end[];
+extern char __hyp_bss_start[], __hyp_bss_end[];
 extern char __idmap_text_start[], __idmap_text_end[];
 extern char __initdata_begin[], __initdata_end[];
 extern char __inittext_begin[], __inittext_end[];
diff --git a/arch/arm64/include/asm/sysreg.h b/arch/arm64/include/asm/sysreg.h
index d4a5fca9..8480169 100644
--- a/arch/arm64/include/asm/sysreg.h
+++ b/arch/arm64/include/asm/sysreg.h
@@ -333,6 +333,55 @@
 
 /*** End of Statistical Profiling Extension ***/
 
+/*
+ * TRBE Registers
+ */
+#define SYS_TRBLIMITR_EL1		sys_reg(3, 0, 9, 11, 0)
+#define SYS_TRBPTR_EL1			sys_reg(3, 0, 9, 11, 1)
+#define SYS_TRBBASER_EL1		sys_reg(3, 0, 9, 11, 2)
+#define SYS_TRBSR_EL1			sys_reg(3, 0, 9, 11, 3)
+#define SYS_TRBMAR_EL1			sys_reg(3, 0, 9, 11, 4)
+#define SYS_TRBTRG_EL1			sys_reg(3, 0, 9, 11, 6)
+#define SYS_TRBIDR_EL1			sys_reg(3, 0, 9, 11, 7)
+
+#define TRBLIMITR_LIMIT_MASK		GENMASK_ULL(51, 0)
+#define TRBLIMITR_LIMIT_SHIFT		12
+#define TRBLIMITR_NVM			BIT(5)
+#define TRBLIMITR_TRIG_MODE_MASK	GENMASK(1, 0)
+#define TRBLIMITR_TRIG_MODE_SHIFT	3
+#define TRBLIMITR_FILL_MODE_MASK	GENMASK(1, 0)
+#define TRBLIMITR_FILL_MODE_SHIFT	1
+#define TRBLIMITR_ENABLE		BIT(0)
+#define TRBPTR_PTR_MASK			GENMASK_ULL(63, 0)
+#define TRBPTR_PTR_SHIFT		0
+#define TRBBASER_BASE_MASK		GENMASK_ULL(51, 0)
+#define TRBBASER_BASE_SHIFT		12
+#define TRBSR_EC_MASK			GENMASK(5, 0)
+#define TRBSR_EC_SHIFT			26
+#define TRBSR_IRQ			BIT(22)
+#define TRBSR_TRG			BIT(21)
+#define TRBSR_WRAP			BIT(20)
+#define TRBSR_ABORT			BIT(18)
+#define TRBSR_STOP			BIT(17)
+#define TRBSR_MSS_MASK			GENMASK(15, 0)
+#define TRBSR_MSS_SHIFT			0
+#define TRBSR_BSC_MASK			GENMASK(5, 0)
+#define TRBSR_BSC_SHIFT			0
+#define TRBSR_FSC_MASK			GENMASK(5, 0)
+#define TRBSR_FSC_SHIFT			0
+#define TRBMAR_SHARE_MASK		GENMASK(1, 0)
+#define TRBMAR_SHARE_SHIFT		8
+#define TRBMAR_OUTER_MASK		GENMASK(3, 0)
+#define TRBMAR_OUTER_SHIFT		4
+#define TRBMAR_INNER_MASK		GENMASK(3, 0)
+#define TRBMAR_INNER_SHIFT		0
+#define TRBTRG_TRG_MASK			GENMASK(31, 0)
+#define TRBTRG_TRG_SHIFT		0
+#define TRBIDR_FLAG			BIT(5)
+#define TRBIDR_PROG			BIT(4)
+#define TRBIDR_ALIGN_MASK		GENMASK(3, 0)
+#define TRBIDR_ALIGN_SHIFT		0
+
 #define SYS_PMINTENSET_EL1		sys_reg(3, 0, 9, 14, 1)
 #define SYS_PMINTENCLR_EL1		sys_reg(3, 0, 9, 14, 2)
 
@@ -579,9 +628,6 @@
 #define SCTLR_ELx_A	(BIT(1))
 #define SCTLR_ELx_M	(BIT(0))
 
-#define SCTLR_ELx_FLAGS	(SCTLR_ELx_M  | SCTLR_ELx_A | SCTLR_ELx_C | \
-			 SCTLR_ELx_SA | SCTLR_ELx_I | SCTLR_ELx_IESB)
-
 /* SCTLR_EL2 specific flags. */
 #define SCTLR_EL2_RES1	((BIT(4))  | (BIT(5))  | (BIT(11)) | (BIT(16)) | \
 			 (BIT(18)) | (BIT(22)) | (BIT(23)) | (BIT(28)) | \
@@ -593,6 +639,10 @@
 #define ENDIAN_SET_EL2		0
 #endif
 
+#define INIT_SCTLR_EL2_MMU_ON						\
+	(SCTLR_ELx_M  | SCTLR_ELx_C | SCTLR_ELx_SA | SCTLR_ELx_I |	\
+	 SCTLR_ELx_IESB | SCTLR_ELx_WXN | ENDIAN_SET_EL2 | SCTLR_EL2_RES1)
+
 #define INIT_SCTLR_EL2_MMU_OFF \
 	(SCTLR_EL2_RES1 | ENDIAN_SET_EL2)
 
@@ -840,6 +890,7 @@
 #define ID_AA64MMFR2_CNP_SHIFT		0
 
 /* id_aa64dfr0 */
+#define ID_AA64DFR0_TRBE_SHIFT		44
 #define ID_AA64DFR0_TRACE_FILT_SHIFT	40
 #define ID_AA64DFR0_DOUBLELOCK_SHIFT	36
 #define ID_AA64DFR0_PMSVER_SHIFT	32
diff --git a/arch/arm64/kernel/asm-offsets.c b/arch/arm64/kernel/asm-offsets.c
index a36e2fc..8930b42 100644
--- a/arch/arm64/kernel/asm-offsets.c
+++ b/arch/arm64/kernel/asm-offsets.c
@@ -120,6 +120,9 @@
   DEFINE(NVHE_INIT_TPIDR_EL2,	offsetof(struct kvm_nvhe_init_params, tpidr_el2));
   DEFINE(NVHE_INIT_STACK_HYP_VA,	offsetof(struct kvm_nvhe_init_params, stack_hyp_va));
   DEFINE(NVHE_INIT_PGD_PA,	offsetof(struct kvm_nvhe_init_params, pgd_pa));
+  DEFINE(NVHE_INIT_HCR_EL2,	offsetof(struct kvm_nvhe_init_params, hcr_el2));
+  DEFINE(NVHE_INIT_VTTBR,	offsetof(struct kvm_nvhe_init_params, vttbr));
+  DEFINE(NVHE_INIT_VTCR,	offsetof(struct kvm_nvhe_init_params, vtcr));
 #endif
 #ifdef CONFIG_CPU_PM
   DEFINE(CPU_CTX_SP,		offsetof(struct cpu_suspend_ctx, sp));
diff --git a/arch/arm64/kernel/cpu-reset.S b/arch/arm64/kernel/cpu-reset.S
index 37721eb..d47ff63 100644
--- a/arch/arm64/kernel/cpu-reset.S
+++ b/arch/arm64/kernel/cpu-reset.S
@@ -30,10 +30,7 @@
  * flat identity mapping.
  */
 SYM_CODE_START(__cpu_soft_restart)
-	/* Clear sctlr_el1 flags. */
-	mrs	x12, sctlr_el1
-	mov_q	x13, SCTLR_ELx_FLAGS
-	bic	x12, x12, x13
+	mov_q	x12, INIT_SCTLR_EL1_MMU_OFF
 	pre_disable_mmu_workaround
 	/*
 	 * either disable EL1&0 translation regime or disable EL2&0 translation
diff --git a/arch/arm64/kernel/cpufeature.c b/arch/arm64/kernel/cpufeature.c
index e5281e1c..e3e0dcb 100644
--- a/arch/arm64/kernel/cpufeature.c
+++ b/arch/arm64/kernel/cpufeature.c
@@ -808,6 +808,12 @@
 					reg->name,
 					ftrp->shift + ftrp->width - 1,
 					ftrp->shift, str, tmp);
+		} else if ((ftr_mask & reg->override->val) == ftr_mask) {
+			reg->override->val &= ~ftr_mask;
+			pr_warn("%s[%d:%d]: impossible override, ignored\n",
+				reg->name,
+				ftrp->shift + ftrp->width - 1,
+				ftrp->shift);
 		}
 
 		val = arm64_ftr_set_value(ftrp, val, ftr_new);
@@ -1619,7 +1625,6 @@
 }
 #endif
 
-#ifdef CONFIG_ARM64_VHE
 static bool runs_at_el2(const struct arm64_cpu_capabilities *entry, int __unused)
 {
 	return is_kernel_in_hyp_mode();
@@ -1638,7 +1643,6 @@
 	if (!alternative_is_applied(ARM64_HAS_VIRT_HOST_EXTN))
 		write_sysreg(read_sysreg(tpidr_el1), tpidr_el2);
 }
-#endif
 
 static void cpu_has_fwb(const struct arm64_cpu_capabilities *__unused)
 {
@@ -1841,7 +1845,6 @@
 		.type = ARM64_CPUCAP_WEAK_LOCAL_CPU_FEATURE,
 		.matches = has_no_hw_prefetch,
 	},
-#ifdef CONFIG_ARM64_VHE
 	{
 		.desc = "Virtualization Host Extensions",
 		.capability = ARM64_HAS_VIRT_HOST_EXTN,
@@ -1849,7 +1852,6 @@
 		.matches = runs_at_el2,
 		.cpu_enable = cpu_copy_el2regs,
 	},
-#endif	/* CONFIG_ARM64_VHE */
 	{
 		.desc = "32-bit EL0 Support",
 		.capability = ARM64_HAS_32BIT_EL0,
diff --git a/arch/arm64/kernel/head.S b/arch/arm64/kernel/head.S
index 840bda1..96873df 100644
--- a/arch/arm64/kernel/head.S
+++ b/arch/arm64/kernel/head.S
@@ -477,14 +477,13 @@
  * booted in EL1 or EL2 respectively.
  */
 SYM_FUNC_START(init_kernel_el)
-	mov_q	x0, INIT_SCTLR_EL1_MMU_OFF
-	msr	sctlr_el1, x0
-
 	mrs	x0, CurrentEL
 	cmp	x0, #CurrentEL_EL2
 	b.eq	init_el2
 
 SYM_INNER_LABEL(init_el1, SYM_L_LOCAL)
+	mov_q	x0, INIT_SCTLR_EL1_MMU_OFF
+	msr	sctlr_el1, x0
 	isb
 	mov_q	x0, INIT_PSTATE_EL1
 	msr	spsr_el1, x0
@@ -504,9 +503,43 @@
 	msr	vbar_el2, x0
 	isb
 
+	/*
+	 * Fruity CPUs seem to have HCR_EL2.E2H set to RES1,
+	 * making it impossible to start in nVHE mode. Is that
+	 * compliant with the architecture? Absolutely not!
+	 */
+	mrs	x0, hcr_el2
+	and	x0, x0, #HCR_E2H
+	cbz	x0, 1f
+
+	/* Switching to VHE requires a sane SCTLR_EL1 as a start */
+	mov_q	x0, INIT_SCTLR_EL1_MMU_OFF
+	msr_s	SYS_SCTLR_EL12, x0
+
+	/*
+	 * Force an eret into a helper "function", and let it return
+	 * to our original caller... This makes sure that we have
+	 * initialised the basic PSTATE state.
+	 */
+	mov	x0, #INIT_PSTATE_EL2
+	msr	spsr_el1, x0
+	adr	x0, __cpu_stick_to_vhe
+	msr	elr_el1, x0
+	eret
+
+1:
+	mov_q	x0, INIT_SCTLR_EL1_MMU_OFF
+	msr	sctlr_el1, x0
+
 	msr	elr_el2, lr
 	mov	w0, #BOOT_CPU_MODE_EL2
 	eret
+
+__cpu_stick_to_vhe:
+	mov	x0, #HVC_VHE_RESTART
+	hvc	#0
+	mov	x0, #BOOT_CPU_MODE_EL2
+	ret
 SYM_FUNC_END(init_kernel_el)
 
 /*
diff --git a/arch/arm64/kernel/hyp-stub.S b/arch/arm64/kernel/hyp-stub.S
index 5eccbd6..43d2126 100644
--- a/arch/arm64/kernel/hyp-stub.S
+++ b/arch/arm64/kernel/hyp-stub.S
@@ -27,12 +27,12 @@
 	ventry	el2_fiq_invalid			// FIQ EL2t
 	ventry	el2_error_invalid		// Error EL2t
 
-	ventry	el2_sync_invalid		// Synchronous EL2h
+	ventry	elx_sync			// Synchronous EL2h
 	ventry	el2_irq_invalid			// IRQ EL2h
 	ventry	el2_fiq_invalid			// FIQ EL2h
 	ventry	el2_error_invalid		// Error EL2h
 
-	ventry	el1_sync			// Synchronous 64-bit EL1
+	ventry	elx_sync			// Synchronous 64-bit EL1
 	ventry	el1_irq_invalid			// IRQ 64-bit EL1
 	ventry	el1_fiq_invalid			// FIQ 64-bit EL1
 	ventry	el1_error_invalid		// Error 64-bit EL1
@@ -45,7 +45,7 @@
 
 	.align 11
 
-SYM_CODE_START_LOCAL(el1_sync)
+SYM_CODE_START_LOCAL(elx_sync)
 	cmp	x0, #HVC_SET_VECTORS
 	b.ne	1f
 	msr	vbar_el2, x1
@@ -71,7 +71,7 @@
 
 9:	mov	x0, xzr
 	eret
-SYM_CODE_END(el1_sync)
+SYM_CODE_END(elx_sync)
 
 // nVHE? No way! Give me the real thing!
 SYM_CODE_START_LOCAL(mutate_to_vhe)
@@ -115,9 +115,10 @@
 	mrs_s	x0, SYS_VBAR_EL12
 	msr	vbar_el1, x0
 
-	// Use EL2 translations for SPE and disable access from EL1
+	// Use EL2 translations for SPE & TRBE and disable access from EL1
 	mrs	x0, mdcr_el2
 	bic	x0, x0, #(MDCR_EL2_E2PB_MASK << MDCR_EL2_E2PB_SHIFT)
+	bic	x0, x0, #(MDCR_EL2_E2TB_MASK << MDCR_EL2_E2TB_SHIFT)
 	msr	mdcr_el2, x0
 
 	// Transfer the MM state from EL1 to EL2
@@ -224,7 +225,6 @@
  * Entry point to switch to VHE if deemed capable
  */
 SYM_FUNC_START(switch_to_vhe)
-#ifdef CONFIG_ARM64_VHE
 	// Need to have booted at EL2
 	adr_l	x1, __boot_cpu_mode
 	ldr	w0, [x1]
@@ -240,6 +240,5 @@
 	mov	x0, #HVC_VHE_RESTART
 	hvc	#0
 1:
-#endif
 	ret
 SYM_FUNC_END(switch_to_vhe)
diff --git a/arch/arm64/kernel/idreg-override.c b/arch/arm64/kernel/idreg-override.c
index 83f1c4b..e628c8c 100644
--- a/arch/arm64/kernel/idreg-override.c
+++ b/arch/arm64/kernel/idreg-override.c
@@ -25,14 +25,26 @@
 	struct {
 		char			name[FTR_DESC_FIELD_LEN];
 		u8			shift;
+		bool			(*filter)(u64 val);
 	} 				fields[];
 };
 
+static bool __init mmfr1_vh_filter(u64 val)
+{
+	/*
+	 * If we ever reach this point while running VHE, we're
+	 * guaranteed to be on one of these funky, VHE-stuck CPUs. If
+	 * the user was trying to force nVHE on us, proceed with
+	 * attitude adjustment.
+	 */
+	return !(is_kernel_in_hyp_mode() && val == 0);
+}
+
 static const struct ftr_set_desc mmfr1 __initconst = {
 	.name		= "id_aa64mmfr1",
 	.override	= &id_aa64mmfr1_override,
 	.fields		= {
-	        { "vh", ID_AA64MMFR1_VHE_SHIFT },
+		{ "vh", ID_AA64MMFR1_VHE_SHIFT, mmfr1_vh_filter },
 		{}
 	},
 };
@@ -124,6 +136,18 @@
 			if (find_field(cmdline, regs[i], f, &v))
 				continue;
 
+			/*
+			 * If an override gets filtered out, advertise
+			 * it by setting the value to 0xf, but
+			 * clearing the mask... Yes, this is fragile.
+			 */
+			if (regs[i]->fields[f].filter &&
+			    !regs[i]->fields[f].filter(v)) {
+				regs[i]->override->val  |= mask;
+				regs[i]->override->mask &= ~mask;
+				continue;
+			}
+
 			regs[i]->override->val  &= ~mask;
 			regs[i]->override->val  |= (v << shift) & mask;
 			regs[i]->override->mask |= mask;
diff --git a/arch/arm64/kernel/image-vars.h b/arch/arm64/kernel/image-vars.h
index 5aa9ed1..bcf3c27 100644
--- a/arch/arm64/kernel/image-vars.h
+++ b/arch/arm64/kernel/image-vars.h
@@ -65,13 +65,13 @@
 KVM_NVHE_ALIAS(kvm_patch_vector_branch);
 KVM_NVHE_ALIAS(kvm_update_va_mask);
 KVM_NVHE_ALIAS(kvm_get_kimage_voffset);
+KVM_NVHE_ALIAS(kvm_compute_final_ctr_el0);
 
 /* Global kernel state accessed by nVHE hyp code. */
 KVM_NVHE_ALIAS(kvm_vgic_global_state);
 
 /* Kernel symbols used to call panic() from nVHE hyp code (via ERET). */
-KVM_NVHE_ALIAS(__hyp_panic_string);
-KVM_NVHE_ALIAS(panic);
+KVM_NVHE_ALIAS(nvhe_hyp_panic_handler);
 
 /* Vectors installed by hyp-init on reset HVC. */
 KVM_NVHE_ALIAS(__hyp_stub_vectors);
@@ -104,6 +104,36 @@
 /* PMU available static key */
 KVM_NVHE_ALIAS(kvm_arm_pmu_available);
 
+/* Position-independent library routines */
+KVM_NVHE_ALIAS_HYP(clear_page, __pi_clear_page);
+KVM_NVHE_ALIAS_HYP(copy_page, __pi_copy_page);
+KVM_NVHE_ALIAS_HYP(memcpy, __pi_memcpy);
+KVM_NVHE_ALIAS_HYP(memset, __pi_memset);
+
+#ifdef CONFIG_KASAN
+KVM_NVHE_ALIAS_HYP(__memcpy, __pi_memcpy);
+KVM_NVHE_ALIAS_HYP(__memset, __pi_memset);
+#endif
+
+/* Kernel memory sections */
+KVM_NVHE_ALIAS(__start_rodata);
+KVM_NVHE_ALIAS(__end_rodata);
+KVM_NVHE_ALIAS(__bss_start);
+KVM_NVHE_ALIAS(__bss_stop);
+
+/* Hyp memory sections */
+KVM_NVHE_ALIAS(__hyp_idmap_text_start);
+KVM_NVHE_ALIAS(__hyp_idmap_text_end);
+KVM_NVHE_ALIAS(__hyp_text_start);
+KVM_NVHE_ALIAS(__hyp_text_end);
+KVM_NVHE_ALIAS(__hyp_bss_start);
+KVM_NVHE_ALIAS(__hyp_bss_end);
+KVM_NVHE_ALIAS(__hyp_rodata_start);
+KVM_NVHE_ALIAS(__hyp_rodata_end);
+
+/* pKVM static key */
+KVM_NVHE_ALIAS(kvm_protected_mode_initialized);
+
 #endif /* CONFIG_KVM */
 
 #endif /* __ARM64_KERNEL_IMAGE_VARS_H */
diff --git a/arch/arm64/kernel/vmlinux.lds.S b/arch/arm64/kernel/vmlinux.lds.S
index 7eea788..709d2c4 100644
--- a/arch/arm64/kernel/vmlinux.lds.S
+++ b/arch/arm64/kernel/vmlinux.lds.S
@@ -5,24 +5,7 @@
  * Written by Martin Mares <mj@atrey.karlin.mff.cuni.cz>
  */
 
-#define RO_EXCEPTION_TABLE_ALIGN	8
-#define RUNTIME_DISCARD_EXIT
-
-#include <asm-generic/vmlinux.lds.h>
-#include <asm/cache.h>
 #include <asm/hyp_image.h>
-#include <asm/kernel-pgtable.h>
-#include <asm/memory.h>
-#include <asm/page.h>
-
-#include "image.h"
-
-OUTPUT_ARCH(aarch64)
-ENTRY(_text)
-
-jiffies = jiffies_64;
-
-
 #ifdef CONFIG_KVM
 #define HYPERVISOR_EXTABLE					\
 	. = ALIGN(SZ_8);					\
@@ -32,9 +15,11 @@
 
 #define HYPERVISOR_DATA_SECTIONS				\
 	HYP_SECTION_NAME(.rodata) : {				\
+		. = ALIGN(PAGE_SIZE);				\
 		__hyp_rodata_start = .;				\
 		*(HYP_SECTION_NAME(.data..ro_after_init))	\
 		*(HYP_SECTION_NAME(.rodata))			\
+		. = ALIGN(PAGE_SIZE);				\
 		__hyp_rodata_end = .;				\
 	}
 
@@ -51,29 +36,52 @@
 		__hyp_reloc_end = .;				\
 	}
 
+#define BSS_FIRST_SECTIONS					\
+	__hyp_bss_start = .;					\
+	*(HYP_SECTION_NAME(.bss))				\
+	. = ALIGN(PAGE_SIZE);					\
+	__hyp_bss_end = .;
+
+/*
+ * We require that __hyp_bss_start and __bss_start are aligned, and enforce it
+ * with an assertion. But the BSS_SECTION macro places an empty .sbss section
+ * between them, which can in some cases cause the linker to misalign them. To
+ * work around the issue, force a page alignment for __bss_start.
+ */
+#define SBSS_ALIGN			PAGE_SIZE
 #else /* CONFIG_KVM */
 #define HYPERVISOR_EXTABLE
 #define HYPERVISOR_DATA_SECTIONS
 #define HYPERVISOR_PERCPU_SECTION
 #define HYPERVISOR_RELOC_SECTION
+#define SBSS_ALIGN			0
 #endif
 
+#define RO_EXCEPTION_TABLE_ALIGN	8
+#define RUNTIME_DISCARD_EXIT
+
+#include <asm-generic/vmlinux.lds.h>
+#include <asm/cache.h>
+#include <asm/kernel-pgtable.h>
+#include <asm/memory.h>
+#include <asm/page.h>
+
+#include "image.h"
+
+OUTPUT_ARCH(aarch64)
+ENTRY(_text)
+
+jiffies = jiffies_64;
+
 #define HYPERVISOR_TEXT					\
-	/*						\
-	 * Align to 4 KB so that			\
-	 * a) the HYP vector table is at its minimum	\
-	 *    alignment of 2048 bytes			\
-	 * b) the HYP init code will not cross a page	\
-	 *    boundary if its size does not exceed	\
-	 *    4 KB (see related ASSERT() below)		\
-	 */						\
-	. = ALIGN(SZ_4K);				\
+	. = ALIGN(PAGE_SIZE);				\
 	__hyp_idmap_text_start = .;			\
 	*(.hyp.idmap.text)				\
 	__hyp_idmap_text_end = .;			\
 	__hyp_text_start = .;				\
 	*(.hyp.text)					\
 	HYPERVISOR_EXTABLE				\
+	. = ALIGN(PAGE_SIZE);				\
 	__hyp_text_end = .;
 
 #define IDMAP_TEXT					\
@@ -276,7 +284,7 @@
 	__pecoff_data_rawsize = ABSOLUTE(. - __initdata_begin);
 	_edata = .;
 
-	BSS_SECTION(0, 0, 0)
+	BSS_SECTION(SBSS_ALIGN, 0, 0)
 
 	. = ALIGN(PAGE_SIZE);
 	init_pg_dir = .;
@@ -309,11 +317,12 @@
 #include "image-vars.h"
 
 /*
- * The HYP init code and ID map text can't be longer than a page each,
- * and should not cross a page boundary.
+ * The HYP init code and ID map text can't be longer than a page each. The
+ * former is page-aligned, but the latter may not be with 16K or 64K pages, so
+ * it should also not cross a page boundary.
  */
-ASSERT(__hyp_idmap_text_end - (__hyp_idmap_text_start & ~(SZ_4K - 1)) <= SZ_4K,
-	"HYP init code too big or misaligned")
+ASSERT(__hyp_idmap_text_end - __hyp_idmap_text_start <= PAGE_SIZE,
+	"HYP init code too big")
 ASSERT(__idmap_text_end - (__idmap_text_start & ~(SZ_4K - 1)) <= SZ_4K,
 	"ID map text too big or misaligned")
 #ifdef CONFIG_HIBERNATION
@@ -324,6 +333,9 @@
 ASSERT((__entry_tramp_text_end - __entry_tramp_text_start) == PAGE_SIZE,
 	"Entry trampoline text too big")
 #endif
+#ifdef CONFIG_KVM
+ASSERT(__hyp_bss_start == __bss_start, "HYP and Host BSS are misaligned")
+#endif
 /*
  * If padding is applied before .head.text, virt<->phys conversions will fail.
  */
diff --git a/arch/arm64/kvm/arm.c b/arch/arm64/kvm/arm.c
index 7f06ba7..4967ea3 100644
--- a/arch/arm64/kvm/arm.c
+++ b/arch/arm64/kvm/arm.c
@@ -206,6 +206,7 @@
 	case KVM_CAP_ARM_INJECT_EXT_DABT:
 	case KVM_CAP_SET_GUEST_DEBUG:
 	case KVM_CAP_VCPU_ATTRIBUTES:
+	case KVM_CAP_PTP_KVM:
 		r = 1;
 		break;
 	case KVM_CAP_ARM_SET_DEVICE_ADDR:
@@ -416,10 +417,12 @@
 
 	if (vcpu_has_ptrauth(vcpu))
 		vcpu_ptrauth_disable(vcpu);
+	kvm_arch_vcpu_load_debug_state_flags(vcpu);
 }
 
 void kvm_arch_vcpu_put(struct kvm_vcpu *vcpu)
 {
+	kvm_arch_vcpu_put_debug_state_flags(vcpu);
 	kvm_arch_vcpu_put_fp(vcpu);
 	if (has_vhe())
 		kvm_vcpu_put_sysregs_vhe(vcpu);
@@ -580,6 +583,8 @@
 
 	vcpu->arch.has_run_once = true;
 
+	kvm_arm_vcpu_init_debug(vcpu);
+
 	if (likely(irqchip_in_kernel(kvm))) {
 		/*
 		 * Map the VGIC hardware resources before running a vcpu the
@@ -1268,7 +1273,7 @@
 }
 
 void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
-					struct kvm_memory_slot *memslot)
+					const struct kvm_memory_slot *memslot)
 {
 	kvm_flush_remote_tlbs(kvm);
 }
@@ -1350,16 +1355,9 @@
 /* A lookup table holding the hypervisor VA for each vector slot */
 static void *hyp_spectre_vector_selector[BP_HARDEN_EL2_SLOTS];
 
-static int __kvm_vector_slot2idx(enum arm64_hyp_spectre_vector slot)
-{
-	return slot - (slot != HYP_VECTOR_DIRECT);
-}
-
 static void kvm_init_vector_slot(void *base, enum arm64_hyp_spectre_vector slot)
 {
-	int idx = __kvm_vector_slot2idx(slot);
-
-	hyp_spectre_vector_selector[slot] = base + (idx * SZ_2K);
+	hyp_spectre_vector_selector[slot] = __kvm_vector_slot2addr(base, slot);
 }
 
 static int kvm_init_vector_slots(void)
@@ -1388,22 +1386,18 @@
 	return 0;
 }
 
-static void cpu_init_hyp_mode(void)
+static void cpu_prepare_hyp_mode(int cpu)
 {
-	struct kvm_nvhe_init_params *params = this_cpu_ptr_nvhe_sym(kvm_init_params);
-	struct arm_smccc_res res;
+	struct kvm_nvhe_init_params *params = per_cpu_ptr_nvhe_sym(kvm_init_params, cpu);
 	unsigned long tcr;
 
-	/* Switch from the HYP stub to our own HYP init vector */
-	__hyp_set_vectors(kvm_get_idmap_vector());
-
 	/*
 	 * Calculate the raw per-cpu offset without a translation from the
 	 * kernel's mapping to the linear mapping, and store it in tpidr_el2
 	 * so that we can use adr_l to access per-cpu variables in EL2.
 	 * Also drop the KASAN tag which gets in the way...
 	 */
-	params->tpidr_el2 = (unsigned long)kasan_reset_tag(this_cpu_ptr_nvhe_sym(__per_cpu_start)) -
+	params->tpidr_el2 = (unsigned long)kasan_reset_tag(per_cpu_ptr_nvhe_sym(__per_cpu_start, cpu)) -
 			    (unsigned long)kvm_ksym_ref(CHOOSE_NVHE_SYM(__per_cpu_start));
 
 	params->mair_el2 = read_sysreg(mair_el1);
@@ -1427,14 +1421,28 @@
 	tcr |= (idmap_t0sz & GENMASK(TCR_TxSZ_WIDTH - 1, 0)) << TCR_T0SZ_OFFSET;
 	params->tcr_el2 = tcr;
 
-	params->stack_hyp_va = kern_hyp_va(__this_cpu_read(kvm_arm_hyp_stack_page) + PAGE_SIZE);
+	params->stack_hyp_va = kern_hyp_va(per_cpu(kvm_arm_hyp_stack_page, cpu) + PAGE_SIZE);
 	params->pgd_pa = kvm_mmu_get_httbr();
+	if (is_protected_kvm_enabled())
+		params->hcr_el2 = HCR_HOST_NVHE_PROTECTED_FLAGS;
+	else
+		params->hcr_el2 = HCR_HOST_NVHE_FLAGS;
+	params->vttbr = params->vtcr = 0;
 
 	/*
 	 * Flush the init params from the data cache because the struct will
 	 * be read while the MMU is off.
 	 */
 	kvm_flush_dcache_to_poc(params, sizeof(*params));
+}
+
+static void hyp_install_host_vector(void)
+{
+	struct kvm_nvhe_init_params *params;
+	struct arm_smccc_res res;
+
+	/* Switch from the HYP stub to our own HYP init vector */
+	__hyp_set_vectors(kvm_get_idmap_vector());
 
 	/*
 	 * Call initialization code, and switch to the full blown HYP code.
@@ -1443,8 +1451,14 @@
 	 * cpus_have_const_cap() wrapper.
 	 */
 	BUG_ON(!system_capabilities_finalized());
+	params = this_cpu_ptr_nvhe_sym(kvm_init_params);
 	arm_smccc_1_1_hvc(KVM_HOST_SMCCC_FUNC(__kvm_hyp_init), virt_to_phys(params), &res);
 	WARN_ON(res.a0 != SMCCC_RET_SUCCESS);
+}
+
+static void cpu_init_hyp_mode(void)
+{
+	hyp_install_host_vector();
 
 	/*
 	 * Disabling SSBD on a non-VHE system requires us to enable SSBS
@@ -1487,7 +1501,10 @@
 	struct bp_hardening_data *data = this_cpu_ptr(&bp_hardening_data);
 	void *vector = hyp_spectre_vector_selector[data->slot];
 
-	*this_cpu_ptr_hyp_sym(kvm_hyp_vector) = (unsigned long)vector;
+	if (!is_protected_kvm_enabled())
+		*this_cpu_ptr_hyp_sym(kvm_hyp_vector) = (unsigned long)vector;
+	else
+		kvm_call_hyp_nvhe(__pkvm_cpu_set_vector, data->slot);
 }
 
 static void cpu_hyp_reinit(void)
@@ -1495,13 +1512,14 @@
 	kvm_init_host_cpu_context(&this_cpu_ptr_hyp_sym(kvm_host_data)->host_ctxt);
 
 	cpu_hyp_reset();
-	cpu_set_hyp_vector();
 
 	if (is_kernel_in_hyp_mode())
 		kvm_timer_init_vhe();
 	else
 		cpu_init_hyp_mode();
 
+	cpu_set_hyp_vector();
+
 	kvm_arm_init_debug();
 
 	if (vgic_present)
@@ -1697,18 +1715,62 @@
 	}
 }
 
+static int do_pkvm_init(u32 hyp_va_bits)
+{
+	void *per_cpu_base = kvm_ksym_ref(kvm_arm_hyp_percpu_base);
+	int ret;
+
+	preempt_disable();
+	hyp_install_host_vector();
+	ret = kvm_call_hyp_nvhe(__pkvm_init, hyp_mem_base, hyp_mem_size,
+				num_possible_cpus(), kern_hyp_va(per_cpu_base),
+				hyp_va_bits);
+	preempt_enable();
+
+	return ret;
+}
+
+static int kvm_hyp_init_protection(u32 hyp_va_bits)
+{
+	void *addr = phys_to_virt(hyp_mem_base);
+	int ret;
+
+	kvm_nvhe_sym(id_aa64mmfr0_el1_sys_val) = read_sanitised_ftr_reg(SYS_ID_AA64MMFR0_EL1);
+	kvm_nvhe_sym(id_aa64mmfr1_el1_sys_val) = read_sanitised_ftr_reg(SYS_ID_AA64MMFR1_EL1);
+
+	ret = create_hyp_mappings(addr, addr + hyp_mem_size, PAGE_HYP);
+	if (ret)
+		return ret;
+
+	ret = do_pkvm_init(hyp_va_bits);
+	if (ret)
+		return ret;
+
+	free_hyp_pgds();
+
+	return 0;
+}
+
 /**
  * Inits Hyp-mode on all online CPUs
  */
 static int init_hyp_mode(void)
 {
+	u32 hyp_va_bits;
 	int cpu;
-	int err = 0;
+	int err = -ENOMEM;
+
+	/*
+	 * The protected Hyp-mode cannot be initialized if the memory pool
+	 * allocation has failed.
+	 */
+	if (is_protected_kvm_enabled() && !hyp_mem_base)
+		goto out_err;
 
 	/*
 	 * Allocate Hyp PGD and setup Hyp identity mapping
 	 */
-	err = kvm_mmu_init();
+	err = kvm_mmu_init(&hyp_va_bits);
 	if (err)
 		goto out_err;
 
@@ -1769,7 +1831,19 @@
 		goto out_err;
 	}
 
-	err = create_hyp_mappings(kvm_ksym_ref(__bss_start),
+	/*
+	 * .hyp.bss is guaranteed to be placed at the beginning of the .bss
+	 * section thanks to an assertion in the linker script. Map it RW and
+	 * the rest of .bss RO.
+	 */
+	err = create_hyp_mappings(kvm_ksym_ref(__hyp_bss_start),
+				  kvm_ksym_ref(__hyp_bss_end), PAGE_HYP);
+	if (err) {
+		kvm_err("Cannot map hyp bss section: %d\n", err);
+		goto out_err;
+	}
+
+	err = create_hyp_mappings(kvm_ksym_ref(__hyp_bss_end),
 				  kvm_ksym_ref(__bss_stop), PAGE_HYP_RO);
 	if (err) {
 		kvm_err("Cannot map bss section\n");
@@ -1790,26 +1864,36 @@
 		}
 	}
 
-	/*
-	 * Map Hyp percpu pages
-	 */
 	for_each_possible_cpu(cpu) {
 		char *percpu_begin = (char *)kvm_arm_hyp_percpu_base[cpu];
 		char *percpu_end = percpu_begin + nvhe_percpu_size();
 
+		/* Map Hyp percpu pages */
 		err = create_hyp_mappings(percpu_begin, percpu_end, PAGE_HYP);
-
 		if (err) {
 			kvm_err("Cannot map hyp percpu region\n");
 			goto out_err;
 		}
+
+		/* Prepare the CPU initialization parameters */
+		cpu_prepare_hyp_mode(cpu);
 	}
 
 	if (is_protected_kvm_enabled()) {
 		init_cpu_logical_map();
 
-		if (!init_psci_relay())
+		if (!init_psci_relay()) {
+			err = -ENODEV;
 			goto out_err;
+		}
+	}
+
+	if (is_protected_kvm_enabled()) {
+		err = kvm_hyp_init_protection(hyp_va_bits);
+		if (err) {
+			kvm_err("Failed to init hyp memory protection\n");
+			goto out_err;
+		}
 	}
 
 	return 0;
@@ -1820,6 +1904,72 @@
 	return err;
 }
 
+static void _kvm_host_prot_finalize(void *discard)
+{
+	WARN_ON(kvm_call_hyp_nvhe(__pkvm_prot_finalize));
+}
+
+static inline int pkvm_mark_hyp(phys_addr_t start, phys_addr_t end)
+{
+	return kvm_call_hyp_nvhe(__pkvm_mark_hyp, start, end);
+}
+
+#define pkvm_mark_hyp_section(__section)		\
+	pkvm_mark_hyp(__pa_symbol(__section##_start),	\
+			__pa_symbol(__section##_end))
+
+static int finalize_hyp_mode(void)
+{
+	int cpu, ret;
+
+	if (!is_protected_kvm_enabled())
+		return 0;
+
+	ret = pkvm_mark_hyp_section(__hyp_idmap_text);
+	if (ret)
+		return ret;
+
+	ret = pkvm_mark_hyp_section(__hyp_text);
+	if (ret)
+		return ret;
+
+	ret = pkvm_mark_hyp_section(__hyp_rodata);
+	if (ret)
+		return ret;
+
+	ret = pkvm_mark_hyp_section(__hyp_bss);
+	if (ret)
+		return ret;
+
+	ret = pkvm_mark_hyp(hyp_mem_base, hyp_mem_base + hyp_mem_size);
+	if (ret)
+		return ret;
+
+	for_each_possible_cpu(cpu) {
+		phys_addr_t start = virt_to_phys((void *)kvm_arm_hyp_percpu_base[cpu]);
+		phys_addr_t end = start + (PAGE_SIZE << nvhe_percpu_order());
+
+		ret = pkvm_mark_hyp(start, end);
+		if (ret)
+			return ret;
+
+		start = virt_to_phys((void *)per_cpu(kvm_arm_hyp_stack_page, cpu));
+		end = start + PAGE_SIZE;
+		ret = pkvm_mark_hyp(start, end);
+		if (ret)
+			return ret;
+	}
+
+	/*
+	 * Flip the static key upfront as that may no longer be possible
+	 * once the host stage 2 is installed.
+	 */
+	static_branch_enable(&kvm_protected_mode_initialized);
+	on_each_cpu(_kvm_host_prot_finalize, NULL, 1);
+
+	return 0;
+}
+
 static void check_kvm_target_cpu(void *ret)
 {
 	*(int *)ret = kvm_target_cpu();
@@ -1894,11 +2044,6 @@
 
 	in_hyp_mode = is_kernel_in_hyp_mode();
 
-	if (!in_hyp_mode && kvm_arch_requires_vhe()) {
-		kvm_pr_unimpl("CPU unsupported in non-VHE mode, not initializing\n");
-		return -ENODEV;
-	}
-
 	if (cpus_have_final_cap(ARM64_WORKAROUND_DEVICE_LOAD_ACQUIRE) ||
 	    cpus_have_final_cap(ARM64_WORKAROUND_1508412))
 		kvm_info("Guests without required CPU erratum workarounds can deadlock system!\n" \
@@ -1936,8 +2081,15 @@
 	if (err)
 		goto out_hyp;
 
+	if (!in_hyp_mode) {
+		err = finalize_hyp_mode();
+		if (err) {
+			kvm_err("Failed to finalize Hyp protection\n");
+			goto out_hyp;
+		}
+	}
+
 	if (is_protected_kvm_enabled()) {
-		static_branch_enable(&kvm_protected_mode_initialized);
 		kvm_info("Protected nVHE mode initialized successfully\n");
 	} else if (in_hyp_mode) {
 		kvm_info("VHE mode initialized successfully\n");
diff --git a/arch/arm64/kvm/debug.c b/arch/arm64/kvm/debug.c
index dbc8905..d5e79d7e 100644
--- a/arch/arm64/kvm/debug.c
+++ b/arch/arm64/kvm/debug.c
@@ -69,6 +69,65 @@
 }
 
 /**
+ * kvm_arm_setup_mdcr_el2 - configure vcpu mdcr_el2 value
+ *
+ * @vcpu:	the vcpu pointer
+ *
+ * This ensures we will trap access to:
+ *  - Performance monitors (MDCR_EL2_TPM/MDCR_EL2_TPMCR)
+ *  - Debug ROM Address (MDCR_EL2_TDRA)
+ *  - OS related registers (MDCR_EL2_TDOSA)
+ *  - Statistical profiler (MDCR_EL2_TPMS/MDCR_EL2_E2PB)
+ *  - Self-hosted Trace Filter controls (MDCR_EL2_TTRF)
+ *  - Self-hosted Trace (MDCR_EL2_TTRF/MDCR_EL2_E2TB)
+ */
+static void kvm_arm_setup_mdcr_el2(struct kvm_vcpu *vcpu)
+{
+	/*
+	 * This also clears MDCR_EL2_E2PB_MASK and MDCR_EL2_E2TB_MASK
+	 * to disable guest access to the profiling and trace buffers
+	 */
+	vcpu->arch.mdcr_el2 = __this_cpu_read(mdcr_el2) & MDCR_EL2_HPMN_MASK;
+	vcpu->arch.mdcr_el2 |= (MDCR_EL2_TPM |
+				MDCR_EL2_TPMS |
+				MDCR_EL2_TTRF |
+				MDCR_EL2_TPMCR |
+				MDCR_EL2_TDRA |
+				MDCR_EL2_TDOSA);
+
+	/* Is the VM being debugged by userspace? */
+	if (vcpu->guest_debug)
+		/* Route all software debug exceptions to EL2 */
+		vcpu->arch.mdcr_el2 |= MDCR_EL2_TDE;
+
+	/*
+	 * Trap debug register access when one of the following is true:
+	 *  - Userspace is using the hardware to debug the guest
+	 *  (KVM_GUESTDBG_USE_HW is set).
+	 *  - The guest is not using debug (KVM_ARM64_DEBUG_DIRTY is clear).
+	 */
+	if ((vcpu->guest_debug & KVM_GUESTDBG_USE_HW) ||
+	    !(vcpu->arch.flags & KVM_ARM64_DEBUG_DIRTY))
+		vcpu->arch.mdcr_el2 |= MDCR_EL2_TDA;
+
+	trace_kvm_arm_set_dreg32("MDCR_EL2", vcpu->arch.mdcr_el2);
+}
+
+/**
+ * kvm_arm_vcpu_init_debug - setup vcpu debug traps
+ *
+ * @vcpu:	the vcpu pointer
+ *
+ * Set vcpu initial mdcr_el2 value.
+ */
+void kvm_arm_vcpu_init_debug(struct kvm_vcpu *vcpu)
+{
+	preempt_disable();
+	kvm_arm_setup_mdcr_el2(vcpu);
+	preempt_enable();
+}
+
+/**
  * kvm_arm_reset_debug_ptr - reset the debug ptr to point to the vcpu state
  */
 
@@ -83,13 +142,7 @@
  * @vcpu:	the vcpu pointer
  *
  * This is called before each entry into the hypervisor to setup any
- * debug related registers. Currently this just ensures we will trap
- * access to:
- *  - Performance monitors (MDCR_EL2_TPM/MDCR_EL2_TPMCR)
- *  - Debug ROM Address (MDCR_EL2_TDRA)
- *  - OS related registers (MDCR_EL2_TDOSA)
- *  - Statistical profiler (MDCR_EL2_TPMS/MDCR_EL2_E2PB)
- *  - Self-hosted Trace Filter controls (MDCR_EL2_TTRF)
+ * debug related registers.
  *
  * Additionally, KVM only traps guest accesses to the debug registers if
  * the guest is not actively using them (see the KVM_ARM64_DEBUG_DIRTY
@@ -101,28 +154,14 @@
 
 void kvm_arm_setup_debug(struct kvm_vcpu *vcpu)
 {
-	bool trap_debug = !(vcpu->arch.flags & KVM_ARM64_DEBUG_DIRTY);
 	unsigned long mdscr, orig_mdcr_el2 = vcpu->arch.mdcr_el2;
 
 	trace_kvm_arm_setup_debug(vcpu, vcpu->guest_debug);
 
-	/*
-	 * This also clears MDCR_EL2_E2PB_MASK to disable guest access
-	 * to the profiling buffer.
-	 */
-	vcpu->arch.mdcr_el2 = __this_cpu_read(mdcr_el2) & MDCR_EL2_HPMN_MASK;
-	vcpu->arch.mdcr_el2 |= (MDCR_EL2_TPM |
-				MDCR_EL2_TPMS |
-				MDCR_EL2_TTRF |
-				MDCR_EL2_TPMCR |
-				MDCR_EL2_TDRA |
-				MDCR_EL2_TDOSA);
+	kvm_arm_setup_mdcr_el2(vcpu);
 
 	/* Is Guest debugging in effect? */
 	if (vcpu->guest_debug) {
-		/* Route all software debug exceptions to EL2 */
-		vcpu->arch.mdcr_el2 |= MDCR_EL2_TDE;
-
 		/* Save guest debug state */
 		save_guest_debug_regs(vcpu);
 
@@ -176,7 +215,6 @@
 
 			vcpu->arch.debug_ptr = &vcpu->arch.external_debug_state;
 			vcpu->arch.flags |= KVM_ARM64_DEBUG_DIRTY;
-			trap_debug = true;
 
 			trace_kvm_arm_set_regset("BKPTS", get_num_brps(),
 						&vcpu->arch.debug_ptr->dbg_bcr[0],
@@ -191,10 +229,6 @@
 	BUG_ON(!vcpu->guest_debug &&
 		vcpu->arch.debug_ptr != &vcpu->arch.vcpu_debug_state);
 
-	/* Trap debug register access */
-	if (trap_debug)
-		vcpu->arch.mdcr_el2 |= MDCR_EL2_TDA;
-
 	/* If KDE or MDE are set, perform a full save/restore cycle. */
 	if (vcpu_read_sys_reg(vcpu, MDSCR_EL1) & (DBG_MDSCR_KDE | DBG_MDSCR_MDE))
 		vcpu->arch.flags |= KVM_ARM64_DEBUG_DIRTY;
@@ -203,7 +237,6 @@
 	if (has_vhe() && orig_mdcr_el2 != vcpu->arch.mdcr_el2)
 		write_sysreg(vcpu->arch.mdcr_el2, mdcr_el2);
 
-	trace_kvm_arm_set_dreg32("MDCR_EL2", vcpu->arch.mdcr_el2);
 	trace_kvm_arm_set_dreg32("MDSCR_EL1", vcpu_read_sys_reg(vcpu, MDSCR_EL1));
 }
 
@@ -231,3 +264,32 @@
 		}
 	}
 }
+
+void kvm_arch_vcpu_load_debug_state_flags(struct kvm_vcpu *vcpu)
+{
+	u64 dfr0;
+
+	/* For VHE, there is nothing to do */
+	if (has_vhe())
+		return;
+
+	dfr0 = read_sysreg(id_aa64dfr0_el1);
+	/*
+	 * If SPE is present on this CPU and is available at current EL,
+	 * we may need to check if the host state needs to be saved.
+	 */
+	if (cpuid_feature_extract_unsigned_field(dfr0, ID_AA64DFR0_PMSVER_SHIFT) &&
+	    !(read_sysreg_s(SYS_PMBIDR_EL1) & BIT(SYS_PMBIDR_EL1_P_SHIFT)))
+		vcpu->arch.flags |= KVM_ARM64_DEBUG_STATE_SAVE_SPE;
+
+	/* Check if we have TRBE implemented and available at the host */
+	if (cpuid_feature_extract_unsigned_field(dfr0, ID_AA64DFR0_TRBE_SHIFT) &&
+	    !(read_sysreg_s(SYS_TRBIDR_EL1) & TRBIDR_PROG))
+		vcpu->arch.flags |= KVM_ARM64_DEBUG_STATE_SAVE_TRBE;
+}
+
+void kvm_arch_vcpu_put_debug_state_flags(struct kvm_vcpu *vcpu)
+{
+	vcpu->arch.flags &= ~(KVM_ARM64_DEBUG_STATE_SAVE_SPE |
+			      KVM_ARM64_DEBUG_STATE_SAVE_TRBE);
+}
diff --git a/arch/arm64/kvm/fpsimd.c b/arch/arm64/kvm/fpsimd.c
index 3e081d5..5621020 100644
--- a/arch/arm64/kvm/fpsimd.c
+++ b/arch/arm64/kvm/fpsimd.c
@@ -11,6 +11,7 @@
 #include <linux/kvm_host.h>
 #include <asm/fpsimd.h>
 #include <asm/kvm_asm.h>
+#include <asm/kvm_hyp.h>
 #include <asm/kvm_mmu.h>
 #include <asm/sysreg.h>
 
@@ -42,6 +43,17 @@
 	if (ret)
 		goto error;
 
+	if (vcpu->arch.sve_state) {
+		void *sve_end;
+
+		sve_end = vcpu->arch.sve_state + vcpu_sve_state_size(vcpu);
+
+		ret = create_hyp_mappings(vcpu->arch.sve_state, sve_end,
+					  PAGE_HYP);
+		if (ret)
+			goto error;
+	}
+
 	vcpu->arch.host_thread_info = kern_hyp_va(ti);
 	vcpu->arch.host_fpsimd_state = kern_hyp_va(fpsimd);
 error:
@@ -109,11 +121,17 @@
 	local_irq_save(flags);
 
 	if (vcpu->arch.flags & KVM_ARM64_FP_ENABLED) {
-		fpsimd_save_and_flush_cpu_state();
+		if (guest_has_sve) {
+			__vcpu_sys_reg(vcpu, ZCR_EL1) = read_sysreg_el1(SYS_ZCR);
 
-		if (guest_has_sve)
-			__vcpu_sys_reg(vcpu, ZCR_EL1) = read_sysreg_s(SYS_ZCR_EL12);
-	} else if (host_has_sve) {
+			/* Restore the VL that was saved when bound to the CPU */
+			if (!has_vhe())
+				sve_cond_update_zcr_vq(vcpu_sve_max_vq(vcpu) - 1,
+						       SYS_ZCR_EL1);
+		}
+
+		fpsimd_save_and_flush_cpu_state();
+	} else if (has_vhe() && host_has_sve) {
 		/*
 		 * The FPSIMD/SVE state in the CPU has not been touched, and we
 		 * have SVE (and VHE): CPACR_EL1 (alias CPTR_EL2) has been
diff --git a/arch/arm64/kvm/guest.c b/arch/arm64/kvm/guest.c
index 9bbd30e..c763808 100644
--- a/arch/arm64/kvm/guest.c
+++ b/arch/arm64/kvm/guest.c
@@ -299,7 +299,7 @@
 
 	memset(vqs, 0, sizeof(vqs));
 
-	max_vq = sve_vq_from_vl(vcpu->arch.sve_max_vl);
+	max_vq = vcpu_sve_max_vq(vcpu);
 	for (vq = SVE_VQ_MIN; vq <= max_vq; ++vq)
 		if (sve_vq_available(vq))
 			vqs[vq_word(vq)] |= vq_mask(vq);
@@ -427,7 +427,7 @@
 		if (!vcpu_has_sve(vcpu) || (reg->id & SVE_REG_SLICE_MASK) > 0)
 			return -ENOENT;
 
-		vq = sve_vq_from_vl(vcpu->arch.sve_max_vl);
+		vq = vcpu_sve_max_vq(vcpu);
 
 		reqoffset = SVE_SIG_ZREG_OFFSET(vq, reg_num) -
 				SVE_SIG_REGS_OFFSET;
@@ -437,7 +437,7 @@
 		if (!vcpu_has_sve(vcpu) || (reg->id & SVE_REG_SLICE_MASK) > 0)
 			return -ENOENT;
 
-		vq = sve_vq_from_vl(vcpu->arch.sve_max_vl);
+		vq = vcpu_sve_max_vq(vcpu);
 
 		reqoffset = SVE_SIG_PREG_OFFSET(vq, reg_num) -
 				SVE_SIG_REGS_OFFSET;
diff --git a/arch/arm64/kvm/handle_exit.c b/arch/arm64/kvm/handle_exit.c
index cebe39f..6f48336 100644
--- a/arch/arm64/kvm/handle_exit.c
+++ b/arch/arm64/kvm/handle_exit.c
@@ -291,3 +291,48 @@
 	if (exception_index == ARM_EXCEPTION_EL1_SERROR)
 		kvm_handle_guest_serror(vcpu, kvm_vcpu_get_esr(vcpu));
 }
+
+void __noreturn __cold nvhe_hyp_panic_handler(u64 esr, u64 spsr, u64 elr,
+					      u64 par, uintptr_t vcpu,
+					      u64 far, u64 hpfar) {
+	u64 elr_in_kimg = __phys_to_kimg(__hyp_pa(elr));
+	u64 hyp_offset = elr_in_kimg - kaslr_offset() - elr;
+	u64 mode = spsr & PSR_MODE_MASK;
+
+	/*
+	 * The nVHE hyp symbols are not included by kallsyms to avoid issues
+	 * with aliasing. That means that the symbols cannot be printed with the
+	 * "%pS" format specifier, so fall back to the vmlinux address if
+	 * there's no better option.
+	 */
+	if (mode != PSR_MODE_EL2t && mode != PSR_MODE_EL2h) {
+		kvm_err("Invalid host exception to nVHE hyp!\n");
+	} else if (ESR_ELx_EC(esr) == ESR_ELx_EC_BRK64 &&
+		   (esr & ESR_ELx_BRK64_ISS_COMMENT_MASK) == BUG_BRK_IMM) {
+		struct bug_entry *bug = find_bug(elr_in_kimg);
+		const char *file = NULL;
+		unsigned int line = 0;
+
+		/* All hyp bugs, including warnings, are treated as fatal. */
+		if (bug)
+			bug_get_file_line(bug, &file, &line);
+
+		if (file)
+			kvm_err("nVHE hyp BUG at: %s:%u!\n", file, line);
+		else
+			kvm_err("nVHE hyp BUG at: %016llx!\n", elr + hyp_offset);
+	} else {
+		kvm_err("nVHE hyp panic at: %016llx!\n", elr + hyp_offset);
+	}
+
+	/*
+	 * Hyp has panicked and we're going to handle that by panicking the
+	 * kernel. The kernel offset will be revealed in the panic so we're
+	 * also safe to reveal the hyp offset as a debugging aid for translating
+	 * hyp VAs to vmlinux addresses.
+	 */
+	kvm_err("Hyp Offset: 0x%llx\n", hyp_offset);
+
+	panic("HYP panic:\nPS:%08llx PC:%016llx ESR:%08llx\nFAR:%016llx HPFAR:%016llx PAR:%016llx\nVCPU:%016lx\n",
+	      spsr, elr, esr, far, hpfar, par, vcpu);
+}
diff --git a/arch/arm64/kvm/hyp/Makefile b/arch/arm64/kvm/hyp/Makefile
index 687598e..b726332 100644
--- a/arch/arm64/kvm/hyp/Makefile
+++ b/arch/arm64/kvm/hyp/Makefile
@@ -10,4 +10,4 @@
 		    -DDISABLE_BRANCH_PROFILING		\
 		    $(DISABLE_STACKLEAK_PLUGIN)
 
-obj-$(CONFIG_KVM) += vhe/ nvhe/ pgtable.o
+obj-$(CONFIG_KVM) += vhe/ nvhe/ pgtable.o reserved_mem.o
diff --git a/arch/arm64/kvm/hyp/fpsimd.S b/arch/arm64/kvm/hyp/fpsimd.S
index 01f114a..3c63592 100644
--- a/arch/arm64/kvm/hyp/fpsimd.S
+++ b/arch/arm64/kvm/hyp/fpsimd.S
@@ -19,3 +19,13 @@
 	fpsimd_restore	x0, 1
 	ret
 SYM_FUNC_END(__fpsimd_restore_state)
+
+SYM_FUNC_START(__sve_restore_state)
+	__sve_load 0, x1, 2
+	ret
+SYM_FUNC_END(__sve_restore_state)
+
+SYM_FUNC_START(__sve_save_state)
+	sve_save 0, x1, 2
+	ret
+SYM_FUNC_END(__sve_save_state)
diff --git a/arch/arm64/kvm/hyp/include/hyp/switch.h b/arch/arm64/kvm/hyp/include/hyp/switch.h
index 6c1f51f..e4a2f29 100644
--- a/arch/arm64/kvm/hyp/include/hyp/switch.h
+++ b/arch/arm64/kvm/hyp/include/hyp/switch.h
@@ -30,8 +30,6 @@
 #include <asm/processor.h>
 #include <asm/thread_info.h>
 
-extern const char __hyp_panic_string[];
-
 extern struct exception_table_entry __start___kvm_ex_table;
 extern struct exception_table_entry __stop___kvm_ex_table;
 
@@ -160,18 +158,10 @@
 	return true;
 }
 
-static inline bool __populate_fault_info(struct kvm_vcpu *vcpu)
+static inline bool __get_fault_info(u64 esr, struct kvm_vcpu_fault_info *fault)
 {
-	u8 ec;
-	u64 esr;
 	u64 hpfar, far;
 
-	esr = vcpu->arch.fault.esr_el2;
-	ec = ESR_ELx_EC(esr);
-
-	if (ec != ESR_ELx_EC_DABT_LOW && ec != ESR_ELx_EC_IABT_LOW)
-		return true;
-
 	far = read_sysreg_el2(SYS_FAR);
 
 	/*
@@ -194,33 +184,59 @@
 		hpfar = read_sysreg(hpfar_el2);
 	}
 
-	vcpu->arch.fault.far_el2 = far;
-	vcpu->arch.fault.hpfar_el2 = hpfar;
+	fault->far_el2 = far;
+	fault->hpfar_el2 = hpfar;
 	return true;
 }
 
+static inline bool __populate_fault_info(struct kvm_vcpu *vcpu)
+{
+	u8 ec;
+	u64 esr;
+
+	esr = vcpu->arch.fault.esr_el2;
+	ec = ESR_ELx_EC(esr);
+
+	if (ec != ESR_ELx_EC_DABT_LOW && ec != ESR_ELx_EC_IABT_LOW)
+		return true;
+
+	return __get_fault_info(esr, &vcpu->arch.fault);
+}
+
+static inline void __hyp_sve_save_host(struct kvm_vcpu *vcpu)
+{
+	struct thread_struct *thread;
+
+	thread = container_of(vcpu->arch.host_fpsimd_state, struct thread_struct,
+			      uw.fpsimd_state);
+
+	__sve_save_state(sve_pffr(thread), &vcpu->arch.host_fpsimd_state->fpsr);
+}
+
+static inline void __hyp_sve_restore_guest(struct kvm_vcpu *vcpu)
+{
+	sve_cond_update_zcr_vq(vcpu_sve_max_vq(vcpu) - 1, SYS_ZCR_EL2);
+	__sve_restore_state(vcpu_sve_pffr(vcpu),
+			    &vcpu->arch.ctxt.fp_regs.fpsr);
+	write_sysreg_el1(__vcpu_sys_reg(vcpu, ZCR_EL1), SYS_ZCR);
+}
+
 /* Check for an FPSIMD/SVE trap and handle as appropriate */
 static inline bool __hyp_handle_fpsimd(struct kvm_vcpu *vcpu)
 {
-	bool vhe, sve_guest, sve_host;
+	bool sve_guest, sve_host;
 	u8 esr_ec;
+	u64 reg;
 
 	if (!system_supports_fpsimd())
 		return false;
 
-	/*
-	 * Currently system_supports_sve() currently implies has_vhe(),
-	 * so the check is redundant. However, has_vhe() can be determined
-	 * statically and helps the compiler remove dead code.
-	 */
-	if (has_vhe() && system_supports_sve()) {
+	if (system_supports_sve()) {
 		sve_guest = vcpu_has_sve(vcpu);
 		sve_host = vcpu->arch.flags & KVM_ARM64_HOST_SVE_IN_USE;
-		vhe = true;
 	} else {
 		sve_guest = false;
 		sve_host = false;
-		vhe = has_vhe();
 	}
 
 	esr_ec = kvm_vcpu_trap_get_class(vcpu);
@@ -229,53 +245,38 @@
 		return false;
 
 	/* Don't handle SVE traps for non-SVE vcpus here: */
-	if (!sve_guest)
-		if (esr_ec != ESR_ELx_EC_FP_ASIMD)
-			return false;
+	if (!sve_guest && esr_ec != ESR_ELx_EC_FP_ASIMD)
+		return false;
 
 	/* Valid trap.  Switch the context: */
-
-	if (vhe) {
-		u64 reg = read_sysreg(cpacr_el1) | CPACR_EL1_FPEN;
-
+	if (has_vhe()) {
+		reg = CPACR_EL1_FPEN;
 		if (sve_guest)
 			reg |= CPACR_EL1_ZEN;
 
-		write_sysreg(reg, cpacr_el1);
+		sysreg_clear_set(cpacr_el1, 0, reg);
 	} else {
-		write_sysreg(read_sysreg(cptr_el2) & ~(u64)CPTR_EL2_TFP,
-			     cptr_el2);
-	}
+		reg = CPTR_EL2_TFP;
+		if (sve_guest)
+			reg |= CPTR_EL2_TZ;
 
+		sysreg_clear_set(cptr_el2, reg, 0);
+	}
 	isb();
 
 	if (vcpu->arch.flags & KVM_ARM64_FP_HOST) {
-		/*
-		 * In the SVE case, VHE is assumed: it is enforced by
-		 * Kconfig and kvm_arch_init().
-		 */
-		if (sve_host) {
-			struct thread_struct *thread = container_of(
-				vcpu->arch.host_fpsimd_state,
-				struct thread_struct, uw.fpsimd_state);
-
-			sve_save_state(sve_pffr(thread),
-				       &vcpu->arch.host_fpsimd_state->fpsr);
-		} else {
+		if (sve_host)
+			__hyp_sve_save_host(vcpu);
+		else
 			__fpsimd_save_state(vcpu->arch.host_fpsimd_state);
-		}
 
 		vcpu->arch.flags &= ~KVM_ARM64_FP_HOST;
 	}
 
-	if (sve_guest) {
-		sve_load_state(vcpu_sve_pffr(vcpu),
-			       &vcpu->arch.ctxt.fp_regs.fpsr,
-			       sve_vq_from_vl(vcpu->arch.sve_max_vl) - 1);
-		write_sysreg_s(__vcpu_sys_reg(vcpu, ZCR_EL1), SYS_ZCR_EL12);
-	} else {
+	if (sve_guest)
+		__hyp_sve_restore_guest(vcpu);
+	else
 		__fpsimd_restore_state(&vcpu->arch.ctxt.fp_regs);
-	}
 
 	/* Skip restoring fpexc32 for AArch64 guests */
 	if (!(read_sysreg(hcr_el2) & HCR_RW))
diff --git a/arch/arm64/kvm/hyp/include/nvhe/early_alloc.h b/arch/arm64/kvm/hyp/include/nvhe/early_alloc.h
new file mode 100644
index 0000000..dc61aaa
--- /dev/null
+++ b/arch/arm64/kvm/hyp/include/nvhe/early_alloc.h
@@ -0,0 +1,14 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+#ifndef __KVM_HYP_EARLY_ALLOC_H
+#define __KVM_HYP_EARLY_ALLOC_H
+
+#include <asm/kvm_pgtable.h>
+
+void hyp_early_alloc_init(void *virt, unsigned long size);
+unsigned long hyp_early_alloc_nr_used_pages(void);
+void *hyp_early_alloc_page(void *arg);
+void *hyp_early_alloc_contig(unsigned int nr_pages);
+
+extern struct kvm_pgtable_mm_ops hyp_early_alloc_mm_ops;
+
+#endif /* __KVM_HYP_EARLY_ALLOC_H */
diff --git a/arch/arm64/kvm/hyp/include/nvhe/gfp.h b/arch/arm64/kvm/hyp/include/nvhe/gfp.h
new file mode 100644
index 0000000..18a4494
--- /dev/null
+++ b/arch/arm64/kvm/hyp/include/nvhe/gfp.h
@@ -0,0 +1,68 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+#ifndef __KVM_HYP_GFP_H
+#define __KVM_HYP_GFP_H
+
+#include <linux/list.h>
+
+#include <nvhe/memory.h>
+#include <nvhe/spinlock.h>
+
+#define HYP_NO_ORDER	UINT_MAX
+
+struct hyp_pool {
+	/*
+	 * Spinlock protecting concurrent changes to the memory pool as well as
+	 * the struct hyp_page of the pool's pages until we have a proper atomic
+	 * API at EL2.
+	 */
+	hyp_spinlock_t lock;
+	struct list_head free_area[MAX_ORDER];
+	phys_addr_t range_start;
+	phys_addr_t range_end;
+	unsigned int max_order;
+};
+
+static inline void hyp_page_ref_inc(struct hyp_page *p)
+{
+	struct hyp_pool *pool = hyp_page_to_pool(p);
+
+	hyp_spin_lock(&pool->lock);
+	p->refcount++;
+	hyp_spin_unlock(&pool->lock);
+}
+
+static inline int hyp_page_ref_dec_and_test(struct hyp_page *p)
+{
+	struct hyp_pool *pool = hyp_page_to_pool(p);
+	int ret;
+
+	hyp_spin_lock(&pool->lock);
+	p->refcount--;
+	ret = (p->refcount == 0);
+	hyp_spin_unlock(&pool->lock);
+
+	return ret;
+}
+
+static inline void hyp_set_page_refcounted(struct hyp_page *p)
+{
+	struct hyp_pool *pool = hyp_page_to_pool(p);
+
+	hyp_spin_lock(&pool->lock);
+	if (p->refcount) {
+		hyp_spin_unlock(&pool->lock);
+		BUG();
+	}
+	p->refcount = 1;
+	hyp_spin_unlock(&pool->lock);
+}
+
+/* Allocation */
+void *hyp_alloc_pages(struct hyp_pool *pool, unsigned int order);
+void hyp_get_page(void *addr);
+void hyp_put_page(void *addr);
+
+/* Used pages cannot be freed */
+int hyp_pool_init(struct hyp_pool *pool, u64 pfn, unsigned int nr_pages,
+		  unsigned int reserved_pages);
+#endif /* __KVM_HYP_GFP_H */
diff --git a/arch/arm64/kvm/hyp/include/nvhe/mem_protect.h b/arch/arm64/kvm/hyp/include/nvhe/mem_protect.h
new file mode 100644
index 0000000..42d81ec
--- /dev/null
+++ b/arch/arm64/kvm/hyp/include/nvhe/mem_protect.h
@@ -0,0 +1,36 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * Copyright (C) 2020 Google LLC
+ * Author: Quentin Perret <qperret@google.com>
+ */
+
+#ifndef __KVM_NVHE_MEM_PROTECT__
+#define __KVM_NVHE_MEM_PROTECT__
+#include <linux/kvm_host.h>
+#include <asm/kvm_hyp.h>
+#include <asm/kvm_pgtable.h>
+#include <asm/virt.h>
+#include <nvhe/spinlock.h>
+
+struct host_kvm {
+	struct kvm_arch arch;
+	struct kvm_pgtable pgt;
+	struct kvm_pgtable_mm_ops mm_ops;
+	hyp_spinlock_t lock;
+};
+extern struct host_kvm host_kvm;
+
+int __pkvm_prot_finalize(void);
+int __pkvm_mark_hyp(phys_addr_t start, phys_addr_t end);
+
+int kvm_host_prepare_stage2(void *mem_pgt_pool, void *dev_pgt_pool);
+void handle_host_mem_abort(struct kvm_cpu_context *host_ctxt);
+
+static __always_inline void __load_host_stage2(void)
+{
+	if (static_branch_likely(&kvm_protected_mode_initialized))
+		__load_stage2(&host_kvm.arch.mmu, host_kvm.arch.vtcr);
+	else
+		write_sysreg(0, vttbr_el2);
+}
+#endif /* __KVM_NVHE_MEM_PROTECT__ */
diff --git a/arch/arm64/kvm/hyp/include/nvhe/memory.h b/arch/arm64/kvm/hyp/include/nvhe/memory.h
new file mode 100644
index 0000000..fd78bde
--- /dev/null
+++ b/arch/arm64/kvm/hyp/include/nvhe/memory.h
@@ -0,0 +1,51 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+#ifndef __KVM_HYP_MEMORY_H
+#define __KVM_HYP_MEMORY_H
+
+#include <asm/kvm_mmu.h>
+#include <asm/page.h>
+
+#include <linux/types.h>
+
+struct hyp_pool;
+struct hyp_page {
+	unsigned int refcount;
+	unsigned int order;
+	struct hyp_pool *pool;
+	struct list_head node;
+};
+
+extern u64 __hyp_vmemmap;
+#define hyp_vmemmap ((struct hyp_page *)__hyp_vmemmap)
+
+#define __hyp_va(phys)	((void *)((phys_addr_t)(phys) - hyp_physvirt_offset))
+
+static inline void *hyp_phys_to_virt(phys_addr_t phys)
+{
+	return __hyp_va(phys);
+}
+
+static inline phys_addr_t hyp_virt_to_phys(void *addr)
+{
+	return __hyp_pa(addr);
+}
+
+#define hyp_phys_to_pfn(phys)	((phys) >> PAGE_SHIFT)
+#define hyp_pfn_to_phys(pfn)	((phys_addr_t)((pfn) << PAGE_SHIFT))
+#define hyp_phys_to_page(phys)	(&hyp_vmemmap[hyp_phys_to_pfn(phys)])
+#define hyp_virt_to_page(virt)	hyp_phys_to_page(__hyp_pa(virt))
+#define hyp_virt_to_pfn(virt)	hyp_phys_to_pfn(__hyp_pa(virt))
+
+#define hyp_page_to_pfn(page)	((struct hyp_page *)(page) - hyp_vmemmap)
+#define hyp_page_to_phys(page)  hyp_pfn_to_phys((hyp_page_to_pfn(page)))
+#define hyp_page_to_virt(page)	__hyp_va(hyp_page_to_phys(page))
+#define hyp_page_to_pool(page)	(((struct hyp_page *)page)->pool)
+
+static inline int hyp_page_count(void *addr)
+{
+	struct hyp_page *p = hyp_virt_to_page(addr);
+
+	return p->refcount;
+}
+
+#endif /* __KVM_HYP_MEMORY_H */
diff --git a/arch/arm64/kvm/hyp/include/nvhe/mm.h b/arch/arm64/kvm/hyp/include/nvhe/mm.h
new file mode 100644
index 0000000..0095f62
--- /dev/null
+++ b/arch/arm64/kvm/hyp/include/nvhe/mm.h
@@ -0,0 +1,96 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+#ifndef __KVM_HYP_MM_H
+#define __KVM_HYP_MM_H
+
+#include <asm/kvm_pgtable.h>
+#include <asm/spectre.h>
+#include <linux/memblock.h>
+#include <linux/types.h>
+
+#include <nvhe/memory.h>
+#include <nvhe/spinlock.h>
+
+#define HYP_MEMBLOCK_REGIONS 128	/* capacity of hyp_memory[] */
+extern struct memblock_region kvm_nvhe_sym(hyp_memory)[];
+extern unsigned int kvm_nvhe_sym(hyp_memblock_nr);	/* valid hyp_memory[] entries */
+extern struct kvm_pgtable pkvm_pgtable;
+extern hyp_spinlock_t pkvm_pgd_lock;
+extern struct hyp_pool hpool;
+extern u64 __io_map_base;
+/* Hyp stage-1 helpers (mm.c); they take pkvm_pgd_lock themselves where needed */
+int hyp_create_idmap(u32 hyp_va_bits);
+int hyp_map_vectors(void);
+int hyp_back_vmemmap(phys_addr_t phys, unsigned long size, phys_addr_t back);
+int pkvm_cpu_set_vector(enum arm64_hyp_spectre_vector slot);
+int pkvm_create_mappings(void *from, void *to, enum kvm_pgtable_prot prot);
+int __pkvm_create_mappings(unsigned long start, unsigned long size,
+			   unsigned long phys, enum kvm_pgtable_prot prot);
+unsigned long __pkvm_create_private_mapping(phys_addr_t phys, size_t size,
+					    enum kvm_pgtable_prot prot);
+
+static inline void hyp_vmemmap_range(phys_addr_t phys, unsigned long size,
+				     unsigned long *start, unsigned long *end)
+{
+	/* Page-aligned span of hyp_vmemmap describing [phys, phys + size) */
+	unsigned long first = (unsigned long)hyp_phys_to_page(phys);
+	unsigned long bytes = (size >> PAGE_SHIFT) * sizeof(struct hyp_page);
+	unsigned long last = first + bytes;
+
+	*start = ALIGN_DOWN(first, PAGE_SIZE);
+	*end = ALIGN(last, PAGE_SIZE);
+}
+
+static inline unsigned long __hyp_pgtable_max_pages(unsigned long nr_pages)
+{
+	unsigned long level, tables = nr_pages, sum = 0;
+
+	/* Worst case: each level needs one table per PTRS_PER_PTE entries below */
+	for (level = 0; level < KVM_PGTABLE_MAX_LEVELS; level++) {
+		tables = DIV_ROUND_UP(tables, PTRS_PER_PTE);
+		sum += tables;
+	}
+
+	return sum;
+}
+
+static inline unsigned long __hyp_pgtable_total_pages(void)
+{
+	struct memblock_region *reg = kvm_nvhe_sym(hyp_memory);
+	unsigned long nr_regions = kvm_nvhe_sym(hyp_memblock_nr);
+	unsigned long idx, pages = 0;
+
+	/* Sum the worst case for mapping every memblock at page granularity */
+	for (idx = 0; idx < nr_regions; idx++, reg++)
+		pages += __hyp_pgtable_max_pages(reg->size >> PAGE_SHIFT);
+
+	return pages;
+}
+
+static inline unsigned long hyp_s1_pgtable_pages(void)
+{
+	unsigned long mem_pages, private_pages;
+
+	mem_pages = __hyp_pgtable_total_pages();
+
+	/* Reserve enough to map up to 1 GiB of private mappings */
+	private_pages = __hyp_pgtable_max_pages(SZ_1G >> PAGE_SHIFT);
+
+	return mem_pages + private_pages;
+}
+
+static inline unsigned long host_s2_mem_pgtable_pages(void)
+{
+	/*
+	 * Worst case for all of memory, plus 16 spare pages as a safe upper
+	 * bound for concatenated PGDs.
+	 */
+	return 16 + __hyp_pgtable_total_pages();
+}
+
+static inline unsigned long host_s2_dev_pgtable_pages(void)
+{
+	/* Size the MMIO pool to cover 1 GiB of device mappings */
+	return __hyp_pgtable_max_pages(SZ_1G >> PAGE_SHIFT);
+}
+
+#endif /* __KVM_HYP_MM_H */
diff --git a/arch/arm64/kvm/hyp/include/nvhe/spinlock.h b/arch/arm64/kvm/hyp/include/nvhe/spinlock.h
new file mode 100644
index 0000000..76b537f
--- /dev/null
+++ b/arch/arm64/kvm/hyp/include/nvhe/spinlock.h
@@ -0,0 +1,92 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * A stand-alone ticket spinlock implementation for use by the non-VHE
+ * KVM hypervisor code running at EL2.
+ *
+ * Copyright (C) 2020 Google LLC
+ * Author: Will Deacon <will@kernel.org>
+ *
+ * Heavily based on the implementation removed by c11090474d70 which was:
+ * Copyright (C) 2012 ARM Ltd.
+ */
+
+#ifndef __ARM64_KVM_NVHE_SPINLOCK_H__
+#define __ARM64_KVM_NVHE_SPINLOCK_H__
+
+#include <asm/alternative.h>
+#include <asm/lse.h>
+
+typedef union hyp_spinlock {
+	u32	__val;		/* both tickets, accessed as one word */
+	struct {		/* 'owner' is the low half of __val, 'next' the high half */
+#ifdef __AARCH64EB__
+		u16 next, owner;
+#else
+		u16 owner, next;
+#endif
+	};
+} hyp_spinlock_t;
+/* Reset to the unlocked state: owner == next == 0 */
+#define hyp_spin_lock_init(l)						\
+do {									\
+	*(l) = (hyp_spinlock_t){ .__val = 0 };				\
+} while (0)
+
+static inline void hyp_spin_lock(hyp_spinlock_t *lock)
+{
+	u32 tmp;
+	hyp_spinlock_t lockval, newval;
+	/* Take a ticket (the old 'next'), then wait for 'owner' to reach it. */
+	asm volatile(
+	/* Atomically increment the next ticket. */
+	ARM64_LSE_ATOMIC_INSN(
+	/* LL/SC */
+"	prfm	pstl1strm, %3\n"
+"1:	ldaxr	%w0, %3\n"
+"	add	%w1, %w0, #(1 << 16)\n"
+"	stxr	%w2, %w1, %3\n"
+"	cbnz	%w2, 1b\n",
+	/* LSE atomics */
+"	mov	%w2, #(1 << 16)\n"
+"	ldadda	%w2, %w0, %3\n"
+	__nops(3))
+	/* %w0 now holds the old lock value; its high half is our ticket. */
+	/* Did we get the lock? */
+"	eor	%w1, %w0, %w0, ror #16\n"
+"	cbz	%w1, 3f\n"
+	/*
+	 * No: spin on the owner. Send a local event to avoid missing an
+	 * unlock before the exclusive load.
+	 */
+"	sevl\n"
+"2:	wfe\n"
+"	ldaxrh	%w2, %4\n"
+"	eor	%w1, %w2, %w0, lsr #16\n"
+"	cbnz	%w1, 2b\n"
+	/* We got the lock. Critical section starts here. */
+"3:"
+	: "=&r" (lockval), "=&r" (newval), "=&r" (tmp), "+Q" (*lock)
+	: "Q" (lock->owner)
+	: "memory");
+}
+
+static inline void hyp_spin_unlock(hyp_spinlock_t *lock)
+{
+	u64 tmp;
+	/* Release: increment 'owner' with release semantics (stlrh/staddlh). */
+	asm volatile(
+	ARM64_LSE_ATOMIC_INSN(
+	/* LL/SC */
+	"	ldrh	%w1, %0\n"
+	"	add	%w1, %w1, #1\n"
+	"	stlrh	%w1, %0",
+	/* LSE atomics */
+	"	mov	%w1, #1\n"
+	"	staddlh	%w1, %0\n"
+	__nops(1))
+	: "=Q" (lock->owner), "=&r" (tmp)
+	:
+	: "memory");
+}
+
+#endif /* __ARM64_KVM_NVHE_SPINLOCK_H__ */
diff --git a/arch/arm64/kvm/hyp/nvhe/Makefile b/arch/arm64/kvm/hyp/nvhe/Makefile
index a6707df..f55201a 100644
--- a/arch/arm64/kvm/hyp/nvhe/Makefile
+++ b/arch/arm64/kvm/hyp/nvhe/Makefile
@@ -9,10 +9,15 @@
 hostprogs := gen-hyprel
 HOST_EXTRACFLAGS += -I$(objtree)/include
 
+lib-objs := clear_page.o copy_page.o memcpy.o memset.o
+lib-objs := $(addprefix ../../../lib/, $(lib-objs))
+
 obj-y := timer-sr.o sysreg-sr.o debug-sr.o switch.o tlb.o hyp-init.o host.o \
-	 hyp-main.o hyp-smp.o psci-relay.o
+	 hyp-main.o hyp-smp.o psci-relay.o early_alloc.o stub.o page_alloc.o \
+	 cache.o setup.o mm.o mem_protect.o
 obj-y += ../vgic-v3-sr.o ../aarch32.o ../vgic-v2-cpuif-proxy.o ../entry.o \
-	 ../fpsimd.o ../hyp-entry.o ../exception.o
+	 ../fpsimd.o ../hyp-entry.o ../exception.o ../pgtable.o
+obj-y += $(lib-objs)
 
 ##
 ## Build rules for compiling nVHE hyp code
diff --git a/arch/arm64/kvm/hyp/nvhe/cache.S b/arch/arm64/kvm/hyp/nvhe/cache.S
new file mode 100644
index 0000000..36cef69
--- /dev/null
+++ b/arch/arm64/kvm/hyp/nvhe/cache.S
@@ -0,0 +1,13 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * Code copied from arch/arm64/mm/cache.S.
+ */
+
+#include <linux/linkage.h>
+#include <asm/assembler.h>
+#include <asm/alternative.h>
+
+SYM_FUNC_START_PI(__flush_dcache_area)
+	dcache_by_line_op civac, sy, x0, x1, x2, x3	// DC CIVAC per line; x0/x1 presumably start/size
+	ret
+SYM_FUNC_END_PI(__flush_dcache_area)
diff --git a/arch/arm64/kvm/hyp/nvhe/debug-sr.c b/arch/arm64/kvm/hyp/nvhe/debug-sr.c
index f401724..7d3f258 100644
--- a/arch/arm64/kvm/hyp/nvhe/debug-sr.c
+++ b/arch/arm64/kvm/hyp/nvhe/debug-sr.c
@@ -21,17 +21,11 @@
 	/* Clear pmscr in case of early return */
 	*pmscr_el1 = 0;
 
-	/* SPE present on this CPU? */
-	if (!cpuid_feature_extract_unsigned_field(read_sysreg(id_aa64dfr0_el1),
-						  ID_AA64DFR0_PMSVER_SHIFT))
-		return;
-
-	/* Yes; is it owned by EL3? */
-	reg = read_sysreg_s(SYS_PMBIDR_EL1);
-	if (reg & BIT(SYS_PMBIDR_EL1_P_SHIFT))
-		return;
-
-	/* No; is the host actually using the thing? */
+	/*
+	 * At this point, we know that this CPU implements
+	 * SPE and is available to the host.
+	 * Check if the host is actually using it ?
+	 */
 	reg = read_sysreg_s(SYS_PMBLIMITR_EL1);
 	if (!(reg & BIT(SYS_PMBLIMITR_EL1_E_SHIFT)))
 		return;
@@ -58,10 +52,43 @@
 	write_sysreg_s(pmscr_el1, SYS_PMSCR_EL1);
 }
 
+static void __debug_save_trace(u64 *trfcr_el1)
+{
+	*trfcr_el1 = 0;	/* 0 makes __debug_restore_trace() a no-op */
+
+	/* Check if the TRBE is enabled */
+	if (!(read_sysreg_s(SYS_TRBLIMITR_EL1) & TRBLIMITR_ENABLE))
+		return;
+	/*
+	 * Prohibit trace generation while we are in guest.
+	 * Since access to TRFCR_EL1 is trapped, the guest can't
+	 * modify the filtering set by the host.
+	 */
+	*trfcr_el1 = read_sysreg_s(SYS_TRFCR_EL1);
+	write_sysreg_s(0, SYS_TRFCR_EL1);
+	isb();	/* synchronize the TRFCR_EL1 write before draining */
+	/* Drain the trace buffer to memory */
+	tsb_csync();
+	dsb(nsh);
+}
+
+static void __debug_restore_trace(u64 trfcr_el1)
+{
+	/* Zero means __debug_save_trace() left nothing to restore */
+	if (trfcr_el1) {
+		/* Restore trace filter controls */
+		write_sysreg_s(trfcr_el1, SYS_TRFCR_EL1);
+	}
+}
+
 void __debug_save_host_buffers_nvhe(struct kvm_vcpu *vcpu)
 {
 	/* Disable and flush SPE data generation */
-	__debug_save_spe(&vcpu->arch.host_debug_state.pmscr_el1);
+	if (vcpu->arch.flags & KVM_ARM64_DEBUG_STATE_SAVE_SPE)	/* only when flagged */
+		__debug_save_spe(&vcpu->arch.host_debug_state.pmscr_el1);
+	/* Disable and flush Self-Hosted Trace (TRBE) generation, if flagged */
+	if (vcpu->arch.flags & KVM_ARM64_DEBUG_STATE_SAVE_TRBE)
+		__debug_save_trace(&vcpu->arch.host_debug_state.trfcr_el1);
 }
 
 void __debug_switch_to_guest(struct kvm_vcpu *vcpu)
@@ -71,7 +98,10 @@
 
 void __debug_restore_host_buffers_nvhe(struct kvm_vcpu *vcpu)
 {
-	__debug_restore_spe(vcpu->arch.host_debug_state.pmscr_el1);
+	if (vcpu->arch.flags & KVM_ARM64_DEBUG_STATE_SAVE_SPE)	/* mirrors the save path */
+		__debug_restore_spe(vcpu->arch.host_debug_state.pmscr_el1);
+	if (vcpu->arch.flags & KVM_ARM64_DEBUG_STATE_SAVE_TRBE)	/* restore TRFCR_EL1 */
+		__debug_restore_trace(vcpu->arch.host_debug_state.trfcr_el1);
 }
 
 void __debug_switch_to_host(struct kvm_vcpu *vcpu)
diff --git a/arch/arm64/kvm/hyp/nvhe/early_alloc.c b/arch/arm64/kvm/hyp/nvhe/early_alloc.c
new file mode 100644
index 0000000..1306c43
--- /dev/null
+++ b/arch/arm64/kvm/hyp/nvhe/early_alloc.c
@@ -0,0 +1,54 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2020 Google LLC
+ * Author: Quentin Perret <qperret@google.com>
+ */
+
+#include <asm/kvm_pgtable.h>
+
+#include <nvhe/early_alloc.h>
+#include <nvhe/memory.h>
+
+struct kvm_pgtable_mm_ops hyp_early_alloc_mm_ops;	/* set by hyp_early_alloc_init() */
+s64 __ro_after_init hyp_physvirt_offset;	/* PA - hyp VA, see __hyp_va() */
+
+static unsigned long base;	/* start of the early allocation region */
+static unsigned long end;	/* end of the region (exclusive) */
+static unsigned long cur;	/* next free address (bump pointer) */
+
+unsigned long hyp_early_alloc_nr_used_pages(void)
+{
+	return (cur - base) / PAGE_SIZE;	/* whole pages consumed so far */
+}
+
+void *hyp_early_alloc_contig(unsigned int nr_pages)
+{
+	unsigned long size = (unsigned long)nr_pages << PAGE_SHIFT;	/* no 32-bit wrap */
+	void *ret = (void *)cur;
+
+	if (!nr_pages)
+		return NULL;
+	/* Bump allocation from [cur, end); fail rather than overrun */
+	if (end - cur < size)
+		return NULL;
+
+	cur += size;
+	memset(ret, 0, size);	/* callers get zeroed pages */
+
+	return ret;
+}
+
+void *hyp_early_alloc_page(void *arg)
+{
+	return hyp_early_alloc_contig(1);	/* one zeroed page; @arg is unused */
+}
+
+void hyp_early_alloc_init(void *virt, unsigned long size)
+{
+	cur = base = (unsigned long)virt;	/* allocate from [base, end) */
+	end = base + size;
+
+	hyp_early_alloc_mm_ops.virt_to_phys = hyp_virt_to_phys;
+	hyp_early_alloc_mm_ops.phys_to_virt = hyp_phys_to_virt;
+	hyp_early_alloc_mm_ops.zalloc_page = hyp_early_alloc_page;
+}
diff --git a/arch/arm64/kvm/hyp/nvhe/gen-hyprel.c b/arch/arm64/kvm/hyp/nvhe/gen-hyprel.c
index ead02c6..6bc88a7 100644
--- a/arch/arm64/kvm/hyp/nvhe/gen-hyprel.c
+++ b/arch/arm64/kvm/hyp/nvhe/gen-hyprel.c
@@ -50,6 +50,18 @@
 #ifndef R_AARCH64_ABS64
 #define R_AARCH64_ABS64			257
 #endif
+#ifndef R_AARCH64_PREL64
+#define R_AARCH64_PREL64		260
+#endif
+#ifndef R_AARCH64_PREL32
+#define R_AARCH64_PREL32		261
+#endif
+#ifndef R_AARCH64_PREL16
+#define R_AARCH64_PREL16		262
+#endif
+#ifndef R_AARCH64_PLT32
+#define R_AARCH64_PLT32			314
+#endif
 #ifndef R_AARCH64_LD_PREL_LO19
 #define R_AARCH64_LD_PREL_LO19		273
 #endif
@@ -371,6 +383,12 @@
 		case R_AARCH64_ABS64:
 			emit_rela_abs64(rela, sh_orig_name);
 			break;
+		/* Allow position-relative data relocations. */
+		case R_AARCH64_PREL64:
+		case R_AARCH64_PREL32:
+		case R_AARCH64_PREL16:
+		case R_AARCH64_PLT32:
+			break;
 		/* Allow relocations to generate PC-relative addressing. */
 		case R_AARCH64_LD_PREL_LO19:
 		case R_AARCH64_ADR_PREL_LO21:
diff --git a/arch/arm64/kvm/hyp/nvhe/host.S b/arch/arm64/kvm/hyp/nvhe/host.S
index 5d94584..2b23400 100644
--- a/arch/arm64/kvm/hyp/nvhe/host.S
+++ b/arch/arm64/kvm/hyp/nvhe/host.S
@@ -79,22 +79,18 @@
 	mov	lr, #(PSR_F_BIT | PSR_I_BIT | PSR_A_BIT | PSR_D_BIT |\
 		      PSR_MODE_EL1h)
 	msr	spsr_el2, lr
-	ldr	lr, =panic
+	ldr	lr, =nvhe_hyp_panic_handler
 	hyp_kimg_va lr, x6
 	msr	elr_el2, lr
 
 	mov	x29, x0
 
-	/* Load the format string into x0 and arguments into x1-7 */
-	ldr	x0, =__hyp_panic_string
-	hyp_kimg_va x0, x6
-
-	/* Load the format arguments into x1-7. */
-	mov	x6, x3
-	get_vcpu_ptr x7, x3
-	mrs	x3, esr_el2
-	mrs	x4, far_el2
-	mrs	x5, hpfar_el2
+	/* Load the panic arguments into x0-7 */
+	mrs	x0, esr_el2
+	get_vcpu_ptr x4, x5
+	mrs	x5, far_el2
+	mrs	x6, hpfar_el2
+	mov	x7, xzr			// Unused argument
 
 	/* Enter the host, conditionally restoring the host context. */
 	cbz	x29, __host_enter_without_restoring
diff --git a/arch/arm64/kvm/hyp/nvhe/hyp-init.S b/arch/arm64/kvm/hyp/nvhe/hyp-init.S
index c631e29..c953fb4b 100644
--- a/arch/arm64/kvm/hyp/nvhe/hyp-init.S
+++ b/arch/arm64/kvm/hyp/nvhe/hyp-init.S
@@ -83,11 +83,6 @@
  * x0: struct kvm_nvhe_init_params PA
  */
 SYM_CODE_START_LOCAL(___kvm_hyp_init)
-alternative_if ARM64_KVM_PROTECTED_MODE
-	mov_q	x1, HCR_HOST_NVHE_PROTECTED_FLAGS
-	msr	hcr_el2, x1
-alternative_else_nop_endif
-
 	ldr	x1, [x0, #NVHE_INIT_TPIDR_EL2]
 	msr	tpidr_el2, x1
 
@@ -97,6 +92,15 @@
 	ldr	x1, [x0, #NVHE_INIT_MAIR_EL2]
 	msr	mair_el2, x1
 
+	ldr	x1, [x0, #NVHE_INIT_HCR_EL2]
+	msr	hcr_el2, x1
+
+	ldr	x1, [x0, #NVHE_INIT_VTTBR]
+	msr	vttbr_el2, x1
+
+	ldr	x1, [x0, #NVHE_INIT_VTCR]
+	msr	vtcr_el2, x1
+
 	ldr	x1, [x0, #NVHE_INIT_PGD_PA]
 	phys_to_ttbr x2, x1
 alternative_if ARM64_HAS_CNP
@@ -115,15 +119,10 @@
 
 	/* Invalidate the stale TLBs from Bootloader */
 	tlbi	alle2
+	tlbi	vmalls12e1
 	dsb	sy
 
-	/*
-	 * Preserve all the RES1 bits while setting the default flags,
-	 * as well as the EE bit on BE. Drop the A flag since the compiler
-	 * is allowed to generate unaligned accesses.
-	 */
-	mov_q	x0, (SCTLR_EL2_RES1 | (SCTLR_ELx_FLAGS & ~SCTLR_ELx_A))
-CPU_BE(	orr	x0, x0, #SCTLR_ELx_EE)
+	mov_q	x0, INIT_SCTLR_EL2_MMU_ON
 alternative_if ARM64_HAS_ADDRESS_AUTH
 	mov_q	x1, (SCTLR_ELx_ENIA | SCTLR_ELx_ENIB | \
 		     SCTLR_ELx_ENDA | SCTLR_ELx_ENDB)
@@ -221,9 +220,7 @@
 	mov	x0, xzr
 reset:
 	/* Reset kvm back to the hyp stub. */
-	mrs	x5, sctlr_el2
-	mov_q	x6, SCTLR_ELx_FLAGS
-	bic	x5, x5, x6		// Clear SCTL_M and etc
+	mov_q	x5, INIT_SCTLR_EL2_MMU_OFF
 	pre_disable_mmu_workaround
 	msr	sctlr_el2, x5
 	isb
@@ -244,4 +241,31 @@
 
 SYM_CODE_END(__kvm_handle_stub_hvc)
 
+SYM_FUNC_START(__pkvm_init_switch_pgd)
+	/* x0 = init params (read with the MMU off), x1 = return address. MMU off: */
+	pre_disable_mmu_workaround
+	mrs	x2, sctlr_el2			// kept for set_sctlr_el2 below
+	bic	x3, x2, #SCTLR_ELx_M
+	msr	sctlr_el2, x3
+	isb
+
+	tlbi	alle2				// drop all EL2 TLB entries
+
+	/* Install the new pgtables */
+	ldr	x3, [x0, #NVHE_INIT_PGD_PA]
+	phys_to_ttbr x4, x3
+alternative_if ARM64_HAS_CNP
+	orr	x4, x4, #TTBR_CNP_BIT
+alternative_else_nop_endif
+	msr	ttbr0_el2, x4
+
+	/* Set the new stack pointer */
+	ldr	x0, [x0, #NVHE_INIT_STACK_HYP_VA]
+	mov	sp, x0
+
+	/* And turn the MMU back on! */
+	set_sctlr_el2	x2
+	ret	x1				// continue at the caller-supplied address
+SYM_FUNC_END(__pkvm_init_switch_pgd)
+
 	.popsection
diff --git a/arch/arm64/kvm/hyp/nvhe/hyp-main.c b/arch/arm64/kvm/hyp/nvhe/hyp-main.c
index 93632820..f36420a 100644
--- a/arch/arm64/kvm/hyp/nvhe/hyp-main.c
+++ b/arch/arm64/kvm/hyp/nvhe/hyp-main.c
@@ -6,12 +6,15 @@
 
 #include <hyp/switch.h>
 
+#include <asm/pgtable-types.h>
 #include <asm/kvm_asm.h>
 #include <asm/kvm_emulate.h>
 #include <asm/kvm_host.h>
 #include <asm/kvm_hyp.h>
 #include <asm/kvm_mmu.h>
 
+#include <nvhe/mem_protect.h>
+#include <nvhe/mm.h>
 #include <nvhe/trap_handler.h>
 
 DEFINE_PER_CPU(struct kvm_nvhe_init_params, kvm_init_params);
@@ -106,6 +109,61 @@
 	__vgic_v3_restore_aprs(kern_hyp_va(cpu_if));
 }
 
+static void handle___pkvm_init(struct kvm_cpu_context *host_ctxt)
+{
+	DECLARE_REG(phys_addr_t, phys, host_ctxt, 1);
+	DECLARE_REG(unsigned long, size, host_ctxt, 2);
+	DECLARE_REG(unsigned long, nr_cpus, host_ctxt, 3);
+	DECLARE_REG(unsigned long *, per_cpu_base, host_ctxt, 4);
+	DECLARE_REG(u32, hyp_va_bits, host_ctxt, 5);
+
+	/*
+	 * __pkvm_init() returns here only on error, so x1 is written only on
+	 * that path. On success it tail-calls __pkvm_init_finalise(), which
+	 * deals with the host context directly.
+	 */
+	cpu_reg(host_ctxt, 1) = __pkvm_init(phys, size, nr_cpus, per_cpu_base,
+					    hyp_va_bits);
+}
+
+static void handle___pkvm_cpu_set_vector(struct kvm_cpu_context *host_ctxt)
+{
+	DECLARE_REG(enum arm64_hyp_spectre_vector, slot, host_ctxt, 1);
+
+	cpu_reg(host_ctxt, 1) = pkvm_cpu_set_vector(slot);	/* 0 or -EINVAL */
+}
+
+static void handle___pkvm_create_mappings(struct kvm_cpu_context *host_ctxt)
+{
+	DECLARE_REG(unsigned long, start, host_ctxt, 1);
+	DECLARE_REG(unsigned long, size, host_ctxt, 2);
+	DECLARE_REG(unsigned long, phys, host_ctxt, 3);
+	DECLARE_REG(enum kvm_pgtable_prot, prot, host_ctxt, 4);
+	/* Map [start, start + size) to @phys in the hyp stage-1 page-table */
+	cpu_reg(host_ctxt, 1) = __pkvm_create_mappings(start, size, phys, prot);
+}
+
+static void handle___pkvm_create_private_mapping(struct kvm_cpu_context *host_ctxt)
+{
+	DECLARE_REG(phys_addr_t, phys, host_ctxt, 1);
+	DECLARE_REG(size_t, size, host_ctxt, 2);
+	DECLARE_REG(enum kvm_pgtable_prot, prot, host_ctxt, 3);
+	/* x1 <- hyp VA for @phys, or an ERR_PTR()-encoded error */
+	cpu_reg(host_ctxt, 1) = __pkvm_create_private_mapping(phys, size, prot);
+}
+
+static void handle___pkvm_prot_finalize(struct kvm_cpu_context *host_ctxt)
+{
+	cpu_reg(host_ctxt, 1) = __pkvm_prot_finalize();	/* currently always 0 */
+}
+
+static void handle___pkvm_mark_hyp(struct kvm_cpu_context *host_ctxt)
+{
+	DECLARE_REG(phys_addr_t, start, host_ctxt, 1);
+	DECLARE_REG(phys_addr_t, end, host_ctxt, 2);
+	/* Mark [start, end) as owned by the hypervisor in the host stage-2 */
+	cpu_reg(host_ctxt, 1) = __pkvm_mark_hyp(start, end);
+}
 typedef void (*hcall_t)(struct kvm_cpu_context *);
 
 #define HANDLE_FUNC(x)	[__KVM_HOST_SMCCC_FUNC_##x] = (hcall_t)handle_##x
@@ -125,6 +183,12 @@
 	HANDLE_FUNC(__kvm_get_mdcr_el2),
 	HANDLE_FUNC(__vgic_v3_save_aprs),
 	HANDLE_FUNC(__vgic_v3_restore_aprs),
+	HANDLE_FUNC(__pkvm_init),
+	HANDLE_FUNC(__pkvm_cpu_set_vector),
+	HANDLE_FUNC(__pkvm_create_mappings),
+	HANDLE_FUNC(__pkvm_create_private_mapping),
+	HANDLE_FUNC(__pkvm_prot_finalize),
+	HANDLE_FUNC(__pkvm_mark_hyp),
 };
 
 static void handle_host_hcall(struct kvm_cpu_context *host_ctxt)
@@ -177,7 +241,16 @@
 	case ESR_ELx_EC_SMC64:
 		handle_host_smc(host_ctxt);
 		break;
+	case ESR_ELx_EC_SVE:
+		sysreg_clear_set(cptr_el2, CPTR_EL2_TZ, 0);
+		isb();
+		sve_cond_update_zcr_vq(ZCR_ELx_LEN_MASK, SYS_ZCR_EL2);
+		break;
+	case ESR_ELx_EC_IABT_LOW:
+	case ESR_ELx_EC_DABT_LOW:
+		handle_host_mem_abort(host_ctxt);
+		break;
 	default:
-		hyp_panic();
+		BUG();
 	}
 }
diff --git a/arch/arm64/kvm/hyp/nvhe/hyp-smp.c b/arch/arm64/kvm/hyp/nvhe/hyp-smp.c
index 8795590..9f54833 100644
--- a/arch/arm64/kvm/hyp/nvhe/hyp-smp.c
+++ b/arch/arm64/kvm/hyp/nvhe/hyp-smp.c
@@ -18,8 +18,7 @@
 
 u64 cpu_logical_map(unsigned int cpu)
 {
-	if (cpu >= ARRAY_SIZE(hyp_cpu_logical_map))
-		hyp_panic();
+	BUG_ON(cpu >= ARRAY_SIZE(hyp_cpu_logical_map));
 
 	return hyp_cpu_logical_map[cpu];
 }
@@ -30,8 +29,7 @@
 	unsigned long this_cpu_base;
 	unsigned long elf_base;
 
-	if (cpu >= ARRAY_SIZE(kvm_arm_hyp_percpu_base))
-		hyp_panic();
+	BUG_ON(cpu >= ARRAY_SIZE(kvm_arm_hyp_percpu_base));
 
 	cpu_base_array = (unsigned long *)&kvm_arm_hyp_percpu_base;
 	this_cpu_base = kern_hyp_va(cpu_base_array[cpu]);
diff --git a/arch/arm64/kvm/hyp/nvhe/hyp.lds.S b/arch/arm64/kvm/hyp/nvhe/hyp.lds.S
index cd119d8..f4562f4 100644
--- a/arch/arm64/kvm/hyp/nvhe/hyp.lds.S
+++ b/arch/arm64/kvm/hyp/nvhe/hyp.lds.S
@@ -25,4 +25,5 @@
 	BEGIN_HYP_SECTION(.data..percpu)
 		PERCPU_INPUT(L1_CACHE_BYTES)
 	END_HYP_SECTION
+	HYP_SECTION(.bss)
 }
diff --git a/arch/arm64/kvm/hyp/nvhe/mem_protect.c b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
new file mode 100644
index 0000000..e342f7f
--- /dev/null
+++ b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
@@ -0,0 +1,279 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2020 Google LLC
+ * Author: Quentin Perret <qperret@google.com>
+ */
+
+#include <linux/kvm_host.h>
+#include <asm/kvm_emulate.h>
+#include <asm/kvm_hyp.h>
+#include <asm/kvm_mmu.h>
+#include <asm/kvm_pgtable.h>
+#include <asm/stage2_pgtable.h>
+
+#include <hyp/switch.h>
+
+#include <nvhe/gfp.h>
+#include <nvhe/memory.h>
+#include <nvhe/mem_protect.h>
+#include <nvhe/mm.h>
+
+#define KVM_HOST_S2_FLAGS (KVM_PGTABLE_S2_NOFWB | KVM_PGTABLE_S2_IDMAP)
+
+extern unsigned long hyp_nr_cpus;
+struct host_kvm host_kvm;
+
+struct hyp_pool host_s2_mem;	/* page-table pages for memory mappings */
+struct hyp_pool host_s2_dev;	/* page-table pages for MMIO, recycled on -ENOMEM */
+
+/*
+ * Copies of the host's CPU feature registers holding sanitized values.
+ */
+u64 id_aa64mmfr0_el1_sys_val;
+u64 id_aa64mmfr1_el1_sys_val;
+
+static const u8 pkvm_hyp_id = 1;	/* host stage-2 owner id for hyp pages */
+
+static void *host_s2_zalloc_pages_exact(size_t size)
+{
+	return hyp_alloc_pages(&host_s2_mem, get_order(size));	/* always host_s2_mem */
+}
+
+static void *host_s2_zalloc_page(void *pool)
+{
+	return hyp_alloc_pages(pool, 0);	/* @pool: presumably the pool given to the map call */
+}
+
+static int prepare_s2_pools(void *mem_pgt_pool, void *dev_pgt_pool)
+{
+	/* Seed the two host stage-2 page-table pools, then wire up mm_ops */
+	int err;
+
+	/* Pool of page-table pages for mappings of memory */
+	err = hyp_pool_init(&host_s2_mem, hyp_virt_to_pfn(mem_pgt_pool),
+			    host_s2_mem_pgtable_pages(), 0);
+	if (err)
+		return err;
+
+	/* Pool of page-table pages for MMIO mappings */
+	err = hyp_pool_init(&host_s2_dev, hyp_virt_to_pfn(dev_pgt_pool),
+			    host_s2_dev_pgtable_pages(), 0);
+	if (err)
+		return err;
+
+	host_kvm.mm_ops = (struct kvm_pgtable_mm_ops) {
+		.zalloc_pages_exact = host_s2_zalloc_pages_exact,
+		.zalloc_page = host_s2_zalloc_page,
+		.phys_to_virt = hyp_phys_to_virt,
+		.virt_to_phys = hyp_virt_to_phys,
+		.page_count = hyp_page_count,
+		.get_page = hyp_get_page,
+		.put_page = hyp_put_page,
+	};
+
+	return 0;
+}
+
+static void prepare_host_vtcr(void)
+{
+	u32 pa_range, ipa_bits;
+
+	/* The host stage 2 is an identity map, so size it from PARange */
+	pa_range = kvm_get_parange(id_aa64mmfr0_el1_sys_val);
+	ipa_bits = id_aa64mmfr0_parange_to_phys_shift(pa_range);
+	host_kvm.arch.vtcr = kvm_get_vtcr(id_aa64mmfr0_el1_sys_val,
+					  id_aa64mmfr1_el1_sys_val,
+					  ipa_bits);
+}
+
+int kvm_host_prepare_stage2(void *mem_pgt_pool, void *dev_pgt_pool)
+{
+	struct kvm_s2_mmu *mmu = &host_kvm.arch.mmu;
+	int ret;
+
+	prepare_host_vtcr();
+	hyp_spin_lock_init(&host_kvm.lock);
+	/* prepare_s2_pools() sets up host_kvm.mm_ops, used by the init below */
+	ret = prepare_s2_pools(mem_pgt_pool, dev_pgt_pool);
+	if (ret)
+		return ret;
+
+	ret = kvm_pgtable_stage2_init_flags(&host_kvm.pgt, &host_kvm.arch,
+					    &host_kvm.mm_ops, KVM_HOST_S2_FLAGS);
+	if (ret)
+		return ret;
+
+	mmu->pgd_phys = __hyp_pa(host_kvm.pgt.pgd);
+	mmu->arch = &host_kvm.arch;
+	mmu->pgt = &host_kvm.pgt;
+	mmu->vmid.vmid_gen = 0;
+	mmu->vmid.vmid = 0;	/* the host stage-2 uses VMID 0 */
+
+	return 0;
+}
+
+int __pkvm_prot_finalize(void)
+{
+	struct kvm_s2_mmu *mmu = &host_kvm.arch.mmu;
+	struct kvm_nvhe_init_params *params = this_cpu_ptr(&kvm_init_params);
+
+	params->vttbr = kvm_get_vttbr(mmu);
+	params->vtcr = host_kvm.arch.vtcr;
+	params->hcr_el2 |= HCR_VM;
+	kvm_flush_dcache_to_poc(params, sizeof(*params));	/* hyp-init.S reads these by PA */
+	/* ...and apply the same settings to this CPU right away */
+	write_sysreg(params->hcr_el2, hcr_el2);
+	__load_stage2(&host_kvm.arch.mmu, host_kvm.arch.vtcr);
+
+	/*
+	 * Make sure to have an ISB before the TLB maintenance below but only
+	 * when __load_stage2() doesn't include one already.
+	 */
+	asm(ALTERNATIVE("isb", "nop", ARM64_WORKAROUND_SPECULATIVE_AT));
+
+	/* Invalidate stale HCR bits that may be cached in TLBs */
+	__tlbi(vmalls12e1);
+	dsb(nsh);
+	isb();
+
+	return 0;
+}
+
+static int host_stage2_unmap_dev_all(void)
+{
+	struct kvm_pgtable *pgt = &host_kvm.pgt;
+	struct memblock_region *reg = hyp_memory;
+	u64 gap = 0;
+	int i, ret;
+
+	/* Unmap every gap between memblocks (i.e. non-memory) to recycle pages */
+	for (i = 0; i < hyp_memblock_nr; i++, reg++) {
+		ret = kvm_pgtable_stage2_unmap(pgt, gap, reg->base - gap);
+		if (ret)
+			return ret;
+		gap = reg->base + reg->size;
+	}
+	return kvm_pgtable_stage2_unmap(pgt, gap, BIT(pgt->ia_bits) - gap);
+}
+
+static bool find_mem_range(phys_addr_t addr, struct kvm_mem_range *range)
+{
+	int cur, left = 0, right = hyp_memblock_nr;
+	struct memblock_region *reg;
+	phys_addr_t end;
+
+	range->start = 0;
+	range->end = ULONG_MAX;
+
+	/* Binary search the sorted memblocks; on a miss, *range is the gap around addr */
+	while (left < right) {
+		cur = (left + right) >> 1;
+		reg = &hyp_memory[cur];
+		end = reg->base + reg->size;
+		if (addr < reg->base) {
+			right = cur;
+			range->end = reg->base;
+		} else if (addr >= end) {
+			left = cur + 1;
+			range->start = end;
+		} else {
+			range->start = reg->base;
+			range->end = end;
+			return true;
+		}
+	}
+
+	return false;	/* @addr is not in any memblock region */
+}
+
+static bool range_is_memory(u64 start, u64 end)
+{
+	struct kvm_mem_range r1, r2;
+
+	/* @end is exclusive, so the last byte to check is end - 1 */
+	if (end < start || !find_mem_range(start, &r1) ||
+	    !find_mem_range(end - 1, &r2))
+		return false;
+
+	return r1.start == r2.start;	/* both ends in the same memblock */
+}
+
+static inline int __host_stage2_idmap(u64 start, u64 end,
+				      enum kvm_pgtable_prot prot, struct hyp_pool *pool)
+{
+	u64 size = end - start;	/* identity map [start, end): IPA == PA */
+
+	return kvm_pgtable_stage2_map(&host_kvm.pgt, start, size, start, prot, pool);
+}
+
+static int host_stage2_idmap(u64 addr)
+{
+	enum kvm_pgtable_prot prot = KVM_PGTABLE_PROT_R | KVM_PGTABLE_PROT_W;
+	struct kvm_mem_range range;
+	bool is_memory = find_mem_range(addr, &range);
+	struct hyp_pool *pool = is_memory ? &host_s2_mem : &host_s2_dev;
+	int ret;
+
+	if (is_memory)
+		prot |= KVM_PGTABLE_PROT_X;
+	/* Memory is mapped RWX and MMIO RW, each from its own page-table pool */
+	hyp_spin_lock(&host_kvm.lock);
+	ret = kvm_pgtable_stage2_find_range(&host_kvm.pgt, addr, prot, &range);
+	if (ret)
+		goto unlock;
+
+	ret = __host_stage2_idmap(range.start, range.end, prot, pool);
+	if (is_memory || ret != -ENOMEM)
+		goto unlock;
+
+	/*
+	 * host_s2_mem has been provided with enough pages to cover all of
+	 * memory with page granularity, so we should never hit the ENOMEM case.
+	 * However, it is difficult to know how much of the MMIO range we will
+	 * need to cover upfront, so we may need to 'recycle' the pages if we
+	 * run out.
+	 */
+	ret = host_stage2_unmap_dev_all();
+	if (ret)
+		goto unlock;
+
+	ret = __host_stage2_idmap(range.start, range.end, prot, pool);
+
+unlock:
+	hyp_spin_unlock(&host_kvm.lock);
+
+	return ret;
+}
+
+int __pkvm_mark_hyp(phys_addr_t start, phys_addr_t end)
+{
+	int err;
+
+	/*
+	 * host_stage2_unmap_dev_all() currently relies on MMIO mappings being
+	 * non-persistent, so don't allow changing page ownership in MMIO range.
+	 */
+	if (!range_is_memory(start, end))
+		return -EINVAL;
+
+	hyp_spin_lock(&host_kvm.lock);
+	err = kvm_pgtable_stage2_set_owner(&host_kvm.pgt, start, end - start,
+					   &host_s2_mem, pkvm_hyp_id);
+	hyp_spin_unlock(&host_kvm.lock);
+
+	return err == -EAGAIN ? 0 : err;	/* -EAGAIN is not reported */
+}
+
+void handle_host_mem_abort(struct kvm_cpu_context *host_ctxt)
+{
+	u64 esr = read_sysreg_el2(SYS_ESR);
+	struct kvm_vcpu_fault_info fault;
+	u64 fault_ipa;
+	int ret;
+
+	BUG_ON(!__get_fault_info(esr, &fault));
+	/* Rebuild the faulting address from HPFAR_EL2 and identity map it */
+	fault_ipa = (fault.hpfar_el2 & HPFAR_MASK) << 8;
+	ret = host_stage2_idmap(fault_ipa);
+	BUG_ON(ret && ret != -EAGAIN);
+}
diff --git a/arch/arm64/kvm/hyp/nvhe/mm.c b/arch/arm64/kvm/hyp/nvhe/mm.c
new file mode 100644
index 0000000..a8efdf0
--- /dev/null
+++ b/arch/arm64/kvm/hyp/nvhe/mm.c
@@ -0,0 +1,173 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2020 Google LLC
+ * Author: Quentin Perret <qperret@google.com>
+ */
+
+#include <linux/kvm_host.h>
+#include <asm/kvm_hyp.h>
+#include <asm/kvm_mmu.h>
+#include <asm/kvm_pgtable.h>
+#include <asm/spectre.h>
+
+#include <nvhe/early_alloc.h>
+#include <nvhe/gfp.h>
+#include <nvhe/memory.h>
+#include <nvhe/mm.h>
+#include <nvhe/spinlock.h>
+
+struct kvm_pgtable pkvm_pgtable;	/* hyp stage-1 page-table */
+hyp_spinlock_t pkvm_pgd_lock;	/* serializes updates to pkvm_pgtable */
+u64 __io_map_base;	/* next free VA for private mappings */
+
+struct memblock_region hyp_memory[HYP_MEMBLOCK_REGIONS];
+unsigned int hyp_memblock_nr;	/* number of valid hyp_memory[] entries */
+
+int __pkvm_create_mappings(unsigned long start, unsigned long size,
+			   unsigned long phys, enum kvm_pgtable_prot prot)
+{
+	int ret;
+
+	/* Every update of pkvm_pgtable is serialized by pkvm_pgd_lock */
+	hyp_spin_lock(&pkvm_pgd_lock);
+	ret = kvm_pgtable_hyp_map(&pkvm_pgtable, start, size, phys, prot);
+	hyp_spin_unlock(&pkvm_pgd_lock);
+	return ret;
+}
+
+unsigned long __pkvm_create_private_mapping(phys_addr_t phys, size_t size,
+					    enum kvm_pgtable_prot prot)
+{
+	unsigned long addr;
+	int err;
+
+	hyp_spin_lock(&pkvm_pgd_lock);
+	/* Map whole pages covering [phys, phys + size) */
+	size = PAGE_ALIGN(size + offset_in_page(phys));
+	addr = __io_map_base;
+
+	/* Are we overflowing on the vmemmap? Compare without wrapping around. */
+	if (size > __hyp_vmemmap - addr) {
+		addr = (unsigned long)ERR_PTR(-ENOMEM);
+		goto out;
+	}
+	__io_map_base += size;
+
+	err = kvm_pgtable_hyp_map(&pkvm_pgtable, addr, size, phys, prot);
+	if (err) {
+		addr = (unsigned long)ERR_PTR(err);
+		goto out;
+	}
+
+	/* Return the VA of @phys itself, not of its page */
+	addr = addr + offset_in_page(phys);
+out:
+	hyp_spin_unlock(&pkvm_pgd_lock);
+
+	return addr;
+}
+
+int pkvm_create_mappings(void *from, void *to, enum kvm_pgtable_prot prot)
+{
+	unsigned long first = (unsigned long)from & PAGE_MASK;
+	unsigned long last = PAGE_ALIGN((unsigned long)to);
+	unsigned long va;
+
+	/*
+	 * Map the page-aligned hull of [from, to) one page at a time,
+	 * translating each hyp VA with hyp_virt_to_phys(). Bail out on the
+	 * first error; pages mapped before it are left in place.
+	 */
+	for (va = first; va < last; va += PAGE_SIZE) {
+		phys_addr_t pa = hyp_virt_to_phys((void *)va);
+		int ret = __pkvm_create_mappings(va, PAGE_SIZE, pa, prot);
+
+		if (ret)
+			return ret;
+	}
+
+	return 0;
+}
+
+int hyp_back_vmemmap(phys_addr_t phys, unsigned long size, phys_addr_t back)
+{
+	unsigned long vstart, vend;
+
+	/* Map the hyp_vmemmap slice for [phys, phys + size) onto @back */
+	hyp_vmemmap_range(phys, size, &vstart, &vend);
+	return __pkvm_create_mappings(vstart, vend - vstart, back, PAGE_HYP);
+}
+
+static void *__hyp_bp_vect_base;
+
+int pkvm_cpu_set_vector(enum arm64_hyp_spectre_vector slot)
+{
+	void *base;
+
+	/* Pick the vector table backing @slot, then offset to the slot */
+	switch (slot) {
+	case HYP_VECTOR_DIRECT:
+		base = __kvm_hyp_vector;
+		break;
+	case HYP_VECTOR_SPECTRE_DIRECT:
+		base = __bp_harden_hyp_vecs;
+		break;
+	case HYP_VECTOR_INDIRECT:
+	case HYP_VECTOR_SPECTRE_INDIRECT:
+		/* Private mapping set up by hyp_map_vectors() */
+		base = __hyp_bp_vect_base;
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	*this_cpu_ptr(&kvm_hyp_vector) =
+		(unsigned long)__kvm_vector_slot2addr(base, slot);
+
+	return 0;
+}
+
+int hyp_map_vectors(void)
+{
+	phys_addr_t phys;
+	void *bp_base;
+
+	if (!cpus_have_const_cap(ARM64_SPECTRE_V3A))
+		return 0;	/* __hyp_bp_vect_base stays NULL */
+	/* Give the hardened vectors their own private hyp VA */
+	phys = __hyp_pa(__bp_harden_hyp_vecs);
+	bp_base = (void *)__pkvm_create_private_mapping(phys,
+							__BP_HARDEN_HYP_VECS_SZ,
+							PAGE_HYP_EXEC);
+	if (IS_ERR_OR_NULL(bp_base))
+		return PTR_ERR(bp_base);
+	/* NOTE(review): PTR_ERR(NULL) is 0, so a NULL base is reported as success */
+	__hyp_bp_vect_base = bp_base;
+
+	return 0;
+}
+
+int hyp_create_idmap(u32 hyp_va_bits)
+{
+	unsigned long start, end;
+
+	start = hyp_virt_to_phys((void *)__hyp_idmap_text_start);
+	start = ALIGN_DOWN(start, PAGE_SIZE);
+
+	end = hyp_virt_to_phys((void *)__hyp_idmap_text_end);
+	end = ALIGN(end, PAGE_SIZE);
+
+	/*
+	 * One half of the VA space is reserved to linearly map portions of
+	 * memory -- see va_layout.c for more details. The other half of the VA
+	 * space contains the trampoline page, and needs some care. Split that
+	 * second half in two and find the quarter of VA space not conflicting
+	 * with the idmap to place the IOs and the vmemmap. IOs use the lower
+	 * half of the quarter and the vmemmap the upper half.
+	 */
+	__io_map_base = start & BIT(hyp_va_bits - 2);	/* bit selecting the idmap's quarter */
+	__io_map_base ^= BIT(hyp_va_bits - 2);		/* flipped: the other quarter */
+	__hyp_vmemmap = __io_map_base | BIT(hyp_va_bits - 3);	/* its upper half */
+	/* Identity map (VA == PA) the page-aligned idmap text, executable */
+	return __pkvm_create_mappings(start, end - start, start, PAGE_HYP_EXEC);
+}
diff --git a/arch/arm64/kvm/hyp/nvhe/page_alloc.c b/arch/arm64/kvm/hyp/nvhe/page_alloc.c
new file mode 100644
index 0000000..237e03b
--- /dev/null
+++ b/arch/arm64/kvm/hyp/nvhe/page_alloc.c
@@ -0,0 +1,195 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2020 Google LLC
+ * Author: Quentin Perret <qperret@google.com>
+ */
+
+#include <asm/kvm_hyp.h>
+#include <nvhe/gfp.h>
+
+/*
+ * Base address of the hyp_vmemmap, through which struct hyp_page entries
+ * are located; assigned in hyp_create_idmap().
+ */
+u64 __hyp_vmemmap;
+
+/*
+ * Index the hyp_vmemmap to find a potential buddy page, but make no assumption
+ * about its current state.
+ *
+ * Example buddy-tree for a 4-pages physically contiguous pool:
+ *
+ *                 o : Page 3
+ *                /
+ *               o-o : Page 2
+ *              /
+ *             /   o : Page 1
+ *            /   /
+ *           o---o-o : Page 0
+ *    Order  2   1 0
+ *
+ * Example of requests on this pool:
+ *   __find_buddy_nocheck(pool, page 0, order 0) => page 1
+ *   __find_buddy_nocheck(pool, page 0, order 1) => page 2
+ *   __find_buddy_nocheck(pool, page 1, order 0) => page 0
+ *   __find_buddy_nocheck(pool, page 2, order 0) => page 3
+ *
+ * Returns NULL when the buddy address falls outside the pool's range.
+ */
+static struct hyp_page *__find_buddy_nocheck(struct hyp_pool *pool,
+					     struct hyp_page *p,
+					     unsigned int order)
+{
+	/* Buddies differ only in the address bit selected by @order. */
+	phys_addr_t buddy_addr = hyp_page_to_phys(p) ^ (PAGE_SIZE << order);
+
+	/*
+	 * Only addresses inside [range_start, range_end) belong to this
+	 * pool; anything else may not even be backed in hyp_vmemmap.
+	 */
+	if (buddy_addr >= pool->range_start && buddy_addr < pool->range_end)
+		return hyp_phys_to_page(buddy_addr);
+
+	return NULL;
+}
+
+/*
+ * Like __find_buddy_nocheck(), but only return the buddy when it heads a
+ * block of exactly @order that is currently linked on a free list.
+ */
+static struct hyp_page *__find_buddy_avail(struct hyp_pool *pool,
+					   struct hyp_page *p,
+					   unsigned int order)
+{
+	struct hyp_page *buddy = __find_buddy_nocheck(pool, p, order);
+
+	if (buddy && buddy->order == order && !list_empty(&buddy->node))
+		return buddy;
+
+	return NULL;
+}
+
+/*
+ * Return the block headed by @p (of order p->order) to @pool: zero it,
+ * merge it with free buddies of equal order while the merged order stays
+ * below pool->max_order, and queue the final head on free_area[order].
+ * No locking is done here: hyp_attach_page() calls it under pool->lock,
+ * and hyp_pool_init() calls it while initialising the pool.
+ */
+static void __hyp_attach_page(struct hyp_pool *pool,
+			      struct hyp_page *p)
+{
+	unsigned int order = p->order;
+	struct hyp_page *buddy;
+
+	memset(hyp_page_to_virt(p), 0, PAGE_SIZE << p->order);
+
+	/*
+	 * Only the first struct hyp_page of a high-order page (otherwise known
+	 * as the 'head') should have p->order set. The non-head pages should
+	 * have p->order = HYP_NO_ORDER. Here @p may no longer be the head
+	 * after coalescing, so make sure to mark it HYP_NO_ORDER proactively.
+	 */
+	p->order = HYP_NO_ORDER;
+	/* Merging at @order produces an (order + 1) block, which must fit below max_order. */
+	for (; (order + 1) < pool->max_order; order++) {
+		buddy = __find_buddy_avail(pool, p, order);
+		if (!buddy)
+			break;
+
+		/* Take the buddy out of its list, and coalesce with @p */
+		list_del_init(&buddy->node);
+		buddy->order = HYP_NO_ORDER;
+		/* The lower-addressed struct hyp_page of the pair becomes the head. */
+		p = min(p, buddy);
+	}
+
+	/* Mark the new head, and insert it */
+	p->order = order;
+	list_add_tail(&p->node, &pool->free_area[order]);
+}
+
+/* Locked wrapper: give @p's block back to the pool that owns the page. */
+static void hyp_attach_page(struct hyp_page *p)
+{
+	struct hyp_pool *owner = hyp_page_to_pool(p);
+
+	hyp_spin_lock(&owner->lock);
+	__hyp_attach_page(owner, p);
+	hyp_spin_unlock(&owner->lock);
+}
+
+/*
+ * Take the free block headed by @p off its free list and split it down to
+ * @order. Each split lowers p->order by one and queues the buddy half, at
+ * that new order, on pool->free_area[]. Returns @p.
+ * hyp_alloc_pages() calls this with pool->lock held.
+ */
+static struct hyp_page *__hyp_extract_page(struct hyp_pool *pool,
+					   struct hyp_page *p,
+					   unsigned int order)
+{
+	struct hyp_page *buddy;
+
+	list_del_init(&p->node);
+	while (p->order > order) {
+		/*
+		 * The buddy of order n - 1 currently has HYP_NO_ORDER as it
+		 * is covered by a higher-level page (whose head is @p). Use
+		 * __find_buddy_nocheck() to find it and inject it in
+		 * free_area[n - 1], effectively splitting @p in half.
+		 */
+		p->order--;
+		buddy = __find_buddy_nocheck(pool, p, p->order);
+		buddy->order = p->order;
+		list_add_tail(&buddy->node, &pool->free_area[buddy->order]);
+	}
+
+	return p;
+}
+
+/*
+ * Drop a reference on the page backing @addr. When the count reaches
+ * zero, the page goes back to its pool.
+ */
+void hyp_put_page(void *addr)
+{
+	struct hyp_page *page = hyp_virt_to_page(addr);
+
+	if (!hyp_page_ref_dec_and_test(page))
+		return;
+
+	hyp_attach_page(page);
+}
+
+/* Take one more reference on the page backing @addr. */
+void hyp_get_page(void *addr)
+{
+	struct hyp_page *page;
+
+	page = hyp_virt_to_page(addr);
+	hyp_page_ref_inc(page);
+}
+
+/*
+ * Allocate a block of 2^@order pages from @pool. The first non-empty free
+ * list at or above @order supplies a block, which is split down to size
+ * and marked refcounted. Returns its virtual address, or NULL when no
+ * large-enough block is free.
+ */
+void *hyp_alloc_pages(struct hyp_pool *pool, unsigned int order)
+{
+	struct hyp_page *page = NULL;
+	unsigned int cur;
+
+	hyp_spin_lock(&pool->lock);
+
+	/* Scan upwards for a free list that has something on it. */
+	for (cur = order; cur < pool->max_order; cur++) {
+		if (!list_empty(&pool->free_area[cur]))
+			break;
+	}
+
+	if (cur < pool->max_order) {
+		page = list_first_entry(&pool->free_area[cur], struct hyp_page,
+					node);
+		page = __hyp_extract_page(pool, page, order);
+	}
+
+	hyp_spin_unlock(&pool->lock);
+
+	if (!page)
+		return NULL;
+
+	hyp_set_page_refcounted(page);
+	return hyp_page_to_virt(page);
+}
+
+/*
+ * Initialise @pool over the @nr_pages physically contiguous pages starting
+ * at @pfn. The struct hyp_page entries for the range are reset and bound
+ * to @pool. The first @reserved_pages pages are left off the free lists;
+ * the rest are attached to the buddy tree. Always returns 0.
+ */
+int hyp_pool_init(struct hyp_pool *pool, u64 pfn, unsigned int nr_pages,
+		  unsigned int reserved_pages)
+{
+	phys_addr_t phys = hyp_pfn_to_phys(pfn);
+	/* Widen before shifting: a 32-bit nr_pages << PAGE_SHIFT can wrap. */
+	phys_addr_t size = (phys_addr_t)nr_pages << PAGE_SHIFT;
+	struct hyp_page *p;
+	unsigned int i;
+
+	hyp_spin_lock_init(&pool->lock);
+	/*
+	 * max_order is exclusive: free_area[] is used for orders
+	 * 0 .. max_order - 1. The largest block that fits in nr_pages pages
+	 * has order floor(log2(nr_pages)), so the bound is one more than
+	 * that, i.e. get_order() of (nr_pages + 1) pages. Using nr_pages
+	 * pages instead yields 0 for a one-page pool (no free list gets
+	 * initialised, yet the page is still attached to free_area[0]). It
+	 * also stops a power-of-two sized pool from forming its largest block.
+	 */
+	pool->max_order = min(MAX_ORDER, get_order(size + PAGE_SIZE));
+	for (i = 0; i < pool->max_order; i++)
+		INIT_LIST_HEAD(&pool->free_area[i]);
+	pool->range_start = phys;
+	pool->range_end = phys + size;
+
+	/* Init the vmemmap portion */
+	p = hyp_phys_to_page(phys);
+	memset(p, 0, sizeof(*p) * nr_pages);
+	for (i = 0; i < nr_pages; i++) {
+		p[i].pool = pool;
+		INIT_LIST_HEAD(&p[i].node);
+	}
+
+	/* Attach the unused pages to the buddy tree */
+	for (i = reserved_pages; i < nr_pages; i++)
+		__hyp_attach_page(pool, &p[i]);
+
+	return 0;
+}
diff --git a/arch/arm64/kvm/hyp/nvhe/psci-relay.c b/arch/arm64/kvm/hyp/nvhe/psci-relay.c
index 63de71c..0850878 100644
--- a/arch/arm64/kvm/hyp/nvhe/psci-relay.c
+++ b/arch/arm64/kvm/hyp/nvhe/psci-relay.c
@@ -11,6 +11,7 @@
 #include <linux/kvm_host.h>
 #include <uapi/linux/psci.h>
 
+#include <nvhe/memory.h>
 #include <nvhe/trap_handler.h>
 
 void kvm_hyp_cpu_entry(unsigned long r0);
@@ -20,9 +21,6 @@
 
 /* Config options set by the host. */
 struct kvm_host_psci_config __ro_after_init kvm_host_psci_config;
-s64 __ro_after_init hyp_physvirt_offset;
-
-#define __hyp_pa(x) ((phys_addr_t)((x)) + hyp_physvirt_offset)
 
 #define INVALID_CPU_ID	UINT_MAX
 
diff --git a/arch/arm64/kvm/hyp/nvhe/setup.c b/arch/arm64/kvm/hyp/nvhe/setup.c
new file mode 100644
index 0000000..7488f53
--- /dev/null
+++ b/arch/arm64/kvm/hyp/nvhe/setup.c
@@ -0,0 +1,214 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2020 Google LLC
+ * Author: Quentin Perret <qperret@google.com>
+ */
+
+#include <linux/kvm_host.h>
+#include <asm/kvm_hyp.h>
+#include <asm/kvm_mmu.h>
+#include <asm/kvm_pgtable.h>
+
+#include <nvhe/early_alloc.h>
+#include <nvhe/gfp.h>
+#include <nvhe/memory.h>
+#include <nvhe/mem_protect.h>
+#include <nvhe/mm.h>
+#include <nvhe/trap_handler.h>
+
+/* Page pool installed by __pkvm_init_finalise(); serves hyp_zalloc_hyp_page(). */
+struct hyp_pool hpool;
+/* mm_ops given to pkvm_pgtable once hpool is up (see __pkvm_init_finalise()). */
+struct kvm_pgtable_mm_ops pkvm_pgtable_mm_ops;
+/* CPU count passed to __pkvm_init(); bounds the per-CPU loops in this file. */
+unsigned long hyp_nr_cpus;
+
+/* Byte size of the per-CPU region [__per_cpu_start, __per_cpu_end). */
+#define hyp_percpu_size ((unsigned long)__per_cpu_end - \
+			 (unsigned long)__per_cpu_start)
+
+/*
+ * Regions carved by divide_memory_pool() out of the memory handed to
+ * __pkvm_init(). In order: backing for the hyp vmemmap, the hyp stage-1
+ * page-table pages, and the host stage-2 memory and device page-table pages.
+ */
+static void *vmemmap_base;
+static void *hyp_pgt_base;
+static void *host_s2_mem_pgt_base;
+static void *host_s2_dev_pgt_base;
+
+/* Carve @nr_pages contiguous pages from the early allocator into *@base. */
+static int early_alloc_region(void **base, unsigned long nr_pages)
+{
+	*base = hyp_early_alloc_contig(nr_pages);
+
+	return *base ? 0 : -ENOMEM;
+}
+
+/*
+ * Hand [@virt, @virt + @size) to the early allocator and carve from it, in
+ * order: the hyp vmemmap backing, the hyp stage-1 page-table pages, and
+ * the host stage-2 memory and device page-table pages. Returns 0, or
+ * -ENOMEM as soon as one carve-out fails.
+ */
+static int divide_memory_pool(void *virt, unsigned long size)
+{
+	unsigned long vstart, vend;
+	int ret;
+
+	hyp_early_alloc_init(virt, size);
+
+	hyp_vmemmap_range(__hyp_pa(virt), size, &vstart, &vend);
+	ret = early_alloc_region(&vmemmap_base, (vend - vstart) >> PAGE_SHIFT);
+	if (ret)
+		return ret;
+
+	ret = early_alloc_region(&hyp_pgt_base, hyp_s1_pgtable_pages());
+	if (ret)
+		return ret;
+
+	ret = early_alloc_region(&host_s2_mem_pgt_base,
+				 host_s2_mem_pgtable_pages());
+	if (ret)
+		return ret;
+
+	return early_alloc_region(&host_s2_dev_pgt_base,
+				  host_s2_dev_pgtable_pages());
+}
+
+/*
+ * Build a fresh hyp stage-1 page-table in pkvm_pgtable. Its table pages
+ * come from hyp_pgt_base via the early allocator. It is populated with,
+ * in order:
+ *  - the idmap;
+ *  - any private vector mapping (hyp_map_vectors());
+ *  - the vmemmap backing for [@phys, @phys + @size);
+ *  - the hyp text, rodata and bss sections;
+ *  - the range __hyp_bss_end..__bss_stop, read-only;
+ *  - the region [@phys, @phys + @size) itself;
+ *  - for each of the hyp_nr_cpus CPUs, its per-CPU area (@per_cpu_base[i])
+ *    and one stack page.
+ * Returns 0, or the first non-zero error from these steps.
+ */
+static int recreate_hyp_mappings(phys_addr_t phys, unsigned long size,
+				 unsigned long *per_cpu_base,
+				 u32 hyp_va_bits)
+{
+	void *start, *end, *virt = hyp_phys_to_virt(phys);
+	unsigned long pgt_size = hyp_s1_pgtable_pages() << PAGE_SHIFT;
+	int ret, i;
+
+	/* Recreate the hyp page-table using the early page allocator */
+	hyp_early_alloc_init(hyp_pgt_base, pgt_size);
+	ret = kvm_pgtable_hyp_init(&pkvm_pgtable, hyp_va_bits,
+				   &hyp_early_alloc_mm_ops);
+	if (ret)
+		return ret;
+
+	/* Also chooses __io_map_base and __hyp_vmemmap from the idmap position. */
+	ret = hyp_create_idmap(hyp_va_bits);
+	if (ret)
+		return ret;
+
+	ret = hyp_map_vectors();
+	if (ret)
+		return ret;
+
+	/* Back the vmemmap for [phys, phys + size) with pages from vmemmap_base. */
+	ret = hyp_back_vmemmap(phys, size, hyp_virt_to_phys(vmemmap_base));
+	if (ret)
+		return ret;
+
+	ret = pkvm_create_mappings(__hyp_text_start, __hyp_text_end, PAGE_HYP_EXEC);
+	if (ret)
+		return ret;
+
+	ret = pkvm_create_mappings(__start_rodata, __end_rodata, PAGE_HYP_RO);
+	if (ret)
+		return ret;
+
+	ret = pkvm_create_mappings(__hyp_rodata_start, __hyp_rodata_end, PAGE_HYP_RO);
+	if (ret)
+		return ret;
+
+	ret = pkvm_create_mappings(__hyp_bss_start, __hyp_bss_end, PAGE_HYP);
+	if (ret)
+		return ret;
+
+	/* Remainder of bss after the hyp's own bss: mapped read-only. */
+	ret = pkvm_create_mappings(__hyp_bss_end, __bss_stop, PAGE_HYP_RO);
+	if (ret)
+		return ret;
+
+	ret = pkvm_create_mappings(virt, virt + size, PAGE_HYP);
+	if (ret)
+		return ret;
+
+	for (i = 0; i < hyp_nr_cpus; i++) {
+		/* This CPU's per-CPU area, converted with kern_hyp_va(). */
+		start = (void *)kern_hyp_va(per_cpu_base[i]);
+		end = start + PAGE_ALIGN(hyp_percpu_size);
+		ret = pkvm_create_mappings(start, end, PAGE_HYP);
+		if (ret)
+			return ret;
+
+		/* A single stack page ending at this CPU's stack_hyp_va. */
+		end = (void *)per_cpu_ptr(&kvm_init_params, i)->stack_hyp_va;
+		start = end - PAGE_SIZE;
+		ret = pkvm_create_mappings(start, end, PAGE_HYP);
+		if (ret)
+			return ret;
+	}
+
+	return 0;
+}
+
+/*
+ * Point each CPU's kvm_init_params at the PGD of the new pkvm_pgtable and
+ * flush the updated structure with __flush_dcache_area().
+ */
+static void update_nvhe_init_params(void)
+{
+	unsigned long cpu;
+
+	for (cpu = 0; cpu < hyp_nr_cpus; cpu++) {
+		struct kvm_nvhe_init_params *cpu_params =
+			per_cpu_ptr(&kvm_init_params, cpu);
+
+		cpu_params->pgd_pa = __hyp_pa(pkvm_pgtable.pgd);
+		__flush_dcache_area(cpu_params, sizeof(*cpu_params));
+	}
+}
+
+/*
+ * zalloc_page callback for pkvm_pgtable_mm_ops: one order-0 page from
+ * hpool (@arg is unused). Pages on hpool's free lists were already zeroed
+ * by __hyp_attach_page(), so no clearing happens here.
+ */
+static void *hyp_zalloc_hyp_page(void *arg)
+{
+	const unsigned int single_page_order = 0;
+
+	return hyp_alloc_pages(&hpool, single_page_order);
+}
+
+/*
+ * Continuation of __pkvm_init(), entered through __pkvm_init_switch_pgd()
+ * once the new page-tables are live. It does three things:
+ *  - replaces the early allocator with hpool;
+ *  - prepares the host stage-2 from the pages reserved for it;
+ *  - switches pkvm_pgtable to hpool-backed mm_ops.
+ * It never returns: the result goes in register 1 of the host context,
+ * and the host is re-entered via __host_enter().
+ */
+void __noreturn __pkvm_init_finalise(void)
+{
+	struct kvm_host_data *host_data = this_cpu_ptr(&kvm_host_data);
+	struct kvm_cpu_context *host_ctxt = &host_data->host_ctxt;
+	unsigned long nr_pages, reserved_pages, pfn;
+	int ret;
+
+	/*
+	 * Now that the vmemmap is backed, install the full-fledged allocator.
+	 * hpool spans the hyp_pgt_base region. Pages the early allocator has
+	 * already handed out (for the new page-table) are passed as reserved,
+	 * so they are not placed on the free lists.
+	 */
+	pfn = hyp_virt_to_pfn(hyp_pgt_base);
+	nr_pages = hyp_s1_pgtable_pages();
+	reserved_pages = hyp_early_alloc_nr_used_pages();
+	ret = hyp_pool_init(&hpool, pfn, nr_pages, reserved_pages);
+	if (ret)
+		goto out;
+
+	ret = kvm_host_prepare_stage2(host_s2_mem_pgt_base, host_s2_dev_pgt_base);
+	if (ret)
+		goto out;
+
+	/* From here on, pkvm_pgtable allocates from and frees to hpool. */
+	pkvm_pgtable_mm_ops = (struct kvm_pgtable_mm_ops) {
+		.zalloc_page = hyp_zalloc_hyp_page,
+		.phys_to_virt = hyp_phys_to_virt,
+		.virt_to_phys = hyp_virt_to_phys,
+		.get_page = hyp_get_page,
+		.put_page = hyp_put_page,
+	};
+	pkvm_pgtable.mm_ops = &pkvm_pgtable_mm_ops;
+
+out:
+	/*
+	 * We tail-called to here from handle___pkvm_init() and will not return,
+	 * so make sure to propagate the return value to the host.
+	 */
+	cpu_reg(host_ctxt, 1) = ret;
+
+	__host_enter(host_ctxt);
+}
+
+/*
+ * Set up the pKVM hyp environment in [@phys, @phys + @size):
+ *  - carve the region up (divide_memory_pool());
+ *  - rebuild the hyp stage-1 page-table in it (recreate_hyp_mappings());
+ *  - publish the new PGD to every CPU's kvm_init_params;
+ *  - switch to the new page-table.
+ * Returns -EINVAL if @phys or @size is not page-aligned, or an error from
+ * the setup steps. On success control does not come back here; it
+ * continues in __pkvm_init_finalise().
+ */
+int __pkvm_init(phys_addr_t phys, unsigned long size, unsigned long nr_cpus,
+		unsigned long *per_cpu_base, u32 hyp_va_bits)
+{
+	struct kvm_nvhe_init_params *params;
+	void *virt = hyp_phys_to_virt(phys);
+	void (*fn)(phys_addr_t params_pa, void *finalize_fn_va);
+	int ret;
+
+	if (!PAGE_ALIGNED(phys) || !PAGE_ALIGNED(size))
+		return -EINVAL;
+
+	hyp_spin_lock_init(&pkvm_pgd_lock);
+	hyp_nr_cpus = nr_cpus;
+
+	ret = divide_memory_pool(virt, size);
+	if (ret)
+		return ret;
+
+	ret = recreate_hyp_mappings(phys, size, per_cpu_base, hyp_va_bits);
+	if (ret)
+		return ret;
+
+	update_nvhe_init_params();
+
+	/*
+	 * Jump in the idmap page to switch to the new page-tables: the switch
+	 * routine is called through its physical address (i.e. via the
+	 * identity mapping). It is given this CPU's params PA and
+	 * __pkvm_init_finalise as the function to continue in.
+	 */
+	params = this_cpu_ptr(&kvm_init_params);
+	fn = (typeof(fn))__hyp_pa(__pkvm_init_switch_pgd);
+	fn(__hyp_pa(params), __pkvm_init_finalise);
+
+	unreachable();
+}
diff --git a/arch/arm64/kvm/hyp/nvhe/stub.c b/arch/arm64/kvm/hyp/nvhe/stub.c
new file mode 100644
index 0000000..c0aa6bb
--- /dev/null
+++ b/arch/arm64/kvm/hyp/nvhe/stub.c
@@ -0,0 +1,22 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Stubs for out-of-line function calls caused by re-using kernel
+ * infrastructure at EL2.
+ *
+ * Copyright (C) 2020 - Google LLC
+ */
+
+#include <linux/list.h>
+
+#ifdef CONFIG_DEBUG_LIST
+/*
+ * With CONFIG_DEBUG_LIST, the reused list helpers call out to these
+ * validators. At EL2 they are stubbed: every list addition is accepted.
+ */
+bool __list_add_valid(struct list_head *new, struct list_head *prev,
+		      struct list_head *next)
+{
+	return true;
+}
+
+/* Deletion counterpart of the stub above; likewise always accepts. */
+bool __list_del_entry_valid(struct list_head *entry)
+{
+	return true;
+}
+#endif /* CONFIG_DEBUG_LIST */
diff --git a/arch/arm64/kvm/hyp/nvhe/switch.c b/arch/arm64/kvm/hyp/nvhe/switch.c
index 68ab6b4..e9f6ea7 100644
--- a/arch/arm64/kvm/hyp/nvhe/switch.c
+++ b/arch/arm64/kvm/hyp/nvhe/switch.c
@@ -28,6 +28,8 @@
 #include <asm/processor.h>
 #include <asm/thread_info.h>
 
+#include <nvhe/mem_protect.h>
+
 /* Non-VHE specific context */
 DEFINE_PER_CPU(struct kvm_host_data, kvm_host_data);
 DEFINE_PER_CPU(struct kvm_cpu_context, kvm_hyp_ctxt);
@@ -41,9 +43,9 @@
 	__activate_traps_common(vcpu);
 
 	val = CPTR_EL2_DEFAULT;
-	val |= CPTR_EL2_TTA | CPTR_EL2_TZ | CPTR_EL2_TAM;
+	val |= CPTR_EL2_TTA | CPTR_EL2_TAM;
 	if (!update_fp_enabled(vcpu)) {
-		val |= CPTR_EL2_TFP;
+		val |= CPTR_EL2_TFP | CPTR_EL2_TZ;
 		__activate_traps_fpsimd32(vcpu);
 	}
 
@@ -68,7 +70,7 @@
 static void __deactivate_traps(struct kvm_vcpu *vcpu)
 {
 	extern char __kvm_hyp_host_vector[];
-	u64 mdcr_el2;
+	u64 mdcr_el2, cptr;
 
 	___deactivate_traps(vcpu);
 
@@ -95,19 +97,17 @@
 
 	mdcr_el2 &= MDCR_EL2_HPMN_MASK;
 	mdcr_el2 |= MDCR_EL2_E2PB_MASK << MDCR_EL2_E2PB_SHIFT;
+	mdcr_el2 |= MDCR_EL2_E2TB_MASK << MDCR_EL2_E2TB_SHIFT;
 
 	write_sysreg(mdcr_el2, mdcr_el2);
-	if (is_protected_kvm_enabled())
-		write_sysreg(HCR_HOST_NVHE_PROTECTED_FLAGS, hcr_el2);
-	else
-		write_sysreg(HCR_HOST_NVHE_FLAGS, hcr_el2);
-	write_sysreg(CPTR_EL2_DEFAULT, cptr_el2);
-	write_sysreg(__kvm_hyp_host_vector, vbar_el2);
-}
+	write_sysreg(this_cpu_ptr(&kvm_init_params)->hcr_el2, hcr_el2);
 
-static void __load_host_stage2(void)
-{
-	write_sysreg(0, vttbr_el2);
+	cptr = CPTR_EL2_DEFAULT;
+	if (vcpu_has_sve(vcpu) && (vcpu->arch.flags & KVM_ARM64_FP_ENABLED))
+		cptr |= CPTR_EL2_TZ;
+
+	write_sysreg(cptr, cptr_el2);
+	write_sysreg(__kvm_hyp_host_vector, vbar_el2);
 }
 
 /* Save VGICv3 state on non-VHE systems */
diff --git a/arch/arm64/kvm/hyp/nvhe/tlb.c b/arch/arm64/kvm/hyp/nvhe/tlb.c
index 229b067..83dc3b2 100644
--- a/arch/arm64/kvm/hyp/nvhe/tlb.c
+++ b/arch/arm64/kvm/hyp/nvhe/tlb.c
@@ -8,6 +8,8 @@
 #include <asm/kvm_mmu.h>
 #include <asm/tlbflush.h>
 
+#include <nvhe/mem_protect.h>
+
 struct tlb_inv_context {
 	u64		tcr;
 };
@@ -43,7 +45,7 @@
 
 static void __tlb_switch_to_host(struct tlb_inv_context *cxt)
 {
-	write_sysreg(0, vttbr_el2);
+	__load_host_stage2();
 
 	if (cpus_have_final_cap(ARM64_WORKAROUND_SPECULATIVE_AT)) {
 		/* Ensure write of the host VMID */
diff --git a/arch/arm64/kvm/hyp/pgtable.c b/arch/arm64/kvm/hyp/pgtable.c
index 926fc07..c37c1dc 100644
--- a/arch/arm64/kvm/hyp/pgtable.c
+++ b/arch/arm64/kvm/hyp/pgtable.c
@@ -9,8 +9,7 @@
 
 #include <linux/bitfield.h>
 #include <asm/kvm_pgtable.h>
-
-#define KVM_PGTABLE_MAX_LEVELS		4U
+#include <asm/stage2_pgtable.h>
 
 #define KVM_PTE_VALID			BIT(0)
 
@@ -49,6 +48,11 @@
 					 KVM_PTE_LEAF_ATTR_LO_S2_S2AP_W | \
 					 KVM_PTE_LEAF_ATTR_HI_S2_XN)
 
+/* Leaf PTE bits 58:55; compared as part of KVM_PTE_LEAF_S2_COMPAT_MASK. */
+#define KVM_PTE_LEAF_ATTR_S2_IGNORED	GENMASK(58, 55)
+
+/* In an invalid PTE, bits 63:56 hold the owner id of the page. */
+#define KVM_INVALID_PTE_OWNER_MASK	GENMASK(63, 56)
+/* Largest owner id accepted by kvm_pgtable_stage2_set_owner(). */
+#define KVM_MAX_OWNER_ID		1
+
 struct kvm_pgtable_walk_data {
 	struct kvm_pgtable		*pgt;
 	struct kvm_pgtable_walker	*walker;
@@ -68,21 +72,36 @@
 	return BIT(kvm_granule_shift(level));
 }
 
-static bool kvm_block_mapping_supported(u64 addr, u64 end, u64 phys, u32 level)
-{
-	u64 granule = kvm_granule_size(level);
+/*
+ * Sentinel output address: a stage-2 "map" with this phys installs an
+ * invalid, owner-annotated leaf instead of a mapping.
+ */
+#define KVM_PHYS_INVALID (-1ULL)
 
+/* True when @phys is below 2^(PA bits for ID_AA64MMFR0_PARANGE_MAX). */
+static bool kvm_phys_is_valid(u64 phys)
+{
+	u64 limit = BIT(id_aa64mmfr0_parange_to_phys_shift(ID_AA64MMFR0_PARANGE_MAX));
+
+	return phys < limit;
+}
+
+static bool kvm_level_supports_block_mapping(u32 level)
+{
 	/*
 	 * Reject invalid block mappings and don't bother with 4TB mappings for
 	 * 52-bit PAs.
 	 */
-	if (level == 0 || (PAGE_SIZE != SZ_4K && level == 1))
+	return !(level == 0 || (PAGE_SIZE != SZ_4K && level == 1));
+}
+
+static bool kvm_block_mapping_supported(u64 addr, u64 end, u64 phys, u32 level)
+{
+	u64 granule = kvm_granule_size(level);
+
+	if (!kvm_level_supports_block_mapping(level))
 		return false;
 
 	if (granule > (end - addr))
 		return false;
 
-	return IS_ALIGNED(addr, granule) && IS_ALIGNED(phys, granule);
+	if (kvm_phys_is_valid(phys) && !IS_ALIGNED(phys, granule))
+		return false;
+
+	return IS_ALIGNED(addr, granule);
 }
 
 static u32 kvm_pgtable_idx(struct kvm_pgtable_walk_data *data, u32 level)
@@ -152,20 +171,20 @@
 	return pte;
 }
 
-static kvm_pte_t *kvm_pte_follow(kvm_pte_t pte)
+static kvm_pte_t *kvm_pte_follow(kvm_pte_t pte, struct kvm_pgtable_mm_ops *mm_ops)
 {
-	return __va(kvm_pte_to_phys(pte));
+	return mm_ops->phys_to_virt(kvm_pte_to_phys(pte));
 }
 
-static void kvm_set_invalid_pte(kvm_pte_t *ptep)
+static void kvm_clear_pte(kvm_pte_t *ptep)
 {
-	kvm_pte_t pte = *ptep;
-	WRITE_ONCE(*ptep, pte & ~KVM_PTE_VALID);
+	WRITE_ONCE(*ptep, 0);
 }
 
-static void kvm_set_table_pte(kvm_pte_t *ptep, kvm_pte_t *childp)
+static void kvm_set_table_pte(kvm_pte_t *ptep, kvm_pte_t *childp,
+			      struct kvm_pgtable_mm_ops *mm_ops)
 {
-	kvm_pte_t old = *ptep, pte = kvm_phys_to_pte(__pa(childp));
+	kvm_pte_t old = *ptep, pte = kvm_phys_to_pte(mm_ops->virt_to_phys(childp));
 
 	pte |= FIELD_PREP(KVM_PTE_TYPE, KVM_PTE_TYPE_TABLE);
 	pte |= KVM_PTE_VALID;
@@ -187,6 +206,11 @@
 	return pte;
 }
 
+/* Build a non-valid leaf PTE whose owner field records @owner_id. */
+static kvm_pte_t kvm_init_invalid_leaf_owner(u8 owner_id)
+{
+	kvm_pte_t pte = FIELD_PREP(KVM_INVALID_PTE_OWNER_MASK, owner_id);
+
+	return pte;
+}
+
 static int kvm_pgtable_visitor_cb(struct kvm_pgtable_walk_data *data, u64 addr,
 				  u32 level, kvm_pte_t *ptep,
 				  enum kvm_pgtable_walk_flags flag)
@@ -228,7 +252,7 @@
 		goto out;
 	}
 
-	childp = kvm_pte_follow(pte);
+	childp = kvm_pte_follow(pte, data->pgt->mm_ops);
 	ret = __kvm_pgtable_walk(data, childp, level + 1);
 	if (ret)
 		goto out;
@@ -303,12 +327,12 @@
 }
 
 struct hyp_map_data {
-	u64		phys;
-	kvm_pte_t	attr;
+	u64				phys;
+	kvm_pte_t			attr;
+	struct kvm_pgtable_mm_ops	*mm_ops;
 };
 
-static int hyp_map_set_prot_attr(enum kvm_pgtable_prot prot,
-				 struct hyp_map_data *data)
+static int hyp_set_prot_attr(enum kvm_pgtable_prot prot, kvm_pte_t *ptep)
 {
 	bool device = prot & KVM_PGTABLE_PROT_DEVICE;
 	u32 mtype = device ? MT_DEVICE_nGnRE : MT_NORMAL;
@@ -333,7 +357,8 @@
 	attr |= FIELD_PREP(KVM_PTE_LEAF_ATTR_LO_S1_AP, ap);
 	attr |= FIELD_PREP(KVM_PTE_LEAF_ATTR_LO_S1_SH, sh);
 	attr |= KVM_PTE_LEAF_ATTR_LO_S1_AF;
-	data->attr = attr;
+	*ptep = attr;
+
 	return 0;
 }
 
@@ -359,6 +384,8 @@
 			  enum kvm_pgtable_walk_flags flag, void * const arg)
 {
 	kvm_pte_t *childp;
+	struct hyp_map_data *data = arg;
+	struct kvm_pgtable_mm_ops *mm_ops = data->mm_ops;
 
 	if (hyp_map_walker_try_leaf(addr, end, level, ptep, arg))
 		return 0;
@@ -366,11 +393,11 @@
 	if (WARN_ON(level == KVM_PGTABLE_MAX_LEVELS - 1))
 		return -EINVAL;
 
-	childp = (kvm_pte_t *)get_zeroed_page(GFP_KERNEL);
+	childp = (kvm_pte_t *)mm_ops->zalloc_page(NULL);
 	if (!childp)
 		return -ENOMEM;
 
-	kvm_set_table_pte(ptep, childp);
+	kvm_set_table_pte(ptep, childp, mm_ops);
 	return 0;
 }
 
@@ -380,6 +407,7 @@
 	int ret;
 	struct hyp_map_data map_data = {
 		.phys	= ALIGN_DOWN(phys, PAGE_SIZE),
+		.mm_ops	= pgt->mm_ops,
 	};
 	struct kvm_pgtable_walker walker = {
 		.cb	= hyp_map_walker,
@@ -387,7 +415,7 @@
 		.arg	= &map_data,
 	};
 
-	ret = hyp_map_set_prot_attr(prot, &map_data);
+	ret = hyp_set_prot_attr(prot, &map_data.attr);
 	if (ret)
 		return ret;
 
@@ -397,16 +425,18 @@
 	return ret;
 }
 
-int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits)
+int kvm_pgtable_hyp_init(struct kvm_pgtable *pgt, u32 va_bits,
+			 struct kvm_pgtable_mm_ops *mm_ops)
 {
 	u64 levels = ARM64_HW_PGTABLE_LEVELS(va_bits);
 
-	pgt->pgd = (kvm_pte_t *)get_zeroed_page(GFP_KERNEL);
+	pgt->pgd = (kvm_pte_t *)mm_ops->zalloc_page(NULL);
 	if (!pgt->pgd)
 		return -ENOMEM;
 
 	pgt->ia_bits		= va_bits;
 	pgt->start_level	= KVM_PGTABLE_MAX_LEVELS - levels;
+	pgt->mm_ops		= mm_ops;
 	pgt->mmu		= NULL;
 	return 0;
 }
@@ -414,7 +444,9 @@
 static int hyp_free_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 			   enum kvm_pgtable_walk_flags flag, void * const arg)
 {
-	free_page((unsigned long)kvm_pte_follow(*ptep));
+	struct kvm_pgtable_mm_ops *mm_ops = arg;
+
+	mm_ops->put_page((void *)kvm_pte_follow(*ptep, mm_ops));
 	return 0;
 }
 
@@ -423,29 +455,75 @@
 	struct kvm_pgtable_walker walker = {
 		.cb	= hyp_free_walker,
 		.flags	= KVM_PGTABLE_WALK_TABLE_POST,
+		.arg	= pgt->mm_ops,
 	};
 
 	WARN_ON(kvm_pgtable_walk(pgt, 0, BIT(pgt->ia_bits), &walker));
-	free_page((unsigned long)pgt->pgd);
+	pgt->mm_ops->put_page(pgt->pgd);
 	pgt->pgd = NULL;
 }
 
 struct stage2_map_data {
 	u64				phys;
 	kvm_pte_t			attr;
+	u8				owner_id;
 
 	kvm_pte_t			*anchor;
+	kvm_pte_t			*childp;
 
 	struct kvm_s2_mmu		*mmu;
-	struct kvm_mmu_memory_cache	*memcache;
+	void				*memcache;
+
+	struct kvm_pgtable_mm_ops	*mm_ops;
 };
 
-static int stage2_map_set_prot_attr(enum kvm_pgtable_prot prot,
-				    struct stage2_map_data *data)
+/*
+ * Compute the VTCR_EL2 value for a stage-2 covering @phys_shift bits of
+ * IPA. @mmfr0 supplies the PARange and @mmfr1 the VMID width.
+ */
+u64 kvm_get_vtcr(u64 mmfr0, u64 mmfr1, u32 phys_shift)
+{
+	u8 lvls = stage2_pgtable_levels(phys_shift);
+	u64 vtcr = VTCR_EL2_FLAGS;
+
+	/* Physical address size and input size (T0SZ). */
+	vtcr |= kvm_get_parange(mmfr0) << VTCR_EL2_PS_SHIFT;
+	vtcr |= VTCR_EL2_T0SZ(phys_shift);
+
+	/*
+	 * Never use fewer than 2 levels, so host PMD huge pages don't have
+	 * to be split at stage 2.
+	 */
+	if (lvls < 2)
+		lvls = 2;
+	vtcr |= VTCR_EL2_LVLS_TO_SL0(lvls);
+
+	/*
+	 * Hardware Access Flag management is enabled unconditionally on all
+	 * CPUs: the bit is RES0 on CPUs without the feature, which ignore it.
+	 */
+	vtcr |= VTCR_EL2_HA;
+
+	/* VMID width */
+	if (get_vmid_bits(mmfr1) == 16)
+		vtcr |= VTCR_EL2_VS_16BIT;
+	else
+		vtcr |= VTCR_EL2_VS_8BIT;
+
+	return vtcr;
+}
+
+/*
+ * FWB applies only when the CPU has ARM64_HAS_STAGE2_FWB and @pgt was not
+ * created with KVM_PGTABLE_S2_NOFWB.
+ */
+static bool stage2_has_fwb(struct kvm_pgtable *pgt)
+{
+	return cpus_have_const_cap(ARM64_HAS_STAGE2_FWB) &&
+	       !(pgt->flags & KVM_PGTABLE_S2_NOFWB);
+}
+
+/* Stage-2 MemAttr encoding of @attr for @pgt, honouring its FWB setting. */
+#define KVM_S2_MEMATTR(pgt, attr) PAGE_S2_MEMATTR(attr, stage2_has_fwb(pgt))
+
+static int stage2_set_prot_attr(struct kvm_pgtable *pgt, enum kvm_pgtable_prot prot,
+				kvm_pte_t *ptep)
 {
 	bool device = prot & KVM_PGTABLE_PROT_DEVICE;
-	kvm_pte_t attr = device ? PAGE_S2_MEMATTR(DEVICE_nGnRE) :
-			    PAGE_S2_MEMATTR(NORMAL);
+	kvm_pte_t attr = device ? KVM_S2_MEMATTR(pgt, DEVICE_nGnRE) :
+			    KVM_S2_MEMATTR(pgt, NORMAL);
 	u32 sh = KVM_PTE_LEAF_ATTR_LO_S2_SH_IS;
 
 	if (!(prot & KVM_PGTABLE_PROT_X))
@@ -461,44 +539,78 @@
 
 	attr |= FIELD_PREP(KVM_PTE_LEAF_ATTR_LO_S2_SH, sh);
 	attr |= KVM_PTE_LEAF_ATTR_LO_S2_AF;
-	data->attr = attr;
+	*ptep = attr;
+
 	return 0;
 }
 
+/*
+ * A leaf has to be rewritten unless both @old and @new are valid and they
+ * differ only in KVM_PTE_LEAF_ATTR_S2_PERMS bits.
+ */
+static bool stage2_pte_needs_update(kvm_pte_t old, kvm_pte_t new)
+{
+	kvm_pte_t changed = (old ^ new) & ~KVM_PTE_LEAF_ATTR_S2_PERMS;
+
+	if (kvm_pte_valid(old) && kvm_pte_valid(new))
+		return changed != 0;
+
+	return true;
+}
+
+static bool stage2_pte_is_counted(kvm_pte_t pte)
+{
+	/*
+	 * Any non-zero entry holds a reference on its table page. That covers
+	 * valid entries, and invalid ones recording a non-zero owner id (the
+	 * page-table owner itself is id 0, which encodes as an all-zero PTE).
+	 */
+	return pte != 0;
+}
+
+/*
+ * Drop the table-page reference held by the entry at @ptep. A valid entry
+ * is first zeroed and the TLB invalidated for @addr at @level in @mmu.
+ * An invalid entry's value is left as is; only the reference is dropped.
+ */
+static void stage2_put_pte(kvm_pte_t *ptep, struct kvm_s2_mmu *mmu, u64 addr,
+			   u32 level, struct kvm_pgtable_mm_ops *mm_ops)
+{
+	/*
+	 * Clear the existing PTE, and perform break-before-make with
+	 * TLB maintenance if it was valid.
+	 */
+	if (kvm_pte_valid(*ptep)) {
+		kvm_clear_pte(ptep);
+		kvm_call_hyp(__kvm_tlb_flush_vmid_ipa, mmu, addr, level);
+	}
+
+	mm_ops->put_page(ptep);
+}
+
 static int stage2_map_walker_try_leaf(u64 addr, u64 end, u32 level,
 				      kvm_pte_t *ptep,
 				      struct stage2_map_data *data)
 {
 	kvm_pte_t new, old = *ptep;
 	u64 granule = kvm_granule_size(level), phys = data->phys;
-	struct page *page = virt_to_page(ptep);
+	struct kvm_pgtable_mm_ops *mm_ops = data->mm_ops;
 
 	if (!kvm_block_mapping_supported(addr, end, phys, level))
 		return -E2BIG;
 
-	new = kvm_init_valid_leaf_pte(phys, data->attr, level);
-	if (kvm_pte_valid(old)) {
+	if (kvm_phys_is_valid(phys))
+		new = kvm_init_valid_leaf_pte(phys, data->attr, level);
+	else
+		new = kvm_init_invalid_leaf_owner(data->owner_id);
+
+	if (stage2_pte_is_counted(old)) {
 		/*
 		 * Skip updating the PTE if we are trying to recreate the exact
 		 * same mapping or only change the access permissions. Instead,
 		 * the vCPU will exit one more time from guest if still needed
 		 * and then go through the path of relaxing permissions.
 		 */
-		if (!((old ^ new) & (~KVM_PTE_LEAF_ATTR_S2_PERMS)))
+		if (!stage2_pte_needs_update(old, new))
 			return -EAGAIN;
 
-		/*
-		 * There's an existing different valid leaf entry, so perform
-		 * break-before-make.
-		 */
-		kvm_set_invalid_pte(ptep);
-		kvm_call_hyp(__kvm_tlb_flush_vmid_ipa, data->mmu, addr, level);
-		put_page(page);
+		stage2_put_pte(ptep, data->mmu, addr, level, mm_ops);
 	}
 
 	smp_store_release(ptep, new);
-	get_page(page);
-	data->phys += granule;
+	if (stage2_pte_is_counted(new))
+		mm_ops->get_page(ptep);
+	if (kvm_phys_is_valid(phys))
+		data->phys += granule;
 	return 0;
 }
 
@@ -512,7 +624,8 @@
 	if (!kvm_block_mapping_supported(addr, end, data->phys, level))
 		return 0;
 
-	kvm_set_invalid_pte(ptep);
+	data->childp = kvm_pte_follow(*ptep, data->mm_ops);
+	kvm_clear_pte(ptep);
 
 	/*
 	 * Invalidate the whole stage-2, as we may have numerous leaf
@@ -527,13 +640,13 @@
 static int stage2_map_walk_leaf(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 				struct stage2_map_data *data)
 {
-	int ret;
+	struct kvm_pgtable_mm_ops *mm_ops = data->mm_ops;
 	kvm_pte_t *childp, pte = *ptep;
-	struct page *page = virt_to_page(ptep);
+	int ret;
 
 	if (data->anchor) {
-		if (kvm_pte_valid(pte))
-			put_page(page);
+		if (stage2_pte_is_counted(pte))
+			mm_ops->put_page(ptep);
 
 		return 0;
 	}
@@ -548,7 +661,7 @@
 	if (!data->memcache)
 		return -ENOMEM;
 
-	childp = kvm_mmu_memory_cache_alloc(data->memcache);
+	childp = mm_ops->zalloc_page(data->memcache);
 	if (!childp)
 		return -ENOMEM;
 
@@ -557,14 +670,11 @@
 	 * a table. Accesses beyond 'end' that fall within the new table
 	 * will be mapped lazily.
 	 */
-	if (kvm_pte_valid(pte)) {
-		kvm_set_invalid_pte(ptep);
-		kvm_call_hyp(__kvm_tlb_flush_vmid_ipa, data->mmu, addr, level);
-		put_page(page);
-	}
+	if (stage2_pte_is_counted(pte))
+		stage2_put_pte(ptep, data->mmu, addr, level, mm_ops);
 
-	kvm_set_table_pte(ptep, childp);
-	get_page(page);
+	kvm_set_table_pte(ptep, childp, mm_ops);
+	mm_ops->get_page(ptep);
 
 	return 0;
 }
@@ -573,19 +683,25 @@
 				      kvm_pte_t *ptep,
 				      struct stage2_map_data *data)
 {
+	struct kvm_pgtable_mm_ops *mm_ops = data->mm_ops;
+	kvm_pte_t *childp;
 	int ret = 0;
 
 	if (!data->anchor)
 		return 0;
 
-	free_page((unsigned long)kvm_pte_follow(*ptep));
-	put_page(virt_to_page(ptep));
-
 	if (data->anchor == ptep) {
+		childp = data->childp;
 		data->anchor = NULL;
+		data->childp = NULL;
 		ret = stage2_map_walk_leaf(addr, end, level, ptep, data);
+	} else {
+		childp = kvm_pte_follow(*ptep, mm_ops);
 	}
 
+	mm_ops->put_page(childp);
+	mm_ops->put_page(ptep);
+
 	return ret;
 }
 
@@ -627,13 +743,14 @@
 
 int kvm_pgtable_stage2_map(struct kvm_pgtable *pgt, u64 addr, u64 size,
 			   u64 phys, enum kvm_pgtable_prot prot,
-			   struct kvm_mmu_memory_cache *mc)
+			   void *mc)
 {
 	int ret;
 	struct stage2_map_data map_data = {
 		.phys		= ALIGN_DOWN(phys, PAGE_SIZE),
 		.mmu		= pgt->mmu,
 		.memcache	= mc,
+		.mm_ops		= pgt->mm_ops,
 	};
 	struct kvm_pgtable_walker walker = {
 		.cb		= stage2_map_walker,
@@ -643,7 +760,10 @@
 		.arg		= &map_data,
 	};
 
-	ret = stage2_map_set_prot_attr(prot, &map_data);
+	if (WARN_ON((pgt->flags & KVM_PGTABLE_S2_IDMAP) && (addr != phys)))
+		return -EINVAL;
+
+	ret = stage2_set_prot_attr(pgt, prot, &map_data.attr);
 	if (ret)
 		return ret;
 
@@ -652,38 +772,63 @@
 	return ret;
 }
 
-static void stage2_flush_dcache(void *addr, u64 size)
+int kvm_pgtable_stage2_set_owner(struct kvm_pgtable *pgt, u64 addr, u64 size,
+				 void *mc, u8 owner_id)
 {
-	if (cpus_have_const_cap(ARM64_HAS_STAGE2_FWB))
-		return;
+	int ret;
+	struct stage2_map_data map_data = {
+		.phys		= KVM_PHYS_INVALID,
+		.mmu		= pgt->mmu,
+		.memcache	= mc,
+		.mm_ops		= pgt->mm_ops,
+		.owner_id	= owner_id,
+	};
+	struct kvm_pgtable_walker walker = {
+		.cb		= stage2_map_walker,
+		.flags		= KVM_PGTABLE_WALK_TABLE_PRE |
+				  KVM_PGTABLE_WALK_LEAF |
+				  KVM_PGTABLE_WALK_TABLE_POST,
+		.arg		= &map_data,
+	};
 
-	__flush_dcache_area(addr, size);
+	if (owner_id > KVM_MAX_OWNER_ID)
+		return -EINVAL;
+
+	ret = kvm_pgtable_walk(pgt, addr, size, &walker);
+	return ret;
 }
 
-static bool stage2_pte_cacheable(kvm_pte_t pte)
+static bool stage2_pte_cacheable(struct kvm_pgtable *pgt, kvm_pte_t pte)
 {
 	u64 memattr = pte & KVM_PTE_LEAF_ATTR_LO_S2_MEMATTR;
-	return memattr == PAGE_S2_MEMATTR(NORMAL);
+	return memattr == KVM_S2_MEMATTR(pgt, NORMAL);
 }
 
 static int stage2_unmap_walker(u64 addr, u64 end, u32 level, kvm_pte_t *ptep,
 			       enum kvm_pgtable_walk_flags flag,
 			       void * const arg)
 {
-	struct kvm_s2_mmu *mmu = arg;
+	struct kvm_pgtable *pgt = arg;
+	struct kvm_s2_mmu *mmu = pgt->mmu;
+	struct kvm_pgtable_mm_ops *mm_ops = pgt->mm_ops;
 	kvm_pte_t pte = *ptep, *childp = NULL;
 	bool need_flush = false;
 
-	if (!kvm_pte_valid(pte))
+	if (!kvm_pte_valid(pte)) {
+		if (stage2_pte_is_counted(pte)) {
+			kvm_clear_pte(ptep);
+			mm_ops->put_page(ptep);
+		}
 		return 0;
+	}
 
 	if (kvm_pte_table(pte, level)) {
-		childp = kvm_pte_follow(pte);
+		childp = kvm_pte_follow(pte, mm_ops);
 
-		if (page_count(virt_to_page(childp)) != 1)
+		if (mm_ops->page_count(childp) != 1)
 			return 0;
-	} else if (stage2_pte_cacheable(pte)) {
-		need_flush = true;
+	} else if (stage2_pte_cacheable(pgt, pte)) {
+		need_flush = !stage2_has_fwb(pgt);
 	}
 
 	/*
@@ -691,17 +836,15 @@
 	 * block entry and rely on the remaining portions being faulted
 	 * back lazily.
 	 */
-	kvm_set_invalid_pte(ptep);
-	kvm_call_hyp(__kvm_tlb_flush_vmid_ipa, mmu, addr, level);
-	put_page(virt_to_page(ptep));
+	stage2_put_pte(ptep, mmu, addr, level, mm_ops);
 
 	if (need_flush) {
-		stage2_flush_dcache(kvm_pte_follow(pte),
+		__flush_dcache_area(kvm_pte_follow(pte, mm_ops),
 				    kvm_granule_size(level));
 	}
 
 	if (childp)
-		free_page((unsigned long)childp);
+		mm_ops->put_page(childp);
 
 	return 0;
 }
@@ -710,7 +853,7 @@
 {
 	struct kvm_pgtable_walker walker = {
 		.cb	= stage2_unmap_walker,
-		.arg	= pgt->mmu,
+		.arg	= pgt,
 		.flags	= KVM_PGTABLE_WALK_LEAF | KVM_PGTABLE_WALK_TABLE_POST,
 	};
 
@@ -842,12 +985,14 @@
 			       enum kvm_pgtable_walk_flags flag,
 			       void * const arg)
 {
+	struct kvm_pgtable *pgt = arg;
+	struct kvm_pgtable_mm_ops *mm_ops = pgt->mm_ops;
 	kvm_pte_t pte = *ptep;
 
-	if (!kvm_pte_valid(pte) || !stage2_pte_cacheable(pte))
+	if (!kvm_pte_valid(pte) || !stage2_pte_cacheable(pgt, pte))
 		return 0;
 
-	stage2_flush_dcache(kvm_pte_follow(pte), kvm_granule_size(level));
+	__flush_dcache_area(kvm_pte_follow(pte, mm_ops), kvm_granule_size(level));
 	return 0;
 }
 
@@ -856,30 +1001,35 @@
 	struct kvm_pgtable_walker walker = {
 		.cb	= stage2_flush_walker,
 		.flags	= KVM_PGTABLE_WALK_LEAF,
+		.arg	= pgt,
 	};
 
-	if (cpus_have_const_cap(ARM64_HAS_STAGE2_FWB))
+	if (stage2_has_fwb(pgt))
 		return 0;
 
 	return kvm_pgtable_walk(pgt, addr, size, &walker);
 }
 
-int kvm_pgtable_stage2_init(struct kvm_pgtable *pgt, struct kvm *kvm)
+int kvm_pgtable_stage2_init_flags(struct kvm_pgtable *pgt, struct kvm_arch *arch,
+				  struct kvm_pgtable_mm_ops *mm_ops,
+				  enum kvm_pgtable_stage2_flags flags)
 {
 	size_t pgd_sz;
-	u64 vtcr = kvm->arch.vtcr;
+	u64 vtcr = arch->vtcr;
 	u32 ia_bits = VTCR_EL2_IPA(vtcr);
 	u32 sl0 = FIELD_GET(VTCR_EL2_SL0_MASK, vtcr);
 	u32 start_level = VTCR_EL2_TGRAN_SL0_BASE - sl0;
 
 	pgd_sz = kvm_pgd_pages(ia_bits, start_level) * PAGE_SIZE;
-	pgt->pgd = alloc_pages_exact(pgd_sz, GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+	pgt->pgd = mm_ops->zalloc_pages_exact(pgd_sz);
 	if (!pgt->pgd)
 		return -ENOMEM;
 
 	pgt->ia_bits		= ia_bits;
 	pgt->start_level	= start_level;
-	pgt->mmu		= &kvm->arch.mmu;
+	pgt->mm_ops		= mm_ops;
+	pgt->mmu		= &arch->mmu;
+	pgt->flags		= flags;
 
 	/* Ensure zeroed PGD pages are visible to the hardware walker */
 	dsb(ishst);
@@ -890,15 +1040,16 @@
 			      enum kvm_pgtable_walk_flags flag,
 			      void * const arg)
 {
+	struct kvm_pgtable_mm_ops *mm_ops = arg;
 	kvm_pte_t pte = *ptep;
 
-	if (!kvm_pte_valid(pte))
+	if (!stage2_pte_is_counted(pte))
 		return 0;
 
-	put_page(virt_to_page(ptep));
+	mm_ops->put_page(ptep);
 
 	if (kvm_pte_table(pte, level))
-		free_page((unsigned long)kvm_pte_follow(pte));
+		mm_ops->put_page(kvm_pte_follow(pte, mm_ops));
 
 	return 0;
 }
@@ -910,10 +1061,85 @@
 		.cb	= stage2_free_walker,
 		.flags	= KVM_PGTABLE_WALK_LEAF |
 			  KVM_PGTABLE_WALK_TABLE_POST,
+		.arg	= pgt->mm_ops,
 	};
 
 	WARN_ON(kvm_pgtable_walk(pgt, 0, BIT(pgt->ia_bits), &walker));
 	pgd_sz = kvm_pgd_pages(pgt->ia_bits, pgt->start_level) * PAGE_SIZE;
-	free_pages_exact(pgt->pgd, pgd_sz);
+	pgt->mm_ops->free_pages_exact(pgt->pgd, pgd_sz);
 	pgt->pgd = NULL;
 }
+
+#define KVM_PTE_LEAF_S2_COMPAT_MASK	(KVM_PTE_LEAF_ATTR_S2_PERMS | \
+					 KVM_PTE_LEAF_ATTR_LO_S2_MEMATTR | \
+					 KVM_PTE_LEAF_ATTR_S2_IGNORED)
+
+static int stage2_check_permission_walker(u64 addr, u64 end, u32 level,
+					  kvm_pte_t *ptep,
+					  enum kvm_pgtable_walk_flags flag,
+					  void * const arg)
+{
+	kvm_pte_t old_attr, pte = *ptep, *new_attr = arg;
+
+	/*
+	 * Compatible mappings are either invalid and owned by the page-table
+	 * owner (whose id is 0), or valid with matching permission attributes.
+	 */
+	if (kvm_pte_valid(pte)) {
+		old_attr = pte & KVM_PTE_LEAF_S2_COMPAT_MASK;
+		if (old_attr != *new_attr)
+			return -EEXIST;
+	} else if (pte) {
+		return -EEXIST;
+	}
+
+	return 0;
+}
+
+int kvm_pgtable_stage2_find_range(struct kvm_pgtable *pgt, u64 addr,
+				  enum kvm_pgtable_prot prot,
+				  struct kvm_mem_range *range)
+{
+	kvm_pte_t attr;
+	struct kvm_pgtable_walker check_perm_walker = {
+		.cb		= stage2_check_permission_walker,
+		.flags		= KVM_PGTABLE_WALK_LEAF,
+		.arg		= &attr,
+	};
+	u64 granule, start, end;
+	u32 level;
+	int ret;
+
+	ret = stage2_set_prot_attr(pgt, prot, &attr);
+	if (ret)
+		return ret;
+	attr &= KVM_PTE_LEAF_S2_COMPAT_MASK;
+
+	for (level = pgt->start_level; level < KVM_PGTABLE_MAX_LEVELS; level++) {
+		granule = kvm_granule_size(level);
+		start = ALIGN_DOWN(addr, granule);
+		end = start + granule;
+
+		if (!kvm_level_supports_block_mapping(level))
+			continue;
+
+		if (start < range->start || range->end < end)
+			continue;
+
+		/*
+		 * Check the presence of existing mappings with incompatible
+		 * permissions within the current block range, and try one level
+		 * deeper if one is found.
+		 */
+		ret = kvm_pgtable_walk(pgt, start, granule, &check_perm_walker);
+		if (ret != -EEXIST)
+			break;
+	}
+
+	if (!ret) {
+		range->start = start;
+		range->end = end;
+	}
+
+	return ret;
+}
diff --git a/arch/arm64/kvm/hyp/reserved_mem.c b/arch/arm64/kvm/hyp/reserved_mem.c
new file mode 100644
index 0000000..83ca23a
--- /dev/null
+++ b/arch/arm64/kvm/hyp/reserved_mem.c
@@ -0,0 +1,113 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Copyright (C) 2020 - Google LLC
+ * Author: Quentin Perret <qperret@google.com>
+ */
+
+#include <linux/kvm_host.h>
+#include <linux/memblock.h>
+#include <linux/sort.h>
+
+#include <asm/kvm_host.h>
+
+#include <nvhe/memory.h>
+#include <nvhe/mm.h>
+
+static struct memblock_region *hyp_memory = kvm_nvhe_sym(hyp_memory);
+static unsigned int *hyp_memblock_nr_ptr = &kvm_nvhe_sym(hyp_memblock_nr);
+
+phys_addr_t hyp_mem_base;
+phys_addr_t hyp_mem_size;
+
+static int cmp_hyp_memblock(const void *p1, const void *p2)
+{
+	const struct memblock_region *r1 = p1;
+	const struct memblock_region *r2 = p2;
+
+	return r1->base < r2->base ? -1 : (r1->base > r2->base);
+}
+
+static void __init sort_memblock_regions(void)
+{
+	sort(hyp_memory,
+	     *hyp_memblock_nr_ptr,
+	     sizeof(struct memblock_region),
+	     cmp_hyp_memblock,
+	     NULL);
+}
+
+static int __init register_memblock_regions(void)
+{
+	struct memblock_region *reg;
+
+	for_each_mem_region(reg) {
+		if (*hyp_memblock_nr_ptr >= HYP_MEMBLOCK_REGIONS)
+			return -ENOMEM;
+
+		hyp_memory[*hyp_memblock_nr_ptr] = *reg;
+		(*hyp_memblock_nr_ptr)++;
+	}
+	sort_memblock_regions();
+
+	return 0;
+}
+
+void __init kvm_hyp_reserve(void)
+{
+	u64 nr_pages, prev, hyp_mem_pages = 0;
+	int ret;
+
+	if (!is_hyp_mode_available() || is_kernel_in_hyp_mode())
+		return;
+
+	if (kvm_get_mode() != KVM_MODE_PROTECTED)
+		return;
+
+	ret = register_memblock_regions();
+	if (ret) {
+		*hyp_memblock_nr_ptr = 0;
+		kvm_err("Failed to register hyp memblocks: %d\n", ret);
+		return;
+	}
+
+	hyp_mem_pages += hyp_s1_pgtable_pages();
+	hyp_mem_pages += host_s2_mem_pgtable_pages();
+	hyp_mem_pages += host_s2_dev_pgtable_pages();
+
+	/*
+	 * The hyp_vmemmap needs to be backed by pages, but these pages
+	 * themselves need to be present in the vmemmap, so compute the number
+	 * of pages needed by looking for a fixed point.
+	 */
+	nr_pages = 0;
+	do {
+		prev = nr_pages;
+		nr_pages = hyp_mem_pages + prev;
+		nr_pages = DIV_ROUND_UP(nr_pages * sizeof(struct hyp_page), PAGE_SIZE);
+		nr_pages += __hyp_pgtable_max_pages(nr_pages);
+	} while (nr_pages != prev);
+	hyp_mem_pages += nr_pages;
+
+	/*
+	 * Try to allocate a PMD-aligned region to reduce TLB pressure once
+	 * this is unmapped from the host stage-2, and fall back to PAGE_SIZE.
+	 */
+	hyp_mem_size = hyp_mem_pages << PAGE_SHIFT;
+	hyp_mem_base = memblock_find_in_range(0, memblock_end_of_DRAM(),
+					      ALIGN(hyp_mem_size, PMD_SIZE),
+					      PMD_SIZE);
+	if (!hyp_mem_base)
+		hyp_mem_base = memblock_find_in_range(0, memblock_end_of_DRAM(),
+						      hyp_mem_size, PAGE_SIZE);
+	else
+		hyp_mem_size = ALIGN(hyp_mem_size, PMD_SIZE);
+
+	if (!hyp_mem_base) {
+		kvm_err("Failed to reserve hyp memory\n");
+		return;
+	}
+	memblock_reserve(hyp_mem_base, hyp_mem_size);
+
+	kvm_info("Reserved %lld MiB at 0x%llx\n", hyp_mem_size >> 20,
+		 hyp_mem_base);
+}
diff --git a/arch/arm64/kvm/hyp/vhe/switch.c b/arch/arm64/kvm/hyp/vhe/switch.c
index af8e940..7b8f7db 100644
--- a/arch/arm64/kvm/hyp/vhe/switch.c
+++ b/arch/arm64/kvm/hyp/vhe/switch.c
@@ -27,8 +27,6 @@
 #include <asm/processor.h>
 #include <asm/thread_info.h>
 
-const char __hyp_panic_string[] = "HYP panic:\nPS:%08llx PC:%016llx ESR:%08llx\nFAR:%016llx HPFAR:%016llx PAR:%016llx\nVCPU:%p\n";
-
 /* VHE specific context */
 DEFINE_PER_CPU(struct kvm_host_data, kvm_host_data);
 DEFINE_PER_CPU(struct kvm_cpu_context, kvm_hyp_ctxt);
@@ -207,7 +205,7 @@
 	__deactivate_traps(vcpu);
 	sysreg_restore_host_state_vhe(host_ctxt);
 
-	panic(__hyp_panic_string,
+	panic("HYP panic:\nPS:%08llx PC:%016llx ESR:%08llx\nFAR:%016llx HPFAR:%016llx PAR:%016llx\nVCPU:%p\n",
 	      spsr, elr,
 	      read_sysreg_el2(SYS_ESR), read_sysreg_el2(SYS_FAR),
 	      read_sysreg(hpfar_el2), par, vcpu);
diff --git a/arch/arm64/kvm/hypercalls.c b/arch/arm64/kvm/hypercalls.c
index ead21b9..30da78f 100644
--- a/arch/arm64/kvm/hypercalls.c
+++ b/arch/arm64/kvm/hypercalls.c
@@ -9,16 +9,65 @@
 #include <kvm/arm_hypercalls.h>
 #include <kvm/arm_psci.h>
 
+static void kvm_ptp_get_time(struct kvm_vcpu *vcpu, u64 *val)
+{
+	struct system_time_snapshot systime_snapshot;
+	u64 cycles = ~0UL;
+	u32 feature;
+
+	/*
+	 * The system time and counter value must be captured at the same
+	 * time to keep consistency and precision.
+	 */
+	ktime_get_snapshot(&systime_snapshot);
+
+	/*
+	 * This is only valid if the current clocksource is the
+	 * architected counter, as this is the only one the guest
+	 * can see.
+	 */
+	if (systime_snapshot.cs_id != CSID_ARM_ARCH_COUNTER)
+		return;
+
+	/*
+	 * The guest selects one of the two reference counters
+	 * (virtual or physical) with the first argument of the SMCCC
+	 * call. In case the identifier is not supported, error out.
+	 */
+	feature = smccc_get_arg1(vcpu);
+	switch (feature) {
+	case KVM_PTP_VIRT_COUNTER:
+		cycles = systime_snapshot.cycles - vcpu_read_sys_reg(vcpu, CNTVOFF_EL2);
+		break;
+	case KVM_PTP_PHYS_COUNTER:
+		cycles = systime_snapshot.cycles;
+		break;
+	default:
+		return;
+	}
+
+	/*
+	 * This relies on the top bit of val[0] never being set for
+	 * valid values of system time, because that is *really* far
+	 * in the future (about 292 years from 1970, and at that stage
+	 * nobody will give a damn about it).
+	 */
+	val[0] = upper_32_bits(systime_snapshot.real);
+	val[1] = lower_32_bits(systime_snapshot.real);
+	val[2] = upper_32_bits(cycles);
+	val[3] = lower_32_bits(cycles);
+}
+
 int kvm_hvc_call_handler(struct kvm_vcpu *vcpu)
 {
 	u32 func_id = smccc_get_function(vcpu);
-	long val = SMCCC_RET_NOT_SUPPORTED;
+	u64 val[4] = {SMCCC_RET_NOT_SUPPORTED};
 	u32 feature;
 	gpa_t gpa;
 
 	switch (func_id) {
 	case ARM_SMCCC_VERSION_FUNC_ID:
-		val = ARM_SMCCC_VERSION_1_1;
+		val[0] = ARM_SMCCC_VERSION_1_1;
 		break;
 	case ARM_SMCCC_ARCH_FEATURES_FUNC_ID:
 		feature = smccc_get_arg1(vcpu);
@@ -28,10 +77,10 @@
 			case SPECTRE_VULNERABLE:
 				break;
 			case SPECTRE_MITIGATED:
-				val = SMCCC_RET_SUCCESS;
+				val[0] = SMCCC_RET_SUCCESS;
 				break;
 			case SPECTRE_UNAFFECTED:
-				val = SMCCC_ARCH_WORKAROUND_RET_UNAFFECTED;
+				val[0] = SMCCC_ARCH_WORKAROUND_RET_UNAFFECTED;
 				break;
 			}
 			break;
@@ -54,22 +103,35 @@
 					break;
 				fallthrough;
 			case SPECTRE_UNAFFECTED:
-				val = SMCCC_RET_NOT_REQUIRED;
+				val[0] = SMCCC_RET_NOT_REQUIRED;
 				break;
 			}
 			break;
 		case ARM_SMCCC_HV_PV_TIME_FEATURES:
-			val = SMCCC_RET_SUCCESS;
+			val[0] = SMCCC_RET_SUCCESS;
 			break;
 		}
 		break;
 	case ARM_SMCCC_HV_PV_TIME_FEATURES:
-		val = kvm_hypercall_pv_features(vcpu);
+		val[0] = kvm_hypercall_pv_features(vcpu);
 		break;
 	case ARM_SMCCC_HV_PV_TIME_ST:
 		gpa = kvm_init_stolen_time(vcpu);
 		if (gpa != GPA_INVALID)
-			val = gpa;
+			val[0] = gpa;
+		break;
+	case ARM_SMCCC_VENDOR_HYP_CALL_UID_FUNC_ID:
+		val[0] = ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_0;
+		val[1] = ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_1;
+		val[2] = ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_2;
+		val[3] = ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_3;
+		break;
+	case ARM_SMCCC_VENDOR_HYP_KVM_FEATURES_FUNC_ID:
+		val[0] = BIT(ARM_SMCCC_KVM_FUNC_FEATURES);
+		val[0] |= BIT(ARM_SMCCC_KVM_FUNC_PTP);
+		break;
+	case ARM_SMCCC_VENDOR_HYP_KVM_PTP_FUNC_ID:
+		kvm_ptp_get_time(vcpu, val);
 		break;
 	case ARM_SMCCC_TRNG_VERSION:
 	case ARM_SMCCC_TRNG_FEATURES:
@@ -81,6 +143,6 @@
 		return kvm_psci_call(vcpu);
 	}
 
-	smccc_set_retval(vcpu, val, 0, 0, 0);
+	smccc_set_retval(vcpu, val[0], val[1], val[2], val[3]);
 	return 1;
 }
diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c
index 8711894..c5d1f3c 100644
--- a/arch/arm64/kvm/mmu.c
+++ b/arch/arm64/kvm/mmu.c
@@ -88,6 +88,44 @@
 	return !pfn_valid(pfn);
 }
 
+static void *stage2_memcache_zalloc_page(void *arg)
+{
+	struct kvm_mmu_memory_cache *mc = arg;
+
+	/* Allocated with __GFP_ZERO, so no need to zero */
+	return kvm_mmu_memory_cache_alloc(mc);
+}
+
+static void *kvm_host_zalloc_pages_exact(size_t size)
+{
+	return alloc_pages_exact(size, GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+}
+
+static void kvm_host_get_page(void *addr)
+{
+	get_page(virt_to_page(addr));
+}
+
+static void kvm_host_put_page(void *addr)
+{
+	put_page(virt_to_page(addr));
+}
+
+static int kvm_host_page_count(void *addr)
+{
+	return page_count(virt_to_page(addr));
+}
+
+static phys_addr_t kvm_host_pa(void *addr)
+{
+	return __pa(addr);
+}
+
+static void *kvm_host_va(phys_addr_t phys)
+{
+	return __va(phys);
+}
+
 /*
  * Unmapping vs dcache management:
  *
@@ -127,7 +165,7 @@
 static void __unmap_stage2_range(struct kvm_s2_mmu *mmu, phys_addr_t start, u64 size,
 				 bool may_block)
 {
-	struct kvm *kvm = mmu->kvm;
+	struct kvm *kvm = kvm_s2_mmu_to_kvm(mmu);
 	phys_addr_t end = start + size;
 
 	assert_spin_locked(&kvm->mmu_lock);
@@ -183,15 +221,39 @@
 	if (hyp_pgtable) {
 		kvm_pgtable_hyp_destroy(hyp_pgtable);
 		kfree(hyp_pgtable);
+		hyp_pgtable = NULL;
 	}
 	mutex_unlock(&kvm_hyp_pgd_mutex);
 }
 
+static bool kvm_host_owns_hyp_mappings(void)
+{
+	if (static_branch_likely(&kvm_protected_mode_initialized))
+		return false;
+
+	/*
+	 * This can happen at boot time when __create_hyp_mappings() is called
+	 * after the hyp protection has been enabled, but the static key has
+	 * not been flipped yet.
+	 */
+	if (!hyp_pgtable && is_protected_kvm_enabled())
+		return false;
+
+	WARN_ON(!hyp_pgtable);
+
+	return true;
+}
+
 static int __create_hyp_mappings(unsigned long start, unsigned long size,
 				 unsigned long phys, enum kvm_pgtable_prot prot)
 {
 	int err;
 
+	if (!kvm_host_owns_hyp_mappings()) {
+		return kvm_call_hyp_nvhe(__pkvm_create_mappings,
+					 start, size, phys, prot);
+	}
+
 	mutex_lock(&kvm_hyp_pgd_mutex);
 	err = kvm_pgtable_hyp_map(hyp_pgtable, start, size, phys, prot);
 	mutex_unlock(&kvm_hyp_pgd_mutex);
@@ -253,6 +315,16 @@
 	unsigned long base;
 	int ret = 0;
 
+	if (!kvm_host_owns_hyp_mappings()) {
+		base = kvm_call_hyp_nvhe(__pkvm_create_private_mapping,
+					 phys_addr, size, prot);
+		if (IS_ERR_OR_NULL((void *)base))
+			return PTR_ERR((void *)base);
+		*haddr = base;
+
+		return 0;
+	}
+
 	mutex_lock(&kvm_hyp_pgd_mutex);
 
 	/*
@@ -351,6 +423,17 @@
 	return 0;
 }
 
+static struct kvm_pgtable_mm_ops kvm_s2_mm_ops = {
+	.zalloc_page		= stage2_memcache_zalloc_page,
+	.zalloc_pages_exact	= kvm_host_zalloc_pages_exact,
+	.free_pages_exact	= free_pages_exact,
+	.get_page		= kvm_host_get_page,
+	.put_page		= kvm_host_put_page,
+	.page_count		= kvm_host_page_count,
+	.phys_to_virt		= kvm_host_va,
+	.virt_to_phys		= kvm_host_pa,
+};
+
 /**
  * kvm_init_stage2_mmu - Initialise a S2 MMU strucrure
  * @kvm:	The pointer to the KVM structure
@@ -374,7 +457,7 @@
 	if (!pgt)
 		return -ENOMEM;
 
-	err = kvm_pgtable_stage2_init(pgt, kvm);
+	err = kvm_pgtable_stage2_init(pgt, &kvm->arch, &kvm_s2_mm_ops);
 	if (err)
 		goto out_free_pgtable;
 
@@ -387,7 +470,7 @@
 	for_each_possible_cpu(cpu)
 		*per_cpu_ptr(mmu->last_vcpu_ran, cpu) = -1;
 
-	mmu->kvm = kvm;
+	mmu->arch = &kvm->arch;
 	mmu->pgt = pgt;
 	mmu->pgd_phys = __pa(pgt->pgd);
 	mmu->vmid.vmid_gen = 0;
@@ -421,10 +504,11 @@
 	 *     +--------------------------------------------+
 	 */
 	do {
-		struct vm_area_struct *vma = find_vma(current->mm, hva);
+		struct vm_area_struct *vma;
 		hva_t vm_start, vm_end;
 
-		if (!vma || vma->vm_start >= reg_end)
+		vma = find_vma_intersection(current->mm, hva, reg_end);
+		if (!vma)
 			break;
 
 		/*
@@ -469,7 +553,7 @@
 
 void kvm_free_stage2_pgd(struct kvm_s2_mmu *mmu)
 {
-	struct kvm *kvm = mmu->kvm;
+	struct kvm *kvm = kvm_s2_mmu_to_kvm(mmu);
 	struct kvm_pgtable *pgt = NULL;
 
 	spin_lock(&kvm->mmu_lock);
@@ -538,7 +622,7 @@
  */
 static void stage2_wp_range(struct kvm_s2_mmu *mmu, phys_addr_t addr, phys_addr_t end)
 {
-	struct kvm *kvm = mmu->kvm;
+	struct kvm *kvm = kvm_s2_mmu_to_kvm(mmu);
 	stage2_apply_range_resched(kvm, addr, end, kvm_pgtable_stage2_wrprotect);
 }
 
@@ -555,7 +639,7 @@
  * Acquires kvm_mmu_lock. Called with kvm->slots_lock mutex acquired,
  * serializing operations for VM memory regions.
  */
-void kvm_mmu_wp_memory_region(struct kvm *kvm, int slot)
+static void kvm_mmu_wp_memory_region(struct kvm *kvm, int slot)
 {
 	struct kvm_memslots *slots = kvm_memslots(kvm);
 	struct kvm_memory_slot *memslot = id_to_memslot(slots, slot);
@@ -839,13 +923,18 @@
 	 * gfn_to_pfn_prot (which calls get_user_pages), so that we don't risk
 	 * the page we just got a reference to gets unmapped before we have a
 	 * chance to grab the mmu_lock, which ensure that if the page gets
-	 * unmapped afterwards, the call to kvm_unmap_hva will take it away
+	 * unmapped afterwards, the call to kvm_unmap_gfn will take it away
 	 * from us again properly. This smp_rmb() interacts with the smp_wmb()
 	 * in kvm_mmu_notifier_invalidate_<page|range_end>.
+	 *
+	 * Besides, __gfn_to_pfn_memslot() instead of gfn_to_pfn_prot() is
+	 * used to avoid unnecessary overhead introduced to locate the memory
+	 * slot because it's always fixed even if @gfn is adjusted for huge pages.
 	 */
 	smp_rmb();
 
-	pfn = gfn_to_pfn_prot(kvm, gfn, write_fault, &writable);
+	pfn = __gfn_to_pfn_memslot(memslot, gfn, false, NULL,
+				   write_fault, &writable, NULL);
 	if (pfn == KVM_PFN_ERR_HWPOISON) {
 		kvm_send_hwpoison_signal(hva, vma_shift);
 		return 0;
@@ -911,7 +1000,7 @@
 	/* Mark the page dirty only if the fault is handled successfully */
 	if (writable && !ret) {
 		kvm_set_pfn_dirty(pfn);
-		mark_page_dirty(kvm, gfn);
+		mark_page_dirty_in_slot(kvm, memslot, gfn);
 	}
 
 out_unlock:
@@ -1064,126 +1153,70 @@
 	return ret;
 }
 
-static int handle_hva_to_gpa(struct kvm *kvm,
-			     unsigned long start,
-			     unsigned long end,
-			     int (*handler)(struct kvm *kvm,
-					    gpa_t gpa, u64 size,
-					    void *data),
-			     void *data)
-{
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
-	int ret = 0;
-
-	slots = kvm_memslots(kvm);
-
-	/* we only care about the pages that the guest sees */
-	kvm_for_each_memslot(memslot, slots) {
-		unsigned long hva_start, hva_end;
-		gfn_t gpa;
-
-		hva_start = max(start, memslot->userspace_addr);
-		hva_end = min(end, memslot->userspace_addr +
-					(memslot->npages << PAGE_SHIFT));
-		if (hva_start >= hva_end)
-			continue;
-
-		gpa = hva_to_gfn_memslot(hva_start, memslot) << PAGE_SHIFT;
-		ret |= handler(kvm, gpa, (u64)(hva_end - hva_start), data);
-	}
-
-	return ret;
-}
-
-static int kvm_unmap_hva_handler(struct kvm *kvm, gpa_t gpa, u64 size, void *data)
-{
-	unsigned flags = *(unsigned *)data;
-	bool may_block = flags & MMU_NOTIFIER_RANGE_BLOCKABLE;
-
-	__unmap_stage2_range(&kvm->arch.mmu, gpa, size, may_block);
-	return 0;
-}
-
-int kvm_unmap_hva_range(struct kvm *kvm,
-			unsigned long start, unsigned long end, unsigned flags)
+bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	if (!kvm->arch.mmu.pgt)
 		return 0;
 
-	trace_kvm_unmap_hva_range(start, end);
-	handle_hva_to_gpa(kvm, start, end, &kvm_unmap_hva_handler, &flags);
+	__unmap_stage2_range(&kvm->arch.mmu, range->start << PAGE_SHIFT,
+			     (range->end - range->start) << PAGE_SHIFT,
+			     range->may_block);
+
 	return 0;
 }
 
-static int kvm_set_spte_handler(struct kvm *kvm, gpa_t gpa, u64 size, void *data)
+bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	kvm_pfn_t *pfn = (kvm_pfn_t *)data;
-
-	WARN_ON(size != PAGE_SIZE);
-
-	/*
-	 * The MMU notifiers will have unmapped a huge PMD before calling
-	 * ->change_pte() (which in turn calls kvm_set_spte_hva()) and
-	 * therefore we never need to clear out a huge PMD through this
-	 * calling path and a memcache is not required.
-	 */
-	kvm_pgtable_stage2_map(kvm->arch.mmu.pgt, gpa, PAGE_SIZE,
-			       __pfn_to_phys(*pfn), KVM_PGTABLE_PROT_R, NULL);
-	return 0;
-}
-
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte)
-{
-	unsigned long end = hva + PAGE_SIZE;
-	kvm_pfn_t pfn = pte_pfn(pte);
+	kvm_pfn_t pfn = pte_pfn(range->pte);
 
 	if (!kvm->arch.mmu.pgt)
 		return 0;
 
-	trace_kvm_set_spte_hva(hva);
+	WARN_ON(range->end - range->start != 1);
 
 	/*
 	 * We've moved a page around, probably through CoW, so let's treat it
 	 * just like a translation fault and clean the cache to the PoC.
 	 */
 	clean_dcache_guest_page(pfn, PAGE_SIZE);
-	handle_hva_to_gpa(kvm, hva, end, &kvm_set_spte_handler, &pfn);
+
+	/*
+	 * The MMU notifiers will have unmapped a huge PMD before calling
+	 * ->change_pte() (which in turn calls kvm_set_spte_gfn()) and
+	 * therefore we never need to clear out a huge PMD through this
+	 * calling path and a memcache is not required.
+	 */
+	kvm_pgtable_stage2_map(kvm->arch.mmu.pgt, range->start << PAGE_SHIFT,
+			       PAGE_SIZE, __pfn_to_phys(pfn),
+			       KVM_PGTABLE_PROT_R, NULL);
+
 	return 0;
 }
 
-static int kvm_age_hva_handler(struct kvm *kvm, gpa_t gpa, u64 size, void *data)
+bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	pte_t pte;
+	u64 size = (range->end - range->start) << PAGE_SHIFT;
 	kvm_pte_t kpte;
+	pte_t pte;
+
+	if (!kvm->arch.mmu.pgt)
+		return 0;
 
 	WARN_ON(size != PAGE_SIZE && size != PMD_SIZE && size != PUD_SIZE);
-	kpte = kvm_pgtable_stage2_mkold(kvm->arch.mmu.pgt, gpa);
+
+	kpte = kvm_pgtable_stage2_mkold(kvm->arch.mmu.pgt,
+					range->start << PAGE_SHIFT);
 	pte = __pte(kpte);
 	return pte_valid(pte) && pte_young(pte);
 }
 
-static int kvm_test_age_hva_handler(struct kvm *kvm, gpa_t gpa, u64 size, void *data)
-{
-	WARN_ON(size != PAGE_SIZE && size != PMD_SIZE && size != PUD_SIZE);
-	return kvm_pgtable_stage2_is_young(kvm->arch.mmu.pgt, gpa);
-}
-
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end)
+bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	if (!kvm->arch.mmu.pgt)
 		return 0;
-	trace_kvm_age_hva(start, end);
-	return handle_hva_to_gpa(kvm, start, end, kvm_age_hva_handler, NULL);
-}
 
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva)
-{
-	if (!kvm->arch.mmu.pgt)
-		return 0;
-	trace_kvm_test_age_hva(hva);
-	return handle_hva_to_gpa(kvm, hva, hva + PAGE_SIZE,
-				 kvm_test_age_hva_handler, NULL);
+	return kvm_pgtable_stage2_is_young(kvm->arch.mmu.pgt,
+					   range->start << PAGE_SHIFT);
 }
 
 phys_addr_t kvm_mmu_get_httbr(void)
@@ -1208,10 +1241,22 @@
 	return err;
 }
 
-int kvm_mmu_init(void)
+static void *kvm_hyp_zalloc_page(void *arg)
+{
+	return (void *)get_zeroed_page(GFP_KERNEL);
+}
+
+static struct kvm_pgtable_mm_ops kvm_hyp_mm_ops = {
+	.zalloc_page		= kvm_hyp_zalloc_page,
+	.get_page		= kvm_host_get_page,
+	.put_page		= kvm_host_put_page,
+	.phys_to_virt		= kvm_host_va,
+	.virt_to_phys		= kvm_host_pa,
+};
+
+int kvm_mmu_init(u32 *hyp_va_bits)
 {
 	int err;
-	u32 hyp_va_bits;
 
 	hyp_idmap_start = __pa_symbol(__hyp_idmap_text_start);
 	hyp_idmap_start = ALIGN_DOWN(hyp_idmap_start, PAGE_SIZE);
@@ -1225,8 +1270,8 @@
 	 */
 	BUG_ON((hyp_idmap_start ^ (hyp_idmap_end - 1)) & PAGE_MASK);
 
-	hyp_va_bits = 64 - ((idmap_t0sz & TCR_T0SZ_MASK) >> TCR_T0SZ_OFFSET);
-	kvm_debug("Using %u-bit virtual addresses at EL2\n", hyp_va_bits);
+	*hyp_va_bits = 64 - ((idmap_t0sz & TCR_T0SZ_MASK) >> TCR_T0SZ_OFFSET);
+	kvm_debug("Using %u-bit virtual addresses at EL2\n", *hyp_va_bits);
 	kvm_debug("IDMAP page: %lx\n", hyp_idmap_start);
 	kvm_debug("HYP VA range: %lx:%lx\n",
 		  kern_hyp_va(PAGE_OFFSET),
@@ -1251,7 +1296,7 @@
 		goto out;
 	}
 
-	err = kvm_pgtable_hyp_init(hyp_pgtable, hyp_va_bits);
+	err = kvm_pgtable_hyp_init(hyp_pgtable, *hyp_va_bits, &kvm_hyp_mm_ops);
 	if (err)
 		goto out_free_pgtable;
 
@@ -1329,10 +1374,11 @@
 	 *     +--------------------------------------------+
 	 */
 	do {
-		struct vm_area_struct *vma = find_vma(current->mm, hva);
+		struct vm_area_struct *vma;
 		hva_t vm_start, vm_end;
 
-		if (!vma || vma->vm_start >= reg_end)
+		vma = find_vma_intersection(current->mm, hva, reg_end);
+		if (!vma)
 			break;
 
 		/*
diff --git a/arch/arm64/kvm/perf.c b/arch/arm64/kvm/perf.c
index 73916432..8f860ae 100644
--- a/arch/arm64/kvm/perf.c
+++ b/arch/arm64/kvm/perf.c
@@ -55,7 +55,8 @@
 	 * hardware performance counters. This could ensure the presence of
 	 * a physical PMU and CONFIG_PERF_EVENT is selected.
 	 */
-	if (IS_ENABLED(CONFIG_ARM_PMU) && perf_num_counters() > 0)
+	if (IS_ENABLED(CONFIG_ARM_PMU) && perf_num_counters() > 0
+				       && !is_protected_kvm_enabled())
 		static_branch_enable(&kvm_arm_pmu_available);
 
 	return perf_register_guest_info_callbacks(&kvm_guest_cbs);
diff --git a/arch/arm64/kvm/pmu.c b/arch/arm64/kvm/pmu.c
index faf32a4..03a6c1f 100644
--- a/arch/arm64/kvm/pmu.c
+++ b/arch/arm64/kvm/pmu.c
@@ -33,7 +33,7 @@
 {
 	struct kvm_host_data *ctx = this_cpu_ptr_hyp_sym(kvm_host_data);
 
-	if (!ctx || !kvm_pmu_switch_needed(attr))
+	if (!kvm_arm_support_pmu_v3() || !ctx || !kvm_pmu_switch_needed(attr))
 		return;
 
 	if (!attr->exclude_host)
@@ -49,7 +49,7 @@
 {
 	struct kvm_host_data *ctx = this_cpu_ptr_hyp_sym(kvm_host_data);
 
-	if (!ctx)
+	if (!kvm_arm_support_pmu_v3() || !ctx)
 		return;
 
 	ctx->pmu_events.events_host &= ~clr;
@@ -172,7 +172,7 @@
 	struct kvm_host_data *host;
 	u32 events_guest, events_host;
 
-	if (!has_vhe())
+	if (!kvm_arm_support_pmu_v3() || !has_vhe())
 		return;
 
 	preempt_disable();
@@ -193,7 +193,7 @@
 	struct kvm_host_data *host;
 	u32 events_guest, events_host;
 
-	if (!has_vhe())
+	if (!kvm_arm_support_pmu_v3() || !has_vhe())
 		return;
 
 	host = this_cpu_ptr_hyp_sym(kvm_host_data);
diff --git a/arch/arm64/kvm/reset.c b/arch/arm64/kvm/reset.c
index bd354cd..2d22775 100644
--- a/arch/arm64/kvm/reset.c
+++ b/arch/arm64/kvm/reset.c
@@ -74,10 +74,6 @@
 	if (!system_supports_sve())
 		return -EINVAL;
 
-	/* Verify that KVM startup enforced this when SVE was detected: */
-	if (WARN_ON(!has_vhe()))
-		return -EINVAL;
-
 	vcpu->arch.sve_max_vl = kvm_sve_max_vl;
 
 	/*
@@ -240,8 +236,8 @@
 		break;
 	}
 
-	/* Reset core registers */
-	memset(vcpu_gp_regs(vcpu), 0, sizeof(*vcpu_gp_regs(vcpu)));
+	/* Zero all registers */
+	memset(&vcpu->arch.ctxt, 0, sizeof(vcpu->arch.ctxt));
 	vcpu_gp_regs(vcpu)->pstate = pstate;
 
 	/* Reset system registers */
@@ -333,19 +329,10 @@
 	return 0;
 }
 
-/*
- * Configure the VTCR_EL2 for this VM. The VTCR value is common
- * across all the physical CPUs on the system. We use system wide
- * sanitised values to fill in different fields, except for Hardware
- * Management of Access Flags. HA Flag is set unconditionally on
- * all CPUs, as it is safe to run with or without the feature and
- * the bit is RES0 on CPUs that don't support it.
- */
 int kvm_arm_setup_stage2(struct kvm *kvm, unsigned long type)
 {
-	u64 vtcr = VTCR_EL2_FLAGS, mmfr0;
-	u32 parange, phys_shift;
-	u8 lvls;
+	u64 mmfr0, mmfr1;
+	u32 phys_shift;
 
 	if (type & ~KVM_VM_TYPE_ARM_IPA_SIZE_MASK)
 		return -EINVAL;
@@ -365,33 +352,8 @@
 	}
 
 	mmfr0 = read_sanitised_ftr_reg(SYS_ID_AA64MMFR0_EL1);
-	parange = cpuid_feature_extract_unsigned_field(mmfr0,
-				ID_AA64MMFR0_PARANGE_SHIFT);
-	if (parange > ID_AA64MMFR0_PARANGE_MAX)
-		parange = ID_AA64MMFR0_PARANGE_MAX;
-	vtcr |= parange << VTCR_EL2_PS_SHIFT;
+	mmfr1 = read_sanitised_ftr_reg(SYS_ID_AA64MMFR1_EL1);
+	kvm->arch.vtcr = kvm_get_vtcr(mmfr0, mmfr1, phys_shift);
 
-	vtcr |= VTCR_EL2_T0SZ(phys_shift);
-	/*
-	 * Use a minimum 2 level page table to prevent splitting
-	 * host PMD huge pages at stage2.
-	 */
-	lvls = stage2_pgtable_levels(phys_shift);
-	if (lvls < 2)
-		lvls = 2;
-	vtcr |= VTCR_EL2_LVLS_TO_SL0(lvls);
-
-	/*
-	 * Enable the Hardware Access Flag management, unconditionally
-	 * on all CPUs. The features is RES0 on CPUs without the support
-	 * and must be ignored by the CPUs.
-	 */
-	vtcr |= VTCR_EL2_HA;
-
-	/* Set the vmid bits */
-	vtcr |= (kvm_get_vmid_bits() == 16) ?
-		VTCR_EL2_VS_16BIT :
-		VTCR_EL2_VS_8BIT;
-	kvm->arch.vtcr = vtcr;
 	return 0;
 }
diff --git a/arch/arm64/kvm/sys_regs.c b/arch/arm64/kvm/sys_regs.c
index 4f2f1e3..52fdb9a0 100644
--- a/arch/arm64/kvm/sys_regs.c
+++ b/arch/arm64/kvm/sys_regs.c
@@ -1472,6 +1472,7 @@
 	{ SYS_DESC(SYS_GCR_EL1), undef_access },
 
 	{ SYS_DESC(SYS_ZCR_EL1), NULL, reset_val, ZCR_EL1, 0, .visibility = sve_visibility },
+	{ SYS_DESC(SYS_TRFCR_EL1), undef_access },
 	{ SYS_DESC(SYS_TTBR0_EL1), access_vm_reg, reset_unknown, TTBR0_EL1 },
 	{ SYS_DESC(SYS_TTBR1_EL1), access_vm_reg, reset_unknown, TTBR1_EL1 },
 	{ SYS_DESC(SYS_TCR_EL1), access_vm_reg, reset_val, TCR_EL1, 0 },
diff --git a/arch/arm64/kvm/trace_arm.h b/arch/arm64/kvm/trace_arm.h
index ff04443..33e4e7d 100644
--- a/arch/arm64/kvm/trace_arm.h
+++ b/arch/arm64/kvm/trace_arm.h
@@ -135,72 +135,6 @@
 		  __entry->vcpu_pc, __entry->instr, __entry->cpsr)
 );
 
-TRACE_EVENT(kvm_unmap_hva_range,
-	TP_PROTO(unsigned long start, unsigned long end),
-	TP_ARGS(start, end),
-
-	TP_STRUCT__entry(
-		__field(	unsigned long,	start		)
-		__field(	unsigned long,	end		)
-	),
-
-	TP_fast_assign(
-		__entry->start		= start;
-		__entry->end		= end;
-	),
-
-	TP_printk("mmu notifier unmap range: %#016lx -- %#016lx",
-		  __entry->start, __entry->end)
-);
-
-TRACE_EVENT(kvm_set_spte_hva,
-	TP_PROTO(unsigned long hva),
-	TP_ARGS(hva),
-
-	TP_STRUCT__entry(
-		__field(	unsigned long,	hva		)
-	),
-
-	TP_fast_assign(
-		__entry->hva		= hva;
-	),
-
-	TP_printk("mmu notifier set pte hva: %#016lx", __entry->hva)
-);
-
-TRACE_EVENT(kvm_age_hva,
-	TP_PROTO(unsigned long start, unsigned long end),
-	TP_ARGS(start, end),
-
-	TP_STRUCT__entry(
-		__field(	unsigned long,	start		)
-		__field(	unsigned long,	end		)
-	),
-
-	TP_fast_assign(
-		__entry->start		= start;
-		__entry->end		= end;
-	),
-
-	TP_printk("mmu notifier age hva: %#016lx -- %#016lx",
-		  __entry->start, __entry->end)
-);
-
-TRACE_EVENT(kvm_test_age_hva,
-	TP_PROTO(unsigned long hva),
-	TP_ARGS(hva),
-
-	TP_STRUCT__entry(
-		__field(	unsigned long,	hva		)
-	),
-
-	TP_fast_assign(
-		__entry->hva		= hva;
-	),
-
-	TP_printk("mmu notifier test age hva: %#016lx", __entry->hva)
-);
-
 TRACE_EVENT(kvm_set_way_flush,
 	    TP_PROTO(unsigned long vcpu_pc, bool cache),
 	    TP_ARGS(vcpu_pc, cache),
diff --git a/arch/arm64/kvm/va_layout.c b/arch/arm64/kvm/va_layout.c
index 9783013..acdb7b3 100644
--- a/arch/arm64/kvm/va_layout.c
+++ b/arch/arm64/kvm/va_layout.c
@@ -288,3 +288,10 @@
 {
 	generate_mov_q(kimage_voffset, origptr, updptr, nr_inst);
 }
+
+void kvm_compute_final_ctr_el0(struct alt_instr *alt,
+			       __le32 *origptr, __le32 *updptr, int nr_inst)
+{
+	generate_mov_q(read_sanitised_ftr_reg(SYS_CTR_EL0),
+		       origptr, updptr, nr_inst);
+}
diff --git a/arch/arm64/kvm/vgic/vgic-init.c b/arch/arm64/kvm/vgic/vgic-init.c
index 052917d..58cbda0 100644
--- a/arch/arm64/kvm/vgic/vgic-init.c
+++ b/arch/arm64/kvm/vgic/vgic-init.c
@@ -335,13 +335,14 @@
 	kfree(dist->spis);
 	dist->spis = NULL;
 	dist->nr_spis = 0;
+	dist->vgic_dist_base = VGIC_ADDR_UNDEF;
 
-	if (kvm->arch.vgic.vgic_model == KVM_DEV_TYPE_ARM_VGIC_V3) {
-		list_for_each_entry_safe(rdreg, next, &dist->rd_regions, list) {
-			list_del(&rdreg->list);
-			kfree(rdreg);
-		}
+	if (dist->vgic_model == KVM_DEV_TYPE_ARM_VGIC_V3) {
+		list_for_each_entry_safe(rdreg, next, &dist->rd_regions, list)
+			vgic_v3_free_redist_region(rdreg);
 		INIT_LIST_HEAD(&dist->rd_regions);
+	} else {
+		dist->vgic_cpu_base = VGIC_ADDR_UNDEF;
 	}
 
 	if (vgic_has_its(kvm))
@@ -362,6 +363,7 @@
 	vgic_flush_pending_lpis(vcpu);
 
 	INIT_LIST_HEAD(&vgic_cpu->ap_list_head);
+	vgic_cpu->rd_iodev.base_addr = VGIC_ADDR_UNDEF;
 }
 
 /* To be called with kvm->lock held */
diff --git a/arch/arm64/kvm/vgic/vgic-its.c b/arch/arm64/kvm/vgic/vgic-its.c
index 40cbaca..ec7543a 100644
--- a/arch/arm64/kvm/vgic/vgic-its.c
+++ b/arch/arm64/kvm/vgic/vgic-its.c
@@ -2218,10 +2218,10 @@
 		/*
 		 * If an LPI carries the HW bit, this means that this
 		 * interrupt is controlled by GICv4, and we do not
-		 * have direct access to that state. Let's simply fail
-		 * the save operation...
+		 * have direct access to that state without GICv4.1.
+		 * Let's simply fail the save operation...
 		 */
-		if (ite->irq->hw)
+		if (ite->irq->hw && !kvm_vgic_global_state.has_gicv4_1)
 			return -EACCES;
 
 		ret = vgic_its_save_ite(its, device, ite, gpa, ite_esz);
diff --git a/arch/arm64/kvm/vgic/vgic-kvm-device.c b/arch/arm64/kvm/vgic/vgic-kvm-device.c
index 4441967..2f66cf2 100644
--- a/arch/arm64/kvm/vgic/vgic-kvm-device.c
+++ b/arch/arm64/kvm/vgic/vgic-kvm-device.c
@@ -226,6 +226,9 @@
 		u64 addr;
 		unsigned long type = (unsigned long)attr->attr;
 
+		if (copy_from_user(&addr, uaddr, sizeof(addr)))
+			return -EFAULT;
+
 		r = kvm_vgic_addr(dev->kvm, type, &addr, false);
 		if (r)
 			return (r == -ENODEV) ? -ENXIO : r;
diff --git a/arch/arm64/kvm/vgic/vgic-mmio-v3.c b/arch/arm64/kvm/vgic/vgic-mmio-v3.c
index 15a6c98..03a2537 100644
--- a/arch/arm64/kvm/vgic/vgic-mmio-v3.c
+++ b/arch/arm64/kvm/vgic/vgic-mmio-v3.c
@@ -251,45 +251,52 @@
 		vgic_enable_lpis(vcpu);
 }
 
+static bool vgic_mmio_vcpu_rdist_is_last(struct kvm_vcpu *vcpu)
+{
+	struct vgic_dist *vgic = &vcpu->kvm->arch.vgic;
+	struct vgic_cpu *vgic_cpu = &vcpu->arch.vgic_cpu;
+	struct vgic_redist_region *iter, *rdreg = vgic_cpu->rdreg;
+
+	if (!rdreg)
+		return false;
+
+	if (vgic_cpu->rdreg_index < rdreg->free_index - 1) {
+		return false;
+	} else if (rdreg->count && vgic_cpu->rdreg_index == (rdreg->count - 1)) {
+		struct list_head *rd_regions = &vgic->rd_regions;
+		gpa_t end = rdreg->base + rdreg->count * KVM_VGIC_V3_REDIST_SIZE;
+
+		/*
+		 * the rdist is the last one of the redist region,
+		 * check whether there is no other contiguous rdist region
+		 */
+		list_for_each_entry(iter, rd_regions, list) {
+			if (iter->base == end && iter->free_index > 0)
+				return false;
+		}
+	}
+	return true;
+}
+
 static unsigned long vgic_mmio_read_v3r_typer(struct kvm_vcpu *vcpu,
 					      gpa_t addr, unsigned int len)
 {
 	unsigned long mpidr = kvm_vcpu_get_mpidr_aff(vcpu);
-	struct vgic_cpu *vgic_cpu = &vcpu->arch.vgic_cpu;
-	struct vgic_redist_region *rdreg = vgic_cpu->rdreg;
 	int target_vcpu_id = vcpu->vcpu_id;
-	gpa_t last_rdist_typer = rdreg->base + GICR_TYPER +
-			(rdreg->free_index - 1) * KVM_VGIC_V3_REDIST_SIZE;
 	u64 value;
 
 	value = (u64)(mpidr & GENMASK(23, 0)) << 32;
 	value |= ((target_vcpu_id & 0xffff) << 8);
 
-	if (addr == last_rdist_typer)
+	if (vgic_has_its(vcpu->kvm))
+		value |= GICR_TYPER_PLPIS;
+
+	if (vgic_mmio_vcpu_rdist_is_last(vcpu))
 		value |= GICR_TYPER_LAST;
-	if (vgic_has_its(vcpu->kvm))
-		value |= GICR_TYPER_PLPIS;
 
 	return extract_bytes(value, addr & 7, len);
 }
 
-static unsigned long vgic_uaccess_read_v3r_typer(struct kvm_vcpu *vcpu,
-						 gpa_t addr, unsigned int len)
-{
-	unsigned long mpidr = kvm_vcpu_get_mpidr_aff(vcpu);
-	int target_vcpu_id = vcpu->vcpu_id;
-	u64 value;
-
-	value = (u64)(mpidr & GENMASK(23, 0)) << 32;
-	value |= ((target_vcpu_id & 0xffff) << 8);
-
-	if (vgic_has_its(vcpu->kvm))
-		value |= GICR_TYPER_PLPIS;
-
-	/* reporting of the Last bit is not supported for userspace */
-	return extract_bytes(value, addr & 7, len);
-}
-
 static unsigned long vgic_mmio_read_v3r_iidr(struct kvm_vcpu *vcpu,
 					     gpa_t addr, unsigned int len)
 {
@@ -612,7 +619,7 @@
 		VGIC_ACCESS_32bit),
 	REGISTER_DESC_WITH_LENGTH_UACCESS(GICR_TYPER,
 		vgic_mmio_read_v3r_typer, vgic_mmio_write_wi,
-		vgic_uaccess_read_v3r_typer, vgic_mmio_uaccess_write_wi, 8,
+		NULL, vgic_mmio_uaccess_write_wi, 8,
 		VGIC_ACCESS_64bit | VGIC_ACCESS_32bit),
 	REGISTER_DESC_WITH_LENGTH(GICR_WAKER,
 		vgic_mmio_read_raz, vgic_mmio_write_wi, 4,
@@ -714,6 +721,7 @@
 		return -EINVAL;
 
 	vgic_cpu->rdreg = rdreg;
+	vgic_cpu->rdreg_index = rdreg->free_index;
 
 	rd_base = rdreg->base + rdreg->free_index * KVM_VGIC_V3_REDIST_SIZE;
 
@@ -768,7 +776,7 @@
 }
 
 /**
- * vgic_v3_insert_redist_region - Insert a new redistributor region
+ * vgic_v3_alloc_redist_region - Allocate a new redistributor region
  *
  * Performs various checks before inserting the rdist region in the list.
  * Those tests depend on whether the size of the rdist region is known
@@ -782,8 +790,8 @@
  *
  * Return 0 on success, < 0 otherwise
  */
-static int vgic_v3_insert_redist_region(struct kvm *kvm, uint32_t index,
-					gpa_t base, uint32_t count)
+static int vgic_v3_alloc_redist_region(struct kvm *kvm, uint32_t index,
+				       gpa_t base, uint32_t count)
 {
 	struct vgic_dist *d = &kvm->arch.vgic;
 	struct vgic_redist_region *rdreg;
@@ -791,10 +799,6 @@
 	size_t size = count * KVM_VGIC_V3_REDIST_SIZE;
 	int ret;
 
-	/* single rdist region already set ?*/
-	if (!count && !list_empty(rd_regions))
-		return -EINVAL;
-
 	/* cross the end of memory ? */
 	if (base + size < base)
 		return -EINVAL;
@@ -805,11 +809,15 @@
 	} else {
 		rdreg = list_last_entry(rd_regions,
 					struct vgic_redist_region, list);
-		if (index != rdreg->index + 1)
+
+		/* Don't mix single region and discrete redist regions */
+		if (!count && rdreg->count)
 			return -EINVAL;
 
-		/* Cannot add an explicitly sized regions after legacy region */
-		if (!rdreg->count)
+		if (!count)
+			return -EEXIST;
+
+		if (index != rdreg->index + 1)
 			return -EINVAL;
 	}
 
@@ -848,11 +856,17 @@
 	return ret;
 }
 
+void vgic_v3_free_redist_region(struct vgic_redist_region *rdreg)
+{
+	list_del(&rdreg->list);
+	kfree(rdreg);
+}
+
 int vgic_v3_set_redist_base(struct kvm *kvm, u32 index, u64 addr, u32 count)
 {
 	int ret;
 
-	ret = vgic_v3_insert_redist_region(kvm, index, addr, count);
+	ret = vgic_v3_alloc_redist_region(kvm, index, addr, count);
 	if (ret)
 		return ret;
 
@@ -861,8 +875,13 @@
 	 * afterwards will register the iodevs when needed.
 	 */
 	ret = vgic_register_all_redist_iodevs(kvm);
-	if (ret)
+	if (ret) {
+		struct vgic_redist_region *rdreg;
+
+		rdreg = vgic_v3_rdist_region_from_index(kvm, index);
+		vgic_v3_free_redist_region(rdreg);
 		return ret;
+	}
 
 	return 0;
 }
diff --git a/arch/arm64/kvm/vgic/vgic-mmio.c b/arch/arm64/kvm/vgic/vgic-mmio.c
index b2d73fc..48c6067 100644
--- a/arch/arm64/kvm/vgic/vgic-mmio.c
+++ b/arch/arm64/kvm/vgic/vgic-mmio.c
@@ -938,10 +938,9 @@
 	return region;
 }
 
-static int vgic_uaccess_read(struct kvm_vcpu *vcpu, struct kvm_io_device *dev,
+static int vgic_uaccess_read(struct kvm_vcpu *vcpu, struct vgic_io_device *iodev,
 			     gpa_t addr, u32 *val)
 {
-	struct vgic_io_device *iodev = kvm_to_vgic_iodev(dev);
 	const struct vgic_register_region *region;
 	struct kvm_vcpu *r_vcpu;
 
@@ -960,10 +959,9 @@
 	return 0;
 }
 
-static int vgic_uaccess_write(struct kvm_vcpu *vcpu, struct kvm_io_device *dev,
+static int vgic_uaccess_write(struct kvm_vcpu *vcpu, struct vgic_io_device *iodev,
 			      gpa_t addr, const u32 *val)
 {
-	struct vgic_io_device *iodev = kvm_to_vgic_iodev(dev);
 	const struct vgic_register_region *region;
 	struct kvm_vcpu *r_vcpu;
 
@@ -986,9 +984,9 @@
 		 bool is_write, int offset, u32 *val)
 {
 	if (is_write)
-		return vgic_uaccess_write(vcpu, &dev->dev, offset, val);
+		return vgic_uaccess_write(vcpu, dev, offset, val);
 	else
-		return vgic_uaccess_read(vcpu, &dev->dev, offset, val);
+		return vgic_uaccess_read(vcpu, dev, offset, val);
 }
 
 static int dispatch_mmio_read(struct kvm_vcpu *vcpu, struct kvm_io_device *dev,
diff --git a/arch/arm64/kvm/vgic/vgic-v3.c b/arch/arm64/kvm/vgic/vgic-v3.c
index 6f53092..41ecf21 100644
--- a/arch/arm64/kvm/vgic/vgic-v3.c
+++ b/arch/arm64/kvm/vgic/vgic-v3.c
@@ -1,6 +1,8 @@
 // SPDX-License-Identifier: GPL-2.0-only
 
 #include <linux/irqchip/arm-gic-v3.h>
+#include <linux/irq.h>
+#include <linux/irqdomain.h>
 #include <linux/kvm.h>
 #include <linux/kvm_host.h>
 #include <kvm/arm_vgic.h>
@@ -356,6 +358,32 @@
 	return 0;
 }
 
+/*
+ * The deactivation of the doorbell interrupt will trigger the
+ * unmapping of the associated vPE.
+ */
+static void unmap_all_vpes(struct vgic_dist *dist)
+{
+	struct irq_desc *desc;
+	int i;
+
+	for (i = 0; i < dist->its_vm.nr_vpes; i++) {
+		desc = irq_to_desc(dist->its_vm.vpes[i]->irq);
+		irq_domain_deactivate_irq(irq_desc_get_irq_data(desc));
+	}
+}
+
+static void map_all_vpes(struct vgic_dist *dist)
+{
+	struct irq_desc *desc;
+	int i;
+
+	for (i = 0; i < dist->its_vm.nr_vpes; i++) {
+		desc = irq_to_desc(dist->its_vm.vpes[i]->irq);
+		irq_domain_activate_irq(irq_desc_get_irq_data(desc), false);
+	}
+}
+
 /**
  * vgic_v3_save_pending_tables - Save the pending tables into guest RAM
  * kvm lock and all vcpu lock must be held
@@ -365,13 +393,28 @@
 	struct vgic_dist *dist = &kvm->arch.vgic;
 	struct vgic_irq *irq;
 	gpa_t last_ptr = ~(gpa_t)0;
-	int ret;
+	bool vlpi_avail = false;
+	int ret = 0;
 	u8 val;
 
+	if (unlikely(!vgic_initialized(kvm)))
+		return -ENXIO;
+
+	/*
+	 * A preparation for getting any VLPI states.
+	 * The above vgic initialized check also ensures that the allocation
+	 * and enabling of the doorbells have already been done.
+	 */
+	if (kvm_vgic_global_state.has_gicv4_1) {
+		unmap_all_vpes(dist);
+		vlpi_avail = true;
+	}
+
 	list_for_each_entry(irq, &dist->lpi_list_head, lpi_list) {
 		int byte_offset, bit_nr;
 		struct kvm_vcpu *vcpu;
 		gpa_t pendbase, ptr;
+		bool is_pending;
 		bool stored;
 
 		vcpu = irq->target_vcpu;
@@ -387,24 +430,35 @@
 		if (ptr != last_ptr) {
 			ret = kvm_read_guest_lock(kvm, ptr, &val, 1);
 			if (ret)
-				return ret;
+				goto out;
 			last_ptr = ptr;
 		}
 
 		stored = val & (1U << bit_nr);
-		if (stored == irq->pending_latch)
+
+		is_pending = irq->pending_latch;
+
+		if (irq->hw && vlpi_avail)
+			vgic_v4_get_vlpi_state(irq, &is_pending);
+
+		if (stored == is_pending)
 			continue;
 
-		if (irq->pending_latch)
+		if (is_pending)
 			val |= 1 << bit_nr;
 		else
 			val &= ~(1 << bit_nr);
 
 		ret = kvm_write_guest_lock(kvm, ptr, &val, 1);
 		if (ret)
-			return ret;
+			goto out;
 	}
-	return 0;
+
+out:
+	if (vlpi_avail)
+		map_all_vpes(dist);
+
+	return ret;
 }
 
 /**
diff --git a/arch/arm64/kvm/vgic/vgic-v4.c b/arch/arm64/kvm/vgic/vgic-v4.c
index 66508b0..c1845d8 100644
--- a/arch/arm64/kvm/vgic/vgic-v4.c
+++ b/arch/arm64/kvm/vgic/vgic-v4.c
@@ -203,6 +203,25 @@
 	kvm_arm_resume_guest(kvm);
 }
 
+/*
+ * Must be called with GICv4.1 and the vPE unmapped, which
+ * indicates the invalidation of any VPT caches associated
+ * with the vPE, thus we can get the VLPI state by peeking
+ * at the VPT.
+ */
+void vgic_v4_get_vlpi_state(struct vgic_irq *irq, bool *val)
+{
+	struct its_vpe *vpe = &irq->target_vcpu->arch.vgic_cpu.vgic_v3.its_vpe;
+	int mask = BIT(irq->intid % BITS_PER_BYTE);
+	void *va;
+	u8 *ptr;
+
+	va = page_address(vpe->vpt_page);
+	ptr = va + irq->intid / BITS_PER_BYTE;
+
+	*val = !!(*ptr & mask);
+}
+
 /**
  * vgic_v4_init - Initialize the GICv4 data structures
  * @kvm:	Pointer to the VM being initialized
@@ -385,6 +404,7 @@
 	struct vgic_its *its;
 	struct vgic_irq *irq;
 	struct its_vlpi_map map;
+	unsigned long flags;
 	int ret;
 
 	if (!vgic_supports_direct_msis(kvm))
@@ -430,6 +450,24 @@
 	irq->host_irq	= virq;
 	atomic_inc(&map.vpe->vlpi_count);
 
+	/* Transfer pending state */
+	raw_spin_lock_irqsave(&irq->irq_lock, flags);
+	if (irq->pending_latch) {
+		ret = irq_set_irqchip_state(irq->host_irq,
+					    IRQCHIP_STATE_PENDING,
+					    irq->pending_latch);
+		WARN_RATELIMIT(ret, "IRQ %d", irq->host_irq);
+
+		/*
+		 * Clear pending_latch and communicate this state
+		 * change via vgic_queue_irq_unlock.
+		 */
+		irq->pending_latch = false;
+		vgic_queue_irq_unlock(kvm, irq, flags);
+	} else {
+		raw_spin_unlock_irqrestore(&irq->irq_lock, flags);
+	}
+
 out:
 	mutex_unlock(&its->its_lock);
 	return ret;
diff --git a/arch/arm64/kvm/vgic/vgic.h b/arch/arm64/kvm/vgic/vgic.h
index 64fcd75..dc1f3d1 100644
--- a/arch/arm64/kvm/vgic/vgic.h
+++ b/arch/arm64/kvm/vgic/vgic.h
@@ -293,6 +293,7 @@
 
 struct vgic_redist_region *vgic_v3_rdist_region_from_index(struct kvm *kvm,
 							   u32 index);
+void vgic_v3_free_redist_region(struct vgic_redist_region *rdreg);
 
 bool vgic_v3_rdist_overlap(struct kvm *kvm, gpa_t base, size_t size);
 
@@ -317,5 +318,6 @@
 int vgic_v4_init(struct kvm *kvm);
 void vgic_v4_teardown(struct kvm *kvm);
 void vgic_v4_configure_vsgis(struct kvm *kvm);
+void vgic_v4_get_vlpi_state(struct vgic_irq *irq, bool *val);
 
 #endif
diff --git a/arch/arm64/lib/clear_page.S b/arch/arm64/lib/clear_page.S
index 073acbf..b84b179 100644
--- a/arch/arm64/lib/clear_page.S
+++ b/arch/arm64/lib/clear_page.S
@@ -14,7 +14,7 @@
  * Parameters:
  *	x0 - dest
  */
-SYM_FUNC_START(clear_page)
+SYM_FUNC_START_PI(clear_page)
 	mrs	x1, dczid_el0
 	and	w1, w1, #0xf
 	mov	x2, #4
@@ -25,5 +25,5 @@
 	tst	x0, #(PAGE_SIZE - 1)
 	b.ne	1b
 	ret
-SYM_FUNC_END(clear_page)
+SYM_FUNC_END_PI(clear_page)
 EXPORT_SYMBOL(clear_page)
diff --git a/arch/arm64/lib/copy_page.S b/arch/arm64/lib/copy_page.S
index e7a79396..29144f4 100644
--- a/arch/arm64/lib/copy_page.S
+++ b/arch/arm64/lib/copy_page.S
@@ -17,7 +17,7 @@
  *	x0 - dest
  *	x1 - src
  */
-SYM_FUNC_START(copy_page)
+SYM_FUNC_START_PI(copy_page)
 alternative_if ARM64_HAS_NO_HW_PREFETCH
 	// Prefetch three cache lines ahead.
 	prfm	pldl1strm, [x1, #128]
@@ -75,5 +75,5 @@
 	stnp	x16, x17, [x0, #112 - 256]
 
 	ret
-SYM_FUNC_END(copy_page)
+SYM_FUNC_END_PI(copy_page)
 EXPORT_SYMBOL(copy_page)
diff --git a/arch/arm64/mm/init.c b/arch/arm64/mm/init.c
index 3685e12..6cb22da 100644
--- a/arch/arm64/mm/init.c
+++ b/arch/arm64/mm/init.c
@@ -35,6 +35,7 @@
 #include <asm/fixmap.h>
 #include <asm/kasan.h>
 #include <asm/kernel-pgtable.h>
+#include <asm/kvm_host.h>
 #include <asm/memory.h>
 #include <asm/numa.h>
 #include <asm/sections.h>
@@ -452,6 +453,8 @@
 
 	dma_pernuma_cma_reserve();
 
+	kvm_hyp_reserve();
+
 	/*
 	 * sparse_init() tries to allocate memory from memblock, so must be
 	 * done after the fixed reservations
diff --git a/arch/mips/include/asm/kvm_host.h b/arch/mips/include/asm/kvm_host.h
index 3a5612e..d0944a7 100644
--- a/arch/mips/include/asm/kvm_host.h
+++ b/arch/mips/include/asm/kvm_host.h
@@ -815,14 +815,7 @@
 	int (*vcpu_init)(struct kvm_vcpu *vcpu);
 	void (*vcpu_uninit)(struct kvm_vcpu *vcpu);
 	int (*vcpu_setup)(struct kvm_vcpu *vcpu);
-	void (*flush_shadow_all)(struct kvm *kvm);
-	/*
-	 * Must take care of flushing any cached GPA PTEs (e.g. guest entries in
-	 * VZ root TLB, or T&E GVA page tables and corresponding root TLB
-	 * mappings).
-	 */
-	void (*flush_shadow_memslot)(struct kvm *kvm,
-				     const struct kvm_memory_slot *slot);
+	void (*prepare_flush_shadow)(struct kvm *kvm);
 	gpa_t (*gva_to_gpa)(gva_t gva);
 	void (*queue_timer_int)(struct kvm_vcpu *vcpu);
 	void (*dequeue_timer_int)(struct kvm_vcpu *vcpu);
@@ -967,11 +960,6 @@
 						   bool write);
 
 #define KVM_ARCH_WANT_MMU_NOTIFIER
-int kvm_unmap_hva_range(struct kvm *kvm,
-			unsigned long start, unsigned long end, unsigned flags);
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte);
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end);
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva);
 
 /* Emulation */
 int kvm_get_inst(u32 *opc, struct kvm_vcpu *vcpu, u32 *out);
@@ -1154,4 +1142,7 @@
 static inline void kvm_arch_vcpu_unblocking(struct kvm_vcpu *vcpu) {}
 static inline void kvm_arch_vcpu_block_finish(struct kvm_vcpu *vcpu) {}
 
+#define __KVM_HAVE_ARCH_FLUSH_REMOTE_TLB
+int kvm_arch_flush_remote_tlb(struct kvm *kvm);
+
 #endif /* __MIPS_KVM_HOST_H__ */
diff --git a/arch/mips/kvm/mips.c b/arch/mips/kvm/mips.c
index 58a8812..4a22ba7 100644
--- a/arch/mips/kvm/mips.c
+++ b/arch/mips/kvm/mips.c
@@ -204,9 +204,7 @@
 {
 	/* Flush whole GPA */
 	kvm_mips_flush_gpa_pt(kvm, 0, ~0);
-
-	/* Let implementation do the rest */
-	kvm_mips_callbacks->flush_shadow_all(kvm);
+	kvm_flush_remote_tlbs(kvm);
 }
 
 void kvm_arch_flush_shadow_memslot(struct kvm *kvm,
@@ -221,8 +219,7 @@
 	/* Flush slot from GPA */
 	kvm_mips_flush_gpa_pt(kvm, slot->base_gfn,
 			      slot->base_gfn + slot->npages - 1);
-	/* Let implementation do the rest */
-	kvm_mips_callbacks->flush_shadow_memslot(kvm, slot);
+	kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
 	spin_unlock(&kvm->mmu_lock);
 }
 
@@ -262,9 +259,8 @@
 		/* Write protect GPA page table entries */
 		needs_flush = kvm_mips_mkclean_gpa_pt(kvm, new->base_gfn,
 					new->base_gfn + new->npages - 1);
-		/* Let implementation do the rest */
 		if (needs_flush)
-			kvm_mips_callbacks->flush_shadow_memslot(kvm, new);
+			kvm_arch_flush_remote_tlbs_memslot(kvm, new);
 		spin_unlock(&kvm->mmu_lock);
 	}
 }
@@ -996,11 +992,16 @@
 
 }
 
-void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
-					struct kvm_memory_slot *memslot)
+int kvm_arch_flush_remote_tlb(struct kvm *kvm)
 {
-	/* Let implementation handle TLB/GVA invalidation */
-	kvm_mips_callbacks->flush_shadow_memslot(kvm, memslot);
+	kvm_mips_callbacks->prepare_flush_shadow(kvm);
+	return 1;
+}
+
+void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
+					const struct kvm_memory_slot *memslot)
+{
+	kvm_flush_remote_tlbs(kvm);
 }
 
 long kvm_arch_vm_ioctl(struct file *filp, unsigned int ioctl, unsigned long arg)
diff --git a/arch/mips/kvm/mmu.c b/arch/mips/kvm/mmu.c
index 3dabeda..8af002b 100644
--- a/arch/mips/kvm/mmu.c
+++ b/arch/mips/kvm/mmu.c
@@ -439,85 +439,34 @@
 				  end_gfn << PAGE_SHIFT);
 }
 
-static int handle_hva_to_gpa(struct kvm *kvm,
-			     unsigned long start,
-			     unsigned long end,
-			     int (*handler)(struct kvm *kvm, gfn_t gfn,
-					    gpa_t gfn_end,
-					    struct kvm_memory_slot *memslot,
-					    void *data),
-			     void *data)
+bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
-	int ret = 0;
-
-	slots = kvm_memslots(kvm);
-
-	/* we only care about the pages that the guest sees */
-	kvm_for_each_memslot(memslot, slots) {
-		unsigned long hva_start, hva_end;
-		gfn_t gfn, gfn_end;
-
-		hva_start = max(start, memslot->userspace_addr);
-		hva_end = min(end, memslot->userspace_addr +
-					(memslot->npages << PAGE_SHIFT));
-		if (hva_start >= hva_end)
-			continue;
-
-		/*
-		 * {gfn(page) | page intersects with [hva_start, hva_end)} =
-		 * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-		 */
-		gfn = hva_to_gfn_memslot(hva_start, memslot);
-		gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
-
-		ret |= handler(kvm, gfn, gfn_end, memslot, data);
-	}
-
-	return ret;
-}
-
-
-static int kvm_unmap_hva_handler(struct kvm *kvm, gfn_t gfn, gfn_t gfn_end,
-				 struct kvm_memory_slot *memslot, void *data)
-{
-	kvm_mips_flush_gpa_pt(kvm, gfn, gfn_end);
+	kvm_mips_flush_gpa_pt(kvm, range->start, range->end);
 	return 1;
 }
 
-int kvm_unmap_hva_range(struct kvm *kvm, unsigned long start, unsigned long end,
-			unsigned flags)
+bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	handle_hva_to_gpa(kvm, start, end, &kvm_unmap_hva_handler, NULL);
-
-	kvm_mips_callbacks->flush_shadow_all(kvm);
-	return 0;
-}
-
-static int kvm_set_spte_handler(struct kvm *kvm, gfn_t gfn, gfn_t gfn_end,
-				struct kvm_memory_slot *memslot, void *data)
-{
-	gpa_t gpa = gfn << PAGE_SHIFT;
-	pte_t hva_pte = *(pte_t *)data;
+	gpa_t gpa = range->start << PAGE_SHIFT;
+	pte_t hva_pte = range->pte;
 	pte_t *gpa_pte = kvm_mips_pte_for_gpa(kvm, NULL, gpa);
 	pte_t old_pte;
 
 	if (!gpa_pte)
-		return 0;
+		return false;
 
 	/* Mapping may need adjusting depending on memslot flags */
 	old_pte = *gpa_pte;
-	if (memslot->flags & KVM_MEM_LOG_DIRTY_PAGES && !pte_dirty(old_pte))
+	if (range->slot->flags & KVM_MEM_LOG_DIRTY_PAGES && !pte_dirty(old_pte))
 		hva_pte = pte_mkclean(hva_pte);
-	else if (memslot->flags & KVM_MEM_READONLY)
+	else if (range->slot->flags & KVM_MEM_READONLY)
 		hva_pte = pte_wrprotect(hva_pte);
 
 	set_pte(gpa_pte, hva_pte);
 
 	/* Replacing an absent or old page doesn't need flushes */
 	if (!pte_present(old_pte) || !pte_young(old_pte))
-		return 0;
+		return false;
 
 	/* Pages swapped, aged, moved, or cleaned require flushes */
 	return !pte_present(hva_pte) ||
@@ -526,27 +475,14 @@
 	       (pte_dirty(old_pte) && !pte_dirty(hva_pte));
 }
 
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte)
+bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	unsigned long end = hva + PAGE_SIZE;
-	int ret;
-
-	ret = handle_hva_to_gpa(kvm, hva, end, &kvm_set_spte_handler, &pte);
-	if (ret)
-		kvm_mips_callbacks->flush_shadow_all(kvm);
-	return 0;
+	return kvm_mips_mkold_gpa_pt(kvm, range->start, range->end);
 }
 
-static int kvm_age_hva_handler(struct kvm *kvm, gfn_t gfn, gfn_t gfn_end,
-			       struct kvm_memory_slot *memslot, void *data)
+bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_mips_mkold_gpa_pt(kvm, gfn, gfn_end);
-}
-
-static int kvm_test_age_hva_handler(struct kvm *kvm, gfn_t gfn, gfn_t gfn_end,
-				    struct kvm_memory_slot *memslot, void *data)
-{
-	gpa_t gpa = gfn << PAGE_SHIFT;
+	gpa_t gpa = range->start << PAGE_SHIFT;
 	pte_t *gpa_pte = kvm_mips_pte_for_gpa(kvm, NULL, gpa);
 
 	if (!gpa_pte)
@@ -554,16 +490,6 @@
 	return pte_young(*gpa_pte);
 }
 
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end)
-{
-	return handle_hva_to_gpa(kvm, start, end, kvm_age_hva_handler, NULL);
-}
-
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva)
-{
-	return handle_hva_to_gpa(kvm, hva, hva, kvm_test_age_hva_handler, NULL);
-}
-
 /**
  * _kvm_mips_map_page_fast() - Fast path GPA fault handler.
  * @vcpu:		VCPU pointer.
diff --git a/arch/mips/kvm/trap_emul.c b/arch/mips/kvm/trap_emul.c
index 0788c00..5f2df49 100644
--- a/arch/mips/kvm/trap_emul.c
+++ b/arch/mips/kvm/trap_emul.c
@@ -687,16 +687,8 @@
 	return 0;
 }
 
-static void kvm_trap_emul_flush_shadow_all(struct kvm *kvm)
+static void kvm_trap_emul_prepare_flush_shadow(struct kvm *kvm)
 {
-	/* Flush GVA page tables and invalidate GVA ASIDs on all VCPUs */
-	kvm_flush_remote_tlbs(kvm);
-}
-
-static void kvm_trap_emul_flush_shadow_memslot(struct kvm *kvm,
-					const struct kvm_memory_slot *slot)
-{
-	kvm_trap_emul_flush_shadow_all(kvm);
 }
 
 static u64 kvm_trap_emul_get_one_regs[] = {
@@ -1280,8 +1272,7 @@
 	.vcpu_init = kvm_trap_emul_vcpu_init,
 	.vcpu_uninit = kvm_trap_emul_vcpu_uninit,
 	.vcpu_setup = kvm_trap_emul_vcpu_setup,
-	.flush_shadow_all = kvm_trap_emul_flush_shadow_all,
-	.flush_shadow_memslot = kvm_trap_emul_flush_shadow_memslot,
+	.prepare_flush_shadow = kvm_trap_emul_prepare_flush_shadow,
 	.gva_to_gpa = kvm_trap_emul_gva_to_gpa_cb,
 	.queue_timer_int = kvm_mips_queue_timer_int_cb,
 	.dequeue_timer_int = kvm_mips_dequeue_timer_int_cb,
diff --git a/arch/mips/kvm/vz.c b/arch/mips/kvm/vz.c
index 2ffbe92..2c75571 100644
--- a/arch/mips/kvm/vz.c
+++ b/arch/mips/kvm/vz.c
@@ -3211,32 +3211,22 @@
 	return 0;
 }
 
-static void kvm_vz_flush_shadow_all(struct kvm *kvm)
+static void kvm_vz_prepare_flush_shadow(struct kvm *kvm)
 {
-	if (cpu_has_guestid) {
-		/* Flush GuestID for each VCPU individually */
-		kvm_flush_remote_tlbs(kvm);
-	} else {
+	if (!cpu_has_guestid) {
 		/*
 		 * For each CPU there is a single GPA ASID used by all VCPUs in
 		 * the VM, so it doesn't make sense for the VCPUs to handle
 		 * invalidation of these ASIDs individually.
 		 *
 		 * Instead mark all CPUs as needing ASID invalidation in
-		 * asid_flush_mask, and just use kvm_flush_remote_tlbs(kvm) to
+		 * asid_flush_mask, and kvm_flush_remote_tlbs(kvm) will
 		 * kick any running VCPUs so they check asid_flush_mask.
 		 */
 		cpumask_setall(&kvm->arch.asid_flush_mask);
-		kvm_flush_remote_tlbs(kvm);
 	}
 }
 
-static void kvm_vz_flush_shadow_memslot(struct kvm *kvm,
-					const struct kvm_memory_slot *slot)
-{
-	kvm_vz_flush_shadow_all(kvm);
-}
-
 static void kvm_vz_vcpu_reenter(struct kvm_vcpu *vcpu)
 {
 	int cpu = smp_processor_id();
@@ -3292,8 +3282,7 @@
 	.vcpu_init = kvm_vz_vcpu_init,
 	.vcpu_uninit = kvm_vz_vcpu_uninit,
 	.vcpu_setup = kvm_vz_vcpu_setup,
-	.flush_shadow_all = kvm_vz_flush_shadow_all,
-	.flush_shadow_memslot = kvm_vz_flush_shadow_memslot,
+	.prepare_flush_shadow = kvm_vz_prepare_flush_shadow,
 	.gva_to_gpa = kvm_vz_gva_to_gpa_cb,
 	.queue_timer_int = kvm_vz_queue_timer_int_cb,
 	.dequeue_timer_int = kvm_vz_dequeue_timer_int_cb,
diff --git a/arch/powerpc/include/asm/kvm_book3s.h b/arch/powerpc/include/asm/kvm_book3s.h
index 2f5f919..2d03f293 100644
--- a/arch/powerpc/include/asm/kvm_book3s.h
+++ b/arch/powerpc/include/asm/kvm_book3s.h
@@ -210,12 +210,12 @@
 				      unsigned int lpid);
 extern int kvmppc_radix_init(void);
 extern void kvmppc_radix_exit(void);
-extern int kvm_unmap_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
-			unsigned long gfn);
-extern int kvm_age_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
-			unsigned long gfn);
-extern int kvm_test_age_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
-			unsigned long gfn);
+extern bool kvm_unmap_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
+			    unsigned long gfn);
+extern bool kvm_age_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
+			  unsigned long gfn);
+extern bool kvm_test_age_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
+			       unsigned long gfn);
 extern long kvmppc_hv_get_dirty_log_radix(struct kvm *kvm,
 			struct kvm_memory_slot *memslot, unsigned long *map);
 extern void kvmppc_radix_flush_memslot(struct kvm *kvm,
diff --git a/arch/powerpc/include/asm/kvm_host.h b/arch/powerpc/include/asm/kvm_host.h
index 05fb00d..1e83359 100644
--- a/arch/powerpc/include/asm/kvm_host.h
+++ b/arch/powerpc/include/asm/kvm_host.h
@@ -56,13 +56,6 @@
 
 #define KVM_ARCH_WANT_MMU_NOTIFIER
 
-extern int kvm_unmap_hva_range(struct kvm *kvm,
-			       unsigned long start, unsigned long end,
-			       unsigned flags);
-extern int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end);
-extern int kvm_test_age_hva(struct kvm *kvm, unsigned long hva);
-extern int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte);
-
 #define HPTEG_CACHE_NUM			(1 << 15)
 #define HPTEG_HASH_BITS_PTE		13
 #define HPTEG_HASH_BITS_PTE_LONG	12
diff --git a/arch/powerpc/include/asm/kvm_ppc.h b/arch/powerpc/include/asm/kvm_ppc.h
index 8aacd76..21ab033 100644
--- a/arch/powerpc/include/asm/kvm_ppc.h
+++ b/arch/powerpc/include/asm/kvm_ppc.h
@@ -281,11 +281,10 @@
 				     const struct kvm_memory_slot *old,
 				     const struct kvm_memory_slot *new,
 				     enum kvm_mr_change change);
-	int (*unmap_hva_range)(struct kvm *kvm, unsigned long start,
-			   unsigned long end);
-	int (*age_hva)(struct kvm *kvm, unsigned long start, unsigned long end);
-	int (*test_age_hva)(struct kvm *kvm, unsigned long hva);
-	void (*set_spte_hva)(struct kvm *kvm, unsigned long hva, pte_t pte);
+	bool (*unmap_gfn_range)(struct kvm *kvm, struct kvm_gfn_range *range);
+	bool (*age_gfn)(struct kvm *kvm, struct kvm_gfn_range *range);
+	bool (*test_age_gfn)(struct kvm *kvm, struct kvm_gfn_range *range);
+	bool (*set_spte_gfn)(struct kvm *kvm, struct kvm_gfn_range *range);
 	void (*free_memslot)(struct kvm_memory_slot *slot);
 	int (*init_vm)(struct kvm *kvm);
 	void (*destroy_vm)(struct kvm *kvm);
diff --git a/arch/powerpc/kvm/book3s.c b/arch/powerpc/kvm/book3s.c
index 44bf567..2b691f4 100644
--- a/arch/powerpc/kvm/book3s.c
+++ b/arch/powerpc/kvm/book3s.c
@@ -834,26 +834,24 @@
 	kvm->arch.kvm_ops->commit_memory_region(kvm, mem, old, new, change);
 }
 
-int kvm_unmap_hva_range(struct kvm *kvm, unsigned long start, unsigned long end,
-			unsigned flags)
+bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm->arch.kvm_ops->unmap_hva_range(kvm, start, end);
+	return kvm->arch.kvm_ops->unmap_gfn_range(kvm, range);
 }
 
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end)
+bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm->arch.kvm_ops->age_hva(kvm, start, end);
+	return kvm->arch.kvm_ops->age_gfn(kvm, range);
 }
 
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm->arch.kvm_ops->test_age_hva(kvm, hva);
+	return kvm->arch.kvm_ops->test_age_gfn(kvm, range);
 }
 
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte)
+bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	kvm->arch.kvm_ops->set_spte_hva(kvm, hva, pte);
-	return 0;
+	return kvm->arch.kvm_ops->set_spte_gfn(kvm, range);
 }
 
 int kvmppc_core_init_vm(struct kvm *kvm)
diff --git a/arch/powerpc/kvm/book3s.h b/arch/powerpc/kvm/book3s.h
index 9b6323e..740e51d 100644
--- a/arch/powerpc/kvm/book3s.h
+++ b/arch/powerpc/kvm/book3s.h
@@ -9,12 +9,10 @@
 
 extern void kvmppc_core_flush_memslot_hv(struct kvm *kvm,
 					 struct kvm_memory_slot *memslot);
-extern int kvm_unmap_hva_range_hv(struct kvm *kvm, unsigned long start,
-				  unsigned long end);
-extern int kvm_age_hva_hv(struct kvm *kvm, unsigned long start,
-			  unsigned long end);
-extern int kvm_test_age_hva_hv(struct kvm *kvm, unsigned long hva);
-extern void kvm_set_spte_hva_hv(struct kvm *kvm, unsigned long hva, pte_t pte);
+extern bool kvm_unmap_gfn_range_hv(struct kvm *kvm, struct kvm_gfn_range *range);
+extern bool kvm_age_gfn_hv(struct kvm *kvm, struct kvm_gfn_range *range);
+extern bool kvm_test_age_gfn_hv(struct kvm *kvm, struct kvm_gfn_range *range);
+extern bool kvm_set_spte_gfn_hv(struct kvm *kvm, struct kvm_gfn_range *range);
 
 extern int kvmppc_mmu_init_pr(struct kvm_vcpu *vcpu);
 extern void kvmppc_mmu_destroy_pr(struct kvm_vcpu *vcpu);
diff --git a/arch/powerpc/kvm/book3s_64_mmu_hv.c b/arch/powerpc/kvm/book3s_64_mmu_hv.c
index bb67735..b7bd9ca 100644
--- a/arch/powerpc/kvm/book3s_64_mmu_hv.c
+++ b/arch/powerpc/kvm/book3s_64_mmu_hv.c
@@ -752,51 +752,6 @@
 	srcu_read_unlock(&kvm->srcu, srcu_idx);
 }
 
-typedef int (*hva_handler_fn)(struct kvm *kvm, struct kvm_memory_slot *memslot,
-			      unsigned long gfn);
-
-static int kvm_handle_hva_range(struct kvm *kvm,
-				unsigned long start,
-				unsigned long end,
-				hva_handler_fn handler)
-{
-	int ret;
-	int retval = 0;
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
-
-	slots = kvm_memslots(kvm);
-	kvm_for_each_memslot(memslot, slots) {
-		unsigned long hva_start, hva_end;
-		gfn_t gfn, gfn_end;
-
-		hva_start = max(start, memslot->userspace_addr);
-		hva_end = min(end, memslot->userspace_addr +
-					(memslot->npages << PAGE_SHIFT));
-		if (hva_start >= hva_end)
-			continue;
-		/*
-		 * {gfn(page) | page intersects with [hva_start, hva_end)} =
-		 * {gfn, gfn+1, ..., gfn_end-1}.
-		 */
-		gfn = hva_to_gfn_memslot(hva_start, memslot);
-		gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
-
-		for (; gfn < gfn_end; ++gfn) {
-			ret = handler(kvm, memslot, gfn);
-			retval |= ret;
-		}
-	}
-
-	return retval;
-}
-
-static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
-			  hva_handler_fn handler)
-{
-	return kvm_handle_hva_range(kvm, hva, hva + 1, handler);
-}
-
 /* Must be called with both HPTE and rmap locked */
 static void kvmppc_unmap_hpte(struct kvm *kvm, unsigned long i,
 			      struct kvm_memory_slot *memslot,
@@ -840,8 +795,8 @@
 	}
 }
 
-static int kvm_unmap_rmapp(struct kvm *kvm, struct kvm_memory_slot *memslot,
-			   unsigned long gfn)
+static bool kvm_unmap_rmapp(struct kvm *kvm, struct kvm_memory_slot *memslot,
+			    unsigned long gfn)
 {
 	unsigned long i;
 	__be64 *hptep;
@@ -874,16 +829,15 @@
 		unlock_rmap(rmapp);
 		__unlock_hpte(hptep, be64_to_cpu(hptep[0]));
 	}
-	return 0;
+	return false;
 }
 
-int kvm_unmap_hva_range_hv(struct kvm *kvm, unsigned long start, unsigned long end)
+bool kvm_unmap_gfn_range_hv(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	hva_handler_fn handler;
+	if (kvm_is_radix(kvm))
+		return kvm_unmap_radix(kvm, range->slot, range->start);
 
-	handler = kvm_is_radix(kvm) ? kvm_unmap_radix : kvm_unmap_rmapp;
-	kvm_handle_hva_range(kvm, start, end, handler);
-	return 0;
+	return kvm_unmap_rmapp(kvm, range->slot, range->start);
 }
 
 void kvmppc_core_flush_memslot_hv(struct kvm *kvm,
@@ -913,8 +867,8 @@
 	}
 }
 
-static int kvm_age_rmapp(struct kvm *kvm, struct kvm_memory_slot *memslot,
-			 unsigned long gfn)
+static bool kvm_age_rmapp(struct kvm *kvm, struct kvm_memory_slot *memslot,
+			  unsigned long gfn)
 {
 	struct revmap_entry *rev = kvm->arch.hpt.rev;
 	unsigned long head, i, j;
@@ -968,26 +922,26 @@
 	return ret;
 }
 
-int kvm_age_hva_hv(struct kvm *kvm, unsigned long start, unsigned long end)
+bool kvm_age_gfn_hv(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	hva_handler_fn handler;
+	if (kvm_is_radix(kvm))
+		kvm_age_radix(kvm, range->slot, range->start);
 
-	handler = kvm_is_radix(kvm) ? kvm_age_radix : kvm_age_rmapp;
-	return kvm_handle_hva_range(kvm, start, end, handler);
+	return kvm_age_rmapp(kvm, range->slot, range->start);
 }
 
-static int kvm_test_age_rmapp(struct kvm *kvm, struct kvm_memory_slot *memslot,
-			      unsigned long gfn)
+static bool kvm_test_age_rmapp(struct kvm *kvm, struct kvm_memory_slot *memslot,
+			       unsigned long gfn)
 {
 	struct revmap_entry *rev = kvm->arch.hpt.rev;
 	unsigned long head, i, j;
 	unsigned long *hp;
-	int ret = 1;
+	bool ret = true;
 	unsigned long *rmapp;
 
 	rmapp = &memslot->arch.rmap[gfn - memslot->base_gfn];
 	if (*rmapp & KVMPPC_RMAP_REFERENCED)
-		return 1;
+		return true;
 
 	lock_rmap(rmapp);
 	if (*rmapp & KVMPPC_RMAP_REFERENCED)
@@ -1002,27 +956,27 @@
 				goto out;
 		} while ((i = j) != head);
 	}
-	ret = 0;
+	ret = false;
 
  out:
 	unlock_rmap(rmapp);
 	return ret;
 }
 
-int kvm_test_age_hva_hv(struct kvm *kvm, unsigned long hva)
+bool kvm_test_age_gfn_hv(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	hva_handler_fn handler;
+	if (kvm_is_radix(kvm))
+		return kvm_test_age_radix(kvm, range->slot, range->start);
 
-	handler = kvm_is_radix(kvm) ? kvm_test_age_radix : kvm_test_age_rmapp;
-	return kvm_handle_hva(kvm, hva, handler);
+	return kvm_test_age_rmapp(kvm, range->slot, range->start);
 }
 
-void kvm_set_spte_hva_hv(struct kvm *kvm, unsigned long hva, pte_t pte)
+bool kvm_set_spte_gfn_hv(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	hva_handler_fn handler;
+	if (kvm_is_radix(kvm))
+		return kvm_unmap_radix(kvm, range->slot, range->start);
 
-	handler = kvm_is_radix(kvm) ? kvm_unmap_radix : kvm_unmap_rmapp;
-	kvm_handle_hva(kvm, hva, handler);
+	return kvm_unmap_rmapp(kvm, range->slot, range->start);
 }
 
 static int vcpus_running(struct kvm *kvm)
diff --git a/arch/powerpc/kvm/book3s_64_mmu_radix.c b/arch/powerpc/kvm/book3s_64_mmu_radix.c
index e603de7..ec4f58f 100644
--- a/arch/powerpc/kvm/book3s_64_mmu_radix.c
+++ b/arch/powerpc/kvm/book3s_64_mmu_radix.c
@@ -993,8 +993,8 @@
 }
 
 /* Called with kvm->mmu_lock held */
-int kvm_unmap_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
-		    unsigned long gfn)
+bool kvm_unmap_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
+		     unsigned long gfn)
 {
 	pte_t *ptep;
 	unsigned long gpa = gfn << PAGE_SHIFT;
@@ -1002,24 +1002,24 @@
 
 	if (kvm->arch.secure_guest & KVMPPC_SECURE_INIT_DONE) {
 		uv_page_inval(kvm->arch.lpid, gpa, PAGE_SHIFT);
-		return 0;
+		return false;
 	}
 
 	ptep = find_kvm_secondary_pte(kvm, gpa, &shift);
 	if (ptep && pte_present(*ptep))
 		kvmppc_unmap_pte(kvm, ptep, gpa, shift, memslot,
 				 kvm->arch.lpid);
-	return 0;
+	return false;
 }
 
 /* Called with kvm->mmu_lock held */
-int kvm_age_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
-		  unsigned long gfn)
+bool kvm_age_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
+		   unsigned long gfn)
 {
 	pte_t *ptep;
 	unsigned long gpa = gfn << PAGE_SHIFT;
 	unsigned int shift;
-	int ref = 0;
+	bool ref = false;
 	unsigned long old, *rmapp;
 
 	if (kvm->arch.secure_guest & KVMPPC_SECURE_INIT_DONE)
@@ -1035,26 +1035,26 @@
 		kvmhv_update_nest_rmap_rc_list(kvm, rmapp, _PAGE_ACCESSED, 0,
 					       old & PTE_RPN_MASK,
 					       1UL << shift);
-		ref = 1;
+		ref = true;
 	}
 	return ref;
 }
 
 /* Called with kvm->mmu_lock held */
-int kvm_test_age_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
-		       unsigned long gfn)
+bool kvm_test_age_radix(struct kvm *kvm, struct kvm_memory_slot *memslot,
+			unsigned long gfn)
 {
 	pte_t *ptep;
 	unsigned long gpa = gfn << PAGE_SHIFT;
 	unsigned int shift;
-	int ref = 0;
+	bool ref = false;
 
 	if (kvm->arch.secure_guest & KVMPPC_SECURE_INIT_DONE)
 		return ref;
 
 	ptep = find_kvm_secondary_pte(kvm, gpa, &shift);
 	if (ptep && pte_present(*ptep) && pte_young(*ptep))
-		ref = 1;
+		ref = true;
 	return ref;
 }
 
diff --git a/arch/powerpc/kvm/book3s_hv.c b/arch/powerpc/kvm/book3s_hv.c
index 13bad6b..07682ad 100644
--- a/arch/powerpc/kvm/book3s_hv.c
+++ b/arch/powerpc/kvm/book3s_hv.c
@@ -4770,7 +4770,7 @@
 		kvmhv_release_all_nested(kvm);
 	kvmppc_rmap_reset(kvm);
 	kvm->arch.process_table = 0;
-	/* Mutual exclusion with kvm_unmap_hva_range etc. */
+	/* Mutual exclusion with kvm_unmap_gfn_range etc. */
 	spin_lock(&kvm->mmu_lock);
 	kvm->arch.radix = 0;
 	spin_unlock(&kvm->mmu_lock);
@@ -4792,7 +4792,7 @@
 	if (err)
 		return err;
 	kvmppc_rmap_reset(kvm);
-	/* Mutual exclusion with kvm_unmap_hva_range etc. */
+	/* Mutual exclusion with kvm_unmap_gfn_range etc. */
 	spin_lock(&kvm->mmu_lock);
 	kvm->arch.radix = 1;
 	spin_unlock(&kvm->mmu_lock);
@@ -5654,10 +5654,10 @@
 	.flush_memslot  = kvmppc_core_flush_memslot_hv,
 	.prepare_memory_region = kvmppc_core_prepare_memory_region_hv,
 	.commit_memory_region  = kvmppc_core_commit_memory_region_hv,
-	.unmap_hva_range = kvm_unmap_hva_range_hv,
-	.age_hva  = kvm_age_hva_hv,
-	.test_age_hva = kvm_test_age_hva_hv,
-	.set_spte_hva = kvm_set_spte_hva_hv,
+	.unmap_gfn_range = kvm_unmap_gfn_range_hv,
+	.age_gfn = kvm_age_gfn_hv,
+	.test_age_gfn = kvm_test_age_gfn_hv,
+	.set_spte_gfn = kvm_set_spte_gfn_hv,
 	.free_memslot = kvmppc_core_free_memslot_hv,
 	.init_vm =  kvmppc_core_init_vm_hv,
 	.destroy_vm = kvmppc_core_destroy_vm_hv,
diff --git a/arch/powerpc/kvm/book3s_pr.c b/arch/powerpc/kvm/book3s_pr.c
index 913944d..d7733b0 100644
--- a/arch/powerpc/kvm/book3s_pr.c
+++ b/arch/powerpc/kvm/book3s_pr.c
@@ -425,61 +425,39 @@
 }
 
 /************* MMU Notifiers *************/
-static void do_kvm_unmap_hva(struct kvm *kvm, unsigned long start,
-			     unsigned long end)
+static bool do_kvm_unmap_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	long i;
 	struct kvm_vcpu *vcpu;
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
 
-	slots = kvm_memslots(kvm);
-	kvm_for_each_memslot(memslot, slots) {
-		unsigned long hva_start, hva_end;
-		gfn_t gfn, gfn_end;
+	kvm_for_each_vcpu(i, vcpu, kvm)
+		kvmppc_mmu_pte_pflush(vcpu, range->start << PAGE_SHIFT,
+				      range->end << PAGE_SHIFT);
 
-		hva_start = max(start, memslot->userspace_addr);
-		hva_end = min(end, memslot->userspace_addr +
-					(memslot->npages << PAGE_SHIFT));
-		if (hva_start >= hva_end)
-			continue;
-		/*
-		 * {gfn(page) | page intersects with [hva_start, hva_end)} =
-		 * {gfn, gfn+1, ..., gfn_end-1}.
-		 */
-		gfn = hva_to_gfn_memslot(hva_start, memslot);
-		gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
-		kvm_for_each_vcpu(i, vcpu, kvm)
-			kvmppc_mmu_pte_pflush(vcpu, gfn << PAGE_SHIFT,
-					      gfn_end << PAGE_SHIFT);
-	}
+	return false;
 }
 
-static int kvm_unmap_hva_range_pr(struct kvm *kvm, unsigned long start,
-				  unsigned long end)
+static bool kvm_unmap_gfn_range_pr(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	do_kvm_unmap_hva(kvm, start, end);
-
-	return 0;
+	return do_kvm_unmap_gfn(kvm, range);
 }
 
-static int kvm_age_hva_pr(struct kvm *kvm, unsigned long start,
-			  unsigned long end)
+static bool kvm_age_gfn_pr(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	/* XXX could be more clever ;) */
-	return 0;
+	return false;
 }
 
-static int kvm_test_age_hva_pr(struct kvm *kvm, unsigned long hva)
+static bool kvm_test_age_gfn_pr(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	/* XXX could be more clever ;) */
-	return 0;
+	return false;
 }
 
-static void kvm_set_spte_hva_pr(struct kvm *kvm, unsigned long hva, pte_t pte)
+static bool kvm_set_spte_gfn_pr(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	/* The page will get remapped properly on its next fault */
-	do_kvm_unmap_hva(kvm, hva, hva + PAGE_SIZE);
+	return do_kvm_unmap_gfn(kvm, range);
 }
 
 /*****************************************/
@@ -2079,10 +2057,10 @@
 	.flush_memslot = kvmppc_core_flush_memslot_pr,
 	.prepare_memory_region = kvmppc_core_prepare_memory_region_pr,
 	.commit_memory_region = kvmppc_core_commit_memory_region_pr,
-	.unmap_hva_range = kvm_unmap_hva_range_pr,
-	.age_hva  = kvm_age_hva_pr,
-	.test_age_hva = kvm_test_age_hva_pr,
-	.set_spte_hva = kvm_set_spte_hva_pr,
+	.unmap_gfn_range = kvm_unmap_gfn_range_pr,
+	.age_gfn  = kvm_age_gfn_pr,
+	.test_age_gfn = kvm_test_age_gfn_pr,
+	.set_spte_gfn = kvm_set_spte_gfn_pr,
 	.free_memslot = kvmppc_core_free_memslot_pr,
 	.init_vm = kvmppc_core_init_vm_pr,
 	.destroy_vm = kvmppc_core_destroy_vm_pr,
diff --git a/arch/powerpc/kvm/e500_mmu_host.c b/arch/powerpc/kvm/e500_mmu_host.c
index ed0c9c4..7f16afc 100644
--- a/arch/powerpc/kvm/e500_mmu_host.c
+++ b/arch/powerpc/kvm/e500_mmu_host.c
@@ -721,45 +721,36 @@
 
 /************* MMU Notifiers *************/
 
-static int kvm_unmap_hva(struct kvm *kvm, unsigned long hva)
+static bool kvm_e500_mmu_unmap_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	trace_kvm_unmap_hva(hva);
-
 	/*
 	 * Flush all shadow tlb entries everywhere. This is slow, but
 	 * we are 100% sure that we catch the to be unmapped page
 	 */
-	kvm_flush_remote_tlbs(kvm);
-
-	return 0;
+	return true;
 }
 
-int kvm_unmap_hva_range(struct kvm *kvm, unsigned long start, unsigned long end,
-			unsigned flags)
+bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	/* kvm_unmap_hva flushes everything anyways */
-	kvm_unmap_hva(kvm, start);
-
-	return 0;
+	return kvm_e500_mmu_unmap_gfn(kvm, range);
 }
 
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end)
+bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	/* XXX could be more clever ;) */
-	return 0;
+	return false;
 }
 
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	/* XXX could be more clever ;) */
-	return 0;
+	return false;
 }
 
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte)
+bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
 	/* The page will get remapped properly on its next fault */
-	kvm_unmap_hva(kvm, hva);
-	return 0;
+	return kvm_e500_mmu_unmap_gfn(kvm, range);
 }
 
 /*****************************************/
diff --git a/arch/powerpc/kvm/trace_booke.h b/arch/powerpc/kvm/trace_booke.h
index 3837842..eff6e82 100644
--- a/arch/powerpc/kvm/trace_booke.h
+++ b/arch/powerpc/kvm/trace_booke.h
@@ -69,21 +69,6 @@
 		)
 );
 
-TRACE_EVENT(kvm_unmap_hva,
-	TP_PROTO(unsigned long hva),
-	TP_ARGS(hva),
-
-	TP_STRUCT__entry(
-		__field(	unsigned long,	hva		)
-	),
-
-	TP_fast_assign(
-		__entry->hva		= hva;
-	),
-
-	TP_printk("unmap hva 0x%lx\n", __entry->hva)
-);
-
 TRACE_EVENT(kvm_booke206_stlb_write,
 	TP_PROTO(__u32 mas0, __u32 mas8, __u32 mas1, __u64 mas2, __u64 mas7_3),
 	TP_ARGS(mas0, mas8, mas1, mas2, mas7_3),
diff --git a/arch/x86/include/asm/cpufeatures.h b/arch/x86/include/asm/cpufeatures.h
index cc96e26..12c4a13 100644
--- a/arch/x86/include/asm/cpufeatures.h
+++ b/arch/x86/include/asm/cpufeatures.h
@@ -336,6 +336,7 @@
 #define X86_FEATURE_AVIC		(15*32+13) /* Virtual Interrupt Controller */
 #define X86_FEATURE_V_VMSAVE_VMLOAD	(15*32+15) /* Virtual VMSAVE VMLOAD */
 #define X86_FEATURE_VGIF		(15*32+16) /* Virtual GIF */
+#define X86_FEATURE_V_SPEC_CTRL		(15*32+20) /* Virtual SPEC_CTRL */
 #define X86_FEATURE_SVME_ADDR_CHK	(15*32+28) /* "" SVME addr check */
 
 /* Intel-defined CPU features, CPUID level 0x00000007:0 (ECX), word 16 */
diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 3768819..90a9686 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1068,25 +1068,36 @@
 	bool tdp_mmu_enabled;
 
 	/*
-	 * List of struct kvmp_mmu_pages being used as roots.
+	 * List of struct kvm_mmu_pages being used as roots.
 	 * All struct kvm_mmu_pages in the list should have
 	 * tdp_mmu_page set.
-	 * All struct kvm_mmu_pages in the list should have a positive
-	 * root_count except when a thread holds the MMU lock and is removing
-	 * an entry from the list.
+	 *
+	 * For reads, this list is protected by:
+	 *	the MMU lock in read mode + RCU or
+	 *	the MMU lock in write mode
+	 *
+	 * For writes, this list is protected by:
+	 *	the MMU lock in read mode + the tdp_mmu_pages_lock or
+	 *	the MMU lock in write mode
+	 *
+	 * Roots will remain in the list until their tdp_mmu_root_count
+	 * drops to zero, at which point the thread that decremented the
+	 * count to zero should remove the root from the list and clean
+	 * it up, freeing the root after an RCU grace period.
 	 */
 	struct list_head tdp_mmu_roots;
 
 	/*
 	 * List of struct kvmp_mmu_pages not being used as roots.
 	 * All struct kvm_mmu_pages in the list should have
-	 * tdp_mmu_page set and a root_count of 0.
+	 * tdp_mmu_page set and a tdp_mmu_root_count of 0.
 	 */
 	struct list_head tdp_mmu_pages;
 
 	/*
 	 * Protects accesses to the following fields when the MMU lock
 	 * is held in read mode:
+	 *  - tdp_mmu_roots (above)
 	 *  - tdp_mmu_pages (above)
 	 *  - the link field of struct kvm_mmu_pages used by the TDP MMU
 	 *  - lpage_disallowed_mmu_pages
@@ -1143,6 +1154,7 @@
 	u64 req_event;
 	u64 halt_poll_success_ns;
 	u64 halt_poll_fail_ns;
+	u64 nested_run;
 };
 
 struct x86_instruction_info;
@@ -1269,8 +1281,8 @@
 	int (*set_identity_map_addr)(struct kvm *kvm, u64 ident_addr);
 	u64 (*get_mt_mask)(struct kvm_vcpu *vcpu, gfn_t gfn, bool is_mmio);
 
-	void (*load_mmu_pgd)(struct kvm_vcpu *vcpu, unsigned long pgd,
-			     int pgd_level);
+	void (*load_mmu_pgd)(struct kvm_vcpu *vcpu, hpa_t root_hpa,
+			     int root_level);
 
 	bool (*has_wbinvd_exit)(void);
 
@@ -1339,6 +1351,7 @@
 	int (*mem_enc_op)(struct kvm *kvm, void __user *argp);
 	int (*mem_enc_reg_region)(struct kvm *kvm, struct kvm_enc_region *argp);
 	int (*mem_enc_unreg_region)(struct kvm *kvm, struct kvm_enc_region *argp);
+	int (*vm_copy_enc_context_from)(struct kvm *kvm, unsigned int source_fd);
 
 	int (*get_msr_feature)(struct kvm_msr_entry *entry);
 
@@ -1357,6 +1370,7 @@
 struct kvm_x86_nested_ops {
 	int (*check_events)(struct kvm_vcpu *vcpu);
 	bool (*hv_timer_pending)(struct kvm_vcpu *vcpu);
+	void (*triple_fault)(struct kvm_vcpu *vcpu);
 	int (*get_state)(struct kvm_vcpu *vcpu,
 			 struct kvm_nested_state __user *user_kvm_nested_state,
 			 unsigned user_data_size);
@@ -1428,9 +1442,6 @@
 int kvm_mmu_create(struct kvm_vcpu *vcpu);
 void kvm_mmu_init_vm(struct kvm *kvm);
 void kvm_mmu_uninit_vm(struct kvm *kvm);
-void kvm_mmu_set_mask_ptes(u64 user_mask, u64 accessed_mask,
-		u64 dirty_mask, u64 nx_mask, u64 x_mask, u64 p_mask,
-		u64 acc_track_mask, u64 me_mask);
 
 void kvm_mmu_reset_context(struct kvm_vcpu *vcpu);
 void kvm_mmu_slot_remove_write_access(struct kvm *kvm,
@@ -1538,6 +1549,11 @@
 int kvm_set_msr(struct kvm_vcpu *vcpu, u32 index, u64 data);
 int kvm_emulate_rdmsr(struct kvm_vcpu *vcpu);
 int kvm_emulate_wrmsr(struct kvm_vcpu *vcpu);
+int kvm_emulate_as_nop(struct kvm_vcpu *vcpu);
+int kvm_emulate_invd(struct kvm_vcpu *vcpu);
+int kvm_emulate_mwait(struct kvm_vcpu *vcpu);
+int kvm_handle_invalid_op(struct kvm_vcpu *vcpu);
+int kvm_emulate_monitor(struct kvm_vcpu *vcpu);
 
 int kvm_fast_pio(struct kvm_vcpu *vcpu, int size, unsigned short port, int in);
 int kvm_emulate_cpuid(struct kvm_vcpu *vcpu);
@@ -1566,14 +1582,14 @@
 unsigned long kvm_get_cr8(struct kvm_vcpu *vcpu);
 void kvm_lmsw(struct kvm_vcpu *vcpu, unsigned long msw);
 void kvm_get_cs_db_l_bits(struct kvm_vcpu *vcpu, int *db, int *l);
-int kvm_set_xcr(struct kvm_vcpu *vcpu, u32 index, u64 xcr);
+int kvm_emulate_xsetbv(struct kvm_vcpu *vcpu);
 
 int kvm_get_msr_common(struct kvm_vcpu *vcpu, struct msr_data *msr);
 int kvm_set_msr_common(struct kvm_vcpu *vcpu, struct msr_data *msr);
 
 unsigned long kvm_get_rflags(struct kvm_vcpu *vcpu);
 void kvm_set_rflags(struct kvm_vcpu *vcpu, unsigned long rflags);
-bool kvm_rdpmc(struct kvm_vcpu *vcpu);
+int kvm_emulate_rdpmc(struct kvm_vcpu *vcpu);
 
 void kvm_queue_exception(struct kvm_vcpu *vcpu, unsigned nr);
 void kvm_queue_exception_e(struct kvm_vcpu *vcpu, unsigned nr, u32 error_code);
@@ -1614,9 +1630,6 @@
 
 int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn);
 void __kvm_mmu_free_some_pages(struct kvm_vcpu *vcpu);
-int kvm_mmu_load(struct kvm_vcpu *vcpu);
-void kvm_mmu_unload(struct kvm_vcpu *vcpu);
-void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu);
 void kvm_mmu_free_roots(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
 			ulong roots_to_free);
 gpa_t translate_nested_gpa(struct kvm_vcpu *vcpu, gpa_t gpa, u32 access,
@@ -1735,11 +1748,7 @@
 	_ASM_EXTABLE(666b, 667b)
 
 #define KVM_ARCH_WANT_MMU_NOTIFIER
-int kvm_unmap_hva_range(struct kvm *kvm, unsigned long start, unsigned long end,
-			unsigned flags);
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end);
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva);
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte);
+
 int kvm_cpu_has_injectable_intr(struct kvm_vcpu *v);
 int kvm_cpu_has_interrupt(struct kvm_vcpu *vcpu);
 int kvm_cpu_has_extint(struct kvm_vcpu *v);
diff --git a/arch/x86/include/asm/svm.h b/arch/x86/include/asm/svm.h
index 1c56194..772e60e 100644
--- a/arch/x86/include/asm/svm.h
+++ b/arch/x86/include/asm/svm.h
@@ -269,7 +269,9 @@
 	 * SEV-ES guests when referenced through the GHCB or for
 	 * saving to the host save area.
 	 */
-	u8 reserved_7[80];
+	u8 reserved_7[72];
+	u32 spec_ctrl;		/* Guest version of SPEC_CTRL at 0x2E0 */
+	u8 reserved_7b[4];
 	u32 pkru;
 	u8 reserved_7a[20];
 	u64 reserved_8;		/* rax already available at 0x01f8 */
diff --git a/arch/x86/kvm/cpuid.h b/arch/x86/kvm/cpuid.h
index 2a0c506..ded84d2 100644
--- a/arch/x86/kvm/cpuid.h
+++ b/arch/x86/kvm/cpuid.h
@@ -248,6 +248,14 @@
 		is_guest_vendor_hygon(best->ebx, best->ecx, best->edx));
 }
 
+static inline bool guest_cpuid_is_intel(struct kvm_vcpu *vcpu)
+{
+	struct kvm_cpuid_entry2 *best;
+
+	best = kvm_find_cpuid_entry(vcpu, 0, 0);
+	return best && is_guest_vendor_intel(best->ebx, best->ecx, best->edx);
+}
+
 static inline int guest_cpuid_family(struct kvm_vcpu *vcpu)
 {
 	struct kvm_cpuid_entry2 *best;
diff --git a/arch/x86/kvm/kvm_cache_regs.h b/arch/x86/kvm/kvm_cache_regs.h
index 2e11da2..07d6079 100644
--- a/arch/x86/kvm/kvm_cache_regs.h
+++ b/arch/x86/kvm/kvm_cache_regs.h
@@ -55,6 +55,13 @@
 	__set_bit(reg, (unsigned long *)&vcpu->arch.regs_avail);
 }
 
+static inline void kvm_register_clear_available(struct kvm_vcpu *vcpu,
+					       enum kvm_reg reg)
+{
+	__clear_bit(reg, (unsigned long *)&vcpu->arch.regs_avail);
+	__clear_bit(reg, (unsigned long *)&vcpu->arch.regs_dirty);
+}
+
 static inline void kvm_register_mark_dirty(struct kvm_vcpu *vcpu,
 					   enum kvm_reg reg)
 {
diff --git a/arch/x86/kvm/lapic.c b/arch/x86/kvm/lapic.c
index cc369b9..0050f39 100644
--- a/arch/x86/kvm/lapic.c
+++ b/arch/x86/kvm/lapic.c
@@ -2869,7 +2869,7 @@
 		return;
 
 	if (is_guest_mode(vcpu)) {
-		r = kvm_x86_ops.nested_ops->check_events(vcpu);
+		r = kvm_check_nested_events(vcpu);
 		if (r < 0)
 			return;
 		/*
diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index c68bfc3..88d0ed52 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -59,7 +59,8 @@
 	return ((2ULL << (e - s)) - 1) << s;
 }
 
-void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 access_mask);
+void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 mmio_mask, u64 access_mask);
+void kvm_mmu_set_ept_masks(bool has_ad_bits, bool has_exec_only);
 
 void
 reset_shadow_zero_bits_mask(struct kvm_vcpu *vcpu, struct kvm_mmu *context);
@@ -73,6 +74,10 @@
 int kvm_handle_page_fault(struct kvm_vcpu *vcpu, u64 error_code,
 				u64 fault_address, char *insn, int insn_len);
 
+int kvm_mmu_load(struct kvm_vcpu *vcpu);
+void kvm_mmu_unload(struct kvm_vcpu *vcpu);
+void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu);
+
 static inline int kvm_mmu_reload(struct kvm_vcpu *vcpu)
 {
 	if (likely(vcpu->arch.mmu->root_hpa != INVALID_PAGE))
@@ -102,8 +107,8 @@
 	if (!VALID_PAGE(root_hpa))
 		return;
 
-	static_call(kvm_x86_load_mmu_pgd)(vcpu, root_hpa | kvm_get_active_pcid(vcpu),
-				 vcpu->arch.mmu->shadow_root_level);
+	static_call(kvm_x86_load_mmu_pgd)(vcpu, root_hpa,
+					  vcpu->arch.mmu->shadow_root_level);
 }
 
 int kvm_tdp_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
@@ -124,7 +129,7 @@
  * write-protects guest page to sync the guest modification, b) another one is
  * used to sync dirty bitmap when we do KVM_GET_DIRTY_LOG. The differences
  * between these two sorts are:
- * 1) the first case clears SPTE_MMU_WRITEABLE bit.
+ * 1) the first case clears MMU-writable bit.
  * 2) the first case requires flushing tlb immediately avoiding corrupting
  *    shadow page table between all vcpus so it should be in the protection of
  *    mmu-lock. And the another case does not need to flush tlb until returning
@@ -135,17 +140,17 @@
  * So, there is the problem: the first case can meet the corrupted tlb caused
  * by another case which write-protects pages but without flush tlb
  * immediately. In order to making the first case be aware this problem we let
- * it flush tlb if we try to write-protect a spte whose SPTE_MMU_WRITEABLE bit
- * is set, it works since another case never touches SPTE_MMU_WRITEABLE bit.
+ * it flush tlb if we try to write-protect a spte whose MMU-writable bit
+ * is set, it works since another case never touches MMU-writable bit.
  *
  * Anyway, whenever a spte is updated (only permission and status bits are
- * changed) we need to check whether the spte with SPTE_MMU_WRITEABLE becomes
+ * changed) we need to check whether the spte with MMU-writable becomes
  * readonly, if that happens, we need to flush tlb. Fortunately,
  * mmu_spte_update() has already handled it perfectly.
  *
- * The rules to use SPTE_MMU_WRITEABLE and PT_WRITABLE_MASK:
+ * The rules to use MMU-writable and PT_WRITABLE_MASK:
  * - if we want to see if it has writable tlb entry or if the spte can be
- *   writable on the mmu mapping, check SPTE_MMU_WRITEABLE, this is the most
+ *   writable on the mmu mapping, check MMU-writable, this is the most
  *   case, otherwise
  * - if we fix page fault on the spte or do write-protection by dirty logging,
  *   check PT_WRITABLE_MASK.
diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index 951dae4..8f15282 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -48,6 +48,7 @@
 #include <asm/memtype.h>
 #include <asm/cmpxchg.h>
 #include <asm/io.h>
+#include <asm/set_memory.h>
 #include <asm/vmx.h>
 #include <asm/kvm_page_track.h>
 #include "trace.h"
@@ -215,10 +216,10 @@
 static void mark_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, u64 gfn,
 			   unsigned int access)
 {
-	u64 mask = make_mmio_spte(vcpu, gfn, access);
+	u64 spte = make_mmio_spte(vcpu, gfn, access);
 
-	trace_mark_mmio_spte(sptep, gfn, mask);
-	mmu_spte_set(sptep, mask);
+	trace_mark_mmio_spte(sptep, gfn, spte);
+	mmu_spte_set(sptep, spte);
 }
 
 static gfn_t get_mmio_spte_gfn(u64 spte)
@@ -236,17 +237,6 @@
 	return spte & shadow_mmio_access_mask;
 }
 
-static bool set_mmio_spte(struct kvm_vcpu *vcpu, u64 *sptep, gfn_t gfn,
-			  kvm_pfn_t pfn, unsigned int access)
-{
-	if (unlikely(is_noslot_pfn(pfn))) {
-		mark_mmio_spte(vcpu, sptep, gfn, access);
-		return true;
-	}
-
-	return false;
-}
-
 static bool check_mmio_spte(struct kvm_vcpu *vcpu, u64 spte)
 {
 	u64 kvm_gen, spte_gen, gen;
@@ -725,8 +715,7 @@
  * handling slots that are not large page aligned.
  */
 static struct kvm_lpage_info *lpage_info_slot(gfn_t gfn,
-					      struct kvm_memory_slot *slot,
-					      int level)
+		const struct kvm_memory_slot *slot, int level)
 {
 	unsigned long idx;
 
@@ -1118,7 +1107,7 @@
 	rmap_printk("spte %p %llx\n", sptep, *sptep);
 
 	if (pt_protect)
-		spte &= ~SPTE_MMU_WRITEABLE;
+		spte &= ~shadow_mmu_writable_mask;
 	spte = spte & ~PT_WRITABLE_MASK;
 
 	return mmu_spte_update(sptep, spte);
@@ -1308,26 +1297,25 @@
 	return flush;
 }
 
-static int kvm_unmap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-			   struct kvm_memory_slot *slot, gfn_t gfn, int level,
-			   unsigned long data)
+static bool kvm_unmap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			    struct kvm_memory_slot *slot, gfn_t gfn, int level,
+			    pte_t unused)
 {
 	return kvm_zap_rmapp(kvm, rmap_head, slot);
 }
 
-static int kvm_set_pte_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-			     struct kvm_memory_slot *slot, gfn_t gfn, int level,
-			     unsigned long data)
+static bool kvm_set_pte_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			      struct kvm_memory_slot *slot, gfn_t gfn, int level,
+			      pte_t pte)
 {
 	u64 *sptep;
 	struct rmap_iterator iter;
 	int need_flush = 0;
 	u64 new_spte;
-	pte_t *ptep = (pte_t *)data;
 	kvm_pfn_t new_pfn;
 
-	WARN_ON(pte_huge(*ptep));
-	new_pfn = pte_pfn(*ptep);
+	WARN_ON(pte_huge(pte));
+	new_pfn = pte_pfn(pte);
 
 restart:
 	for_each_rmap_spte(rmap_head, &iter, sptep) {
@@ -1336,7 +1324,7 @@
 
 		need_flush = 1;
 
-		if (pte_write(*ptep)) {
+		if (pte_write(pte)) {
 			pte_list_remove(rmap_head, sptep);
 			goto restart;
 		} else {
@@ -1424,93 +1412,52 @@
 	     slot_rmap_walk_okay(_iter_);				\
 	     slot_rmap_walk_next(_iter_))
 
-static __always_inline int
-kvm_handle_hva_range(struct kvm *kvm,
-		     unsigned long start,
-		     unsigned long end,
-		     unsigned long data,
-		     int (*handler)(struct kvm *kvm,
-				    struct kvm_rmap_head *rmap_head,
-				    struct kvm_memory_slot *slot,
-				    gfn_t gfn,
-				    int level,
-				    unsigned long data))
+typedef bool (*rmap_handler_t)(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			       struct kvm_memory_slot *slot, gfn_t gfn,
+			       int level, pte_t pte);
+
+static __always_inline bool kvm_handle_gfn_range(struct kvm *kvm,
+						 struct kvm_gfn_range *range,
+						 rmap_handler_t handler)
 {
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
 	struct slot_rmap_walk_iterator iterator;
-	int ret = 0;
-	int i;
+	bool ret = false;
 
-	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
-		slots = __kvm_memslots(kvm, i);
-		kvm_for_each_memslot(memslot, slots) {
-			unsigned long hva_start, hva_end;
-			gfn_t gfn_start, gfn_end;
-
-			hva_start = max(start, memslot->userspace_addr);
-			hva_end = min(end, memslot->userspace_addr +
-				      (memslot->npages << PAGE_SHIFT));
-			if (hva_start >= hva_end)
-				continue;
-			/*
-			 * {gfn(page) | page intersects with [hva_start, hva_end)} =
-			 * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-			 */
-			gfn_start = hva_to_gfn_memslot(hva_start, memslot);
-			gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
-
-			for_each_slot_rmap_range(memslot, PG_LEVEL_4K,
-						 KVM_MAX_HUGEPAGE_LEVEL,
-						 gfn_start, gfn_end - 1,
-						 &iterator)
-				ret |= handler(kvm, iterator.rmap, memslot,
-					       iterator.gfn, iterator.level, data);
-		}
-	}
+	for_each_slot_rmap_range(range->slot, PG_LEVEL_4K, KVM_MAX_HUGEPAGE_LEVEL,
+				 range->start, range->end - 1, &iterator)
+		ret |= handler(kvm, iterator.rmap, range->slot, iterator.gfn,
+			       iterator.level, range->pte);
 
 	return ret;
 }
 
-static int kvm_handle_hva(struct kvm *kvm, unsigned long hva,
-			  unsigned long data,
-			  int (*handler)(struct kvm *kvm,
-					 struct kvm_rmap_head *rmap_head,
-					 struct kvm_memory_slot *slot,
-					 gfn_t gfn, int level,
-					 unsigned long data))
+bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_handle_hva_range(kvm, hva, hva + 1, data, handler);
-}
+	bool flush;
 
-int kvm_unmap_hva_range(struct kvm *kvm, unsigned long start, unsigned long end,
-			unsigned flags)
-{
-	int r;
-
-	r = kvm_handle_hva_range(kvm, start, end, 0, kvm_unmap_rmapp);
+	flush = kvm_handle_gfn_range(kvm, range, kvm_unmap_rmapp);
 
 	if (is_tdp_mmu_enabled(kvm))
-		r |= kvm_tdp_mmu_zap_hva_range(kvm, start, end);
+		flush |= kvm_tdp_mmu_unmap_gfn_range(kvm, range, flush);
 
-	return r;
+	return flush;
 }
 
-int kvm_set_spte_hva(struct kvm *kvm, unsigned long hva, pte_t pte)
+bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	int r;
+	bool flush;
 
-	r = kvm_handle_hva(kvm, hva, (unsigned long)&pte, kvm_set_pte_rmapp);
+	flush = kvm_handle_gfn_range(kvm, range, kvm_set_pte_rmapp);
 
 	if (is_tdp_mmu_enabled(kvm))
-		r |= kvm_tdp_mmu_set_spte_hva(kvm, hva, &pte);
+		flush |= kvm_tdp_mmu_set_spte_gfn(kvm, range);
 
-	return r;
+	return flush;
 }
 
-static int kvm_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-			 struct kvm_memory_slot *slot, gfn_t gfn, int level,
-			 unsigned long data)
+static bool kvm_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			  struct kvm_memory_slot *slot, gfn_t gfn, int level,
+			  pte_t unused)
 {
 	u64 *sptep;
 	struct rmap_iterator iter;
@@ -1519,13 +1466,12 @@
 	for_each_rmap_spte(rmap_head, &iter, sptep)
 		young |= mmu_spte_age(sptep);
 
-	trace_kvm_age_page(gfn, level, slot, young);
 	return young;
 }
 
-static int kvm_test_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
-			      struct kvm_memory_slot *slot, gfn_t gfn,
-			      int level, unsigned long data)
+static bool kvm_test_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
+			       struct kvm_memory_slot *slot, gfn_t gfn,
+			       int level, pte_t unused)
 {
 	u64 *sptep;
 	struct rmap_iterator iter;
@@ -1547,29 +1493,31 @@
 
 	rmap_head = gfn_to_rmap(vcpu->kvm, gfn, sp);
 
-	kvm_unmap_rmapp(vcpu->kvm, rmap_head, NULL, gfn, sp->role.level, 0);
+	kvm_unmap_rmapp(vcpu->kvm, rmap_head, NULL, gfn, sp->role.level, __pte(0));
 	kvm_flush_remote_tlbs_with_address(vcpu->kvm, sp->gfn,
 			KVM_PAGES_PER_HPAGE(sp->role.level));
 }
 
-int kvm_age_hva(struct kvm *kvm, unsigned long start, unsigned long end)
+bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	int young = false;
+	bool young;
 
-	young = kvm_handle_hva_range(kvm, start, end, 0, kvm_age_rmapp);
+	young = kvm_handle_gfn_range(kvm, range, kvm_age_rmapp);
+
 	if (is_tdp_mmu_enabled(kvm))
-		young |= kvm_tdp_mmu_age_hva_range(kvm, start, end);
+		young |= kvm_tdp_mmu_age_gfn_range(kvm, range);
 
 	return young;
 }
 
-int kvm_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	int young = false;
+	bool young;
 
-	young = kvm_handle_hva(kvm, hva, 0, kvm_test_age_rmapp);
+	young = kvm_handle_gfn_range(kvm, range, kvm_test_age_rmapp);
+
 	if (is_tdp_mmu_enabled(kvm))
-		young |= kvm_tdp_mmu_test_age_hva(kvm, hva);
+		young |= kvm_tdp_mmu_test_age_gfn(kvm, range);
 
 	return young;
 }
@@ -2421,6 +2369,15 @@
 
 	kvm_mmu_zap_oldest_mmu_pages(vcpu->kvm, KVM_REFILL_PAGES - avail);
 
+	/*
+	 * Note, this check is intentionally soft, it only guarantees that one
+	 * page is available, while the caller may end up allocating as many as
+	 * four pages, e.g. for PAE roots or for 5-level paging.  Temporarily
+	 * exceeding the (arbitrary by default) limit will not harm the host,
+	 * being too aggressive may unnecessarily kill the guest, and getting an
+	 * exact count is far more trouble than it's worth, especially in the
+	 * page fault paths.
+	 */
 	if (!kvm_mmu_available_pages(vcpu->kvm))
 		return -ENOSPC;
 	return 0;
@@ -2561,9 +2518,6 @@
 	struct kvm_mmu_page *sp;
 	int ret;
 
-	if (set_mmio_spte(vcpu, sptep, gfn, pfn, pte_access))
-		return 0;
-
 	sp = sptep_to_sp(sptep);
 
 	ret = make_spte(vcpu, pte_access, level, gfn, pfn, *sptep, speculative,
@@ -2593,6 +2547,11 @@
 	pgprintk("%s: spte %llx write_fault %d gfn %llx\n", __func__,
 		 *sptep, write_fault, gfn);
 
+	if (unlikely(is_noslot_pfn(pfn))) {
+		mark_mmio_spte(vcpu, sptep, gfn, pte_access);
+		return RET_PF_EMULATE;
+	}
+
 	if (is_shadow_present_pte(*sptep)) {
 		/*
 		 * If we overwrite a PTE page pointer with a 2MB PMD, unlink
@@ -2626,9 +2585,6 @@
 		kvm_flush_remote_tlbs_with_address(vcpu->kvm, gfn,
 				KVM_PAGES_PER_HPAGE(level));
 
-	if (unlikely(is_mmio_spte(*sptep)))
-		ret = RET_PF_EMULATE;
-
 	/*
 	 * The fault is fully spurious if and only if the new SPTE and old SPTE
 	 * are identical, and emulation is not required.
@@ -2745,7 +2701,7 @@
 }
 
 static int host_pfn_mapping_level(struct kvm *kvm, gfn_t gfn, kvm_pfn_t pfn,
-				  struct kvm_memory_slot *slot)
+				  const struct kvm_memory_slot *slot)
 {
 	unsigned long hva;
 	pte_t *pte;
@@ -2771,8 +2727,9 @@
 	return level;
 }
 
-int kvm_mmu_max_mapping_level(struct kvm *kvm, struct kvm_memory_slot *slot,
-			      gfn_t gfn, kvm_pfn_t pfn, int max_level)
+int kvm_mmu_max_mapping_level(struct kvm *kvm,
+			      const struct kvm_memory_slot *slot, gfn_t gfn,
+			      kvm_pfn_t pfn, int max_level)
 {
 	struct kvm_lpage_info *linfo;
 
@@ -2946,9 +2903,19 @@
 		return true;
 	}
 
-	if (unlikely(is_noslot_pfn(pfn)))
+	if (unlikely(is_noslot_pfn(pfn))) {
 		vcpu_cache_mmio_info(vcpu, gva, gfn,
 				     access & shadow_mmio_access_mask);
+		/*
+		 * If MMIO caching is disabled, emulate immediately without
+		 * touching the shadow page tables as attempting to install an
+		 * MMIO SPTE will just be an expensive nop.
+		 */
+		if (unlikely(!shadow_mmio_value)) {
+			*ret_val = RET_PF_EMULATE;
+			return true;
+		}
+	}
 
 	return false;
 }
@@ -3061,6 +3028,9 @@
 			if (!is_shadow_present_pte(spte))
 				break;
 
+		if (!is_shadow_present_pte(spte))
+			break;
+
 		sp = sptep_to_sp(iterator.sptep);
 		if (!is_last_spte(spte, sp->role.level))
 			break;
@@ -3150,12 +3120,10 @@
 
 	sp = to_shadow_page(*root_hpa & PT64_BASE_ADDR_MASK);
 
-	if (kvm_mmu_put_root(kvm, sp)) {
-		if (is_tdp_mmu_page(sp))
-			kvm_tdp_mmu_free_root(kvm, sp);
-		else if (sp->role.invalid)
-			kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
-	}
+	if (is_tdp_mmu_page(sp))
+		kvm_tdp_mmu_put_root(kvm, sp, false);
+	else if (!--sp->root_count && sp->role.invalid)
+		kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
 
 	*root_hpa = INVALID_PAGE;
 }
@@ -3193,14 +3161,17 @@
 		if (mmu->shadow_root_level >= PT64_ROOT_4LEVEL &&
 		    (mmu->root_level >= PT64_ROOT_4LEVEL || mmu->direct_map)) {
 			mmu_free_root_page(kvm, &mmu->root_hpa, &invalid_list);
-		} else {
-			for (i = 0; i < 4; ++i)
-				if (mmu->pae_root[i] != 0)
-					mmu_free_root_page(kvm,
-							   &mmu->pae_root[i],
-							   &invalid_list);
-			mmu->root_hpa = INVALID_PAGE;
+		} else if (mmu->pae_root) {
+			for (i = 0; i < 4; ++i) {
+				if (!IS_VALID_PAE_ROOT(mmu->pae_root[i]))
+					continue;
+
+				mmu_free_root_page(kvm, &mmu->pae_root[i],
+						   &invalid_list);
+				mmu->pae_root[i] = INVALID_PAE_ROOT;
+			}
 		}
+		mmu->root_hpa = INVALID_PAGE;
 		mmu->root_pgd = 0;
 	}
 
@@ -3226,153 +3197,180 @@
 {
 	struct kvm_mmu_page *sp;
 
-	write_lock(&vcpu->kvm->mmu_lock);
-
-	if (make_mmu_pages_available(vcpu)) {
-		write_unlock(&vcpu->kvm->mmu_lock);
-		return INVALID_PAGE;
-	}
 	sp = kvm_mmu_get_page(vcpu, gfn, gva, level, direct, ACC_ALL);
 	++sp->root_count;
 
-	write_unlock(&vcpu->kvm->mmu_lock);
 	return __pa(sp->spt);
 }
 
 static int mmu_alloc_direct_roots(struct kvm_vcpu *vcpu)
 {
-	u8 shadow_root_level = vcpu->arch.mmu->shadow_root_level;
+	struct kvm_mmu *mmu = vcpu->arch.mmu;
+	u8 shadow_root_level = mmu->shadow_root_level;
 	hpa_t root;
 	unsigned i;
 
 	if (is_tdp_mmu_enabled(vcpu->kvm)) {
 		root = kvm_tdp_mmu_get_vcpu_root_hpa(vcpu);
-
-		if (!VALID_PAGE(root))
-			return -ENOSPC;
-		vcpu->arch.mmu->root_hpa = root;
+		mmu->root_hpa = root;
 	} else if (shadow_root_level >= PT64_ROOT_4LEVEL) {
-		root = mmu_alloc_root(vcpu, 0, 0, shadow_root_level,
-				      true);
-
-		if (!VALID_PAGE(root))
-			return -ENOSPC;
-		vcpu->arch.mmu->root_hpa = root;
+		root = mmu_alloc_root(vcpu, 0, 0, shadow_root_level, true);
+		mmu->root_hpa = root;
 	} else if (shadow_root_level == PT32E_ROOT_LEVEL) {
+		if (WARN_ON_ONCE(!mmu->pae_root))
+			return -EIO;
+
 		for (i = 0; i < 4; ++i) {
-			MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->pae_root[i]));
+			WARN_ON_ONCE(IS_VALID_PAE_ROOT(mmu->pae_root[i]));
 
 			root = mmu_alloc_root(vcpu, i << (30 - PAGE_SHIFT),
 					      i << 30, PT32_ROOT_LEVEL, true);
-			if (!VALID_PAGE(root))
-				return -ENOSPC;
-			vcpu->arch.mmu->pae_root[i] = root | PT_PRESENT_MASK;
+			mmu->pae_root[i] = root | PT_PRESENT_MASK |
+					   shadow_me_mask;
 		}
-		vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
-	} else
-		BUG();
+		mmu->root_hpa = __pa(mmu->pae_root);
+	} else {
+		WARN_ONCE(1, "Bad TDP root level = %d\n", shadow_root_level);
+		return -EIO;
+	}
 
 	/* root_pgd is ignored for direct MMUs. */
-	vcpu->arch.mmu->root_pgd = 0;
+	mmu->root_pgd = 0;
 
 	return 0;
 }
 
 static int mmu_alloc_shadow_roots(struct kvm_vcpu *vcpu)
 {
-	u64 pdptr, pm_mask;
+	struct kvm_mmu *mmu = vcpu->arch.mmu;
+	u64 pdptrs[4], pm_mask;
 	gfn_t root_gfn, root_pgd;
 	hpa_t root;
 	int i;
 
-	root_pgd = vcpu->arch.mmu->get_guest_pgd(vcpu);
+	root_pgd = mmu->get_guest_pgd(vcpu);
 	root_gfn = root_pgd >> PAGE_SHIFT;
 
 	if (mmu_check_root(vcpu, root_gfn))
 		return 1;
 
+	if (mmu->root_level == PT32E_ROOT_LEVEL) {
+		for (i = 0; i < 4; ++i) {
+			pdptrs[i] = mmu->get_pdptr(vcpu, i);
+			if (!(pdptrs[i] & PT_PRESENT_MASK))
+				continue;
+
+			if (mmu_check_root(vcpu, pdptrs[i] >> PAGE_SHIFT))
+				return 1;
+		}
+	}
+
 	/*
 	 * Do we shadow a long mode page table? If so we need to
 	 * write-protect the guests page table root.
 	 */
-	if (vcpu->arch.mmu->root_level >= PT64_ROOT_4LEVEL) {
-		MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->root_hpa));
-
+	if (mmu->root_level >= PT64_ROOT_4LEVEL) {
 		root = mmu_alloc_root(vcpu, root_gfn, 0,
-				      vcpu->arch.mmu->shadow_root_level, false);
-		if (!VALID_PAGE(root))
-			return -ENOSPC;
-		vcpu->arch.mmu->root_hpa = root;
+				      mmu->shadow_root_level, false);
+		mmu->root_hpa = root;
 		goto set_root_pgd;
 	}
 
+	if (WARN_ON_ONCE(!mmu->pae_root))
+		return -EIO;
+
 	/*
 	 * We shadow a 32 bit page table. This may be a legacy 2-level
 	 * or a PAE 3-level page table. In either case we need to be aware that
 	 * the shadow page table may be a PAE or a long mode page table.
 	 */
-	pm_mask = PT_PRESENT_MASK;
-	if (vcpu->arch.mmu->shadow_root_level == PT64_ROOT_4LEVEL)
+	pm_mask = PT_PRESENT_MASK | shadow_me_mask;
+	if (mmu->shadow_root_level == PT64_ROOT_4LEVEL) {
 		pm_mask |= PT_ACCESSED_MASK | PT_WRITABLE_MASK | PT_USER_MASK;
 
+		if (WARN_ON_ONCE(!mmu->lm_root))
+			return -EIO;
+
+		mmu->lm_root[0] = __pa(mmu->pae_root) | pm_mask;
+	}
+
 	for (i = 0; i < 4; ++i) {
-		MMU_WARN_ON(VALID_PAGE(vcpu->arch.mmu->pae_root[i]));
-		if (vcpu->arch.mmu->root_level == PT32E_ROOT_LEVEL) {
-			pdptr = vcpu->arch.mmu->get_pdptr(vcpu, i);
-			if (!(pdptr & PT_PRESENT_MASK)) {
-				vcpu->arch.mmu->pae_root[i] = 0;
+		WARN_ON_ONCE(IS_VALID_PAE_ROOT(mmu->pae_root[i]));
+
+		if (mmu->root_level == PT32E_ROOT_LEVEL) {
+			if (!(pdptrs[i] & PT_PRESENT_MASK)) {
+				mmu->pae_root[i] = INVALID_PAE_ROOT;
 				continue;
 			}
-			root_gfn = pdptr >> PAGE_SHIFT;
-			if (mmu_check_root(vcpu, root_gfn))
-				return 1;
+			root_gfn = pdptrs[i] >> PAGE_SHIFT;
 		}
 
 		root = mmu_alloc_root(vcpu, root_gfn, i << 30,
 				      PT32_ROOT_LEVEL, false);
-		if (!VALID_PAGE(root))
-			return -ENOSPC;
-		vcpu->arch.mmu->pae_root[i] = root | pm_mask;
+		mmu->pae_root[i] = root | pm_mask;
 	}
-	vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->pae_root);
 
-	/*
-	 * If we shadow a 32 bit page table with a long mode page
-	 * table we enter this path.
-	 */
-	if (vcpu->arch.mmu->shadow_root_level == PT64_ROOT_4LEVEL) {
-		if (vcpu->arch.mmu->lm_root == NULL) {
-			/*
-			 * The additional page necessary for this is only
-			 * allocated on demand.
-			 */
-
-			u64 *lm_root;
-
-			lm_root = (void*)get_zeroed_page(GFP_KERNEL_ACCOUNT);
-			if (lm_root == NULL)
-				return 1;
-
-			lm_root[0] = __pa(vcpu->arch.mmu->pae_root) | pm_mask;
-
-			vcpu->arch.mmu->lm_root = lm_root;
-		}
-
-		vcpu->arch.mmu->root_hpa = __pa(vcpu->arch.mmu->lm_root);
-	}
+	if (mmu->shadow_root_level == PT64_ROOT_4LEVEL)
+		mmu->root_hpa = __pa(mmu->lm_root);
+	else
+		mmu->root_hpa = __pa(mmu->pae_root);
 
 set_root_pgd:
-	vcpu->arch.mmu->root_pgd = root_pgd;
+	mmu->root_pgd = root_pgd;
 
 	return 0;
 }
 
-static int mmu_alloc_roots(struct kvm_vcpu *vcpu)
+static int mmu_alloc_special_roots(struct kvm_vcpu *vcpu)
 {
-	if (vcpu->arch.mmu->direct_map)
-		return mmu_alloc_direct_roots(vcpu);
-	else
-		return mmu_alloc_shadow_roots(vcpu);
+	struct kvm_mmu *mmu = vcpu->arch.mmu;
+	u64 *lm_root, *pae_root;
+
+	/*
+	 * When shadowing 32-bit or PAE NPT with 64-bit NPT, the PML4 and PDP
+	 * tables are allocated and initialized at root creation as there is no
+	 * equivalent level in the guest's NPT to shadow.  Allocate the tables
+	 * on demand, as running a 32-bit L1 VMM on 64-bit KVM is very rare.
+	 */
+	if (mmu->direct_map || mmu->root_level >= PT64_ROOT_4LEVEL ||
+	    mmu->shadow_root_level < PT64_ROOT_4LEVEL)
+		return 0;
+
+	/*
+	 * This mess only works with 4-level paging and needs to be updated to
+	 * work with 5-level paging.
+	 */
+	if (WARN_ON_ONCE(mmu->shadow_root_level != PT64_ROOT_4LEVEL))
+		return -EIO;
+
+	if (mmu->pae_root && mmu->lm_root)
+		return 0;
+
+	/*
+	 * The special roots should always be allocated in concert.  Yell and
+	 * bail if KVM ends up in a state where only one of the roots is valid.
+	 */
+	if (WARN_ON_ONCE(!tdp_enabled || mmu->pae_root || mmu->lm_root))
+		return -EIO;
+
+	/*
+	 * Unlike 32-bit NPT, the PDP table doesn't need to be in low mem, and
+	 * doesn't need to be decrypted.
+	 */
+	pae_root = (void *)get_zeroed_page(GFP_KERNEL_ACCOUNT);
+	if (!pae_root)
+		return -ENOMEM;
+
+	lm_root = (void *)get_zeroed_page(GFP_KERNEL_ACCOUNT);
+	if (!lm_root) {
+		free_page((unsigned long)pae_root);
+		return -ENOMEM;
+	}
+
+	mmu->pae_root = pae_root;
+	mmu->lm_root = lm_root;
+
+	return 0;
 }
 
 void kvm_mmu_sync_roots(struct kvm_vcpu *vcpu)
@@ -3422,7 +3420,7 @@
 	for (i = 0; i < 4; ++i) {
 		hpa_t root = vcpu->arch.mmu->pae_root[i];
 
-		if (root && VALID_PAGE(root)) {
+		if (IS_VALID_PAE_ROOT(root)) {
 			root &= PT64_BASE_ADDR_MASK;
 			sp = to_shadow_page(root);
 			mmu_sync_children(vcpu, sp);
@@ -3554,11 +3552,12 @@
 			    __is_rsvd_bits_set(rsvd_check, sptes[level], level);
 
 	if (reserved) {
-		pr_err("%s: detect reserved bits on spte, addr 0x%llx, dump hierarchy:\n",
+		pr_err("%s: reserved bits set on MMU-present spte, addr 0x%llx, hierarchy:\n",
 		       __func__, addr);
 		for (level = root; level >= leaf; level--)
-			pr_err("------ spte 0x%llx level %d.\n",
-			       sptes[level], level);
+			pr_err("------ spte = 0x%llx level = %d, rsvd bits = 0x%llx",
+			       sptes[level], level,
+			       rsvd_check->rsvd_bits_mask[(sptes[level] >> 7) & 1][level-1]);
 	}
 
 	return reserved;
@@ -3653,6 +3652,14 @@
 	struct kvm_memory_slot *slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
 	bool async;
 
+	/*
+	 * Retry the page fault if the gfn hit a memslot that is being deleted
+	 * or moved.  This ensures any existing SPTEs for the old memslot will
+	 * be zapped before KVM inserts a new MMIO SPTE for the gfn.
+	 */
+	if (slot && (slot->flags & KVM_MEMSLOT_INVALID))
+		return true;
+
 	/* Don't expose private memslots to L2. */
 	if (is_guest_mode(vcpu) && !kvm_is_visible_memslot(slot)) {
 		*pfn = KVM_PFN_NOSLOT;
@@ -4615,12 +4622,17 @@
 	struct kvm_mmu *context = &vcpu->arch.guest_mmu;
 	union kvm_mmu_role new_role = kvm_calc_shadow_npt_root_page_role(vcpu);
 
-	context->shadow_root_level = new_role.base.level;
-
 	__kvm_mmu_new_pgd(vcpu, nested_cr3, new_role.base, false, false);
 
-	if (new_role.as_u64 != context->mmu_role.as_u64)
+	if (new_role.as_u64 != context->mmu_role.as_u64) {
 		shadow_mmu_init_context(vcpu, context, cr0, cr4, efer, new_role);
+
+		/*
+		 * Override the level set by the common init helper; nested TDP
+		 * always uses the host's TDP configuration.
+		 */
+		context->shadow_root_level = new_role.base.level;
+	}
 }
 EXPORT_SYMBOL_GPL(kvm_init_shadow_npt_mmu);
 
@@ -4802,16 +4814,27 @@
 	r = mmu_topup_memory_caches(vcpu, !vcpu->arch.mmu->direct_map);
 	if (r)
 		goto out;
-	r = mmu_alloc_roots(vcpu);
-	kvm_mmu_sync_roots(vcpu);
+	r = mmu_alloc_special_roots(vcpu);
 	if (r)
 		goto out;
+	write_lock(&vcpu->kvm->mmu_lock);
+	if (make_mmu_pages_available(vcpu))
+		r = -ENOSPC;
+	else if (vcpu->arch.mmu->direct_map)
+		r = mmu_alloc_direct_roots(vcpu);
+	else
+		r = mmu_alloc_shadow_roots(vcpu);
+	write_unlock(&vcpu->kvm->mmu_lock);
+	if (r)
+		goto out;
+
+	kvm_mmu_sync_roots(vcpu);
+
 	kvm_mmu_load_pgd(vcpu);
 	static_call(kvm_x86_tlb_flush_current)(vcpu);
 out:
 	return r;
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_load);
 
 void kvm_mmu_unload(struct kvm_vcpu *vcpu)
 {
@@ -4820,7 +4843,6 @@
 	kvm_mmu_free_roots(vcpu, &vcpu->arch.guest_mmu, KVM_MMU_ROOTS_ALL);
 	WARN_ON(VALID_PAGE(vcpu->arch.guest_mmu.root_hpa));
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_unload);
 
 static bool need_remote_flush(u64 old, u64 new)
 {
@@ -5169,10 +5191,10 @@
 static __always_inline bool
 slot_handle_level_range(struct kvm *kvm, struct kvm_memory_slot *memslot,
 			slot_level_handler fn, int start_level, int end_level,
-			gfn_t start_gfn, gfn_t end_gfn, bool lock_flush_tlb)
+			gfn_t start_gfn, gfn_t end_gfn, bool flush_on_yield,
+			bool flush)
 {
 	struct slot_rmap_walk_iterator iterator;
-	bool flush = false;
 
 	for_each_slot_rmap_range(memslot, start_level, end_level, start_gfn,
 			end_gfn, &iterator) {
@@ -5180,7 +5202,7 @@
 			flush |= fn(kvm, iterator.rmap, memslot);
 
 		if (need_resched() || rwlock_needbreak(&kvm->mmu_lock)) {
-			if (flush && lock_flush_tlb) {
+			if (flush && flush_on_yield) {
 				kvm_flush_remote_tlbs_with_address(kvm,
 						start_gfn,
 						iterator.gfn - start_gfn + 1);
@@ -5190,36 +5212,32 @@
 		}
 	}
 
-	if (flush && lock_flush_tlb) {
-		kvm_flush_remote_tlbs_with_address(kvm, start_gfn,
-						   end_gfn - start_gfn + 1);
-		flush = false;
-	}
-
 	return flush;
 }
 
 static __always_inline bool
 slot_handle_level(struct kvm *kvm, struct kvm_memory_slot *memslot,
 		  slot_level_handler fn, int start_level, int end_level,
-		  bool lock_flush_tlb)
+		  bool flush_on_yield)
 {
 	return slot_handle_level_range(kvm, memslot, fn, start_level,
 			end_level, memslot->base_gfn,
 			memslot->base_gfn + memslot->npages - 1,
-			lock_flush_tlb);
+			flush_on_yield, false);
 }
 
 static __always_inline bool
 slot_handle_leaf(struct kvm *kvm, struct kvm_memory_slot *memslot,
-		 slot_level_handler fn, bool lock_flush_tlb)
+		 slot_level_handler fn, bool flush_on_yield)
 {
 	return slot_handle_level(kvm, memslot, fn, PG_LEVEL_4K,
-				 PG_LEVEL_4K, lock_flush_tlb);
+				 PG_LEVEL_4K, flush_on_yield);
 }
 
 static void free_mmu_pages(struct kvm_mmu *mmu)
 {
+	if (!tdp_enabled && mmu->pae_root)
+		set_memory_encrypted((unsigned long)mmu->pae_root, 1);
 	free_page((unsigned long)mmu->pae_root);
 	free_page((unsigned long)mmu->lm_root);
 }
@@ -5240,9 +5258,11 @@
 	 * while the PDP table is a per-vCPU construct that's allocated at MMU
 	 * creation.  When emulating 32-bit mode, cr3 is only 32 bits even on
 	 * x86_64.  Therefore we need to allocate the PDP table in the first
-	 * 4GB of memory, which happens to fit the DMA32 zone.  Except for
-	 * SVM's 32-bit NPT support, TDP paging doesn't use PAE paging and can
-	 * skip allocating the PDP table.
+	 * 4GB of memory, which happens to fit the DMA32 zone.  TDP paging
+	 * generally doesn't use PAE paging and can skip allocating the PDP
+	 * table.  The main exception, handled here, is SVM's 32-bit NPT.  The
+	 * other exception is for shadowing L1's 32-bit or PAE NPT on 64-bit
+	 * KVM; that horror is handled on-demand by mmu_alloc_special_roots().
 	 */
 	if (tdp_enabled && kvm_mmu_get_tdp_level(vcpu) > PT32E_ROOT_LEVEL)
 		return 0;
@@ -5252,8 +5272,22 @@
 		return -ENOMEM;
 
 	mmu->pae_root = page_address(page);
+
+	/*
+	 * CR3 is only 32 bits when PAE paging is used, thus it's impossible to
+	 * get the CPU to treat the PDPTEs as encrypted.  Decrypt the page so
+	 * that KVM's writes and the CPU's reads get along.  Note, this is
+	 * only necessary when using shadow paging, as 64-bit NPT can get at
+	 * the C-bit even when shadowing 32-bit NPT, and SME isn't supported
+	 * by 32-bit kernels (when KVM itself uses 32-bit NPT).
+	 */
+	if (!tdp_enabled)
+		set_memory_decrypted((unsigned long)mmu->pae_root, 1);
+	else
+		WARN_ON_ONCE(shadow_me_mask);
+
 	for (i = 0; i < 4; ++i)
-		mmu->pae_root[i] = INVALID_PAGE;
+		mmu->pae_root[i] = INVALID_PAE_ROOT;
 
 	return 0;
 }
@@ -5351,6 +5385,8 @@
  */
 static void kvm_mmu_zap_all_fast(struct kvm *kvm)
 {
+	struct kvm_mmu_page *root;
+
 	lockdep_assert_held(&kvm->slots_lock);
 
 	write_lock(&kvm->mmu_lock);
@@ -5365,6 +5401,40 @@
 	 */
 	kvm->arch.mmu_valid_gen = kvm->arch.mmu_valid_gen ? 0 : 1;
 
+
+	if (is_tdp_mmu_enabled(kvm)) {
+		/*
+		 * Mark each TDP MMU root as invalid so that other threads
+		 * will drop their references and allow the root count to
+		 * go to 0.
+		 *
+		 * Also take a reference on all roots so that this thread
+		 * can do the bulk of the work required to free the roots
+		 * once they are invalidated. Without this reference, a
+		 * vCPU thread might drop the last reference to a root and
+		 * get stuck with tearing down the entire paging structure.
+		 *
+		 * Roots which have a zero refcount should be skipped as
+		 * they're already being torn down.
+		 * Already invalid roots should be referenced again so that
+		 * they aren't freed before kvm_tdp_mmu_zap_invalidated_roots()
+		 * is done with them.
+		 *
+		 * This has essentially the same effect for the TDP MMU
+		 * as updating mmu_valid_gen above does for the shadow
+		 * MMU.
+		 *
+		 * In order to ensure all threads see this change when
+		 * handling the MMU reload signal, this must happen in the
+		 * same critical section as kvm_reload_remote_mmus, and
+		 * before kvm_zap_obsolete_pages as kvm_zap_obsolete_pages
+		 * could drop the MMU lock and yield.
+		 */
+		list_for_each_entry(root, &kvm->arch.tdp_mmu_roots, link)
+			if (refcount_inc_not_zero(&root->tdp_mmu_root_count))
+				root->role.invalid = true;
+	}
+
 	/*
 	 * Notify all vcpus to reload its shadow page table and flush TLB.
 	 * Then all vcpus will switch to new shadow page table with the new
@@ -5377,10 +5447,13 @@
 
 	kvm_zap_obsolete_pages(kvm);
 
-	if (is_tdp_mmu_enabled(kvm))
-		kvm_tdp_mmu_zap_all(kvm);
-
 	write_unlock(&kvm->mmu_lock);
+
+	if (is_tdp_mmu_enabled(kvm)) {
+		read_lock(&kvm->mmu_lock);
+		kvm_tdp_mmu_zap_invalidated_roots(kvm);
+		read_unlock(&kvm->mmu_lock);
+	}
 }
 
 static bool kvm_has_zapped_obsolete_pages(struct kvm *kvm)
@@ -5420,7 +5493,7 @@
 	struct kvm_memslots *slots;
 	struct kvm_memory_slot *memslot;
 	int i;
-	bool flush;
+	bool flush = false;
 
 	write_lock(&kvm->mmu_lock);
 	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
@@ -5433,20 +5506,31 @@
 			if (start >= end)
 				continue;
 
-			slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
-						PG_LEVEL_4K,
-						KVM_MAX_HUGEPAGE_LEVEL,
-						start, end - 1, true);
+			flush = slot_handle_level_range(kvm, memslot, kvm_zap_rmapp,
+							PG_LEVEL_4K,
+							KVM_MAX_HUGEPAGE_LEVEL,
+							start, end - 1, true, flush);
 		}
 	}
 
-	if (is_tdp_mmu_enabled(kvm)) {
-		flush = kvm_tdp_mmu_zap_gfn_range(kvm, gfn_start, gfn_end);
-		if (flush)
-			kvm_flush_remote_tlbs(kvm);
-	}
+	if (flush)
+		kvm_flush_remote_tlbs_with_address(kvm, gfn_start, gfn_end);
 
 	write_unlock(&kvm->mmu_lock);
+
+	if (is_tdp_mmu_enabled(kvm)) {
+		flush = false;
+
+		read_lock(&kvm->mmu_lock);
+		for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
+			flush = kvm_tdp_mmu_zap_gfn_range(kvm, i, gfn_start,
+							  gfn_end, flush, true);
+		if (flush)
+			kvm_flush_remote_tlbs_with_address(kvm, gfn_start,
+							   gfn_end);
+
+		read_unlock(&kvm->mmu_lock);
+	}
 }
 
 static bool slot_rmap_write_protect(struct kvm *kvm,
@@ -5465,10 +5549,14 @@
 	write_lock(&kvm->mmu_lock);
 	flush = slot_handle_level(kvm, memslot, slot_rmap_write_protect,
 				start_level, KVM_MAX_HUGEPAGE_LEVEL, false);
-	if (is_tdp_mmu_enabled(kvm))
-		flush |= kvm_tdp_mmu_wrprot_slot(kvm, memslot, PG_LEVEL_4K);
 	write_unlock(&kvm->mmu_lock);
 
+	if (is_tdp_mmu_enabled(kvm)) {
+		read_lock(&kvm->mmu_lock);
+		flush |= kvm_tdp_mmu_wrprot_slot(kvm, memslot, start_level);
+		read_unlock(&kvm->mmu_lock);
+	}
+
 	/*
 	 * We can flush all the TLBs out of the mmu lock without TLB
 	 * corruption since we just change the spte from writable to
@@ -5476,9 +5564,9 @@
 	 * spte from present to present (changing the spte from present
 	 * to nonpresent will flush all the TLBs immediately), in other
 	 * words, the only case we care is mmu_spte_update() where we
-	 * have checked SPTE_HOST_WRITEABLE | SPTE_MMU_WRITEABLE
-	 * instead of PT_WRITABLE_MASK, that means it does not depend
-	 * on PT_WRITABLE_MASK anymore.
+	 * have checked Host-writable | MMU-writable instead of
+	 * PT_WRITABLE_MASK, that means it does not depend on PT_WRITABLE_MASK
+	 * anymore.
 	 */
 	if (flush)
 		kvm_arch_flush_remote_tlbs_memslot(kvm, memslot);
@@ -5529,21 +5617,32 @@
 {
 	/* FIXME: const-ify all uses of struct kvm_memory_slot.  */
 	struct kvm_memory_slot *slot = (struct kvm_memory_slot *)memslot;
+	bool flush;
 
 	write_lock(&kvm->mmu_lock);
-	slot_handle_leaf(kvm, slot, kvm_mmu_zap_collapsible_spte, true);
+	flush = slot_handle_leaf(kvm, slot, kvm_mmu_zap_collapsible_spte, true);
 
-	if (is_tdp_mmu_enabled(kvm))
-		kvm_tdp_mmu_zap_collapsible_sptes(kvm, slot);
+	if (flush)
+		kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
 	write_unlock(&kvm->mmu_lock);
+
+	if (is_tdp_mmu_enabled(kvm)) {
+		flush = false;
+
+		read_lock(&kvm->mmu_lock);
+		flush = kvm_tdp_mmu_zap_collapsible_sptes(kvm, slot, flush);
+		if (flush)
+			kvm_arch_flush_remote_tlbs_memslot(kvm, slot);
+		read_unlock(&kvm->mmu_lock);
+	}
 }
 
 void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
-					struct kvm_memory_slot *memslot)
+					const struct kvm_memory_slot *memslot)
 {
 	/*
 	 * All current use cases for flushing the TLBs for a specific memslot
-	 * are related to dirty logging, and do the TLB flush out of mmu_lock.
+	 * are related to dirty logging, and many flush TLBs outside mmu_lock.
 	 * The interaction between the various operations on memslot must be
 	 * serialized by slots_locks to ensure the TLB flush from one operation
 	 * is observed by any other operation on the same memslot.
@@ -5560,10 +5659,14 @@
 
 	write_lock(&kvm->mmu_lock);
 	flush = slot_handle_leaf(kvm, memslot, __rmap_clear_dirty, false);
-	if (is_tdp_mmu_enabled(kvm))
-		flush |= kvm_tdp_mmu_clear_dirty_slot(kvm, memslot);
 	write_unlock(&kvm->mmu_lock);
 
+	if (is_tdp_mmu_enabled(kvm)) {
+		read_lock(&kvm->mmu_lock);
+		flush |= kvm_tdp_mmu_clear_dirty_slot(kvm, memslot);
+		read_unlock(&kvm->mmu_lock);
+	}
+
 	/*
 	 * It's also safe to flush TLBs out of mmu lock here as currently this
 	 * function is only used for dirty logging, in which case flushing TLB
@@ -5701,25 +5804,6 @@
 	kmem_cache_destroy(mmu_page_header_cache);
 }
 
-static void kvm_set_mmio_spte_mask(void)
-{
-	u64 mask;
-
-	/*
-	 * Set a reserved PA bit in MMIO SPTEs to generate page faults with
-	 * PFEC.RSVD=1 on MMIO accesses.  64-bit PTEs (PAE, x86-64, and EPT
-	 * paging) support a maximum of 52 bits of PA, i.e. if the CPU supports
-	 * 52-bit physical addresses then there are no reserved PA bits in the
-	 * PTEs and so the reserved PA approach must be disabled.
-	 */
-	if (shadow_phys_bits < 52)
-		mask = BIT_ULL(51) | PT_PRESENT_MASK;
-	else
-		mask = 0;
-
-	kvm_mmu_set_mmio_spte_mask(mask, ACC_WRITE_MASK | ACC_USER_MASK);
-}
-
 static bool get_nx_auto_mode(void)
 {
 	/* Return true when CPU has the bug, and mitigations are ON */
@@ -5785,8 +5869,6 @@
 
 	kvm_mmu_reset_all_pte_masks();
 
-	kvm_set_mmio_spte_mask();
-
 	pte_list_desc_cache = kmem_cache_create("pte_list_desc",
 					    sizeof(struct pte_list_desc),
 					    0, SLAB_ACCOUNT, NULL);
diff --git a/arch/x86/kvm/mmu/mmu_audit.c b/arch/x86/kvm/mmu/mmu_audit.c
index ced15fd..cedc17b 100644
--- a/arch/x86/kvm/mmu/mmu_audit.c
+++ b/arch/x86/kvm/mmu/mmu_audit.c
@@ -70,7 +70,7 @@
 	for (i = 0; i < 4; ++i) {
 		hpa_t root = vcpu->arch.mmu->pae_root[i];
 
-		if (root && VALID_PAGE(root)) {
+		if (IS_VALID_PAE_ROOT(root)) {
 			root &= PT64_BASE_ADDR_MASK;
 			sp = to_shadow_page(root);
 			__mmu_spte_walk(vcpu, sp, fn, 2);
diff --git a/arch/x86/kvm/mmu/mmu_internal.h b/arch/x86/kvm/mmu/mmu_internal.h
index 1f6f98c..f2546d6 100644
--- a/arch/x86/kvm/mmu/mmu_internal.h
+++ b/arch/x86/kvm/mmu/mmu_internal.h
@@ -20,6 +20,16 @@
 #define MMU_WARN_ON(x) do { } while (0)
 #endif
 
+/*
+ * Unlike regular MMU roots, PAE "roots", a.k.a. PDPTEs/PDPTRs, have a PRESENT
+ * bit, and thus are guaranteed to be non-zero when valid.  And, when a guest
+ * PDPTR is !PRESENT, its corresponding PAE root cannot be set to INVALID_PAGE,
+ * as the CPU would treat that as a PRESENT PDPTR with reserved bits set.  Use
+ * '0' instead of INVALID_PAGE to indicate an invalid PAE root.
+ */
+#define INVALID_PAE_ROOT	0
+#define IS_VALID_PAE_ROOT(x)	(!!(x))
+
 struct kvm_mmu_page {
 	struct list_head link;
 	struct hlist_node hash_link;
@@ -40,7 +50,11 @@
 	u64 *spt;
 	/* hold the gfn of each spte inside spt */
 	gfn_t *gfns;
-	int root_count;          /* Currently serving as active root */
+	/* Currently serving as active root */
+	union {
+		int root_count;
+		refcount_t tdp_mmu_root_count;
+	};
 	unsigned int unsync_children;
 	struct kvm_rmap_head parent_ptes; /* rmap pointers to parent sptes */
 	DECLARE_BITMAP(unsync_child_bitmap, 512);
@@ -78,9 +92,14 @@
 	return to_shadow_page(__pa(sptep));
 }
 
+static inline int kvm_mmu_role_as_id(union kvm_mmu_page_role role)
+{
+	return role.smm ? 1 : 0;
+}
+
 static inline int kvm_mmu_page_as_id(struct kvm_mmu_page *sp)
 {
-	return sp->role.smm ? 1 : 0;
+	return kvm_mmu_role_as_id(sp->role);
 }
 
 static inline bool kvm_vcpu_ad_need_write_protect(struct kvm_vcpu *vcpu)
@@ -108,22 +127,6 @@
 void kvm_flush_remote_tlbs_with_address(struct kvm *kvm,
 					u64 start_gfn, u64 pages);
 
-static inline void kvm_mmu_get_root(struct kvm *kvm, struct kvm_mmu_page *sp)
-{
-	BUG_ON(!sp->root_count);
-	lockdep_assert_held(&kvm->mmu_lock);
-
-	++sp->root_count;
-}
-
-static inline bool kvm_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *sp)
-{
-	lockdep_assert_held(&kvm->mmu_lock);
-	--sp->root_count;
-
-	return !sp->root_count;
-}
-
 /*
  * Return values of handle_mmio_page_fault, mmu.page_fault, and fast_page_fault().
  *
@@ -146,8 +149,9 @@
 #define SET_SPTE_NEED_REMOTE_TLB_FLUSH	BIT(1)
 #define SET_SPTE_SPURIOUS		BIT(2)
 
-int kvm_mmu_max_mapping_level(struct kvm *kvm, struct kvm_memory_slot *slot,
-			      gfn_t gfn, kvm_pfn_t pfn, int max_level);
+int kvm_mmu_max_mapping_level(struct kvm *kvm,
+			      const struct kvm_memory_slot *slot, gfn_t gfn,
+			      kvm_pfn_t pfn, int max_level);
 int kvm_mmu_hugepage_adjust(struct kvm_vcpu *vcpu, gfn_t gfn,
 			    int max_level, kvm_pfn_t *pfnp,
 			    bool huge_page_disallowed, int *req_level);
diff --git a/arch/x86/kvm/mmu/paging_tmpl.h b/arch/x86/kvm/mmu/paging_tmpl.h
index 55d7b47..70b7e44 100644
--- a/arch/x86/kvm/mmu/paging_tmpl.h
+++ b/arch/x86/kvm/mmu/paging_tmpl.h
@@ -503,6 +503,7 @@
 #endif
 	walker->fault.address = addr;
 	walker->fault.nested_page_fault = mmu != vcpu->arch.walk_mmu;
+	walker->fault.async_page_fault = false;
 
 	trace_kvm_mmu_walker_error(walker->fault.error_code);
 	return 0;
@@ -1084,7 +1085,7 @@
 
 		nr_present++;
 
-		host_writable = sp->spt[i] & SPTE_HOST_WRITEABLE;
+		host_writable = sp->spt[i] & shadow_host_writable_mask;
 
 		set_spte_ret |= set_spte(vcpu, &sp->spt[i],
 					 pte_access, PG_LEVEL_4K,
diff --git a/arch/x86/kvm/mmu/spte.c b/arch/x86/kvm/mmu/spte.c
index ef55f0b..66d43ce 100644
--- a/arch/x86/kvm/mmu/spte.c
+++ b/arch/x86/kvm/mmu/spte.c
@@ -16,13 +16,20 @@
 #include "spte.h"
 
 #include <asm/e820/api.h>
+#include <asm/vmx.h>
 
+static bool __read_mostly enable_mmio_caching = true;
+module_param_named(mmio_caching, enable_mmio_caching, bool, 0444);
+
+u64 __read_mostly shadow_host_writable_mask;
+u64 __read_mostly shadow_mmu_writable_mask;
 u64 __read_mostly shadow_nx_mask;
 u64 __read_mostly shadow_x_mask; /* mutual exclusive with nx_mask */
 u64 __read_mostly shadow_user_mask;
 u64 __read_mostly shadow_accessed_mask;
 u64 __read_mostly shadow_dirty_mask;
 u64 __read_mostly shadow_mmio_value;
+u64 __read_mostly shadow_mmio_mask;
 u64 __read_mostly shadow_mmio_access_mask;
 u64 __read_mostly shadow_present_mask;
 u64 __read_mostly shadow_me_mask;
@@ -38,7 +45,6 @@
 	u64 mask;
 
 	WARN_ON(gen & ~MMIO_SPTE_GEN_MASK);
-	BUILD_BUG_ON((MMIO_SPTE_GEN_HIGH_MASK | MMIO_SPTE_GEN_LOW_MASK) & SPTE_SPECIAL_MASK);
 
 	mask = (gen << MMIO_SPTE_GEN_LOW_SHIFT) & MMIO_SPTE_GEN_LOW_MASK;
 	mask |= (gen << MMIO_SPTE_GEN_HIGH_SHIFT) & MMIO_SPTE_GEN_HIGH_MASK;
@@ -48,16 +54,18 @@
 u64 make_mmio_spte(struct kvm_vcpu *vcpu, u64 gfn, unsigned int access)
 {
 	u64 gen = kvm_vcpu_memslots(vcpu)->generation & MMIO_SPTE_GEN_MASK;
-	u64 mask = generation_mmio_spte_mask(gen);
+	u64 spte = generation_mmio_spte_mask(gen);
 	u64 gpa = gfn << PAGE_SHIFT;
 
+	WARN_ON_ONCE(!shadow_mmio_value);
+
 	access &= shadow_mmio_access_mask;
-	mask |= shadow_mmio_value | access;
-	mask |= gpa | shadow_nonpresent_or_rsvd_mask;
-	mask |= (gpa & shadow_nonpresent_or_rsvd_mask)
+	spte |= shadow_mmio_value | access;
+	spte |= gpa | shadow_nonpresent_or_rsvd_mask;
+	spte |= (gpa & shadow_nonpresent_or_rsvd_mask)
 		<< SHADOW_NONPRESENT_OR_RSVD_MASK_LEN;
 
-	return mask;
+	return spte;
 }
 
 static bool kvm_is_mmio_pfn(kvm_pfn_t pfn)
@@ -86,13 +94,20 @@
 		     bool can_unsync, bool host_writable, bool ad_disabled,
 		     u64 *new_spte)
 {
-	u64 spte = 0;
+	u64 spte = SPTE_MMU_PRESENT_MASK;
 	int ret = 0;
 
 	if (ad_disabled)
-		spte |= SPTE_AD_DISABLED_MASK;
+		spte |= SPTE_TDP_AD_DISABLED_MASK;
 	else if (kvm_vcpu_ad_need_write_protect(vcpu))
-		spte |= SPTE_AD_WRPROT_ONLY_MASK;
+		spte |= SPTE_TDP_AD_WRPROT_ONLY_MASK;
+
+	/*
+	 * Bits 62:52 of PAE SPTEs are reserved.  WARN if said bits are set
+	 * if PAE paging may be employed (shadow paging or any 32-bit KVM).
+	 */
+	WARN_ON_ONCE((!tdp_enabled || !IS_ENABLED(CONFIG_X86_64)) &&
+		     (spte & SPTE_TDP_AD_MASK));
 
 	/*
 	 * For the EPT case, shadow_present_mask is 0 if hardware
@@ -124,7 +139,7 @@
 			kvm_is_mmio_pfn(pfn));
 
 	if (host_writable)
-		spte |= SPTE_HOST_WRITEABLE;
+		spte |= shadow_host_writable_mask;
 	else
 		pte_access &= ~ACC_WRITE_MASK;
 
@@ -134,7 +149,7 @@
 	spte |= (u64)pfn << PAGE_SHIFT;
 
 	if (pte_access & ACC_WRITE_MASK) {
-		spte |= PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE;
+		spte |= PT_WRITABLE_MASK | shadow_mmu_writable_mask;
 
 		/*
 		 * Optimization: for pte sync, if spte was writable the hash
@@ -150,7 +165,7 @@
 				 __func__, gfn);
 			ret |= SET_SPTE_WRITE_PROTECTED_PT;
 			pte_access &= ~ACC_WRITE_MASK;
-			spte &= ~(PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE);
+			spte &= ~(PT_WRITABLE_MASK | shadow_mmu_writable_mask);
 		}
 	}
 
@@ -161,19 +176,20 @@
 		spte = mark_spte_for_access_track(spte);
 
 out:
+	WARN_ON(is_mmio_spte(spte));
 	*new_spte = spte;
 	return ret;
 }
 
 u64 make_nonleaf_spte(u64 *child_pt, bool ad_disabled)
 {
-	u64 spte;
+	u64 spte = SPTE_MMU_PRESENT_MASK;
 
-	spte = __pa(child_pt) | shadow_present_mask | PT_WRITABLE_MASK |
-	       shadow_user_mask | shadow_x_mask | shadow_me_mask;
+	spte |= __pa(child_pt) | shadow_present_mask | PT_WRITABLE_MASK |
+		shadow_user_mask | shadow_x_mask | shadow_me_mask;
 
 	if (ad_disabled)
-		spte |= SPTE_AD_DISABLED_MASK;
+		spte |= SPTE_TDP_AD_DISABLED_MASK;
 	else
 		spte |= shadow_accessed_mask;
 
@@ -188,7 +204,7 @@
 	new_spte |= (u64)new_pfn << PAGE_SHIFT;
 
 	new_spte &= ~PT_WRITABLE_MASK;
-	new_spte &= ~SPTE_HOST_WRITEABLE;
+	new_spte &= ~shadow_host_writable_mask;
 
 	new_spte = mark_spte_for_access_track(new_spte);
 
@@ -242,53 +258,68 @@
 	return spte;
 }
 
-void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 access_mask)
+void kvm_mmu_set_mmio_spte_mask(u64 mmio_value, u64 mmio_mask, u64 access_mask)
 {
 	BUG_ON((u64)(unsigned)access_mask != access_mask);
-	WARN_ON(mmio_value & (shadow_nonpresent_or_rsvd_mask << SHADOW_NONPRESENT_OR_RSVD_MASK_LEN));
 	WARN_ON(mmio_value & shadow_nonpresent_or_rsvd_lower_gfn_mask);
-	shadow_mmio_value = mmio_value | SPTE_MMIO_MASK;
+
+	if (!enable_mmio_caching)
+		mmio_value = 0;
+
+	/*
+	 * Disable MMIO caching if the MMIO value collides with the bits that
+	 * are used to hold the relocated GFN when the L1TF mitigation is
+	 * enabled.  This should never fire as there is no known hardware that
+	 * can trigger this condition, e.g. SME/SEV CPUs that require a custom
+	 * MMIO value are not susceptible to L1TF.
+	 */
+	if (WARN_ON(mmio_value & (shadow_nonpresent_or_rsvd_mask <<
+				  SHADOW_NONPRESENT_OR_RSVD_MASK_LEN)))
+		mmio_value = 0;
+
+	/*
+	 * The masked MMIO value must obviously match itself and a removed SPTE
+	 * must not get a false positive.  Removed SPTEs and MMIO SPTEs should
+	 * never collide as MMIO must set some RWX bits, and removed SPTEs must
+	 * not set any RWX bits.
+	 */
+	if (WARN_ON((mmio_value & mmio_mask) != mmio_value) ||
+	    WARN_ON(mmio_value && (REMOVED_SPTE & mmio_mask) == mmio_value))
+		mmio_value = 0;
+
+	shadow_mmio_value = mmio_value;
+	shadow_mmio_mask  = mmio_mask;
 	shadow_mmio_access_mask = access_mask;
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_set_mmio_spte_mask);
 
-/*
- * Sets the shadow PTE masks used by the MMU.
- *
- * Assumptions:
- *  - Setting either @accessed_mask or @dirty_mask requires setting both
- *  - At least one of @accessed_mask or @acc_track_mask must be set
- */
-void kvm_mmu_set_mask_ptes(u64 user_mask, u64 accessed_mask,
-		u64 dirty_mask, u64 nx_mask, u64 x_mask, u64 p_mask,
-		u64 acc_track_mask, u64 me_mask)
+void kvm_mmu_set_ept_masks(bool has_ad_bits, bool has_exec_only)
 {
-	BUG_ON(!dirty_mask != !accessed_mask);
-	BUG_ON(!accessed_mask && !acc_track_mask);
-	BUG_ON(acc_track_mask & SPTE_SPECIAL_MASK);
+	shadow_user_mask	= VMX_EPT_READABLE_MASK;
+	shadow_accessed_mask	= has_ad_bits ? VMX_EPT_ACCESS_BIT : 0ull;
+	shadow_dirty_mask	= has_ad_bits ? VMX_EPT_DIRTY_BIT : 0ull;
+	shadow_nx_mask		= 0ull;
+	shadow_x_mask		= VMX_EPT_EXECUTABLE_MASK;
+	shadow_present_mask	= has_exec_only ? 0ull : VMX_EPT_READABLE_MASK;
+	shadow_acc_track_mask	= VMX_EPT_RWX_MASK;
+	shadow_me_mask		= 0ull;
 
-	shadow_user_mask = user_mask;
-	shadow_accessed_mask = accessed_mask;
-	shadow_dirty_mask = dirty_mask;
-	shadow_nx_mask = nx_mask;
-	shadow_x_mask = x_mask;
-	shadow_present_mask = p_mask;
-	shadow_acc_track_mask = acc_track_mask;
-	shadow_me_mask = me_mask;
+	shadow_host_writable_mask = EPT_SPTE_HOST_WRITABLE;
+	shadow_mmu_writable_mask  = EPT_SPTE_MMU_WRITABLE;
+
+	/*
+	 * EPT Misconfigurations are generated if the value of bits 2:0
+	 * of an EPT paging-structure entry is 110b (write/execute).
+	 */
+	kvm_mmu_set_mmio_spte_mask(VMX_EPT_MISCONFIG_WX_VALUE,
+				   VMX_EPT_RWX_MASK, 0);
 }
-EXPORT_SYMBOL_GPL(kvm_mmu_set_mask_ptes);
+EXPORT_SYMBOL_GPL(kvm_mmu_set_ept_masks);
 
 void kvm_mmu_reset_all_pte_masks(void)
 {
 	u8 low_phys_bits;
-
-	shadow_user_mask = 0;
-	shadow_accessed_mask = 0;
-	shadow_dirty_mask = 0;
-	shadow_nx_mask = 0;
-	shadow_x_mask = 0;
-	shadow_present_mask = 0;
-	shadow_acc_track_mask = 0;
+	u64 mask;
 
 	shadow_phys_bits = kvm_get_shadow_phys_bits();
 
@@ -315,4 +346,30 @@
 
 	shadow_nonpresent_or_rsvd_lower_gfn_mask =
 		GENMASK_ULL(low_phys_bits - 1, PAGE_SHIFT);
+
+	shadow_user_mask	= PT_USER_MASK;
+	shadow_accessed_mask	= PT_ACCESSED_MASK;
+	shadow_dirty_mask	= PT_DIRTY_MASK;
+	shadow_nx_mask		= PT64_NX_MASK;
+	shadow_x_mask		= 0;
+	shadow_present_mask	= PT_PRESENT_MASK;
+	shadow_acc_track_mask	= 0;
+	shadow_me_mask		= sme_me_mask;
+
+	shadow_host_writable_mask = DEFAULT_SPTE_HOST_WRITEABLE;
+	shadow_mmu_writable_mask  = DEFAULT_SPTE_MMU_WRITEABLE;
+
+	/*
+	 * Set a reserved PA bit in MMIO SPTEs to generate page faults with
+	 * PFEC.RSVD=1 on MMIO accesses.  64-bit PTEs (PAE, x86-64, and EPT
+	 * paging) support a maximum of 52 bits of PA, i.e. if the CPU supports
+	 * 52-bit physical addresses then there are no reserved PA bits in the
+	 * PTEs and so the reserved PA approach must be disabled.
+	 */
+	if (shadow_phys_bits < 52)
+		mask = BIT_ULL(51) | PT_PRESENT_MASK;
+	else
+		mask = 0;
+
+	kvm_mmu_set_mmio_spte_mask(mask, mask, ACC_WRITE_MASK | ACC_USER_MASK);
 }
diff --git a/arch/x86/kvm/mmu/spte.h b/arch/x86/kvm/mmu/spte.h
index 6de3950..bca0ba1 100644
--- a/arch/x86/kvm/mmu/spte.h
+++ b/arch/x86/kvm/mmu/spte.h
@@ -5,18 +5,33 @@
 
 #include "mmu_internal.h"
 
-#define PT_FIRST_AVAIL_BITS_SHIFT 10
-#define PT64_SECOND_AVAIL_BITS_SHIFT 54
+/*
+ * A MMU present SPTE is backed by actual memory and may or may not be present
+ * in hardware.  E.g. MMIO SPTEs are not considered present.  Use bit 11, as it
+ * is ignored by all flavors of SPTEs and checking a low bit often generates
+ * better code than for a high bit, e.g. 56+.  MMU present checks are pervasive
+ * enough that the improved code generation is noticeable in KVM's footprint.
+ */
+#define SPTE_MMU_PRESENT_MASK		BIT_ULL(11)
 
 /*
- * The mask used to denote special SPTEs, which can be either MMIO SPTEs or
- * Access Tracking SPTEs.
+ * TDP SPTEs (more specifically, EPT SPTEs) may not have A/D bits, and may also
+ * be restricted to using write-protection (for L2 when CPU dirty logging, i.e.
+ * PML, is enabled).  Use bits 52 and 53 to hold the type of A/D tracking that
+ * must be employed for a given TDP SPTE.
+ *
+ * Note, the "enabled" mask must be '0', as bits 62:52 are _reserved_ for PAE
+ * paging, including NPT PAE.  This scheme works because legacy shadow paging
+ * is guaranteed to have A/D bits and write-protection is forced only for
+ * TDP with CPU dirty logging (PML).  If NPT ever gains PML-like support, it
+ * must be restricted to 64-bit KVM.
  */
-#define SPTE_SPECIAL_MASK (3ULL << 52)
-#define SPTE_AD_ENABLED_MASK (0ULL << 52)
-#define SPTE_AD_DISABLED_MASK (1ULL << 52)
-#define SPTE_AD_WRPROT_ONLY_MASK (2ULL << 52)
-#define SPTE_MMIO_MASK (3ULL << 52)
+#define SPTE_TDP_AD_SHIFT		52
+#define SPTE_TDP_AD_MASK		(3ULL << SPTE_TDP_AD_SHIFT)
+#define SPTE_TDP_AD_ENABLED_MASK	(0ULL << SPTE_TDP_AD_SHIFT)
+#define SPTE_TDP_AD_DISABLED_MASK	(1ULL << SPTE_TDP_AD_SHIFT)
+#define SPTE_TDP_AD_WRPROT_ONLY_MASK	(2ULL << SPTE_TDP_AD_SHIFT)
+static_assert(SPTE_TDP_AD_ENABLED_MASK == 0);
 
 #ifdef CONFIG_DYNAMIC_PHYSICAL_MASK
 #define PT64_BASE_ADDR_MASK (physical_mask & ~(u64)(PAGE_SIZE-1))
@@ -51,16 +66,46 @@
 	(((address) >> PT64_LEVEL_SHIFT(level)) & ((1 << PT64_LEVEL_BITS) - 1))
 #define SHADOW_PT_INDEX(addr, level) PT64_INDEX(addr, level)
 
-
-#define SPTE_HOST_WRITEABLE	(1ULL << PT_FIRST_AVAIL_BITS_SHIFT)
-#define SPTE_MMU_WRITEABLE	(1ULL << (PT_FIRST_AVAIL_BITS_SHIFT + 1))
+/* Bits 9 and 10 are ignored by all non-EPT PTEs. */
+#define DEFAULT_SPTE_HOST_WRITEABLE	BIT_ULL(9)
+#define DEFAULT_SPTE_MMU_WRITEABLE	BIT_ULL(10)
 
 /*
- * Due to limited space in PTEs, the MMIO generation is a 18 bit subset of
+ * The mask/shift to use for saving the original R/X bits when marking the PTE
+ * as not-present for access tracking purposes. We do not save the W bit as the
+ * PTEs being access tracked also need to be dirty tracked, so the W bit will be
+ * restored only when a write is attempted to the page.  This mask obviously
+ * must not overlap the A/D type mask.
+ */
+#define SHADOW_ACC_TRACK_SAVED_BITS_MASK (PT64_EPT_READABLE_MASK | \
+					  PT64_EPT_EXECUTABLE_MASK)
+#define SHADOW_ACC_TRACK_SAVED_BITS_SHIFT 54
+#define SHADOW_ACC_TRACK_SAVED_MASK	(SHADOW_ACC_TRACK_SAVED_BITS_MASK << \
+					 SHADOW_ACC_TRACK_SAVED_BITS_SHIFT)
+static_assert(!(SPTE_TDP_AD_MASK & SHADOW_ACC_TRACK_SAVED_MASK));
+
+/*
+ * Low ignored bits are at a premium for EPT, use high ignored bits, taking care
+ * to not overlap the A/D type mask or the saved access bits of access-tracked
+ * SPTEs when A/D bits are disabled.
+ */
+#define EPT_SPTE_HOST_WRITABLE		BIT_ULL(57)
+#define EPT_SPTE_MMU_WRITABLE		BIT_ULL(58)
+
+static_assert(!(EPT_SPTE_HOST_WRITABLE & SPTE_TDP_AD_MASK));
+static_assert(!(EPT_SPTE_MMU_WRITABLE & SPTE_TDP_AD_MASK));
+static_assert(!(EPT_SPTE_HOST_WRITABLE & SHADOW_ACC_TRACK_SAVED_MASK));
+static_assert(!(EPT_SPTE_MMU_WRITABLE & SHADOW_ACC_TRACK_SAVED_MASK));
+
+/* Defined only to keep the above static asserts readable. */
+#undef SHADOW_ACC_TRACK_SAVED_MASK
+
+/*
+ * Due to limited space in PTEs, the MMIO generation is a 19 bit subset of
  * the memslots generation and is derived as follows:
  *
- * Bits 0-8 of the MMIO generation are propagated to spte bits 3-11
- * Bits 9-17 of the MMIO generation are propagated to spte bits 54-62
+ * Bits 0-7 of the MMIO generation are propagated to spte bits 3-10
+ * Bits 8-18 of the MMIO generation are propagated to spte bits 52-62
  *
  * The KVM_MEMSLOT_GEN_UPDATE_IN_PROGRESS flag is intentionally not included in
  * the MMIO generation number, as doing so would require stealing a bit from
@@ -71,39 +116,44 @@
  */
 
 #define MMIO_SPTE_GEN_LOW_START		3
-#define MMIO_SPTE_GEN_LOW_END		11
+#define MMIO_SPTE_GEN_LOW_END		10
 
-#define MMIO_SPTE_GEN_HIGH_START	PT64_SECOND_AVAIL_BITS_SHIFT
+#define MMIO_SPTE_GEN_HIGH_START	52
 #define MMIO_SPTE_GEN_HIGH_END		62
 
 #define MMIO_SPTE_GEN_LOW_MASK		GENMASK_ULL(MMIO_SPTE_GEN_LOW_END, \
 						    MMIO_SPTE_GEN_LOW_START)
 #define MMIO_SPTE_GEN_HIGH_MASK		GENMASK_ULL(MMIO_SPTE_GEN_HIGH_END, \
 						    MMIO_SPTE_GEN_HIGH_START)
+static_assert(!(SPTE_MMU_PRESENT_MASK &
+		(MMIO_SPTE_GEN_LOW_MASK | MMIO_SPTE_GEN_HIGH_MASK)));
 
 #define MMIO_SPTE_GEN_LOW_BITS		(MMIO_SPTE_GEN_LOW_END - MMIO_SPTE_GEN_LOW_START + 1)
 #define MMIO_SPTE_GEN_HIGH_BITS		(MMIO_SPTE_GEN_HIGH_END - MMIO_SPTE_GEN_HIGH_START + 1)
 
 /* remember to adjust the comment above as well if you change these */
-static_assert(MMIO_SPTE_GEN_LOW_BITS == 9 && MMIO_SPTE_GEN_HIGH_BITS == 9);
+static_assert(MMIO_SPTE_GEN_LOW_BITS == 8 && MMIO_SPTE_GEN_HIGH_BITS == 11);
 
 #define MMIO_SPTE_GEN_LOW_SHIFT		(MMIO_SPTE_GEN_LOW_START - 0)
 #define MMIO_SPTE_GEN_HIGH_SHIFT	(MMIO_SPTE_GEN_HIGH_START - MMIO_SPTE_GEN_LOW_BITS)
 
 #define MMIO_SPTE_GEN_MASK		GENMASK_ULL(MMIO_SPTE_GEN_LOW_BITS + MMIO_SPTE_GEN_HIGH_BITS - 1, 0)
 
+extern u64 __read_mostly shadow_host_writable_mask;
+extern u64 __read_mostly shadow_mmu_writable_mask;
 extern u64 __read_mostly shadow_nx_mask;
 extern u64 __read_mostly shadow_x_mask; /* mutual exclusive with nx_mask */
 extern u64 __read_mostly shadow_user_mask;
 extern u64 __read_mostly shadow_accessed_mask;
 extern u64 __read_mostly shadow_dirty_mask;
 extern u64 __read_mostly shadow_mmio_value;
+extern u64 __read_mostly shadow_mmio_mask;
 extern u64 __read_mostly shadow_mmio_access_mask;
 extern u64 __read_mostly shadow_present_mask;
 extern u64 __read_mostly shadow_me_mask;
 
 /*
- * SPTEs used by MMUs without A/D bits are marked with SPTE_AD_DISABLED_MASK;
+ * SPTEs in MMUs without A/D bits are marked with SPTE_TDP_AD_DISABLED_MASK;
  * shadow_acc_track_mask is the set of bits to be cleared in non-accessed
  * pages.
  */
@@ -121,28 +171,21 @@
 #define SHADOW_NONPRESENT_OR_RSVD_MASK_LEN 5
 
 /*
- * The mask/shift to use for saving the original R/X bits when marking the PTE
- * as not-present for access tracking purposes. We do not save the W bit as the
- * PTEs being access tracked also need to be dirty tracked, so the W bit will be
- * restored only when a write is attempted to the page.
- */
-#define SHADOW_ACC_TRACK_SAVED_BITS_MASK (PT64_EPT_READABLE_MASK | \
-					  PT64_EPT_EXECUTABLE_MASK)
-#define SHADOW_ACC_TRACK_SAVED_BITS_SHIFT PT64_SECOND_AVAIL_BITS_SHIFT
-
-/*
  * If a thread running without exclusive control of the MMU lock must perform a
  * multi-part operation on an SPTE, it can set the SPTE to REMOVED_SPTE as a
  * non-present intermediate value. Other threads which encounter this value
  * should not modify the SPTE.
  *
- * This constant works because it is considered non-present on both AMD and
- * Intel CPUs and does not create a L1TF vulnerability because the pfn section
- * is zeroed out.
+ * Use a semi-arbitrary value that doesn't set RWX bits, i.e. is not-present on
+ * both AMD and Intel CPUs, and doesn't set PFN bits, i.e. doesn't create an L1TF
+ * vulnerability.  Use only low bits to avoid 64-bit immediates.
  *
  * Only used by the TDP MMU.
  */
-#define REMOVED_SPTE (1ull << 59)
+#define REMOVED_SPTE	0x5a0ULL
+
+/* Removed SPTEs must not be misconstrued as shadow present PTEs. */
+static_assert(!(REMOVED_SPTE & SPTE_MMU_PRESENT_MASK));
 
 static inline bool is_removed_spte(u64 spte)
 {
@@ -167,7 +210,13 @@
 
 static inline bool is_mmio_spte(u64 spte)
 {
-	return (spte & SPTE_SPECIAL_MASK) == SPTE_MMIO_MASK;
+	return (spte & shadow_mmio_mask) == shadow_mmio_value &&
+	       likely(shadow_mmio_value);
+}
+
+static inline bool is_shadow_present_pte(u64 pte)
+{
+	return !!(pte & SPTE_MMU_PRESENT_MASK);
 }
 
 static inline bool sp_ad_disabled(struct kvm_mmu_page *sp)
@@ -177,25 +226,30 @@
 
 static inline bool spte_ad_enabled(u64 spte)
 {
-	MMU_WARN_ON(is_mmio_spte(spte));
-	return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_DISABLED_MASK;
+	MMU_WARN_ON(!is_shadow_present_pte(spte));
+	return (spte & SPTE_TDP_AD_MASK) != SPTE_TDP_AD_DISABLED_MASK;
 }
 
 static inline bool spte_ad_need_write_protect(u64 spte)
 {
-	MMU_WARN_ON(is_mmio_spte(spte));
-	return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_ENABLED_MASK;
+	MMU_WARN_ON(!is_shadow_present_pte(spte));
+	/*
+	 * This is benign for non-TDP SPTEs as SPTE_TDP_AD_ENABLED_MASK is '0',
+	 * and non-TDP SPTEs will never set these bits.  Optimize for 64-bit
+	 * TDP and do the A/D type check unconditionally.
+	 */
+	return (spte & SPTE_TDP_AD_MASK) != SPTE_TDP_AD_ENABLED_MASK;
 }
 
 static inline u64 spte_shadow_accessed_mask(u64 spte)
 {
-	MMU_WARN_ON(is_mmio_spte(spte));
+	MMU_WARN_ON(!is_shadow_present_pte(spte));
 	return spte_ad_enabled(spte) ? shadow_accessed_mask : 0;
 }
 
 static inline u64 spte_shadow_dirty_mask(u64 spte)
 {
-	MMU_WARN_ON(is_mmio_spte(spte));
+	MMU_WARN_ON(!is_shadow_present_pte(spte));
 	return spte_ad_enabled(spte) ? shadow_dirty_mask : 0;
 }
 
@@ -204,11 +258,6 @@
 	return !spte_ad_enabled(spte) && (spte & shadow_acc_track_mask) == 0;
 }
 
-static inline bool is_shadow_present_pte(u64 pte)
-{
-	return (pte != 0) && !is_mmio_spte(pte) && !is_removed_spte(pte);
-}
-
 static inline bool is_large_pte(u64 pte)
 {
 	return pte & PT_PAGE_SIZE_MASK;
@@ -246,8 +295,8 @@
 
 static inline bool spte_can_locklessly_be_made_writable(u64 spte)
 {
-	return (spte & (SPTE_HOST_WRITEABLE | SPTE_MMU_WRITEABLE)) ==
-		(SPTE_HOST_WRITEABLE | SPTE_MMU_WRITEABLE);
+	return (spte & shadow_host_writable_mask) &&
+	       (spte & shadow_mmu_writable_mask);
 }
 
 static inline u64 get_mmio_spte_generation(u64 spte)
diff --git a/arch/x86/kvm/mmu/tdp_mmu.c b/arch/x86/kvm/mmu/tdp_mmu.c
index 018d82e7..8ce8d09 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.c
+++ b/arch/x86/kvm/mmu/tdp_mmu.c
@@ -27,6 +27,15 @@
 	INIT_LIST_HEAD(&kvm->arch.tdp_mmu_pages);
 }
 
+static __always_inline void kvm_lockdep_assert_mmu_lock_held(struct kvm *kvm,
+							     bool shared)
+{
+	if (shared)
+		lockdep_assert_held_read(&kvm->mmu_lock);
+	else
+		lockdep_assert_held_write(&kvm->mmu_lock);
+}
+
 void kvm_mmu_uninit_tdp_mmu(struct kvm *kvm)
 {
 	if (!kvm->arch.tdp_mmu_enabled)
@@ -41,32 +50,85 @@
 	rcu_barrier();
 }
 
-static void tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root)
+static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
+			  gfn_t start, gfn_t end, bool can_yield, bool flush,
+			  bool shared);
+
+static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
 {
-	if (kvm_mmu_put_root(kvm, root))
-		kvm_tdp_mmu_free_root(kvm, root);
+	free_page((unsigned long)sp->spt);
+	kmem_cache_free(mmu_page_header_cache, sp);
 }
 
-static inline bool tdp_mmu_next_root_valid(struct kvm *kvm,
-					   struct kvm_mmu_page *root)
+/*
+ * This is called through call_rcu in order to free TDP page table memory
+ * safely with respect to other kernel threads that may be operating on
+ * the memory.
+ * By only accessing TDP MMU page table memory in an RCU read critical
+ * section, and freeing it after a grace period, lockless access to that
+ * memory won't use it after it is freed.
+ */
+static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
 {
-	lockdep_assert_held_write(&kvm->mmu_lock);
+	struct kvm_mmu_page *sp = container_of(head, struct kvm_mmu_page,
+					       rcu_head);
 
-	if (list_entry_is_head(root, &kvm->arch.tdp_mmu_roots, link))
-		return false;
-
-	kvm_mmu_get_root(kvm, root);
-	return true;
-
+	tdp_mmu_free_sp(sp);
 }
 
-static inline struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
-						     struct kvm_mmu_page *root)
+void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
+			  bool shared)
+{
+	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
+
+	kvm_lockdep_assert_mmu_lock_held(kvm, shared);
+
+	if (!refcount_dec_and_test(&root->tdp_mmu_root_count))
+		return;
+
+	WARN_ON(!root->tdp_mmu_page);
+
+	spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+	list_del_rcu(&root->link);
+	spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
+
+	zap_gfn_range(kvm, root, 0, max_gfn, false, false, shared);
+
+	call_rcu(&root->rcu_head, tdp_mmu_free_sp_rcu_callback);
+}
+
+/*
+ * Finds the next valid root after root (or the first valid root if root
+ * is NULL), takes a reference on it, and returns that next root. If root
+ * is not NULL, this thread should have already taken a reference on it, and
+ * that reference will be dropped. If no valid root is found, this
+ * function will return NULL.
+ */
+static struct kvm_mmu_page *tdp_mmu_next_root(struct kvm *kvm,
+					      struct kvm_mmu_page *prev_root,
+					      bool shared)
 {
 	struct kvm_mmu_page *next_root;
 
-	next_root = list_next_entry(root, link);
-	tdp_mmu_put_root(kvm, root);
+	rcu_read_lock();
+
+	if (prev_root)
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &prev_root->link,
+						  typeof(*prev_root), link);
+	else
+		next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						   typeof(*next_root), link);
+
+	while (next_root && !kvm_tdp_mmu_get_root(kvm, next_root))
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+				&next_root->link, typeof(*next_root), link);
+
+	rcu_read_unlock();
+
+	if (prev_root)
+		kvm_tdp_mmu_put_root(kvm, prev_root, shared);
+
 	return next_root;
 }
 
@@ -75,35 +137,24 @@
  * This makes it safe to release the MMU lock and yield within the loop, but
  * if exiting the loop early, the caller must drop the reference to the most
  * recent root. (Unless keeping a live reference is desirable.)
+ *
+ * If shared is set, this function is operating under the MMU lock in read
+ * mode. In the unlikely event that this thread must free a root, the lock
+ * will be temporarily dropped and reacquired in write mode.
  */
-#define for_each_tdp_mmu_root_yield_safe(_kvm, _root)				\
-	for (_root = list_first_entry(&_kvm->arch.tdp_mmu_roots,	\
-				      typeof(*_root), link);		\
-	     tdp_mmu_next_root_valid(_kvm, _root);			\
-	     _root = tdp_mmu_next_root(_kvm, _root))
+#define for_each_tdp_mmu_root_yield_safe(_kvm, _root, _as_id, _shared)	\
+	for (_root = tdp_mmu_next_root(_kvm, NULL, _shared);		\
+	     _root;							\
+	     _root = tdp_mmu_next_root(_kvm, _root, _shared))		\
+		if (kvm_mmu_page_as_id(_root) != _as_id) {		\
+		} else
 
-#define for_each_tdp_mmu_root(_kvm, _root)				\
-	list_for_each_entry(_root, &_kvm->arch.tdp_mmu_roots, link)
-
-static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-			  gfn_t start, gfn_t end, bool can_yield, bool flush);
-
-void kvm_tdp_mmu_free_root(struct kvm *kvm, struct kvm_mmu_page *root)
-{
-	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-
-	lockdep_assert_held_write(&kvm->mmu_lock);
-
-	WARN_ON(root->root_count);
-	WARN_ON(!root->tdp_mmu_page);
-
-	list_del(&root->link);
-
-	zap_gfn_range(kvm, root, 0, max_gfn, false, false);
-
-	free_page((unsigned long)root->spt);
-	kmem_cache_free(mmu_page_header_cache, root);
-}
+#define for_each_tdp_mmu_root(_kvm, _root, _as_id)				\
+	list_for_each_entry_rcu(_root, &_kvm->arch.tdp_mmu_roots, link,		\
+				lockdep_is_held_type(&_kvm->mmu_lock, 0) ||	\
+				lockdep_is_held(&_kvm->arch.tdp_mmu_pages_lock))	\
+		if (kvm_mmu_page_as_id(_root) != _as_id) {		\
+		} else
 
 static union kvm_mmu_page_role page_role_for_level(struct kvm_vcpu *vcpu,
 						   int level)
@@ -137,81 +188,46 @@
 	return sp;
 }
 
-static struct kvm_mmu_page *get_tdp_mmu_vcpu_root(struct kvm_vcpu *vcpu)
+hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu)
 {
 	union kvm_mmu_page_role role;
 	struct kvm *kvm = vcpu->kvm;
 	struct kvm_mmu_page *root;
 
+	lockdep_assert_held_write(&kvm->mmu_lock);
+
 	role = page_role_for_level(vcpu, vcpu->arch.mmu->shadow_root_level);
 
-	write_lock(&kvm->mmu_lock);
-
 	/* Check for an existing root before allocating a new one. */
-	for_each_tdp_mmu_root(kvm, root) {
-		if (root->role.word == role.word) {
-			kvm_mmu_get_root(kvm, root);
-			write_unlock(&kvm->mmu_lock);
-			return root;
-		}
+	for_each_tdp_mmu_root(kvm, root, kvm_mmu_role_as_id(role)) {
+		if (root->role.word == role.word &&
+		    kvm_tdp_mmu_get_root(kvm, root))
+			goto out;
 	}
 
 	root = alloc_tdp_mmu_page(vcpu, 0, vcpu->arch.mmu->shadow_root_level);
-	root->root_count = 1;
+	refcount_set(&root->tdp_mmu_root_count, 1);
 
-	list_add(&root->link, &kvm->arch.tdp_mmu_roots);
+	spin_lock(&kvm->arch.tdp_mmu_pages_lock);
+	list_add_rcu(&root->link, &kvm->arch.tdp_mmu_roots);
+	spin_unlock(&kvm->arch.tdp_mmu_pages_lock);
 
-	write_unlock(&kvm->mmu_lock);
-
-	return root;
-}
-
-hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu)
-{
-	struct kvm_mmu_page *root;
-
-	root = get_tdp_mmu_vcpu_root(vcpu);
-	if (!root)
-		return INVALID_PAGE;
-
+out:
 	return __pa(root->spt);
 }
 
-static void tdp_mmu_free_sp(struct kvm_mmu_page *sp)
-{
-	free_page((unsigned long)sp->spt);
-	kmem_cache_free(mmu_page_header_cache, sp);
-}
-
-/*
- * This is called through call_rcu in order to free TDP page table memory
- * safely with respect to other kernel threads that may be operating on
- * the memory.
- * By only accessing TDP MMU page table memory in an RCU read critical
- * section, and freeing it after a grace period, lockless access to that
- * memory won't use it after it is freed.
- */
-static void tdp_mmu_free_sp_rcu_callback(struct rcu_head *head)
-{
-	struct kvm_mmu_page *sp = container_of(head, struct kvm_mmu_page,
-					       rcu_head);
-
-	tdp_mmu_free_sp(sp);
-}
-
 static void handle_changed_spte(struct kvm *kvm, int as_id, gfn_t gfn,
 				u64 old_spte, u64 new_spte, int level,
 				bool shared);
 
 static void handle_changed_spte_acc_track(u64 old_spte, u64 new_spte, int level)
 {
-	bool pfn_changed = spte_to_pfn(old_spte) != spte_to_pfn(new_spte);
-
 	if (!is_shadow_present_pte(old_spte) || !is_last_spte(old_spte, level))
 		return;
 
 	if (is_accessed_spte(old_spte) &&
-	    (!is_accessed_spte(new_spte) || pfn_changed))
+	    (!is_shadow_present_pte(new_spte) || !is_accessed_spte(new_spte) ||
+	     spte_to_pfn(old_spte) != spte_to_pfn(new_spte)))
 		kvm_set_pfn_accessed(spte_to_pfn(old_spte));
 }
 
@@ -455,7 +471,7 @@
 
 
 	if (was_leaf && is_dirty_spte(old_spte) &&
-	    (!is_dirty_spte(new_spte) || pfn_changed))
+	    (!is_present || !is_dirty_spte(new_spte) || pfn_changed))
 		kvm_set_pfn_dirty(spte_to_pfn(old_spte));
 
 	/*
@@ -479,8 +495,9 @@
 }
 
 /*
- * tdp_mmu_set_spte_atomic - Set a TDP MMU SPTE atomically and handle the
- * associated bookkeeping
+ * tdp_mmu_set_spte_atomic_no_dirty_log - Set a TDP MMU SPTE atomically
+ * and handle the associated bookkeeping, but do not mark the page dirty
+ * in KVM's dirty bitmaps.
  *
  * @kvm: kvm instance
  * @iter: a tdp_iter instance currently on the SPTE that should be set
@@ -488,9 +505,9 @@
  * Returns: true if the SPTE was set, false if it was not. If false is returned,
  *	    this function will have no side-effects.
  */
-static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
-					   struct tdp_iter *iter,
-					   u64 new_spte)
+static inline bool tdp_mmu_set_spte_atomic_no_dirty_log(struct kvm *kvm,
+							struct tdp_iter *iter,
+							u64 new_spte)
 {
 	lockdep_assert_held_read(&kvm->mmu_lock);
 
@@ -498,19 +515,32 @@
 	 * Do not change removed SPTEs. Only the thread that froze the SPTE
 	 * may modify it.
 	 */
-	if (iter->old_spte == REMOVED_SPTE)
+	if (is_removed_spte(iter->old_spte))
 		return false;
 
 	if (cmpxchg64(rcu_dereference(iter->sptep), iter->old_spte,
 		      new_spte) != iter->old_spte)
 		return false;
 
-	handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
-			    new_spte, iter->level, true);
+	__handle_changed_spte(kvm, iter->as_id, iter->gfn, iter->old_spte,
+			      new_spte, iter->level, true);
+	handle_changed_spte_acc_track(iter->old_spte, new_spte, iter->level);
 
 	return true;
 }
 
+static inline bool tdp_mmu_set_spte_atomic(struct kvm *kvm,
+					   struct tdp_iter *iter,
+					   u64 new_spte)
+{
+	if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, iter, new_spte))
+		return false;
+
+	handle_changed_spte_dirty_log(kvm, iter->as_id, iter->gfn,
+				      iter->old_spte, new_spte, iter->level);
+	return true;
+}
+
 static inline bool tdp_mmu_zap_spte_atomic(struct kvm *kvm,
 					   struct tdp_iter *iter)
 {
@@ -569,7 +599,7 @@
 	 * should be used. If operating under the MMU lock in write mode, the
 	 * use of the removed SPTE should not be necessary.
 	 */
-	WARN_ON(iter->old_spte == REMOVED_SPTE);
+	WARN_ON(is_removed_spte(iter->old_spte));
 
 	WRITE_ONCE(*rcu_dereference(iter->sptep), new_spte);
 
@@ -634,7 +664,8 @@
  * Return false if a yield was not needed.
  */
 static inline bool tdp_mmu_iter_cond_resched(struct kvm *kvm,
-					     struct tdp_iter *iter, bool flush)
+					     struct tdp_iter *iter, bool flush,
+					     bool shared)
 {
 	/* Ensure forward progress has been made before yielding. */
 	if (iter->next_last_level_gfn == iter->yielded_gfn)
@@ -646,7 +677,11 @@
 		if (flush)
 			kvm_flush_remote_tlbs(kvm);
 
-		cond_resched_rwlock_write(&kvm->mmu_lock);
+		if (shared)
+			cond_resched_rwlock_read(&kvm->mmu_lock);
+		else
+			cond_resched_rwlock_write(&kvm->mmu_lock);
+
 		rcu_read_lock();
 
 		WARN_ON(iter->gfn > iter->next_last_level_gfn);
@@ -664,24 +699,32 @@
  * non-root pages mapping GFNs strictly within that range. Returns true if
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
+ *
  * If can_yield is true, will release the MMU lock and reschedule if the
  * scheduler needs the CPU or there is contention on the MMU lock. If this
  * function cannot yield, it will not release the MMU lock or reschedule and
  * the caller must ensure it does not supply too large a GFN range, or the
- * operation can cause a soft lockup.  Note, in some use cases a flush may be
- * required by prior actions.  Ensure the pending flush is performed prior to
- * yielding.
+ * operation can cause a soft lockup.
+ *
+ * If shared is true, this thread holds the MMU lock in read mode and must
+ * account for the possibility that other threads are modifying the paging
+ * structures concurrently. If shared is false, this thread should hold the
+ * MMU lock in write mode.
  */
 static bool zap_gfn_range(struct kvm *kvm, struct kvm_mmu_page *root,
-			  gfn_t start, gfn_t end, bool can_yield, bool flush)
+			  gfn_t start, gfn_t end, bool can_yield, bool flush,
+			  bool shared)
 {
 	struct tdp_iter iter;
 
+	kvm_lockdep_assert_mmu_lock_held(kvm, shared);
+
 	rcu_read_lock();
 
 	tdp_root_for_each_pte(iter, root, start, end) {
+retry:
 		if (can_yield &&
-		    tdp_mmu_iter_cond_resched(kvm, &iter, flush)) {
+		    tdp_mmu_iter_cond_resched(kvm, &iter, flush, shared)) {
 			flush = false;
 			continue;
 		}
@@ -699,8 +742,17 @@
 		    !is_last_spte(iter.old_spte, iter.level))
 			continue;
 
-		tdp_mmu_set_spte(kvm, &iter, 0);
-		flush = true;
+		if (!shared) {
+			tdp_mmu_set_spte(kvm, &iter, 0);
+			flush = true;
+		} else if (!tdp_mmu_zap_spte_atomic(kvm, &iter)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 	}
 
 	rcu_read_unlock();
@@ -712,15 +764,21 @@
  * non-root pages mapping GFNs strictly within that range. Returns true if
  * SPTEs have been cleared and a TLB flush is needed before releasing the
  * MMU lock.
+ *
+ * If shared is true, this thread holds the MMU lock in read mode and must
+ * account for the possibility that other threads are modifying the paging
+ * structures concurrently. If shared is false, this thread should hold the
+ * MMU lock in write mode.
  */
-bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
-				 bool can_yield)
+bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
+				 gfn_t end, bool can_yield, bool flush,
+				 bool shared)
 {
 	struct kvm_mmu_page *root;
-	bool flush = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root)
-		flush = zap_gfn_range(kvm, root, start, end, can_yield, flush);
+	for_each_tdp_mmu_root_yield_safe(kvm, root, as_id, shared)
+		flush = zap_gfn_range(kvm, root, start, end, can_yield, flush,
+				      shared);
 
 	return flush;
 }
@@ -728,9 +786,81 @@
 void kvm_tdp_mmu_zap_all(struct kvm *kvm)
 {
 	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
-	bool flush;
+	bool flush = false;
+	int i;
 
-	flush = kvm_tdp_mmu_zap_gfn_range(kvm, 0, max_gfn);
+	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++)
+		flush = kvm_tdp_mmu_zap_gfn_range(kvm, i, 0, max_gfn,
+						  flush, false);
+
+	if (flush)
+		kvm_flush_remote_tlbs(kvm);
+}
+
+static struct kvm_mmu_page *next_invalidated_root(struct kvm *kvm,
+						  struct kvm_mmu_page *prev_root)
+{
+	struct kvm_mmu_page *next_root;
+
+	if (prev_root)
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &prev_root->link,
+						  typeof(*prev_root), link);
+	else
+		next_root = list_first_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						   typeof(*next_root), link);
+
+	while (next_root && !(next_root->role.invalid &&
+			      refcount_read(&next_root->tdp_mmu_root_count)))
+		next_root = list_next_or_null_rcu(&kvm->arch.tdp_mmu_roots,
+						  &next_root->link,
+						  typeof(*next_root), link);
+
+	return next_root;
+}
+
+/*
+ * Since kvm_mmu_zap_all_fast has acquired a reference to each
+ * invalidated root, they will not be freed until this function drops the
+ * reference. Before dropping that reference, tear down the paging
+ * structure so that whichever thread does drop the last reference
+ * only has to do a trivial amount of work. Since the roots are invalid,
+ * no new SPTEs should be created under them.
+ */
+void kvm_tdp_mmu_zap_invalidated_roots(struct kvm *kvm)
+{
+	gfn_t max_gfn = 1ULL << (shadow_phys_bits - PAGE_SHIFT);
+	struct kvm_mmu_page *next_root;
+	struct kvm_mmu_page *root;
+	bool flush = false;
+
+	lockdep_assert_held_read(&kvm->mmu_lock);
+
+	rcu_read_lock();
+
+	root = next_invalidated_root(kvm, NULL);
+
+	while (root) {
+		next_root = next_invalidated_root(kvm, root);
+
+		rcu_read_unlock();
+
+		flush = zap_gfn_range(kvm, root, 0, max_gfn, true, flush,
+				      true);
+
+		/*
+		 * Put the reference acquired in
+		 * kvm_tdp_mmu_invalidate_roots
+		 */
+		kvm_tdp_mmu_put_root(kvm, root, true);
+
+		root = next_root;
+
+		rcu_read_lock();
+	}
+
+	rcu_read_unlock();
+
 	if (flush)
 		kvm_flush_remote_tlbs(kvm);
 }
@@ -777,12 +907,11 @@
 		trace_mark_mmio_spte(rcu_dereference(iter->sptep), iter->gfn,
 				     new_spte);
 		ret = RET_PF_EMULATE;
-	} else
+	} else {
 		trace_kvm_mmu_set_spte(iter->level, iter->gfn,
 				       rcu_dereference(iter->sptep));
+	}
 
-	trace_kvm_mmu_set_spte(iter->level, iter->gfn,
-			       rcu_dereference(iter->sptep));
 	if (!prefault)
 		vcpu->stat.pf_fixed++;
 
@@ -882,140 +1011,122 @@
 	return ret;
 }
 
-static __always_inline int
-kvm_tdp_mmu_handle_hva_range(struct kvm *kvm,
-			     unsigned long start,
-			     unsigned long end,
-			     unsigned long data,
-			     int (*handler)(struct kvm *kvm,
-					    struct kvm_memory_slot *slot,
-					    struct kvm_mmu_page *root,
-					    gfn_t start,
-					    gfn_t end,
-					    unsigned long data))
+bool kvm_tdp_mmu_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range,
+				 bool flush)
 {
-	struct kvm_memslots *slots;
-	struct kvm_memory_slot *memslot;
 	struct kvm_mmu_page *root;
-	int ret = 0;
-	int as_id;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		as_id = kvm_mmu_page_as_id(root);
-		slots = __kvm_memslots(kvm, as_id);
-		kvm_for_each_memslot(memslot, slots) {
-			unsigned long hva_start, hva_end;
-			gfn_t gfn_start, gfn_end;
+	for_each_tdp_mmu_root(kvm, root, range->slot->as_id)
+		flush |= zap_gfn_range(kvm, root, range->start, range->end,
+				       range->may_block, flush, false);
 
-			hva_start = max(start, memslot->userspace_addr);
-			hva_end = min(end, memslot->userspace_addr +
-				      (memslot->npages << PAGE_SHIFT));
-			if (hva_start >= hva_end)
-				continue;
-			/*
-			 * {gfn(page) | page intersects with [hva_start, hva_end)} =
-			 * {gfn_start, gfn_start+1, ..., gfn_end-1}.
-			 */
-			gfn_start = hva_to_gfn_memslot(hva_start, memslot);
-			gfn_end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, memslot);
+	return flush;
+}
 
-			ret |= handler(kvm, memslot, root, gfn_start,
-				       gfn_end, data);
-		}
+typedef bool (*tdp_handler_t)(struct kvm *kvm, struct tdp_iter *iter,
+			      struct kvm_gfn_range *range);
+
+static __always_inline bool kvm_tdp_mmu_handle_gfn(struct kvm *kvm,
+						   struct kvm_gfn_range *range,
+						   tdp_handler_t handler)
+{
+	struct kvm_mmu_page *root;
+	struct tdp_iter iter;
+	bool ret = false;
+
+	rcu_read_lock();
+
+	/*
+	 * Don't support rescheduling, none of the MMU notifiers that funnel
+	 * into this helper allow blocking; it'd be dead, wasteful code.
+	 */
+	for_each_tdp_mmu_root(kvm, root, range->slot->as_id) {
+		tdp_root_for_each_leaf_pte(iter, root, range->start, range->end)
+			ret |= handler(kvm, &iter, range);
 	}
 
+	rcu_read_unlock();
+
 	return ret;
 }
 
-static int zap_gfn_range_hva_wrapper(struct kvm *kvm,
-				     struct kvm_memory_slot *slot,
-				     struct kvm_mmu_page *root, gfn_t start,
-				     gfn_t end, unsigned long unused)
-{
-	return zap_gfn_range(kvm, root, start, end, false, false);
-}
-
-int kvm_tdp_mmu_zap_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end)
-{
-	return kvm_tdp_mmu_handle_hva_range(kvm, start, end, 0,
-					    zap_gfn_range_hva_wrapper);
-}
-
 /*
  * Mark the SPTEs range of GFNs [start, end) unaccessed and return non-zero
  * if any of the GFNs in the range have been accessed.
  */
-static int age_gfn_range(struct kvm *kvm, struct kvm_memory_slot *slot,
-			 struct kvm_mmu_page *root, gfn_t start, gfn_t end,
-			 unsigned long unused)
+static bool age_gfn_range(struct kvm *kvm, struct tdp_iter *iter,
+			  struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-	int young = 0;
 	u64 new_spte = 0;
 
-	rcu_read_lock();
+	/* If we have a non-accessed entry we don't need to change the pte. */
+	if (!is_accessed_spte(iter->old_spte))
+		return false;
 
-	tdp_root_for_each_leaf_pte(iter, root, start, end) {
+	new_spte = iter->old_spte;
+
+	if (spte_ad_enabled(new_spte)) {
+		new_spte &= ~shadow_accessed_mask;
+	} else {
 		/*
-		 * If we have a non-accessed entry we don't need to change the
-		 * pte.
+		 * Capture the dirty status of the page, so that it doesn't get
+		 * lost when the SPTE is marked for access tracking.
 		 */
-		if (!is_accessed_spte(iter.old_spte))
-			continue;
+		if (is_writable_pte(new_spte))
+			kvm_set_pfn_dirty(spte_to_pfn(new_spte));
 
-		new_spte = iter.old_spte;
-
-		if (spte_ad_enabled(new_spte)) {
-			clear_bit((ffs(shadow_accessed_mask) - 1),
-				  (unsigned long *)&new_spte);
-		} else {
-			/*
-			 * Capture the dirty status of the page, so that it doesn't get
-			 * lost when the SPTE is marked for access tracking.
-			 */
-			if (is_writable_pte(new_spte))
-				kvm_set_pfn_dirty(spte_to_pfn(new_spte));
-
-			new_spte = mark_spte_for_access_track(new_spte);
-		}
-		new_spte &= ~shadow_dirty_mask;
-
-		tdp_mmu_set_spte_no_acc_track(kvm, &iter, new_spte);
-		young = 1;
-
-		trace_kvm_age_page(iter.gfn, iter.level, slot, young);
+		new_spte = mark_spte_for_access_track(new_spte);
 	}
 
-	rcu_read_unlock();
+	tdp_mmu_set_spte_no_acc_track(kvm, iter, new_spte);
 
-	return young;
+	return true;
 }
 
-int kvm_tdp_mmu_age_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end)
+bool kvm_tdp_mmu_age_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, start, end, 0,
-					    age_gfn_range);
+	return kvm_tdp_mmu_handle_gfn(kvm, range, age_gfn_range);
 }
 
-static int test_age_gfn(struct kvm *kvm, struct kvm_memory_slot *slot,
-			struct kvm_mmu_page *root, gfn_t gfn, gfn_t unused,
-			unsigned long unused2)
+static bool test_age_gfn(struct kvm *kvm, struct tdp_iter *iter,
+			 struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-
-	tdp_root_for_each_leaf_pte(iter, root, gfn, gfn + 1)
-		if (is_accessed_spte(iter.old_spte))
-			return 1;
-
-	return 0;
+	return is_accessed_spte(iter->old_spte);
 }
 
-int kvm_tdp_mmu_test_age_hva(struct kvm *kvm, unsigned long hva)
+bool kvm_tdp_mmu_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	return kvm_tdp_mmu_handle_hva_range(kvm, hva, hva + 1, 0,
-					    test_age_gfn);
+	return kvm_tdp_mmu_handle_gfn(kvm, range, test_age_gfn);
+}
+
+static bool set_spte_gfn(struct kvm *kvm, struct tdp_iter *iter,
+			 struct kvm_gfn_range *range)
+{
+	u64 new_spte;
+
+	/* Huge pages aren't expected to be modified without first being zapped. */
+	WARN_ON(pte_huge(range->pte) || range->start + 1 != range->end);
+
+	if (iter->level != PG_LEVEL_4K ||
+	    !is_shadow_present_pte(iter->old_spte))
+		return false;
+
+	/*
+	 * Note, when changing a read-only SPTE, it's not strictly necessary to
+	 * zero the SPTE before setting the new PFN, but doing so preserves the
+	 * invariant that the PFN of a present leaf SPTE can never change.
+	 * See __handle_changed_spte().
+	 */
+	tdp_mmu_set_spte(kvm, iter, 0);
+
+	if (!pte_write(range->pte)) {
+		new_spte = kvm_mmu_changed_pte_notifier_make_spte(iter->old_spte,
+								  pte_pfn(range->pte));
+
+		tdp_mmu_set_spte(kvm, iter, new_spte);
+	}
+
+	return true;
 }
 
 /*
@@ -1024,57 +1135,15 @@
  * notifier.
  * Returns non-zero if a flush is needed before releasing the MMU lock.
  */
-static int set_tdp_spte(struct kvm *kvm, struct kvm_memory_slot *slot,
-			struct kvm_mmu_page *root, gfn_t gfn, gfn_t unused,
-			unsigned long data)
+bool kvm_tdp_mmu_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range)
 {
-	struct tdp_iter iter;
-	pte_t *ptep = (pte_t *)data;
-	kvm_pfn_t new_pfn;
-	u64 new_spte;
-	int need_flush = 0;
+	bool flush = kvm_tdp_mmu_handle_gfn(kvm, range, set_spte_gfn);
 
-	rcu_read_lock();
+	/* FIXME: return 'flush' instead of flushing here. */
+	if (flush)
+		kvm_flush_remote_tlbs_with_address(kvm, range->start, 1);
 
-	WARN_ON(pte_huge(*ptep));
-
-	new_pfn = pte_pfn(*ptep);
-
-	tdp_root_for_each_pte(iter, root, gfn, gfn + 1) {
-		if (iter.level != PG_LEVEL_4K)
-			continue;
-
-		if (!is_shadow_present_pte(iter.old_spte))
-			break;
-
-		tdp_mmu_set_spte(kvm, &iter, 0);
-
-		kvm_flush_remote_tlbs_with_address(kvm, iter.gfn, 1);
-
-		if (!pte_write(*ptep)) {
-			new_spte = kvm_mmu_changed_pte_notifier_make_spte(
-					iter.old_spte, new_pfn);
-
-			tdp_mmu_set_spte(kvm, &iter, new_spte);
-		}
-
-		need_flush = 1;
-	}
-
-	if (need_flush)
-		kvm_flush_remote_tlbs_with_address(kvm, gfn, 1);
-
-	rcu_read_unlock();
-
-	return 0;
-}
-
-int kvm_tdp_mmu_set_spte_hva(struct kvm *kvm, unsigned long address,
-			     pte_t *host_ptep)
-{
-	return kvm_tdp_mmu_handle_hva_range(kvm, address, address + 1,
-					    (unsigned long)host_ptep,
-					    set_tdp_spte);
+	return false;
 }
 
 /*
@@ -1095,7 +1164,8 @@
 
 	for_each_tdp_pte_min_level(iter, root->spt, root->role.level,
 				   min_level, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, false))
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
 			continue;
 
 		if (!is_shadow_present_pte(iter.old_spte) ||
@@ -1105,7 +1175,15 @@
 
 		new_spte = iter.old_spte & ~PT_WRITABLE_MASK;
 
-		tdp_mmu_set_spte_no_dirty_log(kvm, &iter, new_spte);
+		if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, &iter,
+							  new_spte)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 		spte_set = true;
 	}
 
@@ -1122,17 +1200,13 @@
 			     int min_level)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
 		spte_set |= wrprot_gfn_range(kvm, root, slot->base_gfn,
 			     slot->base_gfn + slot->npages, min_level);
-	}
 
 	return spte_set;
 }
@@ -1154,7 +1228,8 @@
 	rcu_read_lock();
 
 	tdp_root_for_each_leaf_pte(iter, root, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, false))
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, false, true))
 			continue;
 
 		if (spte_ad_need_write_protect(iter.old_spte)) {
@@ -1169,7 +1244,15 @@
 				continue;
 		}
 
-		tdp_mmu_set_spte_no_dirty_log(kvm, &iter, new_spte);
+		if (!tdp_mmu_set_spte_atomic_no_dirty_log(kvm, &iter,
+							  new_spte)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
 		spte_set = true;
 	}
 
@@ -1187,17 +1270,13 @@
 bool kvm_tdp_mmu_clear_dirty_slot(struct kvm *kvm, struct kvm_memory_slot *slot)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
 		spte_set |= clear_dirty_gfn_range(kvm, root, slot->base_gfn,
 				slot->base_gfn + slot->npages);
-	}
 
 	return spte_set;
 }
@@ -1259,37 +1338,32 @@
 				       bool wrprot)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 
 	lockdep_assert_held_write(&kvm->mmu_lock);
-	for_each_tdp_mmu_root(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
-
+	for_each_tdp_mmu_root(kvm, root, slot->as_id)
 		clear_dirty_pt_masked(kvm, root, gfn, mask, wrprot);
-	}
 }
 
 /*
  * Clear leaf entries which could be replaced by large mappings, for
  * GFNs within the slot.
  */
-static void zap_collapsible_spte_range(struct kvm *kvm,
+static bool zap_collapsible_spte_range(struct kvm *kvm,
 				       struct kvm_mmu_page *root,
-				       struct kvm_memory_slot *slot)
+				       const struct kvm_memory_slot *slot,
+				       bool flush)
 {
 	gfn_t start = slot->base_gfn;
 	gfn_t end = start + slot->npages;
 	struct tdp_iter iter;
 	kvm_pfn_t pfn;
-	bool spte_set = false;
 
 	rcu_read_lock();
 
 	tdp_root_for_each_pte(iter, root, start, end) {
-		if (tdp_mmu_iter_cond_resched(kvm, &iter, spte_set)) {
-			spte_set = false;
+retry:
+		if (tdp_mmu_iter_cond_resched(kvm, &iter, flush, true)) {
+			flush = false;
 			continue;
 		}
 
@@ -1303,38 +1377,43 @@
 							    pfn, PG_LEVEL_NUM))
 			continue;
 
-		tdp_mmu_set_spte(kvm, &iter, 0);
-
-		spte_set = true;
+		if (!tdp_mmu_zap_spte_atomic(kvm, &iter)) {
+			/*
+			 * The iter must explicitly re-read the SPTE because
+			 * the atomic cmpxchg failed.
+			 */
+			iter.old_spte = READ_ONCE(*rcu_dereference(iter.sptep));
+			goto retry;
+		}
+		flush = true;
 	}
 
 	rcu_read_unlock();
-	if (spte_set)
-		kvm_flush_remote_tlbs(kvm);
+
+	return flush;
 }
 
 /*
  * Clear non-leaf entries (and free associated page tables) which could
  * be replaced by large mappings, for GFNs within the slot.
  */
-void kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
-				       struct kvm_memory_slot *slot)
+bool kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
+				       const struct kvm_memory_slot *slot,
+				       bool flush)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 
-	for_each_tdp_mmu_root_yield_safe(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
+	lockdep_assert_held_read(&kvm->mmu_lock);
 
-		zap_collapsible_spte_range(kvm, root, slot);
-	}
+	for_each_tdp_mmu_root_yield_safe(kvm, root, slot->as_id, true)
+		flush = zap_collapsible_spte_range(kvm, root, slot, flush);
+
+	return flush;
 }
 
 /*
  * Removes write access on the last level SPTE mapping this GFN and unsets the
- * SPTE_MMU_WRITABLE bit to ensure future writes continue to be intercepted.
+ * MMU-writable bit to ensure future writes continue to be intercepted.
  * Returns true if an SPTE was set and a TLB flush is needed.
  */
 static bool write_protect_gfn(struct kvm *kvm, struct kvm_mmu_page *root,
@@ -1351,7 +1430,7 @@
 			break;
 
 		new_spte = iter.old_spte &
-			~(PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE);
+			~(PT_WRITABLE_MASK | shadow_mmu_writable_mask);
 
 		tdp_mmu_set_spte(kvm, &iter, new_spte);
 		spte_set = true;
@@ -1364,24 +1443,19 @@
 
 /*
  * Removes write access on the last level SPTE mapping this GFN and unsets the
- * SPTE_MMU_WRITABLE bit to ensure future writes continue to be intercepted.
+ * MMU-writable bit to ensure future writes continue to be intercepted.
  * Returns true if an SPTE was set and a TLB flush is needed.
  */
 bool kvm_tdp_mmu_write_protect_gfn(struct kvm *kvm,
 				   struct kvm_memory_slot *slot, gfn_t gfn)
 {
 	struct kvm_mmu_page *root;
-	int root_as_id;
 	bool spte_set = false;
 
 	lockdep_assert_held_write(&kvm->mmu_lock);
-	for_each_tdp_mmu_root(kvm, root) {
-		root_as_id = kvm_mmu_page_as_id(root);
-		if (root_as_id != slot->as_id)
-			continue;
-
+	for_each_tdp_mmu_root(kvm, root, slot->as_id)
 		spte_set |= write_protect_gfn(kvm, root, gfn);
-	}
+
 	return spte_set;
 }
 
diff --git a/arch/x86/kvm/mmu/tdp_mmu.h b/arch/x86/kvm/mmu/tdp_mmu.h
index 31096ec..0c684f2 100644
--- a/arch/x86/kvm/mmu/tdp_mmu.h
+++ b/arch/x86/kvm/mmu/tdp_mmu.h
@@ -6,14 +6,28 @@
 #include <linux/kvm_host.h>
 
 hpa_t kvm_tdp_mmu_get_vcpu_root_hpa(struct kvm_vcpu *vcpu);
-void kvm_tdp_mmu_free_root(struct kvm *kvm, struct kvm_mmu_page *root);
 
-bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start, gfn_t end,
-				 bool can_yield);
-static inline bool kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, gfn_t start,
-					     gfn_t end)
+__must_check static inline bool kvm_tdp_mmu_get_root(struct kvm *kvm,
+						     struct kvm_mmu_page *root)
 {
-	return __kvm_tdp_mmu_zap_gfn_range(kvm, start, end, true);
+	if (root->role.invalid)
+		return false;
+
+	return refcount_inc_not_zero(&root->tdp_mmu_root_count);
+}
+
+void kvm_tdp_mmu_put_root(struct kvm *kvm, struct kvm_mmu_page *root,
+			  bool shared);
+
+bool __kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id, gfn_t start,
+				 gfn_t end, bool can_yield, bool flush,
+				 bool shared);
+static inline bool kvm_tdp_mmu_zap_gfn_range(struct kvm *kvm, int as_id,
+					     gfn_t start, gfn_t end, bool flush,
+					     bool shared)
+{
+	return __kvm_tdp_mmu_zap_gfn_range(kvm, as_id, start, end, true, flush,
+					   shared);
 }
 static inline bool kvm_tdp_mmu_zap_sp(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
@@ -29,23 +43,21 @@
 	 * of the shadow page's gfn range and stop iterating before yielding.
 	 */
 	lockdep_assert_held_write(&kvm->mmu_lock);
-	return __kvm_tdp_mmu_zap_gfn_range(kvm, sp->gfn, end, false);
+	return __kvm_tdp_mmu_zap_gfn_range(kvm, kvm_mmu_page_as_id(sp),
+					   sp->gfn, end, false, false, false);
 }
 void kvm_tdp_mmu_zap_all(struct kvm *kvm);
+void kvm_tdp_mmu_zap_invalidated_roots(struct kvm *kvm);
 
 int kvm_tdp_mmu_map(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
 		    int map_writable, int max_level, kvm_pfn_t pfn,
 		    bool prefault);
 
-int kvm_tdp_mmu_zap_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end);
-
-int kvm_tdp_mmu_age_hva_range(struct kvm *kvm, unsigned long start,
-			      unsigned long end);
-int kvm_tdp_mmu_test_age_hva(struct kvm *kvm, unsigned long hva);
-
-int kvm_tdp_mmu_set_spte_hva(struct kvm *kvm, unsigned long address,
-			     pte_t *host_ptep);
+bool kvm_tdp_mmu_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range,
+				 bool flush);
+bool kvm_tdp_mmu_age_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range);
+bool kvm_tdp_mmu_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range);
+bool kvm_tdp_mmu_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range);
 
 bool kvm_tdp_mmu_wrprot_slot(struct kvm *kvm, struct kvm_memory_slot *slot,
 			     int min_level);
@@ -55,8 +67,9 @@
 				       struct kvm_memory_slot *slot,
 				       gfn_t gfn, unsigned long mask,
 				       bool wrprot);
-void kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
-				       struct kvm_memory_slot *slot);
+bool kvm_tdp_mmu_zap_collapsible_sptes(struct kvm *kvm,
+				       const struct kvm_memory_slot *slot,
+				       bool flush);
 
 bool kvm_tdp_mmu_write_protect_gfn(struct kvm *kvm,
 				   struct kvm_memory_slot *slot, gfn_t gfn);
diff --git a/arch/x86/kvm/svm/avic.c b/arch/x86/kvm/svm/avic.c
index 78bdcfa..cd0285f 100644
--- a/arch/x86/kvm/svm/avic.c
+++ b/arch/x86/kvm/svm/avic.c
@@ -270,7 +270,7 @@
 	if (id >= AVIC_MAX_PHYSICAL_ID_COUNT)
 		return -EINVAL;
 
-	if (!svm->vcpu.arch.apic->regs)
+	if (!vcpu->arch.apic->regs)
 		return -EINVAL;
 
 	if (kvm_apicv_activated(vcpu->kvm)) {
@@ -281,7 +281,7 @@
 			return ret;
 	}
 
-	svm->avic_backing_page = virt_to_page(svm->vcpu.arch.apic->regs);
+	svm->avic_backing_page = virt_to_page(vcpu->arch.apic->regs);
 
 	/* Setting AVIC backing page address in the phy APIC ID table */
 	entry = avic_get_physical_id_entry(vcpu, id);
@@ -315,15 +315,16 @@
 	}
 }
 
-int avic_incomplete_ipi_interception(struct vcpu_svm *svm)
+int avic_incomplete_ipi_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u32 icrh = svm->vmcb->control.exit_info_1 >> 32;
 	u32 icrl = svm->vmcb->control.exit_info_1;
 	u32 id = svm->vmcb->control.exit_info_2 >> 32;
 	u32 index = svm->vmcb->control.exit_info_2 & 0xFF;
-	struct kvm_lapic *apic = svm->vcpu.arch.apic;
+	struct kvm_lapic *apic = vcpu->arch.apic;
 
-	trace_kvm_avic_incomplete_ipi(svm->vcpu.vcpu_id, icrh, icrl, id, index);
+	trace_kvm_avic_incomplete_ipi(vcpu->vcpu_id, icrh, icrl, id, index);
 
 	switch (id) {
 	case AVIC_IPI_FAILURE_INVALID_INT_TYPE:
@@ -347,11 +348,11 @@
 		 * set the appropriate IRR bits on the valid target
 		 * vcpus. So, we just need to kick the appropriate vcpu.
 		 */
-		avic_kick_target_vcpus(svm->vcpu.kvm, apic, icrl, icrh);
+		avic_kick_target_vcpus(vcpu->kvm, apic, icrl, icrh);
 		break;
 	case AVIC_IPI_FAILURE_INVALID_TARGET:
 		WARN_ONCE(1, "Invalid IPI target: index=%u, vcpu=%d, icr=%#0x:%#0x\n",
-			  index, svm->vcpu.vcpu_id, icrh, icrl);
+			  index, vcpu->vcpu_id, icrh, icrl);
 		break;
 	case AVIC_IPI_FAILURE_INVALID_BACKING_PAGE:
 		WARN_ONCE(1, "Invalid backing page\n");
@@ -539,8 +540,9 @@
 	return ret;
 }
 
-int avic_unaccelerated_access_interception(struct vcpu_svm *svm)
+int avic_unaccelerated_access_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int ret = 0;
 	u32 offset = svm->vmcb->control.exit_info_1 &
 		     AVIC_UNACCEL_ACCESS_OFFSET_MASK;
@@ -550,7 +552,7 @@
 		     AVIC_UNACCEL_ACCESS_WRITE_MASK;
 	bool trap = is_avic_unaccelerated_access_trap(offset);
 
-	trace_kvm_avic_unaccelerated_access(svm->vcpu.vcpu_id, offset,
+	trace_kvm_avic_unaccelerated_access(vcpu->vcpu_id, offset,
 					    trap, write, vector);
 	if (trap) {
 		/* Handling Trap */
@@ -558,7 +560,7 @@
 		ret = avic_unaccel_trap_write(svm);
 	} else {
 		/* Handling Fault */
-		ret = kvm_emulate_instruction(&svm->vcpu, 0);
+		ret = kvm_emulate_instruction(vcpu, 0);
 	}
 
 	return ret;
@@ -572,7 +574,7 @@
 	if (!avic || !irqchip_in_kernel(vcpu->kvm))
 		return 0;
 
-	ret = avic_init_backing_page(&svm->vcpu);
+	ret = avic_init_backing_page(vcpu);
 	if (ret)
 		return ret;
 
diff --git a/arch/x86/kvm/svm/nested.c b/arch/x86/kvm/svm/nested.c
index fb204ea..f8b78ab 100644
--- a/arch/x86/kvm/svm/nested.c
+++ b/arch/x86/kvm/svm/nested.c
@@ -29,6 +29,8 @@
 #include "lapic.h"
 #include "svm.h"
 
+#define CC KVM_NESTED_VMENTER_CONSISTENCY_CHECK
+
 static void nested_svm_inject_npf_exit(struct kvm_vcpu *vcpu,
 				       struct x86_exception *fault)
 {
@@ -92,12 +94,12 @@
 static void nested_svm_init_mmu_context(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
-	struct vmcb *hsave = svm->nested.hsave;
 
 	WARN_ON(mmu_is_nested(vcpu));
 
 	vcpu->arch.mmu = &vcpu->arch.guest_mmu;
-	kvm_init_shadow_npt_mmu(vcpu, X86_CR0_PG, hsave->save.cr4, hsave->save.efer,
+	kvm_init_shadow_npt_mmu(vcpu, X86_CR0_PG, svm->vmcb01.ptr->save.cr4,
+				svm->vmcb01.ptr->save.efer,
 				svm->nested.ctl.nested_cr3);
 	vcpu->arch.mmu->get_guest_pgd     = nested_svm_get_tdp_cr3;
 	vcpu->arch.mmu->get_pdptr         = nested_svm_get_tdp_pdptr;
@@ -123,7 +125,7 @@
 		return;
 
 	c = &svm->vmcb->control;
-	h = &svm->nested.hsave->control;
+	h = &svm->vmcb01.ptr->control;
 	g = &svm->nested.ctl;
 
 	for (i = 0; i < MAX_INTERCEPT; i++)
@@ -213,44 +215,45 @@
 	return true;
 }
 
-static bool svm_get_nested_state_pages(struct kvm_vcpu *vcpu)
-{
-	struct vcpu_svm *svm = to_svm(vcpu);
-
-	if (WARN_ON(!is_guest_mode(vcpu)))
-		return true;
-
-	if (!nested_svm_vmrun_msrpm(svm)) {
-		vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
-		vcpu->run->internal.suberror =
-			KVM_INTERNAL_ERROR_EMULATION;
-		vcpu->run->internal.ndata = 0;
-		return false;
-	}
-
-	return true;
-}
-
 static bool nested_vmcb_check_controls(struct vmcb_control_area *control)
 {
-	if ((vmcb_is_intercept(control, INTERCEPT_VMRUN)) == 0)
+	if (CC(!vmcb_is_intercept(control, INTERCEPT_VMRUN)))
 		return false;
 
-	if (control->asid == 0)
+	if (CC(control->asid == 0))
 		return false;
 
-	if ((control->nested_ctl & SVM_NESTED_CTL_NP_ENABLE) &&
-	    !npt_enabled)
+	if (CC((control->nested_ctl & SVM_NESTED_CTL_NP_ENABLE) && !npt_enabled))
 		return false;
 
 	return true;
 }
 
-static bool nested_vmcb_check_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
+static bool nested_vmcb_check_cr3_cr4(struct kvm_vcpu *vcpu,
+				      struct vmcb_save_area *save)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
-	bool vmcb12_lma;
+	/*
+	 * These checks are also performed by KVM_SET_SREGS,
+	 * except that EFER.LMA is not checked by SVM against
+	 * CR0.PG && EFER.LME.
+	 */
+	if ((save->efer & EFER_LME) && (save->cr0 & X86_CR0_PG)) {
+		if (CC(!(save->cr4 & X86_CR4_PAE)) ||
+		    CC(!(save->cr0 & X86_CR0_PE)) ||
+		    CC(kvm_vcpu_is_illegal_gpa(vcpu, save->cr3)))
+			return false;
+	}
 
+	if (CC(!kvm_is_valid_cr4(vcpu, save->cr4)))
+		return false;
+
+	return true;
+}
+
+/* Common checks that apply to both L1 and L2 state.  */
+static bool nested_vmcb_valid_sregs(struct kvm_vcpu *vcpu,
+				    struct vmcb_save_area *save)
+{
 	/*
 	 * FIXME: these should be done after copying the fields,
 	 * to avoid TOC/TOU races.  For these save area checks
@@ -258,31 +261,27 @@
 	 * kvm_set_cr4 handle failure; EFER_SVME is an exception
 	 * so it is force-set later in nested_prepare_vmcb_save.
 	 */
-	if ((vmcb12->save.efer & EFER_SVME) == 0)
+	if (CC(!(save->efer & EFER_SVME)))
 		return false;
 
-	if (((vmcb12->save.cr0 & X86_CR0_CD) == 0) && (vmcb12->save.cr0 & X86_CR0_NW))
+	if (CC((save->cr0 & X86_CR0_CD) == 0 && (save->cr0 & X86_CR0_NW)) ||
+	    CC(save->cr0 & ~0xffffffffULL))
 		return false;
 
-	if (!kvm_dr6_valid(vmcb12->save.dr6) || !kvm_dr7_valid(vmcb12->save.dr7))
+	if (CC(!kvm_dr6_valid(save->dr6)) || CC(!kvm_dr7_valid(save->dr7)))
 		return false;
 
-	vmcb12_lma = (vmcb12->save.efer & EFER_LME) && (vmcb12->save.cr0 & X86_CR0_PG);
+	if (!nested_vmcb_check_cr3_cr4(vcpu, save))
+		return false;
 
-	if (vmcb12_lma) {
-		if (!(vmcb12->save.cr4 & X86_CR4_PAE) ||
-		    !(vmcb12->save.cr0 & X86_CR0_PE) ||
-		    kvm_vcpu_is_illegal_gpa(vcpu, vmcb12->save.cr3))
-			return false;
-	}
-	if (!kvm_is_valid_cr4(&svm->vcpu, vmcb12->save.cr4))
+	if (CC(!kvm_valid_efer(vcpu, save->efer)))
 		return false;
 
 	return true;
 }
 
-static void load_nested_vmcb_control(struct vcpu_svm *svm,
-				     struct vmcb_control_area *control)
+static void nested_load_control_from_vmcb12(struct vcpu_svm *svm,
+					    struct vmcb_control_area *control)
 {
 	copy_vmcb_control_area(&svm->nested.ctl, control);
 
@@ -294,9 +293,9 @@
 
 /*
  * Synchronize fields that are written by the processor, so that
- * they can be copied back into the nested_vmcb.
+ * they can be copied back into the vmcb12.
  */
-void sync_nested_vmcb_control(struct vcpu_svm *svm)
+void nested_sync_control_from_vmcb02(struct vcpu_svm *svm)
 {
 	u32 mask;
 	svm->nested.ctl.event_inj      = svm->vmcb->control.event_inj;
@@ -324,8 +323,8 @@
  * Transfer any event that L0 or L1 wanted to inject into L2 to
  * EXIT_INT_INFO.
  */
-static void nested_vmcb_save_pending_event(struct vcpu_svm *svm,
-					   struct vmcb *vmcb12)
+static void nested_save_pending_event_to_vmcb12(struct vcpu_svm *svm,
+						struct vmcb *vmcb12)
 {
 	struct kvm_vcpu *vcpu = &svm->vcpu;
 	u32 exit_int_info = 0;
@@ -369,12 +368,12 @@
 static int nested_svm_load_cr3(struct kvm_vcpu *vcpu, unsigned long cr3,
 			       bool nested_npt)
 {
-	if (kvm_vcpu_is_illegal_gpa(vcpu, cr3))
+	if (CC(kvm_vcpu_is_illegal_gpa(vcpu, cr3)))
 		return -EINVAL;
 
 	if (!nested_npt && is_pae_paging(vcpu) &&
 	    (cr3 != kvm_read_cr3(vcpu) || pdptrs_changed(vcpu))) {
-		if (!load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3))
+		if (CC(!load_pdptrs(vcpu, vcpu->arch.walk_mmu, cr3)))
 			return -EINVAL;
 	}
 
@@ -393,8 +392,21 @@
 	return 0;
 }
 
-static void nested_prepare_vmcb_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
+void nested_vmcb02_compute_g_pat(struct vcpu_svm *svm)
 {
+	if (!svm->nested.vmcb02.ptr)
+		return;
+
+	/* FIXME: merge g_pat from vmcb01 and vmcb12.  */
+	svm->nested.vmcb02.ptr->save.g_pat = svm->vmcb01.ptr->save.g_pat;
+}
+
+static void nested_vmcb02_prepare_save(struct vcpu_svm *svm, struct vmcb *vmcb12)
+{
+	bool new_vmcb12 = false;
+
+	nested_vmcb02_compute_g_pat(svm);
+
 	/* Load the nested guest state */
 	svm->vmcb->save.es = vmcb12->save.es;
 	svm->vmcb->save.cs = vmcb12->save.cs;
@@ -413,7 +425,9 @@
 
 	svm_set_cr0(&svm->vcpu, vmcb12->save.cr0);
 	svm_set_cr4(&svm->vcpu, vmcb12->save.cr4);
-	svm->vmcb->save.cr2 = svm->vcpu.arch.cr2 = vmcb12->save.cr2;
+
+	svm->vcpu.arch.cr2 = vmcb12->save.cr2;
+
 	kvm_rax_write(&svm->vcpu, vmcb12->save.rax);
 	kvm_rsp_write(&svm->vcpu, vmcb12->save.rsp);
 	kvm_rip_write(&svm->vcpu, vmcb12->save.rip);
@@ -422,15 +436,41 @@
 	svm->vmcb->save.rax = vmcb12->save.rax;
 	svm->vmcb->save.rsp = vmcb12->save.rsp;
 	svm->vmcb->save.rip = vmcb12->save.rip;
-	svm->vmcb->save.dr7 = vmcb12->save.dr7 | DR7_FIXED_1;
-	svm->vcpu.arch.dr6  = vmcb12->save.dr6 | DR6_ACTIVE_LOW;
-	svm->vmcb->save.cpl = vmcb12->save.cpl;
+
+	/* These bits will be set properly on the first execution when new_vmcb12 is true */
+	if (unlikely(new_vmcb12 || vmcb_is_dirty(vmcb12, VMCB_DR))) {
+		svm->vmcb->save.dr7 = vmcb12->save.dr7 | DR7_FIXED_1;
+		svm->vcpu.arch.dr6  = vmcb12->save.dr6 | DR6_ACTIVE_LOW;
+		vmcb_mark_dirty(svm->vmcb, VMCB_DR);
+	}
 }
 
-static void nested_prepare_vmcb_control(struct vcpu_svm *svm)
+static void nested_vmcb02_prepare_control(struct vcpu_svm *svm)
 {
 	const u32 mask = V_INTR_MASKING_MASK | V_GIF_ENABLE_MASK | V_GIF_MASK;
 
+	/*
+	 * Filled at exit: exit_code, exit_code_hi, exit_info_1, exit_info_2,
+	 * exit_int_info, exit_int_info_err, next_rip, insn_len, insn_bytes.
+	 */
+
+	/*
+	 * Also covers avic_vapic_bar, avic_backing_page, avic_logical_id,
+	 * avic_physical_id.
+	 */
+	WARN_ON(svm->vmcb01.ptr->control.int_ctl & AVIC_ENABLE_MASK);
+
+	/* Copied from vmcb01.  msrpm_base can be overwritten later.  */
+	svm->vmcb->control.nested_ctl = svm->vmcb01.ptr->control.nested_ctl;
+	svm->vmcb->control.iopm_base_pa = svm->vmcb01.ptr->control.iopm_base_pa;
+	svm->vmcb->control.msrpm_base_pa = svm->vmcb01.ptr->control.msrpm_base_pa;
+
+	/* Done at vmrun: asid.  */
+
+	/* Also overwritten later if necessary.  */
+	svm->vmcb->control.tlb_ctl = TLB_CONTROL_DO_NOTHING;
+
+	/* nested_cr3.  */
 	if (nested_npt_enabled(svm))
 		nested_svm_init_mmu_context(&svm->vcpu);
 
@@ -439,7 +479,7 @@
 
 	svm->vmcb->control.int_ctl             =
 		(svm->nested.ctl.int_ctl & ~mask) |
-		(svm->nested.hsave->control.int_ctl & mask);
+		(svm->vmcb01.ptr->control.int_ctl & mask);
 
 	svm->vmcb->control.virt_ext            = svm->nested.ctl.virt_ext;
 	svm->vmcb->control.int_vector          = svm->nested.ctl.int_vector;
@@ -454,17 +494,28 @@
 	enter_guest_mode(&svm->vcpu);
 
 	/*
-	 * Merge guest and host intercepts - must be called  with vcpu in
-	 * guest-mode to take affect here
+	 * Merge guest and host intercepts - must be called with vcpu in
+	 * guest-mode to take effect.
 	 */
 	recalc_intercepts(svm);
-
-	vmcb_mark_all_dirty(svm->vmcb);
 }
 
-int enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb12_gpa,
+static void nested_svm_copy_common_state(struct vmcb *from_vmcb, struct vmcb *to_vmcb)
+{
+	/*
+	 * Some VMCB state is shared between L1 and L2 and thus has to be
+	 * moved at the time of nested vmrun and vmexit.
+	 *
+	 * VMLOAD/VMSAVE state would also belong in this category, but KVM
+	 * always performs VMLOAD and VMSAVE from the VMCB01.
+	 */
+	to_vmcb->save.spec_ctrl = from_vmcb->save.spec_ctrl;
+}
+
+int enter_svm_guest_mode(struct kvm_vcpu *vcpu, u64 vmcb12_gpa,
 			 struct vmcb *vmcb12)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int ret;
 
 	trace_kvm_nested_vmrun(svm->vmcb->save.rip, vmcb12_gpa,
@@ -482,7 +533,6 @@
 
 
 	svm->nested.vmcb12_gpa = vmcb12_gpa;
-	nested_prepare_vmcb_control(svm);
 	nested_prepare_vmcb_save(svm, vmcb12);
 
 	ret = nested_svm_load_cr3(&svm->vcpu, vmcb12->save.cr3,
@@ -491,46 +541,47 @@
 		return ret;
 
 	if (!npt_enabled)
-		svm->vcpu.arch.mmu->inject_page_fault = svm_inject_page_fault_nested;
+		vcpu->arch.mmu->inject_page_fault = svm_inject_page_fault_nested;
 
 	svm_set_gif(svm, true);
 
 	return 0;
 }
 
-int nested_svm_vmrun(struct vcpu_svm *svm)
+int nested_svm_vmrun(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int ret;
 	struct vmcb *vmcb12;
-	struct vmcb *hsave = svm->nested.hsave;
-	struct vmcb *vmcb = svm->vmcb;
 	struct kvm_host_map map;
 	u64 vmcb12_gpa;
 
-	if (is_smm(&svm->vcpu)) {
-		kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+	++vcpu->stat.nested_run;
+
+	if (is_smm(vcpu)) {
+		kvm_queue_exception(vcpu, UD_VECTOR);
 		return 1;
 	}
 
 	vmcb12_gpa = svm->vmcb->save.rax;
-	ret = kvm_vcpu_map(&svm->vcpu, gpa_to_gfn(vmcb12_gpa), &map);
+	ret = kvm_vcpu_map(vcpu, gpa_to_gfn(vmcb12_gpa), &map);
 	if (ret == -EINVAL) {
-		kvm_inject_gp(&svm->vcpu, 0);
+		kvm_inject_gp(vcpu, 0);
 		return 1;
 	} else if (ret) {
-		return kvm_skip_emulated_instruction(&svm->vcpu);
+		return kvm_skip_emulated_instruction(vcpu);
 	}
 
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
+	ret = kvm_skip_emulated_instruction(vcpu);
 
 	vmcb12 = map.hva;
 
 	if (WARN_ON_ONCE(!svm->nested.initialized))
 		return -EINVAL;
 
-	load_nested_vmcb_control(svm, &vmcb12->control);
+	nested_load_control_from_vmcb12(svm, &vmcb12->control);
 
-	if (!nested_vmcb_check_save(svm, vmcb12) ||
+	if (!nested_vmcb_valid_sregs(vcpu, &vmcb12->save) ||
 	    !nested_vmcb_check_controls(&svm->nested.ctl)) {
 		vmcb12->control.exit_code    = SVM_EXIT_ERR;
 		vmcb12->control.exit_code_hi = 0;
@@ -541,36 +592,25 @@
 
 
 	/* Clear internal status */
-	kvm_clear_exception_queue(&svm->vcpu);
-	kvm_clear_interrupt_queue(&svm->vcpu);
+	kvm_clear_exception_queue(vcpu);
+	kvm_clear_interrupt_queue(vcpu);
 
 	/*
-	 * Save the old vmcb, so we don't need to pick what we save, but can
-	 * restore everything when a VMEXIT occurs
+	 * Since vmcb01 is not in use, we can use it to store some of the L1
+	 * state.
 	 */
-	hsave->save.es     = vmcb->save.es;
-	hsave->save.cs     = vmcb->save.cs;
-	hsave->save.ss     = vmcb->save.ss;
-	hsave->save.ds     = vmcb->save.ds;
-	hsave->save.gdtr   = vmcb->save.gdtr;
-	hsave->save.idtr   = vmcb->save.idtr;
-	hsave->save.efer   = svm->vcpu.arch.efer;
-	hsave->save.cr0    = kvm_read_cr0(&svm->vcpu);
-	hsave->save.cr4    = svm->vcpu.arch.cr4;
-	hsave->save.rflags = kvm_get_rflags(&svm->vcpu);
-	hsave->save.rip    = kvm_rip_read(&svm->vcpu);
-	hsave->save.rsp    = vmcb->save.rsp;
-	hsave->save.rax    = vmcb->save.rax;
-	if (npt_enabled)
-		hsave->save.cr3    = vmcb->save.cr3;
-	else
-		hsave->save.cr3    = kvm_read_cr3(&svm->vcpu);
+	svm->vmcb01.ptr->save.efer   = vcpu->arch.efer;
+	svm->vmcb01.ptr->save.cr0    = kvm_read_cr0(vcpu);
+	svm->vmcb01.ptr->save.cr4    = vcpu->arch.cr4;
+	svm->vmcb01.ptr->save.rflags = kvm_get_rflags(vcpu);
+	svm->vmcb01.ptr->save.rip    = kvm_rip_read(vcpu);
 
-	copy_vmcb_control_area(&hsave->control, &vmcb->control);
+	if (!npt_enabled)
+		svm->vmcb01.ptr->save.cr3 = kvm_read_cr3(vcpu);
 
 	svm->nested.nested_run_pending = 1;
 
-	if (enter_svm_guest_mode(svm, vmcb12_gpa, vmcb12))
+	if (enter_svm_guest_mode(vcpu, vmcb12_gpa, vmcb12))
 		goto out_exit_err;
 
 	if (nested_svm_vmrun_msrpm(svm))
@@ -587,7 +627,7 @@
 	nested_svm_vmexit(svm);
 
 out:
-	kvm_vcpu_unmap(&svm->vcpu, &map, true);
+	kvm_vcpu_unmap(vcpu, &map, true);
 
 	return ret;
 }
@@ -610,27 +650,30 @@
 
 int nested_svm_vmexit(struct vcpu_svm *svm)
 {
-	int rc;
+	struct kvm_vcpu *vcpu = &svm->vcpu;
 	struct vmcb *vmcb12;
-	struct vmcb *hsave = svm->nested.hsave;
 	struct vmcb *vmcb = svm->vmcb;
 	struct kvm_host_map map;
+	int rc;
 
-	rc = kvm_vcpu_map(&svm->vcpu, gpa_to_gfn(svm->nested.vmcb12_gpa), &map);
+	/* Triple faults in L2 should never escape. */
+	WARN_ON_ONCE(kvm_check_request(KVM_REQ_TRIPLE_FAULT, vcpu));
+
+	rc = kvm_vcpu_map(vcpu, gpa_to_gfn(svm->nested.vmcb12_gpa), &map);
 	if (rc) {
 		if (rc == -EINVAL)
-			kvm_inject_gp(&svm->vcpu, 0);
+			kvm_inject_gp(vcpu, 0);
 		return 1;
 	}
 
 	vmcb12 = map.hva;
 
 	/* Exit Guest-Mode */
-	leave_guest_mode(&svm->vcpu);
+	leave_guest_mode(vcpu);
 	svm->nested.vmcb12_gpa = 0;
 	WARN_ON_ONCE(svm->nested.nested_run_pending);
 
-	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, &svm->vcpu);
+	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, vcpu);
 
 	/* in case we halted in L2 */
 	svm->vcpu.arch.mp_state = KVM_MP_STATE_RUNNABLE;
@@ -644,14 +687,14 @@
 	vmcb12->save.gdtr   = vmcb->save.gdtr;
 	vmcb12->save.idtr   = vmcb->save.idtr;
 	vmcb12->save.efer   = svm->vcpu.arch.efer;
-	vmcb12->save.cr0    = kvm_read_cr0(&svm->vcpu);
-	vmcb12->save.cr3    = kvm_read_cr3(&svm->vcpu);
+	vmcb12->save.cr0    = kvm_read_cr0(vcpu);
+	vmcb12->save.cr3    = kvm_read_cr3(vcpu);
 	vmcb12->save.cr2    = vmcb->save.cr2;
 	vmcb12->save.cr4    = svm->vcpu.arch.cr4;
-	vmcb12->save.rflags = kvm_get_rflags(&svm->vcpu);
-	vmcb12->save.rip    = kvm_rip_read(&svm->vcpu);
-	vmcb12->save.rsp    = kvm_rsp_read(&svm->vcpu);
-	vmcb12->save.rax    = kvm_rax_read(&svm->vcpu);
+	vmcb12->save.rflags = kvm_get_rflags(vcpu);
+	vmcb12->save.rip    = kvm_rip_read(vcpu);
+	vmcb12->save.rsp    = kvm_rsp_read(vcpu);
+	vmcb12->save.rax    = kvm_rax_read(vcpu);
 	vmcb12->save.dr7    = vmcb->save.dr7;
 	vmcb12->save.dr6    = svm->vcpu.arch.dr6;
 	vmcb12->save.cpl    = vmcb->save.cpl;
@@ -663,7 +706,7 @@
 	vmcb12->control.exit_info_2       = vmcb->control.exit_info_2;
 
 	if (vmcb12->control.exit_code != SVM_EXIT_ERR)
-		nested_vmcb_save_pending_event(svm, vmcb12);
+		nested_save_pending_event_to_vmcb12(svm, vmcb12);
 
 	if (svm->nrips_enabled)
 		vmcb12->control.next_rip  = vmcb->control.next_rip;
@@ -678,37 +721,39 @@
 	vmcb12->control.pause_filter_thresh =
 		svm->vmcb->control.pause_filter_thresh;
 
-	/* Restore the original control entries */
-	copy_vmcb_control_area(&vmcb->control, &hsave->control);
+	nested_svm_copy_common_state(svm->nested.vmcb02.ptr, svm->vmcb01.ptr);
 
-	/* On vmexit the  GIF is set to false */
+	svm_switch_vmcb(svm, &svm->vmcb01);
+	WARN_ON_ONCE(svm->vmcb->control.exit_code != SVM_EXIT_VMRUN);
+
+	/*
+	 * On vmexit the GIF is set to false and
+	 * no event can be injected in L1.
+	 */
 	svm_set_gif(svm, false);
+	svm->vmcb->control.exit_int_info = 0;
 
-	svm->vmcb->control.tsc_offset = svm->vcpu.arch.tsc_offset =
-		svm->vcpu.arch.l1_tsc_offset;
+	svm->vcpu.arch.tsc_offset = svm->vcpu.arch.l1_tsc_offset;
+	if (svm->vmcb->control.tsc_offset != svm->vcpu.arch.tsc_offset) {
+		svm->vmcb->control.tsc_offset = svm->vcpu.arch.tsc_offset;
+		vmcb_mark_dirty(svm->vmcb, VMCB_INTERCEPTS);
+	}
 
 	svm->nested.ctl.nested_cr3 = 0;
 
-	/* Restore selected save entries */
-	svm->vmcb->save.es = hsave->save.es;
-	svm->vmcb->save.cs = hsave->save.cs;
-	svm->vmcb->save.ss = hsave->save.ss;
-	svm->vmcb->save.ds = hsave->save.ds;
-	svm->vmcb->save.gdtr = hsave->save.gdtr;
-	svm->vmcb->save.idtr = hsave->save.idtr;
-	kvm_set_rflags(&svm->vcpu, hsave->save.rflags);
-	kvm_set_rflags(&svm->vcpu, hsave->save.rflags | X86_EFLAGS_FIXED);
-	svm_set_efer(&svm->vcpu, hsave->save.efer);
-	svm_set_cr0(&svm->vcpu, hsave->save.cr0 | X86_CR0_PE);
-	svm_set_cr4(&svm->vcpu, hsave->save.cr4);
-	kvm_rax_write(&svm->vcpu, hsave->save.rax);
-	kvm_rsp_write(&svm->vcpu, hsave->save.rsp);
-	kvm_rip_write(&svm->vcpu, hsave->save.rip);
-	svm->vmcb->save.dr7 = DR7_FIXED_1;
-	svm->vmcb->save.cpl = 0;
-	svm->vmcb->control.exit_int_info = 0;
+	/*
+	 * Restore processor state that had been saved in vmcb01
+	 */
+	kvm_set_rflags(vcpu, svm->vmcb->save.rflags);
+	svm_set_efer(vcpu, svm->vmcb->save.efer);
+	svm_set_cr0(vcpu, svm->vmcb->save.cr0 | X86_CR0_PE);
+	svm_set_cr4(vcpu, svm->vmcb->save.cr4);
+	kvm_rax_write(vcpu, svm->vmcb->save.rax);
+	kvm_rsp_write(vcpu, svm->vmcb->save.rsp);
+	kvm_rip_write(vcpu, svm->vmcb->save.rip);
 
-	vmcb_mark_all_dirty(svm->vmcb);
+	svm->vcpu.arch.dr7 = DR7_FIXED_1;
+	kvm_update_dr7(&svm->vcpu);
 
 	trace_kvm_nested_vmexit_inject(vmcb12->control.exit_code,
 				       vmcb12->control.exit_info_1,
@@ -717,50 +762,62 @@
 				       vmcb12->control.exit_int_info_err,
 				       KVM_ISA_SVM);
 
-	kvm_vcpu_unmap(&svm->vcpu, &map, true);
+	kvm_vcpu_unmap(vcpu, &map, true);
 
-	nested_svm_uninit_mmu_context(&svm->vcpu);
+	nested_svm_uninit_mmu_context(vcpu);
 
-	rc = nested_svm_load_cr3(&svm->vcpu, hsave->save.cr3, false);
+	rc = nested_svm_load_cr3(vcpu, svm->vmcb->save.cr3, false);
 	if (rc)
 		return 1;
 
-	if (npt_enabled)
-		svm->vmcb->save.cr3 = hsave->save.cr3;
-
 	/*
 	 * Drop what we picked up for L2 via svm_complete_interrupts() so it
 	 * doesn't end up in L1.
 	 */
 	svm->vcpu.arch.nmi_injected = false;
-	kvm_clear_exception_queue(&svm->vcpu);
-	kvm_clear_interrupt_queue(&svm->vcpu);
+	kvm_clear_exception_queue(vcpu);
+	kvm_clear_interrupt_queue(vcpu);
+
+	/*
+	 * If we are here following the completion of a VMRUN that
+	 * is being single-stepped, queue the pending #DB intercept
+	 * right now so that it can be accounted for before we execute
+	 * L1's next instruction.
+	 */
+	if (unlikely(svm->vmcb->save.rflags & X86_EFLAGS_TF))
+		kvm_queue_exception(&(svm->vcpu), DB_VECTOR);
 
 	return 0;
 }
 
+static void nested_svm_triple_fault(struct kvm_vcpu *vcpu)
+{
+	nested_svm_simple_vmexit(to_svm(vcpu), SVM_EXIT_SHUTDOWN);
+}
+
 int svm_allocate_nested(struct vcpu_svm *svm)
 {
-	struct page *hsave_page;
+	struct page *vmcb02_page;
 
 	if (svm->nested.initialized)
 		return 0;
 
-	hsave_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
-	if (!hsave_page)
+	vmcb02_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+	if (!vmcb02_page)
 		return -ENOMEM;
-	svm->nested.hsave = page_address(hsave_page);
+	svm->nested.vmcb02.ptr = page_address(vmcb02_page);
+	svm->nested.vmcb02.pa = __sme_set(page_to_pfn(vmcb02_page) << PAGE_SHIFT);
 
 	svm->nested.msrpm = svm_vcpu_alloc_msrpm();
 	if (!svm->nested.msrpm)
-		goto err_free_hsave;
+		goto err_free_vmcb02;
 	svm_vcpu_init_msrpm(&svm->vcpu, svm->nested.msrpm);
 
 	svm->nested.initialized = true;
 	return 0;
 
-err_free_hsave:
-	__free_page(hsave_page);
+err_free_vmcb02:
+	__free_page(vmcb02_page);
 	return -ENOMEM;
 }
 
@@ -772,8 +829,8 @@
 	svm_vcpu_free_msrpm(svm->nested.msrpm);
 	svm->nested.msrpm = NULL;
 
-	__free_page(virt_to_page(svm->nested.hsave));
-	svm->nested.hsave = NULL;
+	__free_page(virt_to_page(svm->nested.vmcb02.ptr));
+	svm->nested.vmcb02.ptr = NULL;
 
 	svm->nested.initialized = false;
 }
@@ -783,18 +840,19 @@
  */
 void svm_leave_nested(struct vcpu_svm *svm)
 {
-	if (is_guest_mode(&svm->vcpu)) {
-		struct vmcb *hsave = svm->nested.hsave;
-		struct vmcb *vmcb = svm->vmcb;
+	struct kvm_vcpu *vcpu = &svm->vcpu;
 
+	if (is_guest_mode(vcpu)) {
 		svm->nested.nested_run_pending = 0;
-		leave_guest_mode(&svm->vcpu);
-		copy_vmcb_control_area(&vmcb->control, &hsave->control);
-		nested_svm_uninit_mmu_context(&svm->vcpu);
+		leave_guest_mode(vcpu);
+
+		svm_switch_vmcb(svm, &svm->vmcb01);
+
+		nested_svm_uninit_mmu_context(vcpu);
 		vmcb_mark_all_dirty(svm->vmcb);
 	}
 
-	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, &svm->vcpu);
+	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, vcpu);
 }
 
 static int nested_svm_exit_handled_msr(struct vcpu_svm *svm)
@@ -903,16 +961,15 @@
 	return vmexit;
 }
 
-int nested_svm_check_permissions(struct vcpu_svm *svm)
+int nested_svm_check_permissions(struct kvm_vcpu *vcpu)
 {
-	if (!(svm->vcpu.arch.efer & EFER_SVME) ||
-	    !is_paging(&svm->vcpu)) {
-		kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+	if (!(vcpu->arch.efer & EFER_SVME) || !is_paging(vcpu)) {
+		kvm_queue_exception(vcpu, UD_VECTOR);
 		return 1;
 	}
 
-	if (svm->vmcb->save.cpl) {
-		kvm_inject_gp(&svm->vcpu, 0);
+	if (to_svm(vcpu)->vmcb->save.cpl) {
+		kvm_inject_gp(vcpu, 0);
 		return 1;
 	}
 
@@ -960,50 +1017,11 @@
 	nested_svm_vmexit(svm);
 }
 
-static void nested_svm_smi(struct vcpu_svm *svm)
-{
-	svm->vmcb->control.exit_code = SVM_EXIT_SMI;
-	svm->vmcb->control.exit_info_1 = 0;
-	svm->vmcb->control.exit_info_2 = 0;
-
-	nested_svm_vmexit(svm);
-}
-
-static void nested_svm_nmi(struct vcpu_svm *svm)
-{
-	svm->vmcb->control.exit_code = SVM_EXIT_NMI;
-	svm->vmcb->control.exit_info_1 = 0;
-	svm->vmcb->control.exit_info_2 = 0;
-
-	nested_svm_vmexit(svm);
-}
-
-static void nested_svm_intr(struct vcpu_svm *svm)
-{
-	trace_kvm_nested_intr_vmexit(svm->vmcb->save.rip);
-
-	svm->vmcb->control.exit_code   = SVM_EXIT_INTR;
-	svm->vmcb->control.exit_info_1 = 0;
-	svm->vmcb->control.exit_info_2 = 0;
-
-	nested_svm_vmexit(svm);
-}
-
 static inline bool nested_exit_on_init(struct vcpu_svm *svm)
 {
 	return vmcb_is_intercept(&svm->nested.ctl, INTERCEPT_INIT);
 }
 
-static void nested_svm_init(struct vcpu_svm *svm)
-{
-	svm->vmcb->control.exit_code   = SVM_EXIT_INIT;
-	svm->vmcb->control.exit_info_1 = 0;
-	svm->vmcb->control.exit_info_2 = 0;
-
-	nested_svm_vmexit(svm);
-}
-
-
 static int svm_check_nested_events(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
@@ -1017,12 +1035,18 @@
 			return -EBUSY;
 		if (!nested_exit_on_init(svm))
 			return 0;
-		nested_svm_init(svm);
+		nested_svm_simple_vmexit(svm, SVM_EXIT_INIT);
 		return 0;
 	}
 
 	if (vcpu->arch.exception.pending) {
-		if (block_nested_events)
+		/*
+		 * Only a pending nested run can block a pending exception.
+		 * Otherwise an injected NMI/interrupt should either be
+		 * lost or delivered to the nested hypervisor in the EXITINTINFO
+		 * vmcb field, while delivering the pending exception.
+		 */
+		if (svm->nested.nested_run_pending)
                         return -EBUSY;
 		if (!nested_exit_on_exception(svm))
 			return 0;
@@ -1035,7 +1059,7 @@
 			return -EBUSY;
 		if (!nested_exit_on_smi(svm))
 			return 0;
-		nested_svm_smi(svm);
+		nested_svm_simple_vmexit(svm, SVM_EXIT_SMI);
 		return 0;
 	}
 
@@ -1044,7 +1068,7 @@
 			return -EBUSY;
 		if (!nested_exit_on_nmi(svm))
 			return 0;
-		nested_svm_nmi(svm);
+		nested_svm_simple_vmexit(svm, SVM_EXIT_NMI);
 		return 0;
 	}
 
@@ -1053,7 +1077,8 @@
 			return -EBUSY;
 		if (!nested_exit_on_intr(svm))
 			return 0;
-		nested_svm_intr(svm);
+		trace_kvm_nested_intr_vmexit(svm->vmcb->save.rip);
+		nested_svm_simple_vmexit(svm, SVM_EXIT_INTR);
 		return 0;
 	}
 
@@ -1072,8 +1097,8 @@
 	case SVM_EXIT_EXCP_BASE ... SVM_EXIT_EXCP_BASE + 0x1f: {
 		u32 excp_bits = 1 << (exit_code - SVM_EXIT_EXCP_BASE);
 
-		if (get_host_vmcb(svm)->control.intercepts[INTERCEPT_EXCEPTION] &
-				excp_bits)
+		if (svm->vmcb01.ptr->control.intercepts[INTERCEPT_EXCEPTION] &
+		    excp_bits)
 			return NESTED_EXIT_HOST;
 		else if (exit_code == SVM_EXIT_EXCP_BASE + PF_VECTOR &&
 			 svm->vcpu.arch.apf.host_apf_flags)
@@ -1137,10 +1162,9 @@
 	if (copy_to_user(&user_vmcb->control, &svm->nested.ctl,
 			 sizeof(user_vmcb->control)))
 		return -EFAULT;
-	if (copy_to_user(&user_vmcb->save, &svm->nested.hsave->save,
+	if (copy_to_user(&user_vmcb->save, &svm->vmcb01.ptr->save,
 			 sizeof(user_vmcb->save)))
 		return -EFAULT;
-
 out:
 	return kvm_state.size;
 }
@@ -1150,7 +1174,6 @@
 				struct kvm_nested_state *kvm_state)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
-	struct vmcb *hsave = svm->nested.hsave;
 	struct vmcb __user *user_vmcb = (struct vmcb __user *)
 		&user_kvm_nested_state->data.svm[0];
 	struct vmcb_control_area *ctl;
@@ -1195,8 +1218,8 @@
 		return -EINVAL;
 
 	ret  = -ENOMEM;
-	ctl  = kzalloc(sizeof(*ctl),  GFP_KERNEL);
-	save = kzalloc(sizeof(*save), GFP_KERNEL);
+	ctl  = kzalloc(sizeof(*ctl),  GFP_KERNEL_ACCOUNT);
+	save = kzalloc(sizeof(*save), GFP_KERNEL_ACCOUNT);
 	if (!ctl || !save)
 		goto out_free;
 
@@ -1212,7 +1235,7 @@
 
 	/*
 	 * Processor state contains L2 state.  Check that it is
-	 * valid for guest mode (see nested_vmcb_checks).
+	 * valid for guest mode (see nested_vmcb_check_save).
 	 */
 	cr0 = kvm_read_cr0(vcpu);
         if (((cr0 & X86_CR0_CD) == 0) && (cr0 & X86_CR0_NW))
@@ -1221,29 +1244,48 @@
 	/*
 	 * Validate host state saved from before VMRUN (see
 	 * nested_svm_check_permissions).
-	 * TODO: validate reserved bits for all saved state.
 	 */
-	if (!(save->cr0 & X86_CR0_PG))
-		goto out_free;
-	if (!(save->efer & EFER_SVME))
+	if (!(save->cr0 & X86_CR0_PG) ||
+	    !(save->cr0 & X86_CR0_PE) ||
+	    (save->rflags & X86_EFLAGS_VM) ||
+	    !nested_vmcb_valid_sregs(vcpu, save))
 		goto out_free;
 
 	/*
-	 * All checks done, we can enter guest mode.  L1 control fields
-	 * come from the nested save state.  Guest state is already
-	 * in the registers, the save area of the nested state instead
-	 * contains saved L1 state.
+	 * All checks done, we can enter guest mode. Userspace provides
+	 * vmcb12.control, which will be combined with L1 and stored into
+	 * vmcb02, and the L1 save state which we store in vmcb01.
+	 * L2 registers if needed are moved from the current VMCB to VMCB02.
 	 */
 
 	svm->nested.nested_run_pending =
 		!!(kvm_state->flags & KVM_STATE_NESTED_RUN_PENDING);
 
-	copy_vmcb_control_area(&hsave->control, &svm->vmcb->control);
-	hsave->save = *save;
-
 	svm->nested.vmcb12_gpa = kvm_state->hdr.svm.vmcb_pa;
-	load_nested_vmcb_control(svm, ctl);
-	nested_prepare_vmcb_control(svm);
+	if (svm->current_vmcb == &svm->vmcb01)
+		svm->nested.vmcb02.ptr->save = svm->vmcb01.ptr->save;
+
+	svm->vmcb01.ptr->save.es = save->es;
+	svm->vmcb01.ptr->save.cs = save->cs;
+	svm->vmcb01.ptr->save.ss = save->ss;
+	svm->vmcb01.ptr->save.ds = save->ds;
+	svm->vmcb01.ptr->save.gdtr = save->gdtr;
+	svm->vmcb01.ptr->save.idtr = save->idtr;
+	svm->vmcb01.ptr->save.rflags = save->rflags | X86_EFLAGS_FIXED;
+	svm->vmcb01.ptr->save.efer = save->efer;
+	svm->vmcb01.ptr->save.cr0 = save->cr0;
+	svm->vmcb01.ptr->save.cr3 = save->cr3;
+	svm->vmcb01.ptr->save.cr4 = save->cr4;
+	svm->vmcb01.ptr->save.rax = save->rax;
+	svm->vmcb01.ptr->save.rsp = save->rsp;
+	svm->vmcb01.ptr->save.rip = save->rip;
+	svm->vmcb01.ptr->save.cpl = 0;
+
+	nested_load_control_from_vmcb12(svm, ctl);
+
+	svm_switch_vmcb(svm, &svm->nested.vmcb02);
+
+	nested_vmcb02_prepare_control(svm);
 
 	kvm_make_request(KVM_REQ_GET_NESTED_STATE_PAGES, vcpu);
 	ret = 0;
@@ -1254,8 +1296,31 @@
 	return ret;
 }
 
+static bool svm_get_nested_state_pages(struct kvm_vcpu *vcpu)
+{
+	struct vcpu_svm *svm = to_svm(vcpu);
+
+	if (WARN_ON(!is_guest_mode(vcpu)))
+		return true;
+
+	if (nested_svm_load_cr3(&svm->vcpu, vcpu->arch.cr3,
+				nested_npt_enabled(svm)))
+		return false;
+
+	if (!nested_svm_vmrun_msrpm(svm)) {
+		vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
+		vcpu->run->internal.suberror =
+			KVM_INTERNAL_ERROR_EMULATION;
+		vcpu->run->internal.ndata = 0;
+		return false;
+	}
+
+	return true;
+}
+
 struct kvm_x86_nested_ops svm_nested_ops = {
 	.check_events = svm_check_nested_events,
+	.triple_fault = nested_svm_triple_fault,
 	.get_nested_state_pages = svm_get_nested_state_pages,
 	.get_state = svm_get_nested_state,
 	.set_state = svm_set_nested_state,
diff --git a/arch/x86/kvm/svm/sev.c b/arch/x86/kvm/svm/sev.c
index 874ea30..5457138 100644
--- a/arch/x86/kvm/svm/sev.c
+++ b/arch/x86/kvm/svm/sev.c
@@ -66,6 +66,11 @@
 	return ret;
 }
 
+static inline bool is_mirroring_enc_context(struct kvm *kvm)
+{
+	return to_kvm_svm(kvm)->sev_info.enc_context_owner;
+}
+
 /* Must be called with the sev_bitmap_lock held */
 static bool __sev_recycle_asids(int min_asid, int max_asid)
 {
@@ -87,7 +92,7 @@
 	return true;
 }
 
-static int sev_asid_new(struct kvm_sev_info *sev)
+static int sev_asid_new(bool es_active)
 {
 	int pos, min_asid, max_asid;
 	bool retry = true;
@@ -98,8 +103,8 @@
 	 * SEV-enabled guests must use asid from min_sev_asid to max_sev_asid.
 	 * SEV-ES-enabled guest can use from 1 to min_sev_asid - 1.
 	 */
-	min_asid = sev->es_active ? 0 : min_sev_asid - 1;
-	max_asid = sev->es_active ? min_sev_asid - 1 : max_sev_asid;
+	min_asid = es_active ? 0 : min_sev_asid - 1;
+	max_asid = es_active ? min_sev_asid - 1 : max_sev_asid;
 again:
 	pos = find_next_zero_bit(sev_asid_bitmap, max_sev_asid, min_asid);
 	if (pos >= max_asid) {
@@ -179,13 +184,17 @@
 static int sev_guest_init(struct kvm *kvm, struct kvm_sev_cmd *argp)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
+	bool es_active = argp->id == KVM_SEV_ES_INIT;
 	int asid, ret;
 
+	if (kvm->created_vcpus)
+		return -EINVAL;
+
 	ret = -EBUSY;
 	if (unlikely(sev->active))
 		return ret;
 
-	asid = sev_asid_new(sev);
+	asid = sev_asid_new(es_active);
 	if (asid < 0)
 		return ret;
 
@@ -194,6 +203,7 @@
 		goto e_free;
 
 	sev->active = true;
+	sev->es_active = es_active;
 	sev->asid = asid;
 	INIT_LIST_HEAD(&sev->regions_list);
 
@@ -204,16 +214,6 @@
 	return ret;
 }
 
-static int sev_es_guest_init(struct kvm *kvm, struct kvm_sev_cmd *argp)
-{
-	if (!sev_es)
-		return -ENOTTY;
-
-	to_kvm_svm(kvm)->sev_info.es_active = true;
-
-	return sev_guest_init(kvm, argp);
-}
-
 static int sev_bind_asid(struct kvm *kvm, unsigned int handle, int *error)
 {
 	struct sev_data_activate *data;
@@ -564,6 +564,7 @@
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
 	struct sev_data_launch_update_vmsa *vmsa;
+	struct kvm_vcpu *vcpu;
 	int i, ret;
 
 	if (!sev_es_guest(kvm))
@@ -573,8 +574,8 @@
 	if (!vmsa)
 		return -ENOMEM;
 
-	for (i = 0; i < kvm->created_vcpus; i++) {
-		struct vcpu_svm *svm = to_svm(kvm->vcpus[i]);
+	kvm_for_each_vcpu(i, vcpu, kvm) {
+		struct vcpu_svm *svm = to_svm(vcpu);
 
 		/* Perform some pre-encryption checks against the VMSA */
 		ret = sev_es_sync_vmsa(svm);
@@ -637,7 +638,7 @@
 		}
 
 		ret = -ENOMEM;
-		blob = kmalloc(params.len, GFP_KERNEL);
+		blob = kmalloc(params.len, GFP_KERNEL_ACCOUNT);
 		if (!blob)
 			goto e_free;
 
@@ -1074,7 +1075,7 @@
 		}
 
 		ret = -ENOMEM;
-		blob = kmalloc(params.len, GFP_KERNEL);
+		blob = kmalloc(params.len, GFP_KERNEL_ACCOUNT);
 		if (!blob)
 			goto e_free;
 
@@ -1124,15 +1125,22 @@
 	if (copy_from_user(&sev_cmd, argp, sizeof(struct kvm_sev_cmd)))
 		return -EFAULT;
 
+	/* enc_context_owner handles all memory enc operations */
+	if (is_mirroring_enc_context(kvm))
+		return -ENOTTY;
+
 	mutex_lock(&kvm->lock);
 
 	switch (sev_cmd.id) {
+	case KVM_SEV_ES_INIT:
+		if (!sev_es) {
+			r = -ENOTTY;
+			goto out;
+		}
+		fallthrough;
 	case KVM_SEV_INIT:
 		r = sev_guest_init(kvm, &sev_cmd);
 		break;
-	case KVM_SEV_ES_INIT:
-		r = sev_es_guest_init(kvm, &sev_cmd);
-		break;
 	case KVM_SEV_LAUNCH_START:
 		r = sev_launch_start(kvm, &sev_cmd);
 		break;
@@ -1186,6 +1194,10 @@
 	if (!sev_guest(kvm))
 		return -ENOTTY;
 
+	/* If kvm is mirroring encryption context it isn't responsible for it */
+	if (is_mirroring_enc_context(kvm))
+		return -ENOTTY;
+
 	if (range->addr > ULONG_MAX || range->size > ULONG_MAX)
 		return -EINVAL;
 
@@ -1252,6 +1264,10 @@
 	struct enc_region *region;
 	int ret;
 
+	/* If kvm is mirroring encryption context it isn't responsible for it */
+	if (is_mirroring_enc_context(kvm))
+		return -ENOTTY;
+
 	mutex_lock(&kvm->lock);
 
 	if (!sev_guest(kvm)) {
@@ -1282,6 +1298,71 @@
 	return ret;
 }
 
+int svm_vm_copy_asid_from(struct kvm *kvm, unsigned int source_fd)
+{
+	struct file *source_kvm_file;
+	struct kvm *source_kvm;
+	struct kvm_sev_info *mirror_sev;
+	unsigned int asid;
+	int ret;
+
+	source_kvm_file = fget(source_fd);
+	if (!file_is_kvm(source_kvm_file)) {
+		ret = -EBADF;
+		goto e_source_put;
+	}
+
+	source_kvm = source_kvm_file->private_data;
+	mutex_lock(&source_kvm->lock);
+
+	if (!sev_guest(source_kvm)) {
+		ret = -ENOTTY;
+		goto e_source_unlock;
+	}
+
+	/* Mirrors of mirrors should work, but let's not get silly */
+	if (is_mirroring_enc_context(source_kvm) || source_kvm == kvm) {
+		ret = -ENOTTY;
+		goto e_source_unlock;
+	}
+
+	asid = to_kvm_svm(source_kvm)->sev_info.asid;
+
+	/*
+	 * The mirror kvm holds an enc_context_owner ref so its asid can't
+	 * disappear until we're done with it
+	 */
+	kvm_get_kvm(source_kvm);
+
+	fput(source_kvm_file);
+	mutex_unlock(&source_kvm->lock);
+	mutex_lock(&kvm->lock);
+
+	if (sev_guest(kvm)) {
+		ret = -ENOTTY;
+		goto e_mirror_unlock;
+	}
+
+	/* Set enc_context_owner and copy its encryption context over */
+	mirror_sev = &to_kvm_svm(kvm)->sev_info;
+	mirror_sev->enc_context_owner = source_kvm;
+	mirror_sev->asid = asid;
+	mirror_sev->active = true;
+
+	mutex_unlock(&kvm->lock);
+	return 0;
+e_mirror_unlock:
+	mutex_unlock(&kvm->lock);
+	kvm_put_kvm(source_kvm);
+	return ret;
+e_source_unlock:
+	mutex_unlock(&source_kvm->lock);
+e_source_put:
+	if (source_kvm_file)	/* fget() returns NULL for a bad fd */
+		fput(source_kvm_file);
+	return ret;
+}
+
 void sev_vm_destroy(struct kvm *kvm)
 {
 	struct kvm_sev_info *sev = &to_kvm_svm(kvm)->sev_info;
@@ -1293,6 +1374,12 @@
 
+	/* If this is a mirror_kvm release the enc_context_owner and skip sev cleanup */
+	if (is_mirroring_enc_context(kvm)) {
+		kvm_put_kvm(sev->enc_context_owner);
+		return;
+	}
+
 	mutex_lock(&kvm->lock);
 
 	/*
 	 * Ensure that all guest tagged cache entries are flushed before
 	 * releasing the pages back to the system for use. CLFLUSH will
@@ -1775,7 +1862,7 @@
 			       len, GHCB_SCRATCH_AREA_LIMIT);
 			return false;
 		}
-		scratch_va = kzalloc(len, GFP_KERNEL);
+		scratch_va = kzalloc(len, GFP_KERNEL_ACCOUNT);
 		if (!scratch_va)
 			return false;
 
@@ -1849,7 +1936,7 @@
 		vcpu->arch.regs[VCPU_REGS_RAX] = cpuid_fn;
 		vcpu->arch.regs[VCPU_REGS_RCX] = 0;
 
-		ret = svm_invoke_exit_handler(svm, SVM_EXIT_CPUID);
+		ret = svm_invoke_exit_handler(vcpu, SVM_EXIT_CPUID);
 		if (!ret) {
 			ret = -EINVAL;
 			break;
@@ -1899,8 +1986,9 @@
 	return ret;
 }
 
-int sev_handle_vmgexit(struct vcpu_svm *svm)
+int sev_handle_vmgexit(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	struct vmcb_control_area *control = &svm->vmcb->control;
 	u64 ghcb_gpa, exit_code;
 	struct ghcb *ghcb;
@@ -1912,13 +2000,13 @@
 		return sev_handle_vmgexit_msr_protocol(svm);
 
 	if (!ghcb_gpa) {
-		vcpu_unimpl(&svm->vcpu, "vmgexit: GHCB gpa is not set\n");
+		vcpu_unimpl(vcpu, "vmgexit: GHCB gpa is not set\n");
 		return -EINVAL;
 	}
 
-	if (kvm_vcpu_map(&svm->vcpu, ghcb_gpa >> PAGE_SHIFT, &svm->ghcb_map)) {
+	if (kvm_vcpu_map(vcpu, ghcb_gpa >> PAGE_SHIFT, &svm->ghcb_map)) {
 		/* Unable to map GHCB from guest */
-		vcpu_unimpl(&svm->vcpu, "vmgexit: error mapping GHCB [%#llx] from guest\n",
+		vcpu_unimpl(vcpu, "vmgexit: error mapping GHCB [%#llx] from guest\n",
 			    ghcb_gpa);
 		return -EINVAL;
 	}
@@ -1926,7 +2014,7 @@
 	svm->ghcb = svm->ghcb_map.hva;
 	ghcb = svm->ghcb_map.hva;
 
-	trace_kvm_vmgexit_enter(svm->vcpu.vcpu_id, ghcb);
+	trace_kvm_vmgexit_enter(vcpu->vcpu_id, ghcb);
 
 	exit_code = ghcb_get_sw_exit_code(ghcb);
 
@@ -1944,7 +2032,7 @@
 		if (!setup_vmgexit_scratch(svm, true, control->exit_info_2))
 			break;
 
-		ret = kvm_sev_es_mmio_read(&svm->vcpu,
+		ret = kvm_sev_es_mmio_read(vcpu,
 					   control->exit_info_1,
 					   control->exit_info_2,
 					   svm->ghcb_sa);
@@ -1953,19 +2041,19 @@
 		if (!setup_vmgexit_scratch(svm, false, control->exit_info_2))
 			break;
 
-		ret = kvm_sev_es_mmio_write(&svm->vcpu,
+		ret = kvm_sev_es_mmio_write(vcpu,
 					    control->exit_info_1,
 					    control->exit_info_2,
 					    svm->ghcb_sa);
 		break;
 	case SVM_VMGEXIT_NMI_COMPLETE:
-		ret = svm_invoke_exit_handler(svm, SVM_EXIT_IRET);
+		ret = svm_invoke_exit_handler(vcpu, SVM_EXIT_IRET);
 		break;
 	case SVM_VMGEXIT_AP_HLT_LOOP:
-		ret = kvm_emulate_ap_reset_hold(&svm->vcpu);
+		ret = kvm_emulate_ap_reset_hold(vcpu);
 		break;
 	case SVM_VMGEXIT_AP_JUMP_TABLE: {
-		struct kvm_sev_info *sev = &to_kvm_svm(svm->vcpu.kvm)->sev_info;
+		struct kvm_sev_info *sev = &to_kvm_svm(vcpu->kvm)->sev_info;
 
 		switch (control->exit_info_1) {
 		case 0:
@@ -1990,12 +2078,12 @@
 		break;
 	}
 	case SVM_VMGEXIT_UNSUPPORTED_EVENT:
-		vcpu_unimpl(&svm->vcpu,
+		vcpu_unimpl(vcpu,
 			    "vmgexit: unsupported event - exit_info_1=%#llx, exit_info_2=%#llx\n",
 			    control->exit_info_1, control->exit_info_2);
 		break;
 	default:
-		ret = svm_invoke_exit_handler(svm, exit_code);
+		ret = svm_invoke_exit_handler(vcpu, exit_code);
 	}
 
 	return ret;
diff --git a/arch/x86/kvm/svm/svm.c b/arch/x86/kvm/svm/svm.c
index 58a45bb..48b396f3 100644
--- a/arch/x86/kvm/svm/svm.c
+++ b/arch/x86/kvm/svm/svm.c
@@ -95,6 +95,8 @@
 } direct_access_msrs[MAX_DIRECT_ACCESS_MSRS] = {
 	{ .index = MSR_STAR,				.always = true  },
 	{ .index = MSR_IA32_SYSENTER_CS,		.always = true  },
+	{ .index = MSR_IA32_SYSENTER_EIP,		.always = false },
+	{ .index = MSR_IA32_SYSENTER_ESP,		.always = false },
 #ifdef CONFIG_X86_64
 	{ .index = MSR_GS_BASE,				.always = true  },
 	{ .index = MSR_FS_BASE,				.always = true  },
@@ -279,7 +281,7 @@
 			 * In this case we will return to the nested guest
 			 * as soon as we leave SMM.
 			 */
-			if (!is_smm(&svm->vcpu))
+			if (!is_smm(vcpu))
 				svm_free_nested(svm);
 
 		} else {
@@ -363,10 +365,10 @@
 	bool has_error_code = vcpu->arch.exception.has_error_code;
 	u32 error_code = vcpu->arch.exception.error_code;
 
-	kvm_deliver_exception_payload(&svm->vcpu);
+	kvm_deliver_exception_payload(vcpu);
 
 	if (nr == BP_VECTOR && !nrips) {
-		unsigned long rip, old_rip = kvm_rip_read(&svm->vcpu);
+		unsigned long rip, old_rip = kvm_rip_read(vcpu);
 
 		/*
 		 * For guest debugging where we have to reinject #BP if some
@@ -375,8 +377,8 @@
 		 * raises a fault that is not intercepted. Still better than
 		 * failing in all cases.
 		 */
-		(void)skip_emulated_instruction(&svm->vcpu);
-		rip = kvm_rip_read(&svm->vcpu);
+		(void)skip_emulated_instruction(vcpu);
+		rip = kvm_rip_read(vcpu);
 		svm->int3_rip = rip + svm->vmcb->save.cs.base;
 		svm->int3_injected = rip - old_rip;
 	}
@@ -881,7 +883,7 @@
 	 */
 	mask = (mask_bit < 52) ? rsvd_bits(mask_bit, 51) | PT_PRESENT_MASK : 0;
 
-	kvm_mmu_set_mmio_spte_mask(mask, PT_WRITABLE_MASK | PT_USER_MASK);
+	kvm_mmu_set_mmio_spte_mask(mask, mask, PT_WRITABLE_MASK | PT_USER_MASK);
 }
 
 static void svm_hardware_teardown(void)
@@ -1084,8 +1086,8 @@
 	if (is_guest_mode(vcpu)) {
 		/* Write L1's TSC offset.  */
 		g_tsc_offset = svm->vmcb->control.tsc_offset -
-			       svm->nested.hsave->control.tsc_offset;
-		svm->nested.hsave->control.tsc_offset = offset;
+			       svm->vmcb01.ptr->control.tsc_offset;
+		svm->vmcb01.ptr->control.tsc_offset = offset;
 	}
 
 	trace_kvm_write_tsc_offset(vcpu->vcpu_id,
@@ -1113,12 +1115,13 @@
 	}
 }
 
-static void init_vmcb(struct vcpu_svm *svm)
+static void init_vmcb(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	struct vmcb_control_area *control = &svm->vmcb->control;
 	struct vmcb_save_area *save = &svm->vmcb->save;
 
-	svm->vcpu.arch.hflags = 0;
+	vcpu->arch.hflags = 0;
 
 	svm_set_intercept(svm, INTERCEPT_CR0_READ);
 	svm_set_intercept(svm, INTERCEPT_CR3_READ);
@@ -1126,7 +1129,7 @@
 	svm_set_intercept(svm, INTERCEPT_CR0_WRITE);
 	svm_set_intercept(svm, INTERCEPT_CR3_WRITE);
 	svm_set_intercept(svm, INTERCEPT_CR4_WRITE);
-	if (!kvm_vcpu_apicv_active(&svm->vcpu))
+	if (!kvm_vcpu_apicv_active(vcpu))
 		svm_set_intercept(svm, INTERCEPT_CR8_WRITE);
 
 	set_dr_intercepts(svm);
@@ -1170,12 +1173,12 @@
 	svm_set_intercept(svm, INTERCEPT_RDPRU);
 	svm_set_intercept(svm, INTERCEPT_RSM);
 
-	if (!kvm_mwait_in_guest(svm->vcpu.kvm)) {
+	if (!kvm_mwait_in_guest(vcpu->kvm)) {
 		svm_set_intercept(svm, INTERCEPT_MONITOR);
 		svm_set_intercept(svm, INTERCEPT_MWAIT);
 	}
 
-	if (!kvm_hlt_in_guest(svm->vcpu.kvm))
+	if (!kvm_hlt_in_guest(vcpu->kvm))
 		svm_set_intercept(svm, INTERCEPT_HLT);
 
 	control->iopm_base_pa = __sme_set(iopm_base);
@@ -1201,19 +1204,19 @@
 	init_sys_seg(&save->ldtr, SEG_TYPE_LDT);
 	init_sys_seg(&save->tr, SEG_TYPE_BUSY_TSS16);
 
-	svm_set_cr4(&svm->vcpu, 0);
-	svm_set_efer(&svm->vcpu, 0);
+	svm_set_cr4(vcpu, 0);
+	svm_set_efer(vcpu, 0);
 	save->dr6 = 0xffff0ff0;
-	kvm_set_rflags(&svm->vcpu, X86_EFLAGS_FIXED);
+	kvm_set_rflags(vcpu, X86_EFLAGS_FIXED);
 	save->rip = 0x0000fff0;
-	svm->vcpu.arch.regs[VCPU_REGS_RIP] = save->rip;
+	vcpu->arch.regs[VCPU_REGS_RIP] = save->rip;
 
 	/*
 	 * svm_set_cr0() sets PG and WP and clears NW and CD on save->cr0.
 	 * It also updates the guest-visible cr0 value.
 	 */
-	svm_set_cr0(&svm->vcpu, X86_CR0_NW | X86_CR0_CD | X86_CR0_ET);
-	kvm_mmu_reset_context(&svm->vcpu);
+	svm_set_cr0(vcpu, X86_CR0_NW | X86_CR0_CD | X86_CR0_ET);
+	kvm_mmu_reset_context(vcpu);
 
 	save->cr4 = X86_CR4_PAE;
 	/* rdx = ?? */
@@ -1225,17 +1228,18 @@
 		clr_exception_intercept(svm, PF_VECTOR);
 		svm_clr_intercept(svm, INTERCEPT_CR3_READ);
 		svm_clr_intercept(svm, INTERCEPT_CR3_WRITE);
-		save->g_pat = svm->vcpu.arch.pat;
+		save->g_pat = vcpu->arch.pat;
 		save->cr3 = 0;
 		save->cr4 = 0;
 	}
-	svm->asid_generation = 0;
+	svm->current_vmcb->asid_generation = 0;
 	svm->asid = 0;
 
 	svm->nested.vmcb12_gpa = 0;
-	svm->vcpu.arch.hflags = 0;
+	svm->nested.last_vmcb12_gpa = 0;
+	vcpu->arch.hflags = 0;
 
-	if (!kvm_pause_in_guest(svm->vcpu.kvm)) {
+	if (!kvm_pause_in_guest(vcpu->kvm)) {
 		control->pause_filter_count = pause_filter_count;
 		if (pause_filter_thresh)
 			control->pause_filter_thresh = pause_filter_thresh;
@@ -1246,18 +1250,15 @@
 
 	svm_check_invpcid(svm);
 
-	if (kvm_vcpu_apicv_active(&svm->vcpu))
-		avic_init_vmcb(svm);
-
 	/*
-	 * If hardware supports Virtual VMLOAD VMSAVE then enable it
-	 * in VMCB and clear intercepts to avoid #VMEXIT.
+	 * If the host supports V_SPEC_CTRL then disable the interception
+	 * of MSR_IA32_SPEC_CTRL.
 	 */
-	if (vls) {
-		svm_clr_intercept(svm, INTERCEPT_VMLOAD);
-		svm_clr_intercept(svm, INTERCEPT_VMSAVE);
-		svm->vmcb->control.virt_ext |= VIRTUAL_VMLOAD_VMSAVE_ENABLE_MASK;
-	}
+	if (boot_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SPEC_CTRL, 1, 1);
+
+	if (kvm_vcpu_apicv_active(vcpu))
+		avic_init_vmcb(svm);
 
 	if (vgif) {
 		svm_clr_intercept(svm, INTERCEPT_STGI);
@@ -1265,11 +1266,11 @@
 		svm->vmcb->control.int_ctl |= V_GIF_ENABLE_MASK;
 	}
 
-	if (sev_guest(svm->vcpu.kvm)) {
+	if (sev_guest(vcpu->kvm)) {
 		svm->vmcb->control.nested_ctl |= SVM_NESTED_CTL_SEV_ENABLE;
 		clr_exception_intercept(svm, UD_VECTOR);
 
-		if (sev_es_guest(svm->vcpu.kvm)) {
+		if (sev_es_guest(vcpu->kvm)) {
 			/* Perform SEV-ES specific VMCB updates */
 			sev_es_init_vmcb(svm);
 		}
@@ -1291,12 +1292,12 @@
 	svm->virt_spec_ctrl = 0;
 
 	if (!init_event) {
-		svm->vcpu.arch.apic_base = APIC_DEFAULT_PHYS_BASE |
-					   MSR_IA32_APICBASE_ENABLE;
-		if (kvm_vcpu_is_reset_bsp(&svm->vcpu))
-			svm->vcpu.arch.apic_base |= MSR_IA32_APICBASE_BSP;
+		vcpu->arch.apic_base = APIC_DEFAULT_PHYS_BASE |
+				       MSR_IA32_APICBASE_ENABLE;
+		if (kvm_vcpu_is_reset_bsp(vcpu))
+			vcpu->arch.apic_base |= MSR_IA32_APICBASE_BSP;
 	}
-	init_vmcb(svm);
+	init_vmcb(vcpu);
 
 	kvm_cpuid(vcpu, &eax, &dummy, &dummy, &dummy, false);
 	kvm_rdx_write(vcpu, eax);
@@ -1305,10 +1306,25 @@
 		avic_update_vapic_bar(svm, APIC_DEFAULT_PHYS_BASE);
 }
 
+void svm_switch_vmcb(struct vcpu_svm *svm, struct kvm_vmcb_info *target_vmcb)
+{
+	svm->current_vmcb = target_vmcb;
+	svm->vmcb = target_vmcb->ptr;
+	svm->vmcb_pa = target_vmcb->pa;
+
+	/*
+	* Track the physical CPU the target_vmcb is running on
+	* in order to mark the VMCB dirty if the cpu changes at
+	* its next vmrun.
+	*/
+
+	svm->current_vmcb->cpu = svm->vcpu.cpu;
+}
+
 static int svm_create_vcpu(struct kvm_vcpu *vcpu)
 {
 	struct vcpu_svm *svm;
-	struct page *vmcb_page;
+	struct page *vmcb01_page;
 	struct page *vmsa_page = NULL;
 	int err;
 
@@ -1316,11 +1332,11 @@
 	svm = to_svm(vcpu);
 
 	err = -ENOMEM;
-	vmcb_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
-	if (!vmcb_page)
+	vmcb01_page = alloc_page(GFP_KERNEL_ACCOUNT | __GFP_ZERO);
+	if (!vmcb01_page)
 		goto out;
 
-	if (sev_es_guest(svm->vcpu.kvm)) {
+	if (sev_es_guest(vcpu->kvm)) {
 		/*
 		 * SEV-ES guests require a separate VMSA page used to contain
 		 * the encrypted register state of the guest.
@@ -1356,20 +1372,21 @@
 
 	svm_vcpu_init_msrpm(vcpu, svm->msrpm);
 
-	svm->vmcb = page_address(vmcb_page);
-	svm->vmcb_pa = __sme_set(page_to_pfn(vmcb_page) << PAGE_SHIFT);
+	svm->vmcb01.ptr = page_address(vmcb01_page);
+	svm->vmcb01.pa = __sme_set(page_to_pfn(vmcb01_page) << PAGE_SHIFT);
 
 	if (vmsa_page)
 		svm->vmsa = page_address(vmsa_page);
 
-	svm->asid_generation = 0;
 	svm->guest_state_loaded = false;
-	init_vmcb(svm);
+
+	svm_switch_vmcb(svm, &svm->vmcb01);
+	init_vmcb(vcpu);
 
 	svm_init_osvw(vcpu);
 	vcpu->arch.microcode_version = 0x01000065;
 
-	if (sev_es_guest(svm->vcpu.kvm))
+	if (sev_es_guest(vcpu->kvm))
 		/* Perform SEV-ES specific VMCB creation updates */
 		sev_es_create_vcpu(svm);
 
@@ -1379,7 +1396,7 @@
 	if (vmsa_page)
 		__free_page(vmsa_page);
 error_free_vmcb_page:
-	__free_page(vmcb_page);
+	__free_page(vmcb01_page);
 out:
 	return err;
 }
@@ -1407,7 +1424,7 @@
 
 	sev_free_vcpu(vcpu);
 
-	__free_page(pfn_to_page(__sme_clr(svm->vmcb_pa) >> PAGE_SHIFT));
+	__free_page(pfn_to_page(__sme_clr(svm->vmcb01.pa) >> PAGE_SHIFT));
 	__free_pages(virt_to_page(svm->msrpm), MSRPM_ALLOC_ORDER);
 }
 
@@ -1432,7 +1449,7 @@
 	 * Save additional host state that will be restored on VMEXIT (sev-es)
 	 * or subsequent vmload of host save area.
 	 */
-	if (sev_es_guest(svm->vcpu.kvm)) {
+	if (sev_es_guest(vcpu->kvm)) {
 		sev_es_prepare_guest_switch(svm, vcpu->cpu);
 	} else {
 		vmsave(__sme_page_pa(sd->save_area));
@@ -1476,11 +1493,6 @@
 	struct vcpu_svm *svm = to_svm(vcpu);
 	struct svm_cpu_data *sd = per_cpu(svm_data, cpu);
 
-	if (unlikely(cpu != vcpu->cpu)) {
-		svm->asid_generation = 0;
-		vmcb_mark_all_dirty(svm->vmcb);
-	}
-
 	if (sd->current_vmcb != svm->vmcb) {
 		sd->current_vmcb = svm->vmcb;
 		indirect_branch_prediction_barrier();
@@ -1564,7 +1576,7 @@
 	/* Drop int_ctl fields related to VINTR injection.  */
 	svm->vmcb->control.int_ctl &= mask;
 	if (is_guest_mode(&svm->vcpu)) {
-		svm->nested.hsave->control.int_ctl &= mask;
+		svm->vmcb01.ptr->control.int_ctl &= mask;
 
 		WARN_ON((svm->vmcb->control.int_ctl & V_TPR_MASK) !=
 			(svm->nested.ctl.int_ctl & V_TPR_MASK));
@@ -1577,16 +1589,17 @@
 static struct vmcb_seg *svm_seg(struct kvm_vcpu *vcpu, int seg)
 {
 	struct vmcb_save_area *save = &to_svm(vcpu)->vmcb->save;
+	struct vmcb_save_area *save01 = &to_svm(vcpu)->vmcb01.ptr->save;
 
 	switch (seg) {
 	case VCPU_SREG_CS: return &save->cs;
 	case VCPU_SREG_DS: return &save->ds;
 	case VCPU_SREG_ES: return &save->es;
-	case VCPU_SREG_FS: return &save->fs;
-	case VCPU_SREG_GS: return &save->gs;
+	case VCPU_SREG_FS: return &save01->fs;
+	case VCPU_SREG_GS: return &save01->gs;
 	case VCPU_SREG_SS: return &save->ss;
-	case VCPU_SREG_TR: return &save->tr;
-	case VCPU_SREG_LDTR: return &save->ldtr;
+	case VCPU_SREG_TR: return &save01->tr;
+	case VCPU_SREG_LDTR: return &save01->ldtr;
 	}
 	BUG();
 	return NULL;
@@ -1709,37 +1722,10 @@
 	vmcb_mark_dirty(svm->vmcb, VMCB_DT);
 }
 
-static void update_cr0_intercept(struct vcpu_svm *svm)
-{
-	ulong gcr0;
-	u64 *hcr0;
-
-	/*
-	 * SEV-ES guests must always keep the CR intercepts cleared. CR
-	 * tracking is done using the CR write traps.
-	 */
-	if (sev_es_guest(svm->vcpu.kvm))
-		return;
-
-	gcr0 = svm->vcpu.arch.cr0;
-	hcr0 = &svm->vmcb->save.cr0;
-	*hcr0 = (*hcr0 & ~SVM_CR0_SELECTIVE_MASK)
-		| (gcr0 & SVM_CR0_SELECTIVE_MASK);
-
-	vmcb_mark_dirty(svm->vmcb, VMCB_CR);
-
-	if (gcr0 == *hcr0) {
-		svm_clr_intercept(svm, INTERCEPT_CR0_READ);
-		svm_clr_intercept(svm, INTERCEPT_CR0_WRITE);
-	} else {
-		svm_set_intercept(svm, INTERCEPT_CR0_READ);
-		svm_set_intercept(svm, INTERCEPT_CR0_WRITE);
-	}
-}
-
 void svm_set_cr0(struct kvm_vcpu *vcpu, unsigned long cr0)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
+	u64 hcr0 = cr0;
 
 #ifdef CONFIG_X86_64
 	if (vcpu->arch.efer & EFER_LME && !vcpu->arch.guest_state_protected) {
@@ -1757,7 +1743,7 @@
 	vcpu->arch.cr0 = cr0;
 
 	if (!npt_enabled)
-		cr0 |= X86_CR0_PG | X86_CR0_WP;
+		hcr0 |= X86_CR0_PG | X86_CR0_WP;
 
 	/*
 	 * re-enable caching here because the QEMU bios
@@ -1765,10 +1751,26 @@
 	 * reboot
 	 */
 	if (kvm_check_has_quirk(vcpu->kvm, KVM_X86_QUIRK_CD_NW_CLEARED))
-		cr0 &= ~(X86_CR0_CD | X86_CR0_NW);
-	svm->vmcb->save.cr0 = cr0;
+		hcr0 &= ~(X86_CR0_CD | X86_CR0_NW);
+
+	svm->vmcb->save.cr0 = hcr0;
 	vmcb_mark_dirty(svm->vmcb, VMCB_CR);
-	update_cr0_intercept(svm);
+
+	/*
+	 * SEV-ES guests must always keep the CR intercepts cleared. CR
+	 * tracking is done using the CR write traps.
+	 */
+	if (sev_es_guest(vcpu->kvm))
+		return;
+
+	if (hcr0 == cr0) {
+		/* Selective CR0 write remains on.  */
+		svm_clr_intercept(svm, INTERCEPT_CR0_READ);
+		svm_clr_intercept(svm, INTERCEPT_CR0_WRITE);
+	} else {
+		svm_set_intercept(svm, INTERCEPT_CR0_READ);
+		svm_set_intercept(svm, INTERCEPT_CR0_WRITE);
+	}
 }
 
 static bool svm_is_valid_cr4(struct kvm_vcpu *vcpu, unsigned long cr4)
@@ -1847,7 +1849,7 @@
 		vmcb_mark_dirty(svm->vmcb, VMCB_ASID);
 	}
 
-	svm->asid_generation = sd->asid_generation;
+	svm->current_vmcb->asid_generation = sd->asid_generation;
 	svm->asid = sd->next_asid++;
 }
 
@@ -1896,39 +1898,43 @@
 	vmcb_mark_dirty(svm->vmcb, VMCB_DR);
 }
 
-static int pf_interception(struct vcpu_svm *svm)
+static int pf_interception(struct kvm_vcpu *vcpu)
 {
-	u64 fault_address = __sme_clr(svm->vmcb->control.exit_info_2);
+	struct vcpu_svm *svm = to_svm(vcpu);
+
+	u64 fault_address = svm->vmcb->control.exit_info_2;
 	u64 error_code = svm->vmcb->control.exit_info_1;
 
-	return kvm_handle_page_fault(&svm->vcpu, error_code, fault_address,
+	return kvm_handle_page_fault(vcpu, error_code, fault_address,
 			static_cpu_has(X86_FEATURE_DECODEASSISTS) ?
 			svm->vmcb->control.insn_bytes : NULL,
 			svm->vmcb->control.insn_len);
 }
 
-static int npf_interception(struct vcpu_svm *svm)
+static int npf_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
+
 	u64 fault_address = __sme_clr(svm->vmcb->control.exit_info_2);
 	u64 error_code = svm->vmcb->control.exit_info_1;
 
 	trace_kvm_page_fault(fault_address, error_code);
-	return kvm_mmu_page_fault(&svm->vcpu, fault_address, error_code,
+	return kvm_mmu_page_fault(vcpu, fault_address, error_code,
 			static_cpu_has(X86_FEATURE_DECODEASSISTS) ?
 			svm->vmcb->control.insn_bytes : NULL,
 			svm->vmcb->control.insn_len);
 }
 
-static int db_interception(struct vcpu_svm *svm)
+static int db_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_run *kvm_run = svm->vcpu.run;
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct kvm_run *kvm_run = vcpu->run;
+	struct vcpu_svm *svm = to_svm(vcpu);
 
-	if (!(svm->vcpu.guest_debug &
+	if (!(vcpu->guest_debug &
 	      (KVM_GUESTDBG_SINGLESTEP | KVM_GUESTDBG_USE_HW_BP)) &&
 		!svm->nmi_singlestep) {
 		u32 payload = svm->vmcb->save.dr6 ^ DR6_ACTIVE_LOW;
-		kvm_queue_exception_p(&svm->vcpu, DB_VECTOR, payload);
+		kvm_queue_exception_p(vcpu, DB_VECTOR, payload);
 		return 1;
 	}
 
@@ -1938,7 +1944,7 @@
 		kvm_make_request(KVM_REQ_EVENT, vcpu);
 	}
 
-	if (svm->vcpu.guest_debug &
+	if (vcpu->guest_debug &
 	    (KVM_GUESTDBG_SINGLESTEP | KVM_GUESTDBG_USE_HW_BP)) {
 		kvm_run->exit_reason = KVM_EXIT_DEBUG;
 		kvm_run->debug.arch.dr6 = svm->vmcb->save.dr6;
@@ -1952,9 +1958,10 @@
 	return 1;
 }
 
-static int bp_interception(struct vcpu_svm *svm)
+static int bp_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_run *kvm_run = svm->vcpu.run;
+	struct vcpu_svm *svm = to_svm(vcpu);
+	struct kvm_run *kvm_run = vcpu->run;
 
 	kvm_run->exit_reason = KVM_EXIT_DEBUG;
 	kvm_run->debug.arch.pc = svm->vmcb->save.cs.base + svm->vmcb->save.rip;
@@ -1962,14 +1969,14 @@
 	return 0;
 }
 
-static int ud_interception(struct vcpu_svm *svm)
+static int ud_interception(struct kvm_vcpu *vcpu)
 {
-	return handle_ud(&svm->vcpu);
+	return handle_ud(vcpu);
 }
 
-static int ac_interception(struct vcpu_svm *svm)
+static int ac_interception(struct kvm_vcpu *vcpu)
 {
-	kvm_queue_exception_e(&svm->vcpu, AC_VECTOR, 0);
+	kvm_queue_exception_e(vcpu, AC_VECTOR, 0);
 	return 1;
 }
 
@@ -2012,7 +2019,7 @@
 	return true;
 }
 
-static void svm_handle_mce(struct vcpu_svm *svm)
+static void svm_handle_mce(struct kvm_vcpu *vcpu)
 {
 	if (is_erratum_383()) {
 		/*
@@ -2021,7 +2028,7 @@
 		 */
 		pr_err("KVM: Guest triggered AMD Erratum 383\n");
 
-		kvm_make_request(KVM_REQ_TRIPLE_FAULT, &svm->vcpu);
+		kvm_make_request(KVM_REQ_TRIPLE_FAULT, vcpu);
 
 		return;
 	}
@@ -2033,20 +2040,21 @@
 	kvm_machine_check();
 }
 
-static int mc_interception(struct vcpu_svm *svm)
+static int mc_interception(struct kvm_vcpu *vcpu)
 {
 	return 1;
 }
 
-static int shutdown_interception(struct vcpu_svm *svm)
+static int shutdown_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_run *kvm_run = svm->vcpu.run;
+	struct kvm_run *kvm_run = vcpu->run;
+	struct vcpu_svm *svm = to_svm(vcpu);
 
 	/*
 	 * The VM save area has already been encrypted so it
 	 * cannot be reinitialized - just terminate.
 	 */
-	if (sev_es_guest(svm->vcpu.kvm))
+	if (sev_es_guest(vcpu->kvm))
 		return -EINVAL;
 
 	/*
@@ -2054,20 +2062,20 @@
 	 * so reinitialize it.
 	 */
 	clear_page(svm->vmcb);
-	init_vmcb(svm);
+	init_vmcb(vcpu);
 
 	kvm_run->exit_reason = KVM_EXIT_SHUTDOWN;
 	return 0;
 }
 
-static int io_interception(struct vcpu_svm *svm)
+static int io_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u32 io_info = svm->vmcb->control.exit_info_1; /* address size bug? */
 	int size, in, string;
 	unsigned port;
 
-	++svm->vcpu.stat.io_exits;
+	++vcpu->stat.io_exits;
 	string = (io_info & SVM_IOIO_STR_MASK) != 0;
 	in = (io_info & SVM_IOIO_TYPE_MASK) != 0;
 	port = io_info >> 16;
@@ -2082,93 +2090,69 @@
 
 	svm->next_rip = svm->vmcb->control.exit_info_2;
 
-	return kvm_fast_pio(&svm->vcpu, size, port, in);
+	return kvm_fast_pio(vcpu, size, port, in);
 }
 
-static int nmi_interception(struct vcpu_svm *svm)
+static int nmi_interception(struct kvm_vcpu *vcpu)
 {
 	return 1;
 }
 
-static int intr_interception(struct vcpu_svm *svm)
+static int intr_interception(struct kvm_vcpu *vcpu)
 {
-	++svm->vcpu.stat.irq_exits;
+	++vcpu->stat.irq_exits;
 	return 1;
 }
 
-static int nop_on_interception(struct vcpu_svm *svm)
+static int vmload_vmsave_interception(struct kvm_vcpu *vcpu, bool vmload)
 {
-	return 1;
-}
-
-static int halt_interception(struct vcpu_svm *svm)
-{
-	return kvm_emulate_halt(&svm->vcpu);
-}
-
-static int vmmcall_interception(struct vcpu_svm *svm)
-{
-	return kvm_emulate_hypercall(&svm->vcpu);
-}
-
-static int vmload_interception(struct vcpu_svm *svm)
-{
-	struct vmcb *nested_vmcb;
+	struct vcpu_svm *svm = to_svm(vcpu);
+	struct vmcb *vmcb12;
 	struct kvm_host_map map;
 	int ret;
 
-	if (nested_svm_check_permissions(svm))
+	if (nested_svm_check_permissions(vcpu))
 		return 1;
 
-	ret = kvm_vcpu_map(&svm->vcpu, gpa_to_gfn(svm->vmcb->save.rax), &map);
+	ret = kvm_vcpu_map(vcpu, gpa_to_gfn(svm->vmcb->save.rax), &map);
 	if (ret) {
 		if (ret == -EINVAL)
-			kvm_inject_gp(&svm->vcpu, 0);
+			kvm_inject_gp(vcpu, 0);
 		return 1;
 	}
 
-	nested_vmcb = map.hva;
+	vmcb12 = map.hva;
 
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
+	ret = kvm_skip_emulated_instruction(vcpu);
 
-	nested_svm_vmloadsave(nested_vmcb, svm->vmcb);
-	kvm_vcpu_unmap(&svm->vcpu, &map, true);
+	if (vmload) {
+		nested_svm_vmloadsave(vmcb12, svm->vmcb);
+		svm->sysenter_eip_hi = 0;
+		svm->sysenter_esp_hi = 0;
+	} else
+		nested_svm_vmloadsave(svm->vmcb, vmcb12);
+
+	kvm_vcpu_unmap(vcpu, &map, true);
 
 	return ret;
 }
 
-static int vmsave_interception(struct vcpu_svm *svm)
+static int vmload_interception(struct kvm_vcpu *vcpu)
 {
-	struct vmcb *nested_vmcb;
-	struct kvm_host_map map;
-	int ret;
-
-	if (nested_svm_check_permissions(svm))
-		return 1;
-
-	ret = kvm_vcpu_map(&svm->vcpu, gpa_to_gfn(svm->vmcb->save.rax), &map);
-	if (ret) {
-		if (ret == -EINVAL)
-			kvm_inject_gp(&svm->vcpu, 0);
-		return 1;
-	}
-
-	nested_vmcb = map.hva;
-
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
-
-	nested_svm_vmloadsave(svm->vmcb, nested_vmcb);
-	kvm_vcpu_unmap(&svm->vcpu, &map, true);
-
-	return ret;
+	return vmload_vmsave_interception(vcpu, true);
 }
 
-static int vmrun_interception(struct vcpu_svm *svm)
+static int vmsave_interception(struct kvm_vcpu *vcpu)
 {
-	if (nested_svm_check_permissions(svm))
+	return vmload_vmsave_interception(vcpu, false);
+}
+
+static int vmrun_interception(struct kvm_vcpu *vcpu)
+{
+	if (nested_svm_check_permissions(vcpu))
 		return 1;
 
-	return nested_svm_vmrun(svm);
+	return nested_svm_vmrun(vcpu);
 }
 
 enum {
@@ -2207,7 +2191,7 @@
 		[SVM_INSTR_VMLOAD] = SVM_EXIT_VMLOAD,
 		[SVM_INSTR_VMSAVE] = SVM_EXIT_VMSAVE,
 	};
-	int (*const svm_instr_handlers[])(struct vcpu_svm *svm) = {
+	int (*const svm_instr_handlers[])(struct kvm_vcpu *vcpu) = {
 		[SVM_INSTR_VMRUN] = vmrun_interception,
 		[SVM_INSTR_VMLOAD] = vmload_interception,
 		[SVM_INSTR_VMSAVE] = vmsave_interception,
@@ -2216,17 +2200,13 @@
 	int ret;
 
 	if (is_guest_mode(vcpu)) {
-		svm->vmcb->control.exit_code = guest_mode_exit_codes[opcode];
-		svm->vmcb->control.exit_info_1 = 0;
-		svm->vmcb->control.exit_info_2 = 0;
-
 		/* Returns '1' or -errno on failure, '0' on success. */
-		ret = nested_svm_vmexit(svm);
+		ret = nested_svm_simple_vmexit(svm, guest_mode_exit_codes[opcode]);
 		if (ret)
 			return ret;
 		return 1;
 	}
-	return svm_instr_handlers[opcode](svm);
+	return svm_instr_handlers[opcode](vcpu);
 }
 
 /*
@@ -2237,9 +2217,9 @@
  *      regions (e.g. SMM memory on host).
  *   2) VMware backdoor
  */
-static int gp_interception(struct vcpu_svm *svm)
+static int gp_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u32 error_code = svm->vmcb->control.exit_info_1;
 	int opcode;
 
@@ -2304,73 +2284,52 @@
 	}
 }
 
-static int stgi_interception(struct vcpu_svm *svm)
+static int stgi_interception(struct kvm_vcpu *vcpu)
 {
 	int ret;
 
-	if (nested_svm_check_permissions(svm))
+	if (nested_svm_check_permissions(vcpu))
 		return 1;
 
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
-	svm_set_gif(svm, true);
+	ret = kvm_skip_emulated_instruction(vcpu);
+	svm_set_gif(to_svm(vcpu), true);
 	return ret;
 }
 
-static int clgi_interception(struct vcpu_svm *svm)
+static int clgi_interception(struct kvm_vcpu *vcpu)
 {
 	int ret;
 
-	if (nested_svm_check_permissions(svm))
+	if (nested_svm_check_permissions(vcpu))
 		return 1;
 
-	ret = kvm_skip_emulated_instruction(&svm->vcpu);
-	svm_set_gif(svm, false);
+	ret = kvm_skip_emulated_instruction(vcpu);
+	svm_set_gif(to_svm(vcpu), false);
 	return ret;
 }
 
-static int invlpga_interception(struct vcpu_svm *svm)
+static int invlpga_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
-
-	trace_kvm_invlpga(svm->vmcb->save.rip, kvm_rcx_read(&svm->vcpu),
-			  kvm_rax_read(&svm->vcpu));
+	trace_kvm_invlpga(to_svm(vcpu)->vmcb->save.rip, kvm_rcx_read(vcpu),
+			  kvm_rax_read(vcpu));
 
 	/* Let's treat INVLPGA the same as INVLPG (can be optimized!) */
-	kvm_mmu_invlpg(vcpu, kvm_rax_read(&svm->vcpu));
+	kvm_mmu_invlpg(vcpu, kvm_rax_read(vcpu));
 
-	return kvm_skip_emulated_instruction(&svm->vcpu);
+	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int skinit_interception(struct vcpu_svm *svm)
+static int skinit_interception(struct kvm_vcpu *vcpu)
 {
-	trace_kvm_skinit(svm->vmcb->save.rip, kvm_rax_read(&svm->vcpu));
+	trace_kvm_skinit(to_svm(vcpu)->vmcb->save.rip, kvm_rax_read(vcpu));
 
-	kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+	kvm_queue_exception(vcpu, UD_VECTOR);
 	return 1;
 }
 
-static int wbinvd_interception(struct vcpu_svm *svm)
+static int task_switch_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_emulate_wbinvd(&svm->vcpu);
-}
-
-static int xsetbv_interception(struct vcpu_svm *svm)
-{
-	u64 new_bv = kvm_read_edx_eax(&svm->vcpu);
-	u32 index = kvm_rcx_read(&svm->vcpu);
-
-	int err = kvm_set_xcr(&svm->vcpu, index, new_bv);
-	return kvm_complete_insn_gp(&svm->vcpu, err);
-}
-
-static int rdpru_interception(struct vcpu_svm *svm)
-{
-	kvm_queue_exception(&svm->vcpu, UD_VECTOR);
-	return 1;
-}
-
-static int task_switch_interception(struct vcpu_svm *svm)
-{
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u16 tss_selector;
 	int reason;
 	int int_type = svm->vmcb->control.exit_int_info &
@@ -2399,7 +2358,7 @@
 	if (reason == TASK_SWITCH_GATE) {
 		switch (type) {
 		case SVM_EXITINTINFO_TYPE_NMI:
-			svm->vcpu.arch.nmi_injected = false;
+			vcpu->arch.nmi_injected = false;
 			break;
 		case SVM_EXITINTINFO_TYPE_EXEPT:
 			if (svm->vmcb->control.exit_info_2 &
@@ -2408,10 +2367,10 @@
 				error_code =
 					(u32)svm->vmcb->control.exit_info_2;
 			}
-			kvm_clear_exception_queue(&svm->vcpu);
+			kvm_clear_exception_queue(vcpu);
 			break;
 		case SVM_EXITINTINFO_TYPE_INTR:
-			kvm_clear_interrupt_queue(&svm->vcpu);
+			kvm_clear_interrupt_queue(vcpu);
 			break;
 		default:
 			break;
@@ -2422,77 +2381,58 @@
 	    int_type == SVM_EXITINTINFO_TYPE_SOFT ||
 	    (int_type == SVM_EXITINTINFO_TYPE_EXEPT &&
 	     (int_vec == OF_VECTOR || int_vec == BP_VECTOR))) {
-		if (!skip_emulated_instruction(&svm->vcpu))
+		if (!skip_emulated_instruction(vcpu))
 			return 0;
 	}
 
 	if (int_type != SVM_EXITINTINFO_TYPE_SOFT)
 		int_vec = -1;
 
-	return kvm_task_switch(&svm->vcpu, tss_selector, int_vec, reason,
+	return kvm_task_switch(vcpu, tss_selector, int_vec, reason,
 			       has_error_code, error_code);
 }
 
-static int cpuid_interception(struct vcpu_svm *svm)
+static int iret_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_emulate_cpuid(&svm->vcpu);
-}
+	struct vcpu_svm *svm = to_svm(vcpu);
 
-static int iret_interception(struct vcpu_svm *svm)
-{
-	++svm->vcpu.stat.nmi_window_exits;
-	svm->vcpu.arch.hflags |= HF_IRET_MASK;
-	if (!sev_es_guest(svm->vcpu.kvm)) {
+	++vcpu->stat.nmi_window_exits;
+	vcpu->arch.hflags |= HF_IRET_MASK;
+	if (!sev_es_guest(vcpu->kvm)) {
 		svm_clr_intercept(svm, INTERCEPT_IRET);
-		svm->nmi_iret_rip = kvm_rip_read(&svm->vcpu);
+		svm->nmi_iret_rip = kvm_rip_read(vcpu);
 	}
-	kvm_make_request(KVM_REQ_EVENT, &svm->vcpu);
+	kvm_make_request(KVM_REQ_EVENT, vcpu);
 	return 1;
 }
 
-static int invd_interception(struct vcpu_svm *svm)
-{
-	/* Treat an INVD instruction as a NOP and just skip it. */
-	return kvm_skip_emulated_instruction(&svm->vcpu);
-}
-
-static int invlpg_interception(struct vcpu_svm *svm)
+static int invlpg_interception(struct kvm_vcpu *vcpu)
 {
 	if (!static_cpu_has(X86_FEATURE_DECODEASSISTS))
-		return kvm_emulate_instruction(&svm->vcpu, 0);
+		return kvm_emulate_instruction(vcpu, 0);
 
-	kvm_mmu_invlpg(&svm->vcpu, svm->vmcb->control.exit_info_1);
-	return kvm_skip_emulated_instruction(&svm->vcpu);
+	kvm_mmu_invlpg(vcpu, to_svm(vcpu)->vmcb->control.exit_info_1);
+	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int emulate_on_interception(struct vcpu_svm *svm)
+static int emulate_on_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_emulate_instruction(&svm->vcpu, 0);
+	return kvm_emulate_instruction(vcpu, 0);
 }
 
-static int rsm_interception(struct vcpu_svm *svm)
+static int rsm_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_emulate_instruction_from_buffer(&svm->vcpu, rsm_ins_bytes, 2);
+	return kvm_emulate_instruction_from_buffer(vcpu, rsm_ins_bytes, 2);
 }
 
-static int rdpmc_interception(struct vcpu_svm *svm)
-{
-	int err;
-
-	if (!nrips)
-		return emulate_on_interception(svm);
-
-	err = kvm_rdpmc(&svm->vcpu);
-	return kvm_complete_insn_gp(&svm->vcpu, err);
-}
-
-static bool check_selective_cr0_intercepted(struct vcpu_svm *svm,
+static bool check_selective_cr0_intercepted(struct kvm_vcpu *vcpu,
 					    unsigned long val)
 {
-	unsigned long cr0 = svm->vcpu.arch.cr0;
+	struct vcpu_svm *svm = to_svm(vcpu);
+	unsigned long cr0 = vcpu->arch.cr0;
 	bool ret = false;
 
-	if (!is_guest_mode(&svm->vcpu) ||
+	if (!is_guest_mode(vcpu) ||
 	    (!(vmcb_is_intercept(&svm->nested.ctl, INTERCEPT_SELECTIVE_CR0))))
 		return false;
 
@@ -2509,17 +2449,18 @@
 
 #define CR_VALID (1ULL << 63)
 
-static int cr_interception(struct vcpu_svm *svm)
+static int cr_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int reg, cr;
 	unsigned long val;
 	int err;
 
 	if (!static_cpu_has(X86_FEATURE_DECODEASSISTS))
-		return emulate_on_interception(svm);
+		return emulate_on_interception(vcpu);
 
 	if (unlikely((svm->vmcb->control.exit_info_1 & CR_VALID) == 0))
-		return emulate_on_interception(svm);
+		return emulate_on_interception(vcpu);
 
 	reg = svm->vmcb->control.exit_info_1 & SVM_EXITINFO_REG_MASK;
 	if (svm->vmcb->control.exit_code == SVM_EXIT_CR0_SEL_WRITE)
@@ -2530,61 +2471,61 @@
 	err = 0;
 	if (cr >= 16) { /* mov to cr */
 		cr -= 16;
-		val = kvm_register_read(&svm->vcpu, reg);
+		val = kvm_register_read(vcpu, reg);
 		trace_kvm_cr_write(cr, val);
 		switch (cr) {
 		case 0:
-			if (!check_selective_cr0_intercepted(svm, val))
-				err = kvm_set_cr0(&svm->vcpu, val);
+			if (!check_selective_cr0_intercepted(vcpu, val))
+				err = kvm_set_cr0(vcpu, val);
 			else
 				return 1;
 
 			break;
 		case 3:
-			err = kvm_set_cr3(&svm->vcpu, val);
+			err = kvm_set_cr3(vcpu, val);
 			break;
 		case 4:
-			err = kvm_set_cr4(&svm->vcpu, val);
+			err = kvm_set_cr4(vcpu, val);
 			break;
 		case 8:
-			err = kvm_set_cr8(&svm->vcpu, val);
+			err = kvm_set_cr8(vcpu, val);
 			break;
 		default:
 			WARN(1, "unhandled write to CR%d", cr);
-			kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+			kvm_queue_exception(vcpu, UD_VECTOR);
 			return 1;
 		}
 	} else { /* mov from cr */
 		switch (cr) {
 		case 0:
-			val = kvm_read_cr0(&svm->vcpu);
+			val = kvm_read_cr0(vcpu);
 			break;
 		case 2:
-			val = svm->vcpu.arch.cr2;
+			val = vcpu->arch.cr2;
 			break;
 		case 3:
-			val = kvm_read_cr3(&svm->vcpu);
+			val = kvm_read_cr3(vcpu);
 			break;
 		case 4:
-			val = kvm_read_cr4(&svm->vcpu);
+			val = kvm_read_cr4(vcpu);
 			break;
 		case 8:
-			val = kvm_get_cr8(&svm->vcpu);
+			val = kvm_get_cr8(vcpu);
 			break;
 		default:
 			WARN(1, "unhandled read from CR%d", cr);
-			kvm_queue_exception(&svm->vcpu, UD_VECTOR);
+			kvm_queue_exception(vcpu, UD_VECTOR);
 			return 1;
 		}
-		kvm_register_write(&svm->vcpu, reg, val);
+		kvm_register_write(vcpu, reg, val);
 		trace_kvm_cr_read(cr, val);
 	}
-	return kvm_complete_insn_gp(&svm->vcpu, err);
+	return kvm_complete_insn_gp(vcpu, err);
 }
 
-static int cr_trap(struct vcpu_svm *svm)
+static int cr_trap(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct vcpu_svm *svm = to_svm(vcpu);
 	unsigned long old_value, new_value;
 	unsigned int cr;
 	int ret = 0;
@@ -2606,7 +2547,7 @@
 		kvm_post_set_cr4(vcpu, old_value, new_value);
 		break;
 	case 8:
-		ret = kvm_set_cr8(&svm->vcpu, new_value);
+		ret = kvm_set_cr8(vcpu, new_value);
 		break;
 	default:
 		WARN(1, "unhandled CR%d write trap", cr);
@@ -2617,57 +2558,57 @@
 	return kvm_complete_insn_gp(vcpu, ret);
 }
 
-static int dr_interception(struct vcpu_svm *svm)
+static int dr_interception(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	int reg, dr;
 	unsigned long val;
 	int err = 0;
 
-	if (svm->vcpu.guest_debug == 0) {
+	if (vcpu->guest_debug == 0) {
 		/*
 		 * No more DR vmexits; force a reload of the debug registers
 		 * and reenter on this instruction.  The next vmexit will
 		 * retrieve the full state of the debug registers.
 		 */
 		clr_dr_intercepts(svm);
-		svm->vcpu.arch.switch_db_regs |= KVM_DEBUGREG_WONT_EXIT;
+		vcpu->arch.switch_db_regs |= KVM_DEBUGREG_WONT_EXIT;
 		return 1;
 	}
 
 	if (!boot_cpu_has(X86_FEATURE_DECODEASSISTS))
-		return emulate_on_interception(svm);
+		return emulate_on_interception(vcpu);
 
 	reg = svm->vmcb->control.exit_info_1 & SVM_EXITINFO_REG_MASK;
 	dr = svm->vmcb->control.exit_code - SVM_EXIT_READ_DR0;
 	if (dr >= 16) { /* mov to DRn  */
 		dr -= 16;
-		val = kvm_register_read(&svm->vcpu, reg);
-		err = kvm_set_dr(&svm->vcpu, dr, val);
+		val = kvm_register_read(vcpu, reg);
+		err = kvm_set_dr(vcpu, dr, val);
 	} else {
-		kvm_get_dr(&svm->vcpu, dr, &val);
-		kvm_register_write(&svm->vcpu, reg, val);
+		kvm_get_dr(vcpu, dr, &val);
+		kvm_register_write(vcpu, reg, val);
 	}
 
-	return kvm_complete_insn_gp(&svm->vcpu, err);
+	return kvm_complete_insn_gp(vcpu, err);
 }
 
-static int cr8_write_interception(struct vcpu_svm *svm)
+static int cr8_write_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_run *kvm_run = svm->vcpu.run;
 	int r;
 
-	u8 cr8_prev = kvm_get_cr8(&svm->vcpu);
+	u8 cr8_prev = kvm_get_cr8(vcpu);
 	/* instruction emulation calls kvm_set_cr8() */
-	r = cr_interception(svm);
-	if (lapic_in_kernel(&svm->vcpu))
+	r = cr_interception(vcpu);
+	if (lapic_in_kernel(vcpu))
 		return r;
-	if (cr8_prev <= kvm_get_cr8(&svm->vcpu))
+	if (cr8_prev <= kvm_get_cr8(vcpu))
 		return r;
-	kvm_run->exit_reason = KVM_EXIT_SET_TPR;
+	vcpu->run->exit_reason = KVM_EXIT_SET_TPR;
 	return 0;
 }
 
-static int efer_trap(struct vcpu_svm *svm)
+static int efer_trap(struct kvm_vcpu *vcpu)
 {
 	struct msr_data msr_info;
 	int ret;
@@ -2680,10 +2621,10 @@
 	 */
 	msr_info.host_initiated = false;
 	msr_info.index = MSR_EFER;
-	msr_info.data = svm->vmcb->control.exit_info_1 & ~EFER_SVME;
-	ret = kvm_set_msr_common(&svm->vcpu, &msr_info);
+	msr_info.data = to_svm(vcpu)->vmcb->control.exit_info_1 & ~EFER_SVME;
+	ret = kvm_set_msr_common(vcpu, &msr_info);
 
-	return kvm_complete_insn_gp(&svm->vcpu, ret);
+	return kvm_complete_insn_gp(vcpu, ret);
 }
 
 static int svm_get_msr_feature(struct kvm_msr_entry *msr)
@@ -2710,30 +2651,34 @@
 
 	switch (msr_info->index) {
 	case MSR_STAR:
-		msr_info->data = svm->vmcb->save.star;
+		msr_info->data = svm->vmcb01.ptr->save.star;
 		break;
 #ifdef CONFIG_X86_64
 	case MSR_LSTAR:
-		msr_info->data = svm->vmcb->save.lstar;
+		msr_info->data = svm->vmcb01.ptr->save.lstar;
 		break;
 	case MSR_CSTAR:
-		msr_info->data = svm->vmcb->save.cstar;
+		msr_info->data = svm->vmcb01.ptr->save.cstar;
 		break;
 	case MSR_KERNEL_GS_BASE:
-		msr_info->data = svm->vmcb->save.kernel_gs_base;
+		msr_info->data = svm->vmcb01.ptr->save.kernel_gs_base;
 		break;
 	case MSR_SYSCALL_MASK:
-		msr_info->data = svm->vmcb->save.sfmask;
+		msr_info->data = svm->vmcb01.ptr->save.sfmask;
 		break;
 #endif
 	case MSR_IA32_SYSENTER_CS:
-		msr_info->data = svm->vmcb->save.sysenter_cs;
+		msr_info->data = svm->vmcb01.ptr->save.sysenter_cs;
 		break;
 	case MSR_IA32_SYSENTER_EIP:
-		msr_info->data = svm->sysenter_eip;
+		msr_info->data = (u32)svm->vmcb01.ptr->save.sysenter_eip;
+		if (guest_cpuid_is_intel(vcpu))
+			msr_info->data |= (u64)svm->sysenter_eip_hi << 32;
 		break;
 	case MSR_IA32_SYSENTER_ESP:
-		msr_info->data = svm->sysenter_esp;
+		msr_info->data = svm->vmcb01.ptr->save.sysenter_esp;
+		if (guest_cpuid_is_intel(vcpu))
+			msr_info->data |= (u64)svm->sysenter_esp_hi << 32;
 		break;
 	case MSR_TSC_AUX:
 		if (!boot_cpu_has(X86_FEATURE_RDTSCP))
@@ -2771,7 +2716,10 @@
 		    !guest_has_spec_ctrl_msr(vcpu))
 			return 1;
 
-		msr_info->data = svm->spec_ctrl;
+		if (boot_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+			msr_info->data = svm->vmcb->save.spec_ctrl;
+		else
+			msr_info->data = svm->spec_ctrl;
 		break;
 	case MSR_AMD64_VIRT_SPEC_CTRL:
 		if (!msr_info->host_initiated &&
@@ -2809,8 +2757,8 @@
 static int svm_complete_emulated_msr(struct kvm_vcpu *vcpu, int err)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
-	if (!sev_es_guest(svm->vcpu.kvm) || !err)
-		return kvm_complete_insn_gp(&svm->vcpu, err);
+	if (!sev_es_guest(vcpu->kvm) || !err)
+		return kvm_complete_insn_gp(vcpu, err);
 
 	ghcb_set_sw_exit_info_1(svm->ghcb, 1);
 	ghcb_set_sw_exit_info_2(svm->ghcb,
@@ -2820,11 +2768,6 @@
 	return 1;
 }
 
-static int rdmsr_interception(struct vcpu_svm *svm)
-{
-	return kvm_emulate_rdmsr(&svm->vcpu);
-}
-
 static int svm_set_vm_cr(struct kvm_vcpu *vcpu, u64 data)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
@@ -2861,7 +2804,9 @@
 		if (!kvm_mtrr_valid(vcpu, MSR_IA32_CR_PAT, data))
 			return 1;
 		vcpu->arch.pat = data;
-		svm->vmcb->save.g_pat = data;
+		svm->vmcb01.ptr->save.g_pat = data;
+		if (is_guest_mode(vcpu))
+			nested_vmcb02_compute_g_pat(svm);
 		vmcb_mark_dirty(svm->vmcb, VMCB_NPT);
 		break;
 	case MSR_IA32_SPEC_CTRL:
@@ -2872,7 +2817,10 @@
 		if (kvm_spec_ctrl_test_value(data))
 			return 1;
 
-		svm->spec_ctrl = data;
+		if (boot_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+			svm->vmcb->save.spec_ctrl = data;
+		else
+			svm->spec_ctrl = data;
 		if (!data)
 			break;
 
@@ -2915,32 +2863,39 @@
 		svm->virt_spec_ctrl = data;
 		break;
 	case MSR_STAR:
-		svm->vmcb->save.star = data;
+		svm->vmcb01.ptr->save.star = data;
 		break;
 #ifdef CONFIG_X86_64
 	case MSR_LSTAR:
-		svm->vmcb->save.lstar = data;
+		svm->vmcb01.ptr->save.lstar = data;
 		break;
 	case MSR_CSTAR:
-		svm->vmcb->save.cstar = data;
+		svm->vmcb01.ptr->save.cstar = data;
 		break;
 	case MSR_KERNEL_GS_BASE:
-		svm->vmcb->save.kernel_gs_base = data;
+		svm->vmcb01.ptr->save.kernel_gs_base = data;
 		break;
 	case MSR_SYSCALL_MASK:
-		svm->vmcb->save.sfmask = data;
+		svm->vmcb01.ptr->save.sfmask = data;
 		break;
 #endif
 	case MSR_IA32_SYSENTER_CS:
-		svm->vmcb->save.sysenter_cs = data;
+		svm->vmcb01.ptr->save.sysenter_cs = data;
 		break;
 	case MSR_IA32_SYSENTER_EIP:
-		svm->sysenter_eip = data;
-		svm->vmcb->save.sysenter_eip = data;
+		svm->vmcb01.ptr->save.sysenter_eip = (u32)data;
+		/*
+		 * We only intercept the MSR_IA32_SYSENTER_{EIP|ESP} msrs
+		 * when we spoof an Intel vendor ID (for cross vendor migration).
+		 * In this case we use this intercept to track the high
+		 * 32 bit part of these msrs to support Intel's
+		 * implementation of SYSENTER/SYSEXIT.
+		 */
+		svm->sysenter_eip_hi = guest_cpuid_is_intel(vcpu) ? (data >> 32) : 0;
 		break;
 	case MSR_IA32_SYSENTER_ESP:
-		svm->sysenter_esp = data;
-		svm->vmcb->save.sysenter_esp = data;
+		svm->vmcb01.ptr->save.sysenter_esp = (u32)data;
+		svm->sysenter_esp_hi = guest_cpuid_is_intel(vcpu) ? (data >> 32) : 0;
 		break;
 	case MSR_TSC_AUX:
 		if (!boot_cpu_has(X86_FEATURE_RDTSCP))
@@ -3006,38 +2961,32 @@
 	return 0;
 }
 
-static int wrmsr_interception(struct vcpu_svm *svm)
+static int msr_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_emulate_wrmsr(&svm->vcpu);
-}
-
-static int msr_interception(struct vcpu_svm *svm)
-{
-	if (svm->vmcb->control.exit_info_1)
-		return wrmsr_interception(svm);
+	if (to_svm(vcpu)->vmcb->control.exit_info_1)
+		return kvm_emulate_wrmsr(vcpu);
 	else
-		return rdmsr_interception(svm);
+		return kvm_emulate_rdmsr(vcpu);
 }
 
-static int interrupt_window_interception(struct vcpu_svm *svm)
+static int interrupt_window_interception(struct kvm_vcpu *vcpu)
 {
-	kvm_make_request(KVM_REQ_EVENT, &svm->vcpu);
-	svm_clear_vintr(svm);
+	kvm_make_request(KVM_REQ_EVENT, vcpu);
+	svm_clear_vintr(to_svm(vcpu));
 
 	/*
 	 * For AVIC, the only reason to end up here is ExtINTs.
 	 * In this case AVIC was temporarily disabled for
 	 * requesting the IRQ window and we have to re-enable it.
 	 */
-	svm_toggle_avic_for_irq_window(&svm->vcpu, true);
+	svm_toggle_avic_for_irq_window(vcpu, true);
 
-	++svm->vcpu.stat.irq_window_exits;
+	++vcpu->stat.irq_window_exits;
 	return 1;
 }
 
-static int pause_interception(struct vcpu_svm *svm)
+static int pause_interception(struct kvm_vcpu *vcpu)
 {
-	struct kvm_vcpu *vcpu = &svm->vcpu;
 	bool in_kernel;
 
 	/*
@@ -3045,35 +2994,18 @@
 	 * vcpu->arch.preempted_in_kernel can never be true.  Just
 	 * set in_kernel to false as well.
 	 */
-	in_kernel = !sev_es_guest(svm->vcpu.kvm) && svm_get_cpl(vcpu) == 0;
+	in_kernel = !sev_es_guest(vcpu->kvm) && svm_get_cpl(vcpu) == 0;
 
 	if (!kvm_pause_in_guest(vcpu->kvm))
 		grow_ple_window(vcpu);
 
 	kvm_vcpu_on_spin(vcpu, in_kernel);
-	return 1;
+	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int nop_interception(struct vcpu_svm *svm)
+static int invpcid_interception(struct kvm_vcpu *vcpu)
 {
-	return kvm_skip_emulated_instruction(&(svm->vcpu));
-}
-
-static int monitor_interception(struct vcpu_svm *svm)
-{
-	printk_once(KERN_WARNING "kvm: MONITOR instruction emulated as NOP!\n");
-	return nop_interception(svm);
-}
-
-static int mwait_interception(struct vcpu_svm *svm)
-{
-	printk_once(KERN_WARNING "kvm: MWAIT instruction emulated as NOP!\n");
-	return nop_interception(svm);
-}
-
-static int invpcid_interception(struct vcpu_svm *svm)
-{
-	struct kvm_vcpu *vcpu = &svm->vcpu;
+	struct vcpu_svm *svm = to_svm(vcpu);
 	unsigned long type;
 	gva_t gva;
 
@@ -3098,7 +3030,7 @@
 	return kvm_handle_invpcid(vcpu, type, gva);
 }
 
-static int (*const svm_exit_handlers[])(struct vcpu_svm *svm) = {
+static int (*const svm_exit_handlers[])(struct kvm_vcpu *vcpu) = {
 	[SVM_EXIT_READ_CR0]			= cr_interception,
 	[SVM_EXIT_READ_CR3]			= cr_interception,
 	[SVM_EXIT_READ_CR4]			= cr_interception,
@@ -3133,15 +3065,15 @@
 	[SVM_EXIT_EXCP_BASE + GP_VECTOR]	= gp_interception,
 	[SVM_EXIT_INTR]				= intr_interception,
 	[SVM_EXIT_NMI]				= nmi_interception,
-	[SVM_EXIT_SMI]				= nop_on_interception,
-	[SVM_EXIT_INIT]				= nop_on_interception,
+	[SVM_EXIT_SMI]				= kvm_emulate_as_nop,
+	[SVM_EXIT_INIT]				= kvm_emulate_as_nop,
 	[SVM_EXIT_VINTR]			= interrupt_window_interception,
-	[SVM_EXIT_RDPMC]			= rdpmc_interception,
-	[SVM_EXIT_CPUID]			= cpuid_interception,
+	[SVM_EXIT_RDPMC]			= kvm_emulate_rdpmc,
+	[SVM_EXIT_CPUID]			= kvm_emulate_cpuid,
 	[SVM_EXIT_IRET]                         = iret_interception,
-	[SVM_EXIT_INVD]                         = invd_interception,
+	[SVM_EXIT_INVD]                         = kvm_emulate_invd,
 	[SVM_EXIT_PAUSE]			= pause_interception,
-	[SVM_EXIT_HLT]				= halt_interception,
+	[SVM_EXIT_HLT]				= kvm_emulate_halt,
 	[SVM_EXIT_INVLPG]			= invlpg_interception,
 	[SVM_EXIT_INVLPGA]			= invlpga_interception,
 	[SVM_EXIT_IOIO]				= io_interception,
@@ -3149,17 +3081,17 @@
 	[SVM_EXIT_TASK_SWITCH]			= task_switch_interception,
 	[SVM_EXIT_SHUTDOWN]			= shutdown_interception,
 	[SVM_EXIT_VMRUN]			= vmrun_interception,
-	[SVM_EXIT_VMMCALL]			= vmmcall_interception,
+	[SVM_EXIT_VMMCALL]			= kvm_emulate_hypercall,
 	[SVM_EXIT_VMLOAD]			= vmload_interception,
 	[SVM_EXIT_VMSAVE]			= vmsave_interception,
 	[SVM_EXIT_STGI]				= stgi_interception,
 	[SVM_EXIT_CLGI]				= clgi_interception,
 	[SVM_EXIT_SKINIT]			= skinit_interception,
-	[SVM_EXIT_WBINVD]                       = wbinvd_interception,
-	[SVM_EXIT_MONITOR]			= monitor_interception,
-	[SVM_EXIT_MWAIT]			= mwait_interception,
-	[SVM_EXIT_XSETBV]			= xsetbv_interception,
-	[SVM_EXIT_RDPRU]			= rdpru_interception,
+	[SVM_EXIT_WBINVD]                       = kvm_emulate_wbinvd,
+	[SVM_EXIT_MONITOR]			= kvm_emulate_monitor,
+	[SVM_EXIT_MWAIT]			= kvm_emulate_mwait,
+	[SVM_EXIT_XSETBV]			= kvm_emulate_xsetbv,
+	[SVM_EXIT_RDPRU]			= kvm_handle_invalid_op,
 	[SVM_EXIT_EFER_WRITE_TRAP]		= efer_trap,
 	[SVM_EXIT_CR0_WRITE_TRAP]		= cr_trap,
 	[SVM_EXIT_CR4_WRITE_TRAP]		= cr_trap,
@@ -3177,6 +3109,7 @@
 	struct vcpu_svm *svm = to_svm(vcpu);
 	struct vmcb_control_area *control = &svm->vmcb->control;
 	struct vmcb_save_area *save = &svm->vmcb->save;
+	struct vmcb_save_area *save01 = &svm->vmcb01.ptr->save;
 
 	if (!dump_invalid_vmcb) {
 		pr_warn_ratelimited("set kvm_amd.dump_invalid_vmcb=1 to dump internal KVM state.\n");
@@ -3239,28 +3172,28 @@
 	       save->ds.limit, save->ds.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "fs:",
-	       save->fs.selector, save->fs.attrib,
-	       save->fs.limit, save->fs.base);
+	       save01->fs.selector, save01->fs.attrib,
+	       save01->fs.limit, save01->fs.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "gs:",
-	       save->gs.selector, save->gs.attrib,
-	       save->gs.limit, save->gs.base);
+	       save01->gs.selector, save01->gs.attrib,
+	       save01->gs.limit, save01->gs.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "gdtr:",
 	       save->gdtr.selector, save->gdtr.attrib,
 	       save->gdtr.limit, save->gdtr.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "ldtr:",
-	       save->ldtr.selector, save->ldtr.attrib,
-	       save->ldtr.limit, save->ldtr.base);
+	       save01->ldtr.selector, save01->ldtr.attrib,
+	       save01->ldtr.limit, save01->ldtr.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "idtr:",
 	       save->idtr.selector, save->idtr.attrib,
 	       save->idtr.limit, save->idtr.base);
 	pr_err("%-5s s: %04x a: %04x l: %08x b: %016llx\n",
 	       "tr:",
-	       save->tr.selector, save->tr.attrib,
-	       save->tr.limit, save->tr.base);
+	       save01->tr.selector, save01->tr.attrib,
+	       save01->tr.limit, save01->tr.base);
 	pr_err("cpl:            %d                efer:         %016llx\n",
 		save->cpl, save->efer);
 	pr_err("%-15s %016llx %-13s %016llx\n",
@@ -3274,15 +3207,15 @@
 	pr_err("%-15s %016llx %-13s %016llx\n",
 	       "rsp:", save->rsp, "rax:", save->rax);
 	pr_err("%-15s %016llx %-13s %016llx\n",
-	       "star:", save->star, "lstar:", save->lstar);
+	       "star:", save01->star, "lstar:", save01->lstar);
 	pr_err("%-15s %016llx %-13s %016llx\n",
-	       "cstar:", save->cstar, "sfmask:", save->sfmask);
+	       "cstar:", save01->cstar, "sfmask:", save01->sfmask);
 	pr_err("%-15s %016llx %-13s %016llx\n",
-	       "kernel_gs_base:", save->kernel_gs_base,
-	       "sysenter_cs:", save->sysenter_cs);
+	       "kernel_gs_base:", save01->kernel_gs_base,
+	       "sysenter_cs:", save01->sysenter_cs);
 	pr_err("%-15s %016llx %-13s %016llx\n",
-	       "sysenter_esp:", save->sysenter_esp,
-	       "sysenter_eip:", save->sysenter_eip);
+	       "sysenter_esp:", save01->sysenter_esp,
+	       "sysenter_eip:", save01->sysenter_eip);
 	pr_err("%-15s %016llx %-13s %016llx\n",
 	       "gpat:", save->g_pat, "dbgctl:", save->dbgctl);
 	pr_err("%-15s %016llx %-13s %016llx\n",
@@ -3309,24 +3242,24 @@
 	return -EINVAL;
 }
 
-int svm_invoke_exit_handler(struct vcpu_svm *svm, u64 exit_code)
+int svm_invoke_exit_handler(struct kvm_vcpu *vcpu, u64 exit_code)
 {
-	if (svm_handle_invalid_exit(&svm->vcpu, exit_code))
+	if (svm_handle_invalid_exit(vcpu, exit_code))
 		return 0;
 
 #ifdef CONFIG_RETPOLINE
 	if (exit_code == SVM_EXIT_MSR)
-		return msr_interception(svm);
+		return msr_interception(vcpu);
 	else if (exit_code == SVM_EXIT_VINTR)
-		return interrupt_window_interception(svm);
+		return interrupt_window_interception(vcpu);
 	else if (exit_code == SVM_EXIT_INTR)
-		return intr_interception(svm);
+		return intr_interception(vcpu);
 	else if (exit_code == SVM_EXIT_HLT)
-		return halt_interception(svm);
+		return kvm_emulate_halt(vcpu);
 	else if (exit_code == SVM_EXIT_NPF)
-		return npf_interception(svm);
+		return npf_interception(vcpu);
 #endif
-	return svm_exit_handlers[exit_code](svm);
+	return svm_exit_handlers[exit_code](vcpu);
 }
 
 static void svm_get_exit_info(struct kvm_vcpu *vcpu, u64 *info1, u64 *info2,
@@ -3395,7 +3328,7 @@
 	if (exit_fastpath != EXIT_FASTPATH_NONE)
 		return 1;
 
-	return svm_invoke_exit_handler(svm, exit_code);
+	return svm_invoke_exit_handler(vcpu, exit_code);
 }
 
 static void reload_tss(struct kvm_vcpu *vcpu)
@@ -3406,15 +3339,28 @@
 	load_TR_desc();
 }
 
-static void pre_svm_run(struct vcpu_svm *svm)
+static void pre_svm_run(struct kvm_vcpu *vcpu)
 {
-	struct svm_cpu_data *sd = per_cpu(svm_data, svm->vcpu.cpu);
+	struct svm_cpu_data *sd = per_cpu(svm_data, vcpu->cpu);
+	struct vcpu_svm *svm = to_svm(vcpu);
 
-	if (sev_guest(svm->vcpu.kvm))
-		return pre_sev_run(svm, svm->vcpu.cpu);
+	/*
+	 * If the previous vmrun of the vmcb occurred on a different
+	 * physical cpu, then we must mark the vmcb dirty and assign
+	 * a new asid.
+	 */
+
+	if (unlikely(svm->current_vmcb->cpu != vcpu->cpu)) {
+		svm->current_vmcb->asid_generation = 0;
+		vmcb_mark_all_dirty(svm->vmcb);
+		svm->current_vmcb->cpu = vcpu->cpu;
+	}
+
+	if (sev_guest(vcpu->kvm))
+		return pre_sev_run(svm, vcpu->cpu);
 
 	/* FIXME: handle wraparound of asid_generation */
-	if (svm->asid_generation != sd->asid_generation)
+	if (svm->current_vmcb->asid_generation != sd->asid_generation)
 		new_asid(svm, sd);
 }
 
@@ -3424,7 +3370,7 @@
 
 	svm->vmcb->control.event_inj = SVM_EVTINJ_VALID | SVM_EVTINJ_TYPE_NMI;
 	vcpu->arch.hflags |= HF_NMI_MASK;
-	if (!sev_es_guest(svm->vcpu.kvm))
+	if (!sev_es_guest(vcpu->kvm))
 		svm_set_intercept(svm, INTERCEPT_IRET);
 	++vcpu->stat.nmi_injections;
 }
@@ -3478,7 +3424,7 @@
 		return false;
 
 	ret = (vmcb->control.int_state & SVM_INTERRUPT_SHADOW_MASK) ||
-	      (svm->vcpu.arch.hflags & HF_NMI_MASK);
+	      (vcpu->arch.hflags & HF_NMI_MASK);
 
 	return ret;
 }
@@ -3498,9 +3444,7 @@
 
 static bool svm_get_nmi_mask(struct kvm_vcpu *vcpu)
 {
-	struct vcpu_svm *svm = to_svm(vcpu);
-
-	return !!(svm->vcpu.arch.hflags & HF_NMI_MASK);
+	return !!(vcpu->arch.hflags & HF_NMI_MASK);
 }
 
 static void svm_set_nmi_mask(struct kvm_vcpu *vcpu, bool masked)
@@ -3508,12 +3452,12 @@
 	struct vcpu_svm *svm = to_svm(vcpu);
 
 	if (masked) {
-		svm->vcpu.arch.hflags |= HF_NMI_MASK;
-		if (!sev_es_guest(svm->vcpu.kvm))
+		vcpu->arch.hflags |= HF_NMI_MASK;
+		if (!sev_es_guest(vcpu->kvm))
 			svm_set_intercept(svm, INTERCEPT_IRET);
 	} else {
-		svm->vcpu.arch.hflags &= ~HF_NMI_MASK;
-		if (!sev_es_guest(svm->vcpu.kvm))
+		vcpu->arch.hflags &= ~HF_NMI_MASK;
+		if (!sev_es_guest(vcpu->kvm))
 			svm_clr_intercept(svm, INTERCEPT_IRET);
 	}
 }
@@ -3526,7 +3470,7 @@
 	if (!gif_set(svm))
 		return true;
 
-	if (sev_es_guest(svm->vcpu.kvm)) {
+	if (sev_es_guest(vcpu->kvm)) {
 		/*
 		 * SEV-ES guests to not expose RFLAGS. Use the VMCB interrupt mask
 		 * bit to determine the state of the IF flag.
@@ -3536,7 +3480,7 @@
 	} else if (is_guest_mode(vcpu)) {
 		/* As long as interrupts are being delivered...  */
 		if ((svm->nested.ctl.int_ctl & V_INTR_MASKING_MASK)
-		    ? !(svm->nested.hsave->save.rflags & X86_EFLAGS_IF)
+		    ? !(svm->vmcb01.ptr->save.rflags & X86_EFLAGS_IF)
 		    : !(kvm_get_rflags(vcpu) & X86_EFLAGS_IF))
 			return true;
 
@@ -3595,8 +3539,7 @@
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
 
-	if ((svm->vcpu.arch.hflags & (HF_NMI_MASK | HF_IRET_MASK))
-	    == HF_NMI_MASK)
+	if ((vcpu->arch.hflags & (HF_NMI_MASK | HF_IRET_MASK)) == HF_NMI_MASK)
 		return; /* IRET will cause a vm exit */
 
 	if (!gif_set(svm)) {
@@ -3638,7 +3581,7 @@
 	if (static_cpu_has(X86_FEATURE_FLUSHBYASID))
 		svm->vmcb->control.tlb_ctl = TLB_CONTROL_FLUSH_ASID;
 	else
-		svm->asid_generation--;
+		svm->current_vmcb->asid_generation--;
 }
 
 static void svm_flush_tlb_gva(struct kvm_vcpu *vcpu, gva_t gva)
@@ -3675,8 +3618,9 @@
 	svm->vmcb->control.int_ctl |= cr8 & V_TPR_MASK;
 }
 
-static void svm_complete_interrupts(struct vcpu_svm *svm)
+static void svm_complete_interrupts(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
 	u8 vector;
 	int type;
 	u32 exitintinfo = svm->vmcb->control.exit_int_info;
@@ -3688,28 +3632,28 @@
 	 * If we've made progress since setting HF_IRET_MASK, we've
 	 * executed an IRET and can allow NMI injection.
 	 */
-	if ((svm->vcpu.arch.hflags & HF_IRET_MASK) &&
-	    (sev_es_guest(svm->vcpu.kvm) ||
-	     kvm_rip_read(&svm->vcpu) != svm->nmi_iret_rip)) {
-		svm->vcpu.arch.hflags &= ~(HF_NMI_MASK | HF_IRET_MASK);
-		kvm_make_request(KVM_REQ_EVENT, &svm->vcpu);
+	if ((vcpu->arch.hflags & HF_IRET_MASK) &&
+	    (sev_es_guest(vcpu->kvm) ||
+	     kvm_rip_read(vcpu) != svm->nmi_iret_rip)) {
+		vcpu->arch.hflags &= ~(HF_NMI_MASK | HF_IRET_MASK);
+		kvm_make_request(KVM_REQ_EVENT, vcpu);
 	}
 
-	svm->vcpu.arch.nmi_injected = false;
-	kvm_clear_exception_queue(&svm->vcpu);
-	kvm_clear_interrupt_queue(&svm->vcpu);
+	vcpu->arch.nmi_injected = false;
+	kvm_clear_exception_queue(vcpu);
+	kvm_clear_interrupt_queue(vcpu);
 
 	if (!(exitintinfo & SVM_EXITINTINFO_VALID))
 		return;
 
-	kvm_make_request(KVM_REQ_EVENT, &svm->vcpu);
+	kvm_make_request(KVM_REQ_EVENT, vcpu);
 
 	vector = exitintinfo & SVM_EXITINTINFO_VEC_MASK;
 	type = exitintinfo & SVM_EXITINTINFO_TYPE_MASK;
 
 	switch (type) {
 	case SVM_EXITINTINFO_TYPE_NMI:
-		svm->vcpu.arch.nmi_injected = true;
+		vcpu->arch.nmi_injected = true;
 		break;
 	case SVM_EXITINTINFO_TYPE_EXEPT:
 		/*
@@ -3725,21 +3669,20 @@
 		 */
 		if (kvm_exception_is_soft(vector)) {
 			if (vector == BP_VECTOR && int3_injected &&
-			    kvm_is_linear_rip(&svm->vcpu, svm->int3_rip))
-				kvm_rip_write(&svm->vcpu,
-					      kvm_rip_read(&svm->vcpu) -
-					      int3_injected);
+			    kvm_is_linear_rip(vcpu, svm->int3_rip))
+				kvm_rip_write(vcpu,
+					      kvm_rip_read(vcpu) - int3_injected);
 			break;
 		}
 		if (exitintinfo & SVM_EXITINTINFO_VALID_ERR) {
 			u32 err = svm->vmcb->control.exit_int_info_err;
-			kvm_requeue_exception_e(&svm->vcpu, vector, err);
+			kvm_requeue_exception_e(vcpu, vector, err);
 
 		} else
-			kvm_requeue_exception(&svm->vcpu, vector);
+			kvm_requeue_exception(vcpu, vector);
 		break;
 	case SVM_EXITINTINFO_TYPE_INTR:
-		kvm_queue_interrupt(&svm->vcpu, vector, false);
+		kvm_queue_interrupt(vcpu, vector, false);
 		break;
 	default:
 		break;
@@ -3754,7 +3697,7 @@
 	control->exit_int_info = control->event_inj;
 	control->exit_int_info_err = control->event_inj_err;
 	control->event_inj = 0;
-	svm_complete_interrupts(svm);
+	svm_complete_interrupts(vcpu);
 }
 
 static fastpath_t svm_exit_handlers_fastpath(struct kvm_vcpu *vcpu)
@@ -3766,9 +3709,10 @@
 	return EXIT_FASTPATH_NONE;
 }
 
-static noinstr void svm_vcpu_enter_exit(struct kvm_vcpu *vcpu,
-					struct vcpu_svm *svm)
+static noinstr void svm_vcpu_enter_exit(struct kvm_vcpu *vcpu)
 {
+	struct vcpu_svm *svm = to_svm(vcpu);
+
 	/*
 	 * VMENTER enables interrupts (host state), but the kernel state is
 	 * interrupts disabled when this is invoked. Also tell RCU about
@@ -3789,12 +3733,14 @@
 	guest_enter_irqoff();
 	lockdep_hardirqs_on(CALLER_ADDR0);
 
-	if (sev_es_guest(svm->vcpu.kvm)) {
+	if (sev_es_guest(vcpu->kvm)) {
 		__svm_sev_es_vcpu_run(svm->vmcb_pa);
 	} else {
 		struct svm_cpu_data *sd = per_cpu(svm_data, vcpu->cpu);
 
-		__svm_vcpu_run(svm->vmcb_pa, (unsigned long *)&svm->vcpu.arch.regs);
+		vmload(svm->vmcb01.pa);
+		__svm_vcpu_run(svm->vmcb_pa, (unsigned long *)&vcpu->arch.regs);
+		vmsave(svm->vmcb01.pa);
 
 		vmload(__sme_page_pa(sd->save_area));
 	}
@@ -3845,7 +3791,7 @@
 		smp_send_reschedule(vcpu->cpu);
 	}
 
-	pre_svm_run(svm);
+	pre_svm_run(vcpu);
 
 	sync_lapic_to_cr8(vcpu);
 
@@ -3859,7 +3805,7 @@
 	 * Run with all-zero DR6 unless needed, so that we can get the exact cause
 	 * of a #DB.
 	 */
-	if (unlikely(svm->vcpu.arch.switch_db_regs & KVM_DEBUGREG_WONT_EXIT))
+	if (unlikely(vcpu->arch.switch_db_regs & KVM_DEBUGREG_WONT_EXIT))
 		svm_set_dr6(svm, vcpu->arch.dr6);
 	else
 		svm_set_dr6(svm, DR6_ACTIVE_LOW);
@@ -3875,9 +3821,10 @@
 	 * is no need to worry about the conditional branch over the wrmsr
 	 * being speculatively taken.
 	 */
-	x86_spec_ctrl_set_guest(svm->spec_ctrl, svm->virt_spec_ctrl);
+	if (!static_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+		x86_spec_ctrl_set_guest(svm->spec_ctrl, svm->virt_spec_ctrl);
 
-	svm_vcpu_enter_exit(vcpu, svm);
+	svm_vcpu_enter_exit(vcpu);
 
 	/*
 	 * We do not use IBRS in the kernel. If this vCPU has used the
@@ -3894,15 +3841,17 @@
 	 * If the L02 MSR bitmap does not intercept the MSR, then we need to
 	 * save it.
 	 */
-	if (unlikely(!msr_write_intercepted(vcpu, MSR_IA32_SPEC_CTRL)))
+	if (!static_cpu_has(X86_FEATURE_V_SPEC_CTRL) &&
+	    unlikely(!msr_write_intercepted(vcpu, MSR_IA32_SPEC_CTRL)))
 		svm->spec_ctrl = native_read_msr(MSR_IA32_SPEC_CTRL);
 
-	if (!sev_es_guest(svm->vcpu.kvm))
+	if (!sev_es_guest(vcpu->kvm))
 		reload_tss(vcpu);
 
-	x86_spec_ctrl_restore_host(svm->spec_ctrl, svm->virt_spec_ctrl);
+	if (!static_cpu_has(X86_FEATURE_V_SPEC_CTRL))
+		x86_spec_ctrl_restore_host(svm->spec_ctrl, svm->virt_spec_ctrl);
 
-	if (!sev_es_guest(svm->vcpu.kvm)) {
+	if (!sev_es_guest(vcpu->kvm)) {
 		vcpu->arch.cr2 = svm->vmcb->save.cr2;
 		vcpu->arch.regs[VCPU_REGS_RAX] = svm->vmcb->save.rax;
 		vcpu->arch.regs[VCPU_REGS_RSP] = svm->vmcb->save.rsp;
@@ -3910,7 +3859,7 @@
 	}
 
 	if (unlikely(svm->vmcb->control.exit_code == SVM_EXIT_NMI))
-		kvm_before_interrupt(&svm->vcpu);
+		kvm_before_interrupt(vcpu);
 
 	kvm_load_host_xsave_state(vcpu);
 	stgi();
@@ -3918,13 +3867,13 @@
 	/* Any pending NMI will happen here */
 
 	if (unlikely(svm->vmcb->control.exit_code == SVM_EXIT_NMI))
-		kvm_after_interrupt(&svm->vcpu);
+		kvm_after_interrupt(vcpu);
 
 	sync_cr8_to_lapic(vcpu);
 
 	svm->next_rip = 0;
-	if (is_guest_mode(&svm->vcpu)) {
-		sync_nested_vmcb_control(svm);
+	if (is_guest_mode(vcpu)) {
+		nested_sync_control_from_vmcb02(svm);
 		svm->nested.nested_run_pending = 0;
 	}
 
@@ -3933,13 +3882,11 @@
 
 	/* if exit due to PF check for async PF */
 	if (svm->vmcb->control.exit_code == SVM_EXIT_EXCP_BASE + PF_VECTOR)
-		svm->vcpu.arch.apf.host_apf_flags =
+		vcpu->arch.apf.host_apf_flags =
 			kvm_read_and_reset_apf_flags();
 
-	if (npt_enabled) {
-		vcpu->arch.regs_avail &= ~(1 << VCPU_EXREG_PDPTR);
-		vcpu->arch.regs_dirty &= ~(1 << VCPU_EXREG_PDPTR);
-	}
+	if (npt_enabled)
+		kvm_register_clear_available(vcpu, VCPU_EXREG_PDPTR);
 
 	/*
 	 * We need to handle MC intercepts here before the vcpu has a chance to
@@ -3947,9 +3894,9 @@
 	 */
 	if (unlikely(svm->vmcb->control.exit_code ==
 		     SVM_EXIT_EXCP_BASE + MC_VECTOR))
-		svm_handle_mce(svm);
+		svm_handle_mce(vcpu);
 
-	svm_complete_interrupts(svm);
+	svm_complete_interrupts(vcpu);
 
 	if (is_guest_mode(vcpu))
 		return EXIT_FASTPATH_NONE;
@@ -3957,21 +3904,26 @@
 	return svm_exit_handlers_fastpath(vcpu);
 }
 
-static void svm_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long root,
+static void svm_load_mmu_pgd(struct kvm_vcpu *vcpu, hpa_t root_hpa,
 			     int root_level)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
 	unsigned long cr3;
 
-	cr3 = __sme_set(root);
 	if (npt_enabled) {
-		svm->vmcb->control.nested_cr3 = cr3;
+		svm->vmcb->control.nested_cr3 = __sme_set(root_hpa);
 		vmcb_mark_dirty(svm->vmcb, VMCB_NPT);
 
 		/* Loading L2's CR3 is handled by enter_svm_guest_mode.  */
 		if (!test_bit(VCPU_EXREG_CR3, (ulong *)&vcpu->arch.regs_avail))
 			return;
 		cr3 = vcpu->arch.cr3;
+	} else if (vcpu->arch.mmu->shadow_root_level >= PT64_ROOT_4LEVEL) {
+		cr3 = __sme_set(root_hpa) | kvm_get_active_pcid(vcpu);
+	} else {
+		/* PCID in the guest should be impossible with a 32-bit MMU. */
+		WARN_ON_ONCE(kvm_get_active_pcid(vcpu));
+		cr3 = root_hpa;
 	}
 
 	svm->vmcb->save.cr3 = cr3;
@@ -4048,7 +4000,7 @@
 
 	/* Update nrips enabled cache */
 	svm->nrips_enabled = kvm_cpu_cap_has(X86_FEATURE_NRIPS) &&
-			     guest_cpuid_has(&svm->vcpu, X86_FEATURE_NRIPS);
+			     guest_cpuid_has(vcpu, X86_FEATURE_NRIPS);
 
 	/* Check again if INVPCID interception if required */
 	svm_check_invpcid(svm);
@@ -4060,24 +4012,50 @@
 			vcpu->arch.reserved_gpa_bits &= ~(1UL << (best->ebx & 0x3f));
 	}
 
-	if (!kvm_vcpu_apicv_active(vcpu))
-		return;
+	if (kvm_vcpu_apicv_active(vcpu)) {
+		/*
+		 * AVIC does not work with an x2APIC mode guest. If the X2APIC feature
+		 * is exposed to the guest, disable AVIC.
+		 */
+		if (guest_cpuid_has(vcpu, X86_FEATURE_X2APIC))
+			kvm_request_apicv_update(vcpu->kvm, false,
+						 APICV_INHIBIT_REASON_X2APIC);
 
-	/*
-	 * AVIC does not work with an x2APIC mode guest. If the X2APIC feature
-	 * is exposed to the guest, disable AVIC.
-	 */
-	if (guest_cpuid_has(vcpu, X86_FEATURE_X2APIC))
-		kvm_request_apicv_update(vcpu->kvm, false,
-					 APICV_INHIBIT_REASON_X2APIC);
+		/*
+		 * Currently, AVIC does not work with nested virtualization.
+		 * So, we disable AVIC when cpuid for SVM is set in the L1 guest.
+		 */
+		if (nested && guest_cpuid_has(vcpu, X86_FEATURE_SVM))
+			kvm_request_apicv_update(vcpu->kvm, false,
+						 APICV_INHIBIT_REASON_NESTED);
+	}
 
-	/*
-	 * Currently, AVIC does not work with nested virtualization.
-	 * So, we disable AVIC when cpuid for SVM is set in the L1 guest.
-	 */
-	if (nested && guest_cpuid_has(vcpu, X86_FEATURE_SVM))
-		kvm_request_apicv_update(vcpu->kvm, false,
-					 APICV_INHIBIT_REASON_NESTED);
+	if (guest_cpuid_is_intel(vcpu)) {
+		/*
+		 * We must intercept SYSENTER_EIP and SYSENTER_ESP
+		 * accesses because the processor only stores 32 bits.
+		 * For the same reason we cannot use virtual VMLOAD/VMSAVE.
+		 */
+		svm_set_intercept(svm, INTERCEPT_VMLOAD);
+		svm_set_intercept(svm, INTERCEPT_VMSAVE);
+		svm->vmcb->control.virt_ext &= ~VIRTUAL_VMLOAD_VMSAVE_ENABLE_MASK;
+
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SYSENTER_EIP, 0, 0);
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SYSENTER_ESP, 0, 0);
+	} else {
+		/*
+		 * If hardware supports Virtual VMLOAD VMSAVE then enable it
+		 * in VMCB and clear intercepts to avoid #VMEXIT.
+		 */
+		if (vls) {
+			svm_clr_intercept(svm, INTERCEPT_VMLOAD);
+			svm_clr_intercept(svm, INTERCEPT_VMSAVE);
+			svm->vmcb->control.virt_ext |= VIRTUAL_VMLOAD_VMSAVE_ENABLE_MASK;
+		}
+		/* No need to intercept these MSRs */
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SYSENTER_EIP, 1, 1);
+		set_msr_interception(vcpu, svm->msrpm, MSR_IA32_SYSENTER_ESP, 1, 1);
+	}
 }
 
 static bool svm_has_wbinvd_exit(void)
@@ -4349,15 +4327,15 @@
 			if (!(saved_efer & EFER_SVME))
 				return 1;
 
-			if (kvm_vcpu_map(&svm->vcpu,
+			if (kvm_vcpu_map(vcpu,
 					 gpa_to_gfn(vmcb12_gpa), &map) == -EINVAL)
 				return 1;
 
 			if (svm_allocate_nested(svm))
 				return 1;
 
-			ret = enter_svm_guest_mode(svm, vmcb12_gpa, map.hva);
-			kvm_vcpu_unmap(&svm->vcpu, &map, true);
+			ret = enter_svm_guest_mode(vcpu, vmcb12_gpa, map.hva);
+			kvm_vcpu_unmap(vcpu, &map, true);
 		}
 	}
 
@@ -4612,6 +4590,8 @@
 	.mem_enc_reg_region = svm_register_enc_region,
 	.mem_enc_unreg_region = svm_unregister_enc_region,
 
+	.vm_copy_enc_context_from = svm_vm_copy_asid_from,
+
 	.can_emulate_instruction = svm_can_emulate_instruction,
 
 	.apic_init_signal_blocked = svm_apic_init_signal_blocked,
diff --git a/arch/x86/kvm/svm/svm.h b/arch/x86/kvm/svm/svm.h
index 39e071f..02f8ece 100644
--- a/arch/x86/kvm/svm/svm.h
+++ b/arch/x86/kvm/svm/svm.h
@@ -28,7 +28,7 @@
 };
 #define NR_HOST_SAVE_USER_MSRS ARRAY_SIZE(host_save_user_msrs)
 
-#define MAX_DIRECT_ACCESS_MSRS	18
+#define MAX_DIRECT_ACCESS_MSRS	20
 #define MSRPM_OFFSETS	16
 extern u32 msrpm_offsets[MSRPM_OFFSETS] __read_mostly;
 extern bool npt_enabled;
@@ -65,6 +65,7 @@
 	unsigned long pages_locked; /* Number of pages locked */
 	struct list_head regions_list;  /* List of registered regions */
 	u64 ap_jump_table;	/* SEV-ES AP Jump Table address */
+	struct kvm *enc_context_owner; /* Owner of copied encryption context */
 };
 
 struct kvm_svm {
@@ -81,11 +82,19 @@
 
 struct kvm_vcpu;
 
+struct kvm_vmcb_info {
+	struct vmcb *ptr;
+	unsigned long pa;
+	int cpu;
+	uint64_t asid_generation;
+};
+
 struct svm_nested_state {
-	struct vmcb *hsave;
+	struct kvm_vmcb_info vmcb02;
 	u64 hsave_msr;
 	u64 vm_cr_msr;
 	u64 vmcb12_gpa;
+	u64 last_vmcb12_gpa;
 
 	/* These are the merged vectors */
 	u32 *msrpm;
@@ -104,11 +113,12 @@
 	struct kvm_vcpu vcpu;
 	struct vmcb *vmcb;
 	unsigned long vmcb_pa;
+	struct kvm_vmcb_info vmcb01;
+	struct kvm_vmcb_info *current_vmcb;
 	struct svm_cpu_data *svm_data;
 	u32 asid;
-	uint64_t asid_generation;
-	uint64_t sysenter_esp;
-	uint64_t sysenter_eip;
+	u32 sysenter_esp_hi;
+	u32 sysenter_eip_hi;
 	uint64_t tsc_aux;
 
 	u64 msr_decfg;
@@ -239,19 +249,16 @@
 	vmcb->control.clean &= ~(1 << bit);
 }
 
+static inline bool vmcb_is_dirty(struct vmcb *vmcb, int bit)
+{
+	return !(vmcb->control.clean & (1 << bit));
+}
+
 static inline struct vcpu_svm *to_svm(struct kvm_vcpu *vcpu)
 {
 	return container_of(vcpu, struct vcpu_svm, vcpu);
 }
 
-static inline struct vmcb *get_host_vmcb(struct vcpu_svm *svm)
-{
-	if (is_guest_mode(&svm->vcpu))
-		return svm->nested.hsave;
-	else
-		return svm->vmcb;
-}
-
 static inline void vmcb_set_intercept(struct vmcb_control_area *control, u32 bit)
 {
 	WARN_ON_ONCE(bit >= 32 * MAX_INTERCEPT);
@@ -272,7 +279,7 @@
 
 static inline void set_dr_intercepts(struct vcpu_svm *svm)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	if (!sev_es_guest(svm->vcpu.kvm)) {
 		vmcb_set_intercept(&vmcb->control, INTERCEPT_DR0_READ);
@@ -299,7 +306,7 @@
 
 static inline void clr_dr_intercepts(struct vcpu_svm *svm)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	vmcb->control.intercepts[INTERCEPT_DR] = 0;
 
@@ -314,7 +321,7 @@
 
 static inline void set_exception_intercept(struct vcpu_svm *svm, u32 bit)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	WARN_ON_ONCE(bit >= 32);
 	vmcb_set_intercept(&vmcb->control, INTERCEPT_EXCEPTION_OFFSET + bit);
@@ -324,7 +331,7 @@
 
 static inline void clr_exception_intercept(struct vcpu_svm *svm, u32 bit)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	WARN_ON_ONCE(bit >= 32);
 	vmcb_clr_intercept(&vmcb->control, INTERCEPT_EXCEPTION_OFFSET + bit);
@@ -334,7 +341,7 @@
 
 static inline void svm_set_intercept(struct vcpu_svm *svm, int bit)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	vmcb_set_intercept(&vmcb->control, bit);
 
@@ -343,7 +350,7 @@
 
 static inline void svm_clr_intercept(struct vcpu_svm *svm, int bit)
 {
-	struct vmcb *vmcb = get_host_vmcb(svm);
+	struct vmcb *vmcb = svm->vmcb01.ptr;
 
 	vmcb_clr_intercept(&vmcb->control, bit);
 
@@ -405,7 +412,7 @@
 bool svm_nmi_blocked(struct kvm_vcpu *vcpu);
 bool svm_interrupt_blocked(struct kvm_vcpu *vcpu);
 void svm_set_gif(struct vcpu_svm *svm, bool value);
-int svm_invoke_exit_handler(struct vcpu_svm *svm, u64 exit_code);
+int svm_invoke_exit_handler(struct kvm_vcpu *vcpu, u64 exit_code);
 void set_msr_interception(struct kvm_vcpu *vcpu, u32 *msrpm, u32 msr,
 			  int read, int write);
 
@@ -437,20 +444,30 @@
 	return vmcb_is_intercept(&svm->nested.ctl, INTERCEPT_NMI);
 }
 
-int enter_svm_guest_mode(struct vcpu_svm *svm, u64 vmcb_gpa,
-			 struct vmcb *nested_vmcb);
+int enter_svm_guest_mode(struct kvm_vcpu *vcpu, u64 vmcb_gpa, struct vmcb *vmcb12);
 void svm_leave_nested(struct vcpu_svm *svm);
 void svm_free_nested(struct vcpu_svm *svm);
 int svm_allocate_nested(struct vcpu_svm *svm);
-int nested_svm_vmrun(struct vcpu_svm *svm);
+int nested_svm_vmrun(struct kvm_vcpu *vcpu);
 void nested_svm_vmloadsave(struct vmcb *from_vmcb, struct vmcb *to_vmcb);
 int nested_svm_vmexit(struct vcpu_svm *svm);
+
+static inline int nested_svm_simple_vmexit(struct vcpu_svm *svm, u32 exit_code)
+{
+	svm->vmcb->control.exit_code   = exit_code;
+	svm->vmcb->control.exit_info_1 = 0;
+	svm->vmcb->control.exit_info_2 = 0;
+	return nested_svm_vmexit(svm);
+}
+
 int nested_svm_exit_handled(struct vcpu_svm *svm);
-int nested_svm_check_permissions(struct vcpu_svm *svm);
+int nested_svm_check_permissions(struct kvm_vcpu *vcpu);
 int nested_svm_check_exception(struct vcpu_svm *svm, unsigned nr,
 			       bool has_error_code, u32 error_code);
 int nested_svm_exit_special(struct vcpu_svm *svm);
-void sync_nested_vmcb_control(struct vcpu_svm *svm);
+void nested_sync_control_from_vmcb02(struct vcpu_svm *svm);
+void nested_vmcb02_compute_g_pat(struct vcpu_svm *svm);
+void svm_switch_vmcb(struct vcpu_svm *svm, struct kvm_vmcb_info *target_vmcb);
 
 extern struct kvm_x86_nested_ops svm_nested_ops;
 
@@ -491,8 +508,8 @@
 int avic_vm_init(struct kvm *kvm);
 void avic_init_vmcb(struct vcpu_svm *svm);
 void svm_toggle_avic_for_irq_window(struct kvm_vcpu *vcpu, bool activate);
-int avic_incomplete_ipi_interception(struct vcpu_svm *svm);
-int avic_unaccelerated_access_interception(struct vcpu_svm *svm);
+int avic_incomplete_ipi_interception(struct kvm_vcpu *vcpu);
+int avic_unaccelerated_access_interception(struct kvm_vcpu *vcpu);
 int avic_init_vcpu(struct vcpu_svm *svm);
 void avic_vcpu_load(struct kvm_vcpu *vcpu, int cpu);
 void avic_vcpu_put(struct kvm_vcpu *vcpu);
@@ -561,11 +578,12 @@
 			    struct kvm_enc_region *range);
 int svm_unregister_enc_region(struct kvm *kvm,
 			      struct kvm_enc_region *range);
+int svm_vm_copy_asid_from(struct kvm *kvm, unsigned int source_fd);
 void pre_sev_run(struct vcpu_svm *svm, int cpu);
 void __init sev_hardware_setup(void);
 void sev_hardware_teardown(void);
 void sev_free_vcpu(struct kvm_vcpu *vcpu);
-int sev_handle_vmgexit(struct vcpu_svm *svm);
+int sev_handle_vmgexit(struct kvm_vcpu *vcpu);
 int sev_es_string_io(struct vcpu_svm *svm, int size, unsigned int port, int in);
 void sev_es_init_vmcb(struct vcpu_svm *svm);
 void sev_es_create_vcpu(struct vcpu_svm *svm);
diff --git a/arch/x86/kvm/svm/vmenter.S b/arch/x86/kvm/svm/vmenter.S
index 6feb8c0..4fa17df 100644
--- a/arch/x86/kvm/svm/vmenter.S
+++ b/arch/x86/kvm/svm/vmenter.S
@@ -79,28 +79,10 @@
 
 	/* Enter guest mode */
 	sti
-1:	vmload %_ASM_AX
-	jmp 3f
-2:	cmpb $0, kvm_rebooting
-	jne 3f
-	ud2
-	_ASM_EXTABLE(1b, 2b)
 
-3:	vmrun %_ASM_AX
-	jmp 5f
-4:	cmpb $0, kvm_rebooting
-	jne 5f
-	ud2
-	_ASM_EXTABLE(3b, 4b)
+1:	vmrun %_ASM_AX
 
-5:	vmsave %_ASM_AX
-	jmp 7f
-6:	cmpb $0, kvm_rebooting
-	jne 7f
-	ud2
-	_ASM_EXTABLE(5b, 6b)
-7:
-	cli
+2:	cli
 
 #ifdef CONFIG_RETPOLINE
 	/* IMPORTANT: Stuff the RSB immediately after VM-Exit, before RET! */
@@ -167,6 +149,13 @@
 #endif
 	pop %_ASM_BP
 	ret
+
+3:	cmpb $0, kvm_rebooting
+	jne 2b
+	ud2
+
+	_ASM_EXTABLE(1b, 3b)
+
 SYM_FUNC_END(__svm_vcpu_run)
 
 /**
@@ -186,18 +175,15 @@
 #endif
 	push %_ASM_BX
 
-	/* Enter guest mode */
+	/* Move @vmcb to RAX. */
 	mov %_ASM_ARG1, %_ASM_AX
+
+	/* Enter guest mode */
 	sti
 
 1:	vmrun %_ASM_AX
-	jmp 3f
-2:	cmpb $0, kvm_rebooting
-	jne 3f
-	ud2
-	_ASM_EXTABLE(1b, 2b)
 
-3:	cli
+2:	cli
 
 #ifdef CONFIG_RETPOLINE
 	/* IMPORTANT: Stuff the RSB immediately after VM-Exit, before RET! */
@@ -217,4 +203,11 @@
 #endif
 	pop %_ASM_BP
 	ret
+
+3:	cmpb $0, kvm_rebooting
+	jne 2b
+	ud2
+
+	_ASM_EXTABLE(1b, 3b)
+
 SYM_FUNC_END(__svm_sev_es_vcpu_run)
diff --git a/arch/x86/kvm/vmx/nested.c b/arch/x86/kvm/vmx/nested.c
index bcca0b8..e02ebc2 100644
--- a/arch/x86/kvm/vmx/nested.c
+++ b/arch/x86/kvm/vmx/nested.c
@@ -21,13 +21,7 @@
 static bool __read_mostly nested_early_check = 0;
 module_param(nested_early_check, bool, S_IRUGO);
 
-#define CC(consistency_check)						\
-({									\
-	bool failed = (consistency_check);				\
-	if (failed)							\
-		trace_kvm_nested_vmenter_failed(#consistency_check, 0);	\
-	failed;								\
-})
+#define CC KVM_NESTED_VMENTER_CONSISTENCY_CHECK
 
 /*
  * Hyper-V requires all of these, so mark them as supported even though
@@ -2570,11 +2564,6 @@
 		return -EINVAL;
 	}
 
-	/* Shadow page tables on either EPT or shadow page tables. */
-	if (nested_vmx_load_cr3(vcpu, vmcs12->guest_cr3, nested_cpu_has_ept(vmcs12),
-				entry_failure_code))
-		return -EINVAL;
-
 	/*
 	 * Immediately write vmcs02.GUEST_CR3.  It will be propagated to vmcs12
 	 * on nested VM-Exit, which can occur without actually running L2 and
@@ -3115,11 +3104,16 @@
 static bool nested_get_vmcs12_pages(struct kvm_vcpu *vcpu)
 {
 	struct vmcs12 *vmcs12 = get_vmcs12(vcpu);
+	enum vm_entry_failure_code entry_failure_code;
 	struct vcpu_vmx *vmx = to_vmx(vcpu);
 	struct kvm_host_map *map;
 	struct page *page;
 	u64 hpa;
 
+	if (nested_vmx_load_cr3(vcpu, vmcs12->guest_cr3, nested_cpu_has_ept(vmcs12),
+				&entry_failure_code))
+		return false;
+
 	if (nested_cpu_has2(vmcs12, SECONDARY_EXEC_VIRTUALIZE_APIC_ACCESSES)) {
 		/*
 		 * Translate L1 physical address to host physical
@@ -3363,6 +3357,10 @@
 	}
 
 	if (from_vmentry) {
+		if (nested_vmx_load_cr3(vcpu, vmcs12->guest_cr3,
+		    nested_cpu_has_ept(vmcs12), &entry_failure_code))
+			goto vmentry_fail_vmexit_guest_mode;
+
 		failed_index = nested_vmx_load_msr(vcpu,
 						   vmcs12->vm_entry_msr_load_addr,
 						   vmcs12->vm_entry_msr_load_count);
@@ -3453,6 +3451,8 @@
 	u32 interrupt_shadow = vmx_get_interrupt_shadow(vcpu);
 	enum nested_evmptrld_status evmptrld_status;
 
+	++vcpu->stat.nested_run;
+
 	if (!nested_vmx_check_permission(vcpu))
 		return 1;
 
@@ -3810,9 +3810,15 @@
 
 	/*
 	 * Process any exceptions that are not debug traps before MTF.
+	 *
+	 * Note that only a pending nested run can block a pending exception.
+	 * Otherwise an injected NMI/interrupt should either be
+	 * lost or delivered to the nested hypervisor in the IDT_VECTORING_INFO,
+	 * while delivering the pending exception.
 	 */
+
 	if (vcpu->arch.exception.pending && !vmx_pending_dbg_trap(vcpu)) {
-		if (block_nested_events)
+		if (vmx->nested.nested_run_pending)
 			return -EBUSY;
 		if (!nested_vmx_check_exception(vcpu, &exit_qual))
 			goto no_vmexit;
@@ -3829,7 +3835,7 @@
 	}
 
 	if (vcpu->arch.exception.pending) {
-		if (block_nested_events)
+		if (vmx->nested.nested_run_pending)
 			return -EBUSY;
 		if (!nested_vmx_check_exception(vcpu, &exit_qual))
 			goto no_vmexit;
@@ -4422,6 +4428,9 @@
 	/* trying to cancel vmlaunch/vmresume is a bug */
 	WARN_ON_ONCE(vmx->nested.nested_run_pending);
 
+	/* Similarly, triple faults in L2 should never escape. */
+	WARN_ON_ONCE(kvm_check_request(KVM_REQ_TRIPLE_FAULT, vcpu));
+
 	kvm_clear_request(KVM_REQ_GET_NESTED_STATE_PAGES, vcpu);
 
 	/* Service the TLB flush request for L2 before switching to L1. */
@@ -4558,6 +4567,11 @@
 	vmx->fail = 0;
 }
 
+static void nested_vmx_triple_fault(struct kvm_vcpu *vcpu)
+{
+	nested_vmx_vmexit(vcpu, EXIT_REASON_TRIPLE_FAULT, 0, 0);
+}
+
 /*
  * Decode the memory-address operand of a vmx instruction, as recorded on an
  * exit caused by such an instruction (run by a guest hypervisor).
@@ -5479,16 +5493,11 @@
 		if (!nested_vmx_check_eptp(vcpu, new_eptp))
 			return 1;
 
-		kvm_mmu_unload(vcpu);
 		mmu->ept_ad = accessed_dirty;
 		mmu->mmu_role.base.ad_disabled = !accessed_dirty;
 		vmcs12->ept_pointer = new_eptp;
-		/*
-		 * TODO: Check what's the correct approach in case
-		 * mmu reload fails. Currently, we just let the next
-		 * reload potentially fail
-		 */
-		kvm_mmu_reload(vcpu);
+
+		kvm_make_request(KVM_REQ_MMU_RELOAD, vcpu);
 	}
 
 	return 0;
@@ -6599,6 +6608,7 @@
 struct kvm_x86_nested_ops vmx_nested_ops = {
 	.check_events = vmx_check_nested_events,
 	.hv_timer_pending = nested_vmx_preemption_timer_pending,
+	.triple_fault = nested_vmx_triple_fault,
 	.get_state = vmx_get_nested_state,
 	.set_state = vmx_set_nested_state,
 	.get_nested_state_pages = vmx_get_nested_state_pages,
diff --git a/arch/x86/kvm/vmx/vmx.c b/arch/x86/kvm/vmx/vmx.c
index 32cf828..c05e6e2 100644
--- a/arch/x86/kvm/vmx/vmx.c
+++ b/arch/x86/kvm/vmx/vmx.c
@@ -472,26 +472,6 @@
 static bool __read_mostly enlightened_vmcs = true;
 module_param(enlightened_vmcs, bool, 0444);
 
-/* check_ept_pointer() should be under protection of ept_pointer_lock. */
-static void check_ept_pointer_match(struct kvm *kvm)
-{
-	struct kvm_vcpu *vcpu;
-	u64 tmp_eptp = INVALID_PAGE;
-	int i;
-
-	kvm_for_each_vcpu(i, vcpu, kvm) {
-		if (!VALID_PAGE(tmp_eptp)) {
-			tmp_eptp = to_vmx(vcpu)->ept_pointer;
-		} else if (tmp_eptp != to_vmx(vcpu)->ept_pointer) {
-			to_kvm_vmx(kvm)->ept_pointers_match
-				= EPT_POINTERS_MISMATCH;
-			return;
-		}
-	}
-
-	to_kvm_vmx(kvm)->ept_pointers_match = EPT_POINTERS_MATCH;
-}
-
 static int kvm_fill_hv_flush_list_func(struct hv_guest_mapping_flush_list *flush,
 		void *data)
 {
@@ -501,47 +481,70 @@
 			range->pages);
 }
 
-static inline int __hv_remote_flush_tlb_with_range(struct kvm *kvm,
-		struct kvm_vcpu *vcpu, struct kvm_tlb_range *range)
+static inline int hv_remote_flush_root_ept(hpa_t root_ept,
+					   struct kvm_tlb_range *range)
 {
-	u64 ept_pointer = to_vmx(vcpu)->ept_pointer;
-
-	/*
-	 * FLUSH_GUEST_PHYSICAL_ADDRESS_SPACE hypercall needs address
-	 * of the base of EPT PML4 table, strip off EPT configuration
-	 * information.
-	 */
 	if (range)
-		return hyperv_flush_guest_mapping_range(ept_pointer & PAGE_MASK,
+		return hyperv_flush_guest_mapping_range(root_ept,
 				kvm_fill_hv_flush_list_func, (void *)range);
 	else
-		return hyperv_flush_guest_mapping(ept_pointer & PAGE_MASK);
+		return hyperv_flush_guest_mapping(root_ept);
 }
 
 static int hv_remote_flush_tlb_with_range(struct kvm *kvm,
 		struct kvm_tlb_range *range)
 {
+	struct kvm_vmx *kvm_vmx = to_kvm_vmx(kvm);
 	struct kvm_vcpu *vcpu;
-	int ret = 0, i;
+	int ret = 0, i, nr_unique_valid_roots;
+	hpa_t root;
 
-	spin_lock(&to_kvm_vmx(kvm)->ept_pointer_lock);
+	spin_lock(&kvm_vmx->hv_root_ept_lock);
 
-	if (to_kvm_vmx(kvm)->ept_pointers_match == EPT_POINTERS_CHECK)
-		check_ept_pointer_match(kvm);
+	if (!VALID_PAGE(kvm_vmx->hv_root_ept)) {
+		nr_unique_valid_roots = 0;
 
-	if (to_kvm_vmx(kvm)->ept_pointers_match != EPT_POINTERS_MATCH) {
+		/*
+		 * Flush all valid roots, and see if all vCPUs have converged
+		 * on a common root, in which case future flushes can skip the
+		 * loop and flush the common root.
+		 */
 		kvm_for_each_vcpu(i, vcpu, kvm) {
-			/* If ept_pointer is invalid pointer, bypass flush request. */
-			if (VALID_PAGE(to_vmx(vcpu)->ept_pointer))
-				ret |= __hv_remote_flush_tlb_with_range(
-					kvm, vcpu, range);
+			root = to_vmx(vcpu)->hv_root_ept;
+			if (!VALID_PAGE(root) || root == kvm_vmx->hv_root_ept)
+				continue;
+
+			/*
+			 * Set the tracked root to the first valid root.  Keep
+			 * this root for the entirety of the loop even if more
+			 * roots are encountered as a low effort optimization
+			 * to avoid flushing the same (first) root again.
+			 */
+			if (++nr_unique_valid_roots == 1)
+				kvm_vmx->hv_root_ept = root;
+
+			if (!ret)
+				ret = hv_remote_flush_root_ept(root, range);
+
+			/*
+			 * Stop processing roots if a failure occurred and
+			 * multiple valid roots have already been detected.
+			 */
+			if (ret && nr_unique_valid_roots > 1)
+				break;
 		}
+
+		/*
+		 * The optimized flush of a single root can't be used if there
+		 * are multiple valid roots (obviously).
+		 */
+		if (nr_unique_valid_roots > 1)
+			kvm_vmx->hv_root_ept = INVALID_PAGE;
 	} else {
-		ret = __hv_remote_flush_tlb_with_range(kvm,
-				kvm_get_vcpu(kvm, 0), range);
+		ret = hv_remote_flush_root_ept(kvm_vmx->hv_root_ept, range);
 	}
 
-	spin_unlock(&to_kvm_vmx(kvm)->ept_pointer_lock);
+	spin_unlock(&kvm_vmx->hv_root_ept_lock);
 	return ret;
 }
 static int hv_remote_flush_tlb(struct kvm *kvm)
@@ -559,7 +562,7 @@
 	 * evmcs in singe VM shares same assist page.
 	 */
 	if (!*p_hv_pa_pg)
-		*p_hv_pa_pg = kzalloc(PAGE_SIZE, GFP_KERNEL);
+		*p_hv_pa_pg = kzalloc(PAGE_SIZE, GFP_KERNEL_ACCOUNT);
 
 	if (!*p_hv_pa_pg)
 		return -ENOMEM;
@@ -576,6 +579,21 @@
 
 #endif /* IS_ENABLED(CONFIG_HYPERV) */
 
+static void hv_track_root_ept(struct kvm_vcpu *vcpu, hpa_t root_ept)
+{
+#if IS_ENABLED(CONFIG_HYPERV)
+	struct kvm_vmx *kvm_vmx = to_kvm_vmx(vcpu->kvm);
+
+	if (kvm_x86_ops.tlb_remote_flush == hv_remote_flush_tlb) {
+		spin_lock(&kvm_vmx->hv_root_ept_lock);
+		to_vmx(vcpu)->hv_root_ept = root_ept;
+		if (root_ept != kvm_vmx->hv_root_ept)
+			kvm_vmx->hv_root_ept = INVALID_PAGE;
+		spin_unlock(&kvm_vmx->hv_root_ept_lock);
+	}
+#endif
+}
+
 /*
  * Comment's format: document - errata name - stepping - processor name.
  * Refer from
@@ -3088,8 +3106,7 @@
 	return 4;
 }
 
-u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa,
-		   int root_level)
+u64 construct_eptp(struct kvm_vcpu *vcpu, hpa_t root_hpa, int root_level)
 {
 	u64 eptp = VMX_EPTP_MT_WB;
 
@@ -3098,13 +3115,13 @@
 	if (enable_ept_ad_bits &&
 	    (!is_guest_mode(vcpu) || nested_ept_ad_enabled(vcpu)))
 		eptp |= VMX_EPTP_AD_ENABLE_BIT;
-	eptp |= (root_hpa & PAGE_MASK);
+	eptp |= root_hpa;
 
 	return eptp;
 }
 
-static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long pgd,
-			     int pgd_level)
+static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, hpa_t root_hpa,
+			     int root_level)
 {
 	struct kvm *kvm = vcpu->kvm;
 	bool update_guest_cr3 = true;
@@ -3112,16 +3129,10 @@
 	u64 eptp;
 
 	if (enable_ept) {
-		eptp = construct_eptp(vcpu, pgd, pgd_level);
+		eptp = construct_eptp(vcpu, root_hpa, root_level);
 		vmcs_write64(EPT_POINTER, eptp);
 
-		if (kvm_x86_ops.tlb_remote_flush) {
-			spin_lock(&to_kvm_vmx(kvm)->ept_pointer_lock);
-			to_vmx(vcpu)->ept_pointer = eptp;
-			to_kvm_vmx(kvm)->ept_pointers_match
-				= EPT_POINTERS_CHECK;
-			spin_unlock(&to_kvm_vmx(kvm)->ept_pointer_lock);
-		}
+		hv_track_root_ept(vcpu, root_hpa);
 
 		if (!enable_unrestricted_guest && !is_paging(vcpu))
 			guest_cr3 = to_kvm_vmx(kvm)->ept_identity_map_addr;
@@ -3131,7 +3142,7 @@
 			update_guest_cr3 = false;
 		vmx_ept_load_pdptrs(vcpu);
 	} else {
-		guest_cr3 = pgd;
+		guest_cr3 = root_hpa | kvm_get_active_pcid(vcpu);
 	}
 
 	if (update_guest_cr3)
@@ -4314,15 +4325,6 @@
 	vmx->secondary_exec_control = exec_control;
 }
 
-static void ept_set_mmio_spte_mask(void)
-{
-	/*
-	 * EPT Misconfigurations can be generated if the value of bits 2:0
-	 * of an EPT paging-structure entry is 110b (write/execute).
-	 */
-	kvm_mmu_set_mmio_spte_mask(VMX_EPT_MISCONFIG_WX_VALUE, 0);
-}
-
 #define VMX_XSS_EXIT_BITMAP 0
 
 /*
@@ -5184,17 +5186,6 @@
 	return 1;
 }
 
-static int handle_vmcall(struct kvm_vcpu *vcpu)
-{
-	return kvm_emulate_hypercall(vcpu);
-}
-
-static int handle_invd(struct kvm_vcpu *vcpu)
-{
-	/* Treat an INVD instruction as a NOP and just skip it. */
-	return kvm_skip_emulated_instruction(vcpu);
-}
-
 static int handle_invlpg(struct kvm_vcpu *vcpu)
 {
 	unsigned long exit_qualification = vmx_get_exit_qual(vcpu);
@@ -5203,28 +5194,6 @@
 	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int handle_rdpmc(struct kvm_vcpu *vcpu)
-{
-	int err;
-
-	err = kvm_rdpmc(vcpu);
-	return kvm_complete_insn_gp(vcpu, err);
-}
-
-static int handle_wbinvd(struct kvm_vcpu *vcpu)
-{
-	return kvm_emulate_wbinvd(vcpu);
-}
-
-static int handle_xsetbv(struct kvm_vcpu *vcpu)
-{
-	u64 new_bv = kvm_read_edx_eax(vcpu);
-	u32 index = kvm_rcx_read(vcpu);
-
-	int err = kvm_set_xcr(vcpu, index, new_bv);
-	return kvm_complete_insn_gp(vcpu, err);
-}
-
 static int handle_apic_access(struct kvm_vcpu *vcpu)
 {
 	if (likely(fasteoi)) {
@@ -5485,18 +5454,6 @@
 	}
 }
 
-static void vmx_enable_tdp(void)
-{
-	kvm_mmu_set_mask_ptes(VMX_EPT_READABLE_MASK,
-		enable_ept_ad_bits ? VMX_EPT_ACCESS_BIT : 0ull,
-		enable_ept_ad_bits ? VMX_EPT_DIRTY_BIT : 0ull,
-		0ull, VMX_EPT_EXECUTABLE_MASK,
-		cpu_has_vmx_ept_execute_only() ? 0ull : VMX_EPT_READABLE_MASK,
-		VMX_EPT_RWX_MASK, 0ull);
-
-	ept_set_mmio_spte_mask();
-}
-
 /*
  * Indicate a busy-waiting vcpu in spinlock. We do not enable the PAUSE
  * exiting, so only get here on cpu with PAUSE-Loop-Exiting.
@@ -5516,34 +5473,11 @@
 	return kvm_skip_emulated_instruction(vcpu);
 }
 
-static int handle_nop(struct kvm_vcpu *vcpu)
-{
-	return kvm_skip_emulated_instruction(vcpu);
-}
-
-static int handle_mwait(struct kvm_vcpu *vcpu)
-{
-	printk_once(KERN_WARNING "kvm: MWAIT instruction emulated as NOP!\n");
-	return handle_nop(vcpu);
-}
-
-static int handle_invalid_op(struct kvm_vcpu *vcpu)
-{
-	kvm_queue_exception(vcpu, UD_VECTOR);
-	return 1;
-}
-
 static int handle_monitor_trap(struct kvm_vcpu *vcpu)
 {
 	return 1;
 }
 
-static int handle_monitor(struct kvm_vcpu *vcpu)
-{
-	printk_once(KERN_WARNING "kvm: MONITOR instruction emulated as NOP!\n");
-	return handle_nop(vcpu);
-}
-
 static int handle_invpcid(struct kvm_vcpu *vcpu)
 {
 	u32 vmx_instruction_info;
@@ -5668,10 +5602,10 @@
 	[EXIT_REASON_MSR_WRITE]               = kvm_emulate_wrmsr,
 	[EXIT_REASON_INTERRUPT_WINDOW]        = handle_interrupt_window,
 	[EXIT_REASON_HLT]                     = kvm_emulate_halt,
-	[EXIT_REASON_INVD]		      = handle_invd,
+	[EXIT_REASON_INVD]		      = kvm_emulate_invd,
 	[EXIT_REASON_INVLPG]		      = handle_invlpg,
-	[EXIT_REASON_RDPMC]                   = handle_rdpmc,
-	[EXIT_REASON_VMCALL]                  = handle_vmcall,
+	[EXIT_REASON_RDPMC]                   = kvm_emulate_rdpmc,
+	[EXIT_REASON_VMCALL]                  = kvm_emulate_hypercall,
 	[EXIT_REASON_VMCLEAR]		      = handle_vmx_instruction,
 	[EXIT_REASON_VMLAUNCH]		      = handle_vmx_instruction,
 	[EXIT_REASON_VMPTRLD]		      = handle_vmx_instruction,
@@ -5685,8 +5619,8 @@
 	[EXIT_REASON_APIC_ACCESS]             = handle_apic_access,
 	[EXIT_REASON_APIC_WRITE]              = handle_apic_write,
 	[EXIT_REASON_EOI_INDUCED]             = handle_apic_eoi_induced,
-	[EXIT_REASON_WBINVD]                  = handle_wbinvd,
-	[EXIT_REASON_XSETBV]                  = handle_xsetbv,
+	[EXIT_REASON_WBINVD]                  = kvm_emulate_wbinvd,
+	[EXIT_REASON_XSETBV]                  = kvm_emulate_xsetbv,
 	[EXIT_REASON_TASK_SWITCH]             = handle_task_switch,
 	[EXIT_REASON_MCE_DURING_VMENTRY]      = handle_machine_check,
 	[EXIT_REASON_GDTR_IDTR]		      = handle_desc,
@@ -5694,13 +5628,13 @@
 	[EXIT_REASON_EPT_VIOLATION]	      = handle_ept_violation,
 	[EXIT_REASON_EPT_MISCONFIG]           = handle_ept_misconfig,
 	[EXIT_REASON_PAUSE_INSTRUCTION]       = handle_pause,
-	[EXIT_REASON_MWAIT_INSTRUCTION]	      = handle_mwait,
+	[EXIT_REASON_MWAIT_INSTRUCTION]	      = kvm_emulate_mwait,
 	[EXIT_REASON_MONITOR_TRAP_FLAG]       = handle_monitor_trap,
-	[EXIT_REASON_MONITOR_INSTRUCTION]     = handle_monitor,
+	[EXIT_REASON_MONITOR_INSTRUCTION]     = kvm_emulate_monitor,
 	[EXIT_REASON_INVEPT]                  = handle_vmx_instruction,
 	[EXIT_REASON_INVVPID]                 = handle_vmx_instruction,
-	[EXIT_REASON_RDRAND]                  = handle_invalid_op,
-	[EXIT_REASON_RDSEED]                  = handle_invalid_op,
+	[EXIT_REASON_RDRAND]                  = kvm_handle_invalid_op,
+	[EXIT_REASON_RDSEED]                  = kvm_handle_invalid_op,
 	[EXIT_REASON_PML_FULL]		      = handle_pml_full,
 	[EXIT_REASON_INVPCID]                 = handle_invpcid,
 	[EXIT_REASON_VMFUNC]		      = handle_vmx_instruction,
@@ -5787,12 +5721,23 @@
 	       vmcs_readl(limit + GUEST_GDTR_BASE - GUEST_GDTR_LIMIT));
 }
 
-void dump_vmcs(void)
+static void vmx_dump_msrs(const char *name, const struct vmx_msrs *m)
 {
+	unsigned int i;
+	const struct vmx_msr_entry *e;
+
+	pr_err("MSR %s:\n", name);
+	for (i = 0, e = m->val; i < m->nr; ++i, ++e)
+		pr_err("  %2u: msr=0x%08x value=0x%016llx\n", i, e->index, e->value);
+}
+
+void dump_vmcs(struct kvm_vcpu *vcpu)
+{
+	struct vcpu_vmx *vmx = to_vmx(vcpu);
 	u32 vmentry_ctl, vmexit_ctl;
 	u32 cpu_based_exec_ctrl, pin_based_exec_ctrl, secondary_exec_control;
 	unsigned long cr4;
-	u64 efer;
+	int efer_slot;
 
 	if (!dump_invalid_vmcs) {
 		pr_warn_ratelimited("set kvm_intel.dump_invalid_vmcs=1 to dump internal KVM state.\n");
@@ -5804,7 +5749,6 @@
 	cpu_based_exec_ctrl = vmcs_read32(CPU_BASED_VM_EXEC_CONTROL);
 	pin_based_exec_ctrl = vmcs_read32(PIN_BASED_VM_EXEC_CONTROL);
 	cr4 = vmcs_readl(GUEST_CR4);
-	efer = vmcs_read64(GUEST_IA32_EFER);
 	secondary_exec_control = 0;
 	if (cpu_has_secondary_exec_ctrls())
 		secondary_exec_control = vmcs_read32(SECONDARY_VM_EXEC_CONTROL);
@@ -5816,9 +5760,7 @@
 	pr_err("CR4: actual=0x%016lx, shadow=0x%016lx, gh_mask=%016lx\n",
 	       cr4, vmcs_readl(CR4_READ_SHADOW), vmcs_readl(CR4_GUEST_HOST_MASK));
 	pr_err("CR3 = 0x%016lx\n", vmcs_readl(GUEST_CR3));
-	if ((secondary_exec_control & SECONDARY_EXEC_ENABLE_EPT) &&
-	    (cr4 & X86_CR4_PAE) && !(efer & EFER_LMA))
-	{
+	if (cpu_has_vmx_ept()) {
 		pr_err("PDPTR0 = 0x%016llx  PDPTR1 = 0x%016llx\n",
 		       vmcs_read64(GUEST_PDPTR0), vmcs_read64(GUEST_PDPTR1));
 		pr_err("PDPTR2 = 0x%016llx  PDPTR3 = 0x%016llx\n",
@@ -5841,10 +5783,20 @@
 	vmx_dump_sel("LDTR:", GUEST_LDTR_SELECTOR);
 	vmx_dump_dtsel("IDTR:", GUEST_IDTR_LIMIT);
 	vmx_dump_sel("TR:  ", GUEST_TR_SELECTOR);
-	if ((vmexit_ctl & (VM_EXIT_SAVE_IA32_PAT | VM_EXIT_SAVE_IA32_EFER)) ||
-	    (vmentry_ctl & (VM_ENTRY_LOAD_IA32_PAT | VM_ENTRY_LOAD_IA32_EFER)))
-		pr_err("EFER =     0x%016llx  PAT = 0x%016llx\n",
-		       efer, vmcs_read64(GUEST_IA32_PAT));
+	efer_slot = vmx_find_loadstore_msr_slot(&vmx->msr_autoload.guest, MSR_EFER);
+	if (vmentry_ctl & VM_ENTRY_LOAD_IA32_EFER)
+		pr_err("EFER= 0x%016llx\n", vmcs_read64(GUEST_IA32_EFER));
+	else if (efer_slot >= 0)
+		pr_err("EFER= 0x%016llx (autoload)\n",
+		       vmx->msr_autoload.guest.val[efer_slot].value);
+	else if (vmentry_ctl & VM_ENTRY_IA32E_MODE)
+		pr_err("EFER= 0x%016llx (effective)\n",
+		       vcpu->arch.efer | (EFER_LMA | EFER_LME));
+	else
+		pr_err("EFER= 0x%016llx (effective)\n",
+		       vcpu->arch.efer & ~(EFER_LMA | EFER_LME));
+	if (vmentry_ctl & VM_ENTRY_LOAD_IA32_PAT)
+		pr_err("PAT = 0x%016llx\n", vmcs_read64(GUEST_IA32_PAT));
 	pr_err("DebugCtl = 0x%016llx  DebugExceptions = 0x%016lx\n",
 	       vmcs_read64(GUEST_IA32_DEBUGCTL),
 	       vmcs_readl(GUEST_PENDING_DBG_EXCEPTIONS));
@@ -5860,6 +5812,10 @@
 	if (secondary_exec_control & SECONDARY_EXEC_VIRTUAL_INTR_DELIVERY)
 		pr_err("InterruptStatus = %04x\n",
 		       vmcs_read16(GUEST_INTR_STATUS));
+	if (vmcs_read32(VM_ENTRY_MSR_LOAD_COUNT) > 0)
+		vmx_dump_msrs("guest autoload", &vmx->msr_autoload.guest);
+	if (vmcs_read32(VM_EXIT_MSR_STORE_COUNT) > 0)
+		vmx_dump_msrs("guest autostore", &vmx->msr_autostore.guest);
 
 	pr_err("*** Host State ***\n");
 	pr_err("RIP = 0x%016lx  RSP = 0x%016lx\n",
@@ -5881,14 +5837,16 @@
 	       vmcs_readl(HOST_IA32_SYSENTER_ESP),
 	       vmcs_read32(HOST_IA32_SYSENTER_CS),
 	       vmcs_readl(HOST_IA32_SYSENTER_EIP));
-	if (vmexit_ctl & (VM_EXIT_LOAD_IA32_PAT | VM_EXIT_LOAD_IA32_EFER))
-		pr_err("EFER = 0x%016llx  PAT = 0x%016llx\n",
-		       vmcs_read64(HOST_IA32_EFER),
-		       vmcs_read64(HOST_IA32_PAT));
+	if (vmexit_ctl & VM_EXIT_LOAD_IA32_EFER)
+		pr_err("EFER= 0x%016llx\n", vmcs_read64(HOST_IA32_EFER));
+	if (vmexit_ctl & VM_EXIT_LOAD_IA32_PAT)
+		pr_err("PAT = 0x%016llx\n", vmcs_read64(HOST_IA32_PAT));
 	if (cpu_has_load_perf_global_ctrl() &&
 	    vmexit_ctl & VM_EXIT_LOAD_IA32_PERF_GLOBAL_CTRL)
 		pr_err("PerfGlobCtl = 0x%016llx\n",
 		       vmcs_read64(HOST_IA32_PERF_GLOBAL_CTRL));
+	if (vmcs_read32(VM_EXIT_MSR_LOAD_COUNT) > 0)
+		vmx_dump_msrs("host autoload", &vmx->msr_autoload.host);
 
 	pr_err("*** Control State ***\n");
 	pr_err("PinBased=%08x CPUBased=%08x SecondaryExec=%08x\n",
@@ -5997,7 +5955,7 @@
 	}
 
 	if (exit_reason.failed_vmentry) {
-		dump_vmcs();
+		dump_vmcs(vcpu);
 		vcpu->run->exit_reason = KVM_EXIT_FAIL_ENTRY;
 		vcpu->run->fail_entry.hardware_entry_failure_reason
 			= exit_reason.full;
@@ -6006,7 +5964,7 @@
 	}
 
 	if (unlikely(vmx->fail)) {
-		dump_vmcs();
+		dump_vmcs(vcpu);
 		vcpu->run->exit_reason = KVM_EXIT_FAIL_ENTRY;
 		vcpu->run->fail_entry.hardware_entry_failure_reason
 			= vmcs_read32(VM_INSTRUCTION_ERROR);
@@ -6092,7 +6050,7 @@
 unexpected_vmexit:
 	vcpu_unimpl(vcpu, "vmx: unexpected exit reason 0x%x\n",
 		    exit_reason.full);
-	dump_vmcs();
+	dump_vmcs(vcpu);
 	vcpu->run->exit_reason = KVM_EXIT_INTERNAL_ERROR;
 	vcpu->run->internal.suberror =
 			KVM_INTERNAL_ERROR_UNEXPECTED_EXIT_REASON;
@@ -6989,8 +6947,9 @@
 	vmx->pi_desc.nv = POSTED_INTR_VECTOR;
 	vmx->pi_desc.sn = 1;
 
-	vmx->ept_pointer = INVALID_PAGE;
-
+#if IS_ENABLED(CONFIG_HYPERV)
+	vmx->hv_root_ept = INVALID_PAGE;
+#endif
 	return 0;
 
 free_vmcs:
@@ -7007,7 +6966,9 @@
 
 static int vmx_vm_init(struct kvm *kvm)
 {
-	spin_lock_init(&to_kvm_vmx(kvm)->ept_pointer_lock);
+#if IS_ENABLED(CONFIG_HYPERV)
+	spin_lock_init(&to_kvm_vmx(kvm)->hv_root_ept_lock);
+#endif
 
 	if (!ple_gap)
 		kvm->arch.pause_in_guest = true;
@@ -7848,7 +7809,8 @@
 	set_bit(0, vmx_vpid_bitmap); /* 0 is reserved for host */
 
 	if (enable_ept)
-		vmx_enable_tdp();
+		kvm_mmu_set_ept_masks(enable_ept_ad_bits,
+				      cpu_has_vmx_ept_execute_only());
 
 	if (!enable_ept)
 		ept_lpage_level = 0;
diff --git a/arch/x86/kvm/vmx/vmx.h b/arch/x86/kvm/vmx/vmx.h
index 89da5e1..7886a08 100644
--- a/arch/x86/kvm/vmx/vmx.h
+++ b/arch/x86/kvm/vmx/vmx.h
@@ -325,7 +325,9 @@
 	 */
 	u64 msr_ia32_feature_control;
 	u64 msr_ia32_feature_control_valid_bits;
-	u64 ept_pointer;
+#if IS_ENABLED(CONFIG_HYPERV)
+	u64 hv_root_ept;
+#endif
 
 	struct pt_desc pt_desc;
 	struct lbr_desc lbr_desc;
@@ -338,12 +340,6 @@
 	} shadow_msr_intercept;
 };
 
-enum ept_pointers_status {
-	EPT_POINTERS_CHECK = 0,
-	EPT_POINTERS_MATCH = 1,
-	EPT_POINTERS_MISMATCH = 2
-};
-
 struct kvm_vmx {
 	struct kvm kvm;
 
@@ -351,8 +347,10 @@
 	bool ept_identity_pagetable_done;
 	gpa_t ept_identity_map_addr;
 
-	enum ept_pointers_status ept_pointers_match;
-	spinlock_t ept_pointer_lock;
+#if IS_ENABLED(CONFIG_HYPERV)
+	hpa_t hv_root_ept;
+	spinlock_t hv_root_ept_lock;
+#endif
 };
 
 bool nested_vmx_allowed(struct kvm_vcpu *vcpu);
@@ -376,8 +374,7 @@
 void ept_save_pdptrs(struct kvm_vcpu *vcpu);
 void vmx_get_segment(struct kvm_vcpu *vcpu, struct kvm_segment *var, int seg);
 void vmx_set_segment(struct kvm_vcpu *vcpu, struct kvm_segment *var, int seg);
-u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa,
-		   int root_level);
+u64 construct_eptp(struct kvm_vcpu *vcpu, hpa_t root_hpa, int root_level);
 
 void vmx_update_exception_bitmap(struct kvm_vcpu *vcpu);
 void vmx_update_msr_bitmap(struct kvm_vcpu *vcpu);
@@ -543,6 +540,6 @@
 	return is_unrestricted_guest(vcpu) || __vmx_guest_state_valid(vcpu);
 }
 
-void dump_vmcs(void);
+void dump_vmcs(struct kvm_vcpu *vcpu);
 
 #endif /* __KVM_X86_VMX_H */
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index eca6362..863d3ed 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -245,6 +245,7 @@
 	VCPU_STAT("l1d_flush", l1d_flush),
 	VCPU_STAT("halt_poll_success_ns", halt_poll_success_ns),
 	VCPU_STAT("halt_poll_fail_ns", halt_poll_fail_ns),
+	VCPU_STAT("nested_run", nested_run),
 	VM_STAT("mmu_shadow_zapped", mmu_shadow_zapped),
 	VM_STAT("mmu_pte_write", mmu_pte_write),
 	VM_STAT("mmu_pde_zapped", mmu_pde_zapped),
@@ -543,8 +544,6 @@
 
 	if (!vcpu->arch.exception.pending && !vcpu->arch.exception.injected) {
 	queue:
-		if (has_error && !is_protmode(vcpu))
-			has_error = false;
 		if (reinject) {
 			/*
 			 * On vmentry, vcpu->arch.exception.pending is only
@@ -983,14 +982,17 @@
 	return 0;
 }
 
-int kvm_set_xcr(struct kvm_vcpu *vcpu, u32 index, u64 xcr)
+int kvm_emulate_xsetbv(struct kvm_vcpu *vcpu)
 {
-	if (static_call(kvm_x86_get_cpl)(vcpu) == 0)
-		return __kvm_set_xcr(vcpu, index, xcr);
+	if (static_call(kvm_x86_get_cpl)(vcpu) != 0 ||
+	    __kvm_set_xcr(vcpu, kvm_rcx_read(vcpu), kvm_read_edx_eax(vcpu))) {
+		kvm_inject_gp(vcpu, 0);
+		return 1;
+	}
 
-	return 1;
+	return kvm_skip_emulated_instruction(vcpu);
 }
-EXPORT_SYMBOL_GPL(kvm_set_xcr);
+EXPORT_SYMBOL_GPL(kvm_emulate_xsetbv);
 
 bool kvm_is_valid_cr4(struct kvm_vcpu *vcpu, unsigned long cr4)
 {
@@ -1191,20 +1193,21 @@
 }
 EXPORT_SYMBOL_GPL(kvm_get_dr);
 
-bool kvm_rdpmc(struct kvm_vcpu *vcpu)
+int kvm_emulate_rdpmc(struct kvm_vcpu *vcpu)
 {
 	u32 ecx = kvm_rcx_read(vcpu);
 	u64 data;
-	int err;
 
-	err = kvm_pmu_rdpmc(vcpu, ecx, &data);
-	if (err)
-		return err;
+	if (kvm_pmu_rdpmc(vcpu, ecx, &data)) {
+		kvm_inject_gp(vcpu, 0);
+		return 1;
+	}
+
 	kvm_rax_write(vcpu, (u32)data);
 	kvm_rdx_write(vcpu, data >> 32);
-	return err;
+	return kvm_skip_emulated_instruction(vcpu);
 }
-EXPORT_SYMBOL_GPL(kvm_rdpmc);
+EXPORT_SYMBOL_GPL(kvm_emulate_rdpmc);
 
 /*
  * List of msr numbers which we expose to userspace through KVM_GET_MSRS
@@ -1791,6 +1794,40 @@
 }
 EXPORT_SYMBOL_GPL(kvm_emulate_wrmsr);
 
+int kvm_emulate_as_nop(struct kvm_vcpu *vcpu)
+{
+	return kvm_skip_emulated_instruction(vcpu);
+}
+EXPORT_SYMBOL_GPL(kvm_emulate_as_nop);
+
+int kvm_emulate_invd(struct kvm_vcpu *vcpu)
+{
+	/* Treat an INVD instruction as a NOP and just skip it. */
+	return kvm_emulate_as_nop(vcpu);
+}
+EXPORT_SYMBOL_GPL(kvm_emulate_invd);
+
+int kvm_emulate_mwait(struct kvm_vcpu *vcpu)
+{
+	pr_warn_once("kvm: MWAIT instruction emulated as NOP!\n");
+	return kvm_emulate_as_nop(vcpu);
+}
+EXPORT_SYMBOL_GPL(kvm_emulate_mwait);
+
+int kvm_handle_invalid_op(struct kvm_vcpu *vcpu)
+{
+	kvm_queue_exception(vcpu, UD_VECTOR);
+	return 1;
+}
+EXPORT_SYMBOL_GPL(kvm_handle_invalid_op);
+
+int kvm_emulate_monitor(struct kvm_vcpu *vcpu)
+{
+	pr_warn_once("kvm: MONITOR instruction emulated as NOP!\n");
+	return kvm_emulate_as_nop(vcpu);
+}
+EXPORT_SYMBOL_GPL(kvm_emulate_monitor);
+
 static inline bool kvm_vcpu_exit_request(struct kvm_vcpu *vcpu)
 {
 	xfer_to_guest_mode_prepare();
@@ -3382,6 +3419,12 @@
 		msr_info->data = 0;
 		break;
 	case MSR_F15H_PERF_CTL0 ... MSR_F15H_PERF_CTR5:
+		if (kvm_pmu_is_valid_msr(vcpu, msr_info->index))
+			return kvm_pmu_get_msr(vcpu, msr_info);
+		if (!msr_info->host_initiated)
+			return 1;
+		msr_info->data = 0;
+		break;
 	case MSR_K7_EVNTSEL0 ... MSR_K7_EVNTSEL3:
 	case MSR_K7_PERFCTR0 ... MSR_K7_PERFCTR3:
 	case MSR_P6_PERFCTR0 ... MSR_P6_PERFCTR1:
@@ -3771,6 +3814,7 @@
 	case KVM_CAP_X86_USER_SPACE_MSR:
 	case KVM_CAP_X86_MSR_FILTER:
 	case KVM_CAP_ENFORCE_PV_FEATURE_CPUID:
+	case KVM_CAP_VM_COPY_ENC_CONTEXT_FROM:
 		r = 1;
 		break;
 #ifdef CONFIG_KVM_XEN
@@ -4675,7 +4719,6 @@
 			kvm_update_pv_runtime(vcpu);
 
 		return 0;
-
 	default:
 		return -EINVAL;
 	}
@@ -5357,6 +5400,11 @@
 			kvm->arch.bus_lock_detection_enabled = true;
 		r = 0;
 		break;
+	case KVM_CAP_VM_COPY_ENC_CONTEXT_FROM:
+		r = -ENOTTY;
+		if (kvm_x86_ops.vm_copy_enc_context_from)
+			r = kvm_x86_ops.vm_copy_enc_context_from(kvm, cap->args[0]);
+		return r;
 	default:
 		r = -EINVAL;
 		break;
@@ -8045,9 +8093,6 @@
 	if (r)
 		goto out_free_percpu;
 
-	kvm_mmu_set_mask_ptes(PT_USER_MASK, PT_ACCESSED_MASK,
-			PT_DIRTY_MASK, PT64_NX_MASK, 0,
-			PT_PRESENT_MASK, 0, sme_me_mask);
 	kvm_timer_init();
 
 	perf_register_guest_info_callbacks(&kvm_guest_cbs);
@@ -8369,6 +8414,27 @@
 	static_call(kvm_x86_update_cr8_intercept)(vcpu, tpr, max_irr);
 }
 
+
+int kvm_check_nested_events(struct kvm_vcpu *vcpu)
+{
+	if (WARN_ON_ONCE(!is_guest_mode(vcpu)))
+		return -EIO;
+
+	if (kvm_check_request(KVM_REQ_TRIPLE_FAULT, vcpu)) {
+		kvm_x86_ops.nested_ops->triple_fault(vcpu);
+		return 1;
+	}
+
+	return kvm_x86_ops.nested_ops->check_events(vcpu);
+}
+
+static void kvm_inject_exception(struct kvm_vcpu *vcpu)
+{
+	if (vcpu->arch.exception.has_error_code && !is_protmode(vcpu))
+		vcpu->arch.exception.has_error_code = false;
+	static_call(kvm_x86_queue_exception)(vcpu);
+}
+
 static void inject_pending_event(struct kvm_vcpu *vcpu, bool *req_immediate_exit)
 {
 	int r;
@@ -8377,7 +8443,7 @@
 	/* try to reinject previous events if any */
 
 	if (vcpu->arch.exception.injected) {
-		static_call(kvm_x86_queue_exception)(vcpu);
+		kvm_inject_exception(vcpu);
 		can_inject = false;
 	}
 	/*
@@ -8414,7 +8480,7 @@
 	 * from L2 to L1.
 	 */
 	if (is_guest_mode(vcpu)) {
-		r = kvm_x86_ops.nested_ops->check_events(vcpu);
+		r = kvm_check_nested_events(vcpu);
 		if (r < 0)
 			goto busy;
 	}
@@ -8440,7 +8506,7 @@
 			}
 		}
 
-		static_call(kvm_x86_queue_exception)(vcpu);
+		kvm_inject_exception(vcpu);
 		can_inject = false;
 	}
 
@@ -8977,10 +9043,14 @@
 			goto out;
 		}
 		if (kvm_check_request(KVM_REQ_TRIPLE_FAULT, vcpu)) {
-			vcpu->run->exit_reason = KVM_EXIT_SHUTDOWN;
-			vcpu->mmio_needed = 0;
-			r = 0;
-			goto out;
+			if (is_guest_mode(vcpu)) {
+				kvm_x86_ops.nested_ops->triple_fault(vcpu);
+			} else {
+				vcpu->run->exit_reason = KVM_EXIT_SHUTDOWN;
+				vcpu->mmio_needed = 0;
+				r = 0;
+				goto out;
+			}
 		}
 		if (kvm_check_request(KVM_REQ_APF_HALT, vcpu)) {
 			/* Page is swapped out. Do synthetic halt */
@@ -9278,7 +9348,7 @@
 static inline bool kvm_vcpu_running(struct kvm_vcpu *vcpu)
 {
 	if (is_guest_mode(vcpu))
-		kvm_x86_ops.nested_ops->check_events(vcpu);
+		kvm_check_nested_events(vcpu);
 
 	return (vcpu->arch.mp_state == KVM_MP_STATE_RUNNABLE &&
 		!vcpu->arch.apf.halted);
@@ -11541,7 +11611,7 @@
 
 		fallthrough;
 	case INVPCID_TYPE_ALL_INCL_GLOBAL:
-		kvm_mmu_unload(vcpu);
+		kvm_make_request(KVM_REQ_MMU_RELOAD, vcpu);
 		return kvm_skip_emulated_instruction(vcpu);
 
 	default:
diff --git a/arch/x86/kvm/x86.h b/arch/x86/kvm/x86.h
index 9035e34a..5334bf4 100644
--- a/arch/x86/kvm/x86.h
+++ b/arch/x86/kvm/x86.h
@@ -8,6 +8,14 @@
 #include "kvm_cache_regs.h"
 #include "kvm_emulate.h"
 
+#define KVM_NESTED_VMENTER_CONSISTENCY_CHECK(consistency_check)		\
+({									\
+	bool failed = (consistency_check);				\
+	if (failed)							\
+		trace_kvm_nested_vmenter_failed(#consistency_check, 0);	\
+	failed;								\
+})
+
 #define KVM_DEFAULT_PLE_GAP		128
 #define KVM_VMX_DEFAULT_PLE_WINDOW	4096
 #define KVM_DEFAULT_PLE_WINDOW_GROW	2
@@ -48,6 +56,8 @@
 
 #define MSR_IA32_CR_PAT_DEFAULT  0x0007040600070406ULL
 
+int kvm_check_nested_events(struct kvm_vcpu *vcpu);
+
 static inline void kvm_clear_exception_queue(struct kvm_vcpu *vcpu)
 {
 	vcpu->arch.exception.pending = false;
diff --git a/drivers/clocksource/arm_arch_timer.c b/drivers/clocksource/arm_arch_timer.c
index d017782..e0f167e 100644
--- a/drivers/clocksource/arm_arch_timer.c
+++ b/drivers/clocksource/arm_arch_timer.c
@@ -16,6 +16,7 @@
 #include <linux/cpu_pm.h>
 #include <linux/clockchips.h>
 #include <linux/clocksource.h>
+#include <linux/clocksource_ids.h>
 #include <linux/interrupt.h>
 #include <linux/of_irq.h>
 #include <linux/of_address.h>
@@ -24,6 +25,8 @@
 #include <linux/sched/clock.h>
 #include <linux/sched_clock.h>
 #include <linux/acpi.h>
+#include <linux/arm-smccc.h>
+#include <linux/ptp_kvm.h>
 
 #include <asm/arch_timer.h>
 #include <asm/virt.h>
@@ -191,6 +194,7 @@
 
 static struct clocksource clocksource_counter = {
 	.name	= "arch_sys_counter",
+	.id	= CSID_ARM_ARCH_COUNTER,
 	.rating	= 400,
 	.read	= arch_counter_read,
 	.mask	= CLOCKSOURCE_MASK(56),
@@ -1657,3 +1661,35 @@
 }
 TIMER_ACPI_DECLARE(arch_timer, ACPI_SIG_GTDT, arch_timer_acpi_init);
 #endif
+
+int kvm_arch_ptp_get_crosststamp(u64 *cycle, struct timespec64 *ts,
+				 struct clocksource **cs)
+{
+	struct arm_smccc_res hvc_res;
+	u32 ptp_counter;
+	ktime_t ktime;
+
+	if (!IS_ENABLED(CONFIG_HAVE_ARM_SMCCC_DISCOVERY))
+		return -EOPNOTSUPP;
+
+	if (arch_timer_uses_ppi == ARCH_TIMER_VIRT_PPI)
+		ptp_counter = KVM_PTP_VIRT_COUNTER;
+	else
+		ptp_counter = KVM_PTP_PHYS_COUNTER;
+
+	arm_smccc_1_1_invoke(ARM_SMCCC_VENDOR_HYP_KVM_PTP_FUNC_ID,
+			     ptp_counter, &hvc_res);
+
+	if ((int)(hvc_res.a0) < 0)
+		return -EOPNOTSUPP;
+
+	ktime = (u64)hvc_res.a0 << 32 | hvc_res.a1;
+	*ts = ktime_to_timespec64(ktime);
+	if (cycle)
+		*cycle = (u64)hvc_res.a2 << 32 | hvc_res.a3;
+	if (cs)
+		*cs = &clocksource_counter;
+
+	return 0;
+}
+EXPORT_SYMBOL_GPL(kvm_arch_ptp_get_crosststamp);
diff --git a/drivers/firmware/psci/psci.c b/drivers/firmware/psci/psci.c
index f5fc429..69e296f 100644
--- a/drivers/firmware/psci/psci.c
+++ b/drivers/firmware/psci/psci.c
@@ -23,6 +23,7 @@
 
 #include <asm/cpuidle.h>
 #include <asm/cputype.h>
+#include <asm/hypervisor.h>
 #include <asm/system_misc.h>
 #include <asm/smp_plat.h>
 #include <asm/suspend.h>
@@ -498,6 +499,7 @@
 		psci_init_cpu_suspend();
 		psci_init_system_suspend();
 		psci_init_system_reset2();
+		kvm_init_hyp_services();
 	}
 
 	return 0;
diff --git a/drivers/firmware/smccc/Makefile b/drivers/firmware/smccc/Makefile
index 72ab840..40d1914 100644
--- a/drivers/firmware/smccc/Makefile
+++ b/drivers/firmware/smccc/Makefile
@@ -1,4 +1,4 @@
 # SPDX-License-Identifier: GPL-2.0
 #
-obj-$(CONFIG_HAVE_ARM_SMCCC_DISCOVERY)	+= smccc.o
+obj-$(CONFIG_HAVE_ARM_SMCCC_DISCOVERY)	+= smccc.o kvm_guest.o
 obj-$(CONFIG_ARM_SMCCC_SOC_ID)	+= soc_id.o
diff --git a/drivers/firmware/smccc/kvm_guest.c b/drivers/firmware/smccc/kvm_guest.c
new file mode 100644
index 0000000..2d3e866
--- /dev/null
+++ b/drivers/firmware/smccc/kvm_guest.c
@@ -0,0 +1,50 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#define pr_fmt(fmt) "smccc: KVM: " fmt
+
+#include <linux/arm-smccc.h>
+#include <linux/bitmap.h>
+#include <linux/kernel.h>
+#include <linux/string.h>
+
+#include <asm/hypervisor.h>
+
+static DECLARE_BITMAP(__kvm_arm_hyp_services, ARM_SMCCC_KVM_NUM_FUNCS) __ro_after_init = { };
+
+void __init kvm_init_hyp_services(void)
+{
+	struct arm_smccc_res res;
+	u32 val[4];
+
+	if (arm_smccc_1_1_get_conduit() != SMCCC_CONDUIT_HVC)
+		return;
+
+	arm_smccc_1_1_invoke(ARM_SMCCC_VENDOR_HYP_CALL_UID_FUNC_ID, &res);
+	if (res.a0 != ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_0 ||
+	    res.a1 != ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_1 ||
+	    res.a2 != ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_2 ||
+	    res.a3 != ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_3)
+		return;
+
+	memset(&res, 0, sizeof(res));
+	arm_smccc_1_1_invoke(ARM_SMCCC_VENDOR_HYP_KVM_FEATURES_FUNC_ID, &res);
+
+	val[0] = lower_32_bits(res.a0);
+	val[1] = lower_32_bits(res.a1);
+	val[2] = lower_32_bits(res.a2);
+	val[3] = lower_32_bits(res.a3);
+
+	bitmap_from_arr32(__kvm_arm_hyp_services, val, ARM_SMCCC_KVM_NUM_FUNCS);
+
+	pr_info("hypervisor services detected (0x%08lx 0x%08lx 0x%08lx 0x%08lx)\n",
+		 res.a3, res.a2, res.a1, res.a0);
+}
+
+bool kvm_arm_hyp_service_available(u32 func_id)
+{
+	if (func_id >= ARM_SMCCC_KVM_NUM_FUNCS)
+		return false;
+
+	return test_bit(func_id, __kvm_arm_hyp_services);
+}
+EXPORT_SYMBOL_GPL(kvm_arm_hyp_service_available);
diff --git a/drivers/firmware/smccc/smccc.c b/drivers/firmware/smccc/smccc.c
index d52bfc5..028f81d 100644
--- a/drivers/firmware/smccc/smccc.c
+++ b/drivers/firmware/smccc/smccc.c
@@ -8,6 +8,7 @@
 #include <linux/cache.h>
 #include <linux/init.h>
 #include <linux/arm-smccc.h>
+#include <linux/kernel.h>
 #include <asm/archrandom.h>
 
 static u32 smccc_version = ARM_SMCCC_VERSION_1_0;
diff --git a/drivers/hwtracing/coresight/Kconfig b/drivers/hwtracing/coresight/Kconfig
index 7b44ba2..84530fd 100644
--- a/drivers/hwtracing/coresight/Kconfig
+++ b/drivers/hwtracing/coresight/Kconfig
@@ -97,15 +97,15 @@
 	  module will be called coresight-etm3x.
 
 config CORESIGHT_SOURCE_ETM4X
-	tristate "CoreSight Embedded Trace Macrocell 4.x driver"
+	tristate "CoreSight ETMv4.x / ETE driver"
 	depends on ARM64
 	select CORESIGHT_LINKS_AND_SINKS
 	select PID_IN_CONTEXTIDR
 	help
-	  This driver provides support for the ETM4.x tracer module, tracing the
-	  instructions that a processor is executing. This is primarily useful
-	  for instruction level tracing. Depending on the implemented version
-	  data tracing may also be available.
+	  This driver provides support for the CoreSight Embedded Trace Macrocell
+	  version 4.x and the Embedded Trace Extensions (ETE). Both are CPU tracer
+	  modules, tracing the instructions that a processor is executing. This is
+	  primarily useful for instruction level tracing.
 
 	  To compile this driver as a module, choose M here: the
 	  module will be called coresight-etm4x.
@@ -173,4 +173,18 @@
 	  CTI trigger connections between this and other devices.These
 	  registers are not used in normal operation and can leave devices in
 	  an inconsistent state.
+
+config CORESIGHT_TRBE
+	tristate "Trace Buffer Extension (TRBE) driver"
+	depends on ARM64 && CORESIGHT_SOURCE_ETM4X
+	help
+	  This driver provides support for the per-CPU Trace Buffer Extension (TRBE).
+	  TRBE always needs to be used along with its corresponding per-CPU ETE
+	  component. ETE generates trace data which is then captured with TRBE.
+	  Unlike traditional sink devices, TRBE is a CPU feature accessible via
+	  system registers. But its explicit dependency on the trace unit (ETE)
+	  requires it to be plugged in as a CoreSight sink device.
+
+	  To compile this driver as a module, choose M here: the module will be
+	  called coresight-trbe.
 endif
diff --git a/drivers/hwtracing/coresight/Makefile b/drivers/hwtracing/coresight/Makefile
index f20e357..d608165 100644
--- a/drivers/hwtracing/coresight/Makefile
+++ b/drivers/hwtracing/coresight/Makefile
@@ -21,5 +21,6 @@
 obj-$(CONFIG_CORESIGHT_CPU_DEBUG) += coresight-cpu-debug.o
 obj-$(CONFIG_CORESIGHT_CATU) += coresight-catu.o
 obj-$(CONFIG_CORESIGHT_CTI) += coresight-cti.o
+obj-$(CONFIG_CORESIGHT_TRBE) += coresight-trbe.o
 coresight-cti-y := coresight-cti-core.o	coresight-cti-platform.o \
 		   coresight-cti-sysfs.o
diff --git a/drivers/hwtracing/coresight/coresight-core.c b/drivers/hwtracing/coresight/coresight-core.c
index 0062c89..55c6456 100644
--- a/drivers/hwtracing/coresight/coresight-core.c
+++ b/drivers/hwtracing/coresight/coresight-core.c
@@ -23,6 +23,7 @@
 #include "coresight-priv.h"
 
 static DEFINE_MUTEX(coresight_mutex);
+static DEFINE_PER_CPU(struct coresight_device *, csdev_sink);
 
 /**
  * struct coresight_node - elements of a path, from source to sink
@@ -70,6 +71,18 @@
 }
 EXPORT_SYMBOL_GPL(coresight_remove_cti_ops);
 
+void coresight_set_percpu_sink(int cpu, struct coresight_device *csdev)
+{
+	per_cpu(csdev_sink, cpu) = csdev;
+}
+EXPORT_SYMBOL_GPL(coresight_set_percpu_sink);
+
+struct coresight_device *coresight_get_percpu_sink(int cpu)
+{
+	return per_cpu(csdev_sink, cpu);
+}
+EXPORT_SYMBOL_GPL(coresight_get_percpu_sink);
+
 static int coresight_id_match(struct device *dev, void *data)
 {
 	int trace_id, i_trace_id;
@@ -784,6 +797,14 @@
 	if (csdev == sink)
 		goto out;
 
+	if (coresight_is_percpu_source(csdev) && coresight_is_percpu_sink(sink) &&
+	    sink == per_cpu(csdev_sink, source_ops(csdev)->cpu_id(csdev))) {
+		if (_coresight_build_path(sink, sink, path) == 0) {
+			found = true;
+			goto out;
+		}
+	}
+
 	/* Not a sink - recursively explore each port found on this element */
 	for (i = 0; i < csdev->pdata->nr_outport; i++) {
 		struct coresight_device *child_dev;
@@ -999,8 +1020,12 @@
 	int depth = 0;
 
 	/* look for a default sink if we have not found for this device */
-	if (!csdev->def_sink)
-		csdev->def_sink = coresight_find_sink(csdev, &depth);
+	if (!csdev->def_sink) {
+		if (coresight_is_percpu_source(csdev))
+			csdev->def_sink = per_cpu(csdev_sink, source_ops(csdev)->cpu_id(csdev));
+		if (!csdev->def_sink)
+			csdev->def_sink = coresight_find_sink(csdev, &depth);
+	}
 	return csdev->def_sink;
 }
 
diff --git a/drivers/hwtracing/coresight/coresight-etm-perf.c b/drivers/hwtracing/coresight/coresight-etm-perf.c
index 0f603b4..f123c26 100644
--- a/drivers/hwtracing/coresight/coresight-etm-perf.c
+++ b/drivers/hwtracing/coresight/coresight-etm-perf.c
@@ -24,7 +24,26 @@
 static struct pmu etm_pmu;
 static bool etm_perf_up;
 
-static DEFINE_PER_CPU(struct perf_output_handle, ctx_handle);
+/*
+ * An ETM context for a running event includes the perf aux handle
+ * and aux_data. For ETM, the aux_data (etm_event_data), consists of
+ * the trace path and the sink configuration. The event data is accessible
+ * via perf_get_aux(handle). However, a sink could "end" a perf output
+ * handle via the IRQ handler. And if the "sink" encounters a failure
+ * to "begin" another session (e.g. due to lack of space in the buffer),
+ * the handle will be cleared. Thus, the event_data may not be accessible
+ * from the handle when we get to the etm_event_stop(), which is required
+ * for stopping the trace path. The event_data is guaranteed to stay alive
+ * until "free_aux()", which cannot happen as long as the event is active on
+ * the ETM. Thus the event_data for the session must be part of the ETM context
+ * to make sure we can disable the trace path.
+ */
+struct etm_ctxt {
+	struct perf_output_handle handle;
+	struct etm_event_data *event_data;
+};
+
+static DEFINE_PER_CPU(struct etm_ctxt, etm_ctxt);
 static DEFINE_PER_CPU(struct coresight_device *, csdev_src);
 
 /*
@@ -232,6 +251,25 @@
 	schedule_work(&event_data->work);
 }
 
+/*
+ * Check if two given sinks are compatible with each other,
+ * so that they can use the same sink buffers, when an event
+ * moves around.
+ */
+static bool sinks_compatible(struct coresight_device *a,
+			     struct coresight_device *b)
+{
+	if (!a || !b)
+		return false;
+	/*
+	 * If the sinks are of the same subtype and driven
+	 * by the same driver, we can use the same buffer
+	 * on these sinks.
+	 */
+	return (a->subtype.sink_subtype == b->subtype.sink_subtype) &&
+	       (sink_ops(a) == sink_ops(b));
+}
+
 static void *etm_setup_aux(struct perf_event *event, void **pages,
 			   int nr_pages, bool overwrite)
 {
@@ -239,6 +277,7 @@
 	int cpu = event->cpu;
 	cpumask_t *mask;
 	struct coresight_device *sink = NULL;
+	struct coresight_device *user_sink = NULL, *last_sink = NULL;
 	struct etm_event_data *event_data = NULL;
 
 	event_data = alloc_event_data(cpu);
@@ -249,7 +288,7 @@
 	/* First get the selected sink from user space. */
 	if (event->attr.config2) {
 		id = (u32)event->attr.config2;
-		sink = coresight_get_sink_by_id(id);
+		sink = user_sink = coresight_get_sink_by_id(id);
 	}
 
 	mask = &event_data->mask;
@@ -277,14 +316,33 @@
 		}
 
 		/*
-		 * No sink provided - look for a default sink for one of the
-		 * devices. At present we only support topology where all CPUs
-		 * use the same sink [N:1], so only need to find one sink. The
-		 * coresight_build_path later will remove any CPU that does not
-		 * attach to the sink, or if we have not found a sink.
+		 * No sink provided - look for a default sink for all the ETMs,
+		 * where this event can be scheduled.
+		 * We allocate the sink specific buffers only once for this
+		 * event. If the ETMs have different default sink devices, we
+		 * can only use a single "type" of sink as the event can carry
+		 * only one sink specific buffer. Thus we have to make sure
+		 * that the sinks are of the same type and driven by the same
+		 * driver, as the one we allocate the buffer for. As such
+		 * we choose the first sink and check if the remaining ETMs
+		 * have a compatible default sink. We don't trace on a CPU
+		 * if the sink is not compatible.
 		 */
-		if (!sink)
+		if (!user_sink) {
+			/* Find the default sink for this ETM */
 			sink = coresight_find_default_sink(csdev);
+			if (!sink) {
+				cpumask_clear_cpu(cpu, mask);
+				continue;
+			}
+
+			/* Check if this sink is compatible with the sink chosen so far */
+			if (last_sink && !sinks_compatible(last_sink, sink)) {
+				cpumask_clear_cpu(cpu, mask);
+				continue;
+			}
+			last_sink = sink;
+		}
 
 		/*
 		 * Building a path doesn't enable it, it simply builds a
@@ -312,7 +370,12 @@
 	if (!sink_ops(sink)->alloc_buffer || !sink_ops(sink)->free_buffer)
 		goto err;
 
-	/* Allocate the sink buffer for this session */
+	/*
+	 * Allocate the sink buffer for this session. All the sinks
+	 * where this event can be scheduled are ensured to be of the
+	 * same type. Thus the same sink configuration is used by the
+	 * sinks.
+	 */
 	event_data->snk_config =
 			sink_ops(sink)->alloc_buffer(sink, event, pages,
 						     nr_pages, overwrite);
@@ -332,13 +395,18 @@
 {
 	int cpu = smp_processor_id();
 	struct etm_event_data *event_data;
-	struct perf_output_handle *handle = this_cpu_ptr(&ctx_handle);
+	struct etm_ctxt *ctxt = this_cpu_ptr(&etm_ctxt);
+	struct perf_output_handle *handle = &ctxt->handle;
 	struct coresight_device *sink, *csdev = per_cpu(csdev_src, cpu);
 	struct list_head *path;
 
 	if (!csdev)
 		goto fail;
 
+	/* Have we messed up our tracking ? */
+	if (WARN_ON(ctxt->event_data))
+		goto fail;
+
 	/*
 	 * Deal with the ring buffer API and get a handle on the
 	 * session's information.
@@ -374,6 +442,8 @@
 	if (source_ops(csdev)->enable(csdev, event, CS_MODE_PERF))
 		goto fail_disable_path;
 
+	/* Save the event_data for this ETM */
+	ctxt->event_data = event_data;
 out:
 	return;
 
@@ -392,13 +462,30 @@
 	int cpu = smp_processor_id();
 	unsigned long size;
 	struct coresight_device *sink, *csdev = per_cpu(csdev_src, cpu);
-	struct perf_output_handle *handle = this_cpu_ptr(&ctx_handle);
-	struct etm_event_data *event_data = perf_get_aux(handle);
+	struct etm_ctxt *ctxt = this_cpu_ptr(&etm_ctxt);
+	struct perf_output_handle *handle = &ctxt->handle;
+	struct etm_event_data *event_data;
 	struct list_head *path;
 
+	/*
+	 * If we still have access to the event_data via handle,
+	 * confirm that we haven't messed up the tracking.
+	 */
+	if (handle->event &&
+	    WARN_ON(perf_get_aux(handle) != ctxt->event_data))
+		return;
+
+	event_data = ctxt->event_data;
+	/* Clear the event_data as this ETM is stopping the trace. */
+	ctxt->event_data = NULL;
+
 	if (event->hw.state == PERF_HES_STOPPED)
 		return;
 
+	/* We must have a valid event_data for a running event */
+	if (WARN_ON(!event_data))
+		return;
+
 	if (!csdev)
 		return;
 
@@ -416,7 +503,13 @@
 	/* tell the core */
 	event->hw.state = PERF_HES_STOPPED;
 
-	if (mode & PERF_EF_UPDATE) {
+	/*
+	 * If the handle is not bound to an event anymore
+	 * (e.g, the sink driver was unable to restart the
+	 * handle due to lack of buffer space), we don't
+	 * have to do anything here.
+	 */
+	if (handle->event && (mode & PERF_EF_UPDATE)) {
 		if (WARN_ON_ONCE(handle->event != event))
 			return;
 
diff --git a/drivers/hwtracing/coresight/coresight-etm4x-core.c b/drivers/hwtracing/coresight/coresight-etm4x-core.c
index 15016f7..efb84ce 100644
--- a/drivers/hwtracing/coresight/coresight-etm4x-core.c
+++ b/drivers/hwtracing/coresight/coresight-etm4x-core.c
@@ -31,6 +31,7 @@
 #include <linux/pm_runtime.h>
 #include <linux/property.h>
 
+#include <asm/barrier.h>
 #include <asm/sections.h>
 #include <asm/sysreg.h>
 #include <asm/local.h>
@@ -114,30 +115,91 @@
 	}
 }
 
-static void etm4_os_unlock_csa(struct etmv4_drvdata *drvdata, struct csdev_access *csa)
+static u64 ete_sysreg_read(u32 offset, bool _relaxed, bool _64bit)
 {
-	/* Writing 0 to TRCOSLAR unlocks the trace registers */
-	etm4x_relaxed_write32(csa, 0x0, TRCOSLAR);
-	drvdata->os_unlock = true;
+	u64 res = 0;
+
+	switch (offset) {
+	ETE_READ_CASES(res)
+	default :
+		pr_warn_ratelimited("ete: trying to read unsupported register @%x\n",
+				    offset);
+	}
+
+	if (!_relaxed)
+		__iormb(res);	/* Imitate the !relaxed I/O helpers */
+
+	return res;
+}
+
+static void ete_sysreg_write(u64 val, u32 offset, bool _relaxed, bool _64bit)
+{
+	if (!_relaxed)
+		__iowmb();	/* Imitate the !relaxed I/O helpers */
+	if (!_64bit)
+		val &= GENMASK(31, 0);
+
+	switch (offset) {
+	ETE_WRITE_CASES(val)
+	default :
+		pr_warn_ratelimited("ete: trying to write to unsupported register @%x\n",
+				    offset);
+	}
+}
+
+static void etm_detect_os_lock(struct etmv4_drvdata *drvdata,
+			       struct csdev_access *csa)
+{
+	u32 oslsr = etm4x_relaxed_read32(csa, TRCOSLSR);
+
+	drvdata->os_lock_model = ETM_OSLSR_OSLM(oslsr);
+}
+
+static void etm_write_os_lock(struct etmv4_drvdata *drvdata,
+			      struct csdev_access *csa, u32 val)
+{
+	val = !!val;
+
+	switch (drvdata->os_lock_model) {
+	case ETM_OSLOCK_PRESENT:
+		etm4x_relaxed_write32(csa, val, TRCOSLAR);
+		break;
+	case ETM_OSLOCK_PE:
+		write_sysreg_s(val, SYS_OSLAR_EL1);
+		break;
+	default:
+		pr_warn_once("CPU%d: Unsupported Trace OSLock model: %x\n",
+			     smp_processor_id(), drvdata->os_lock_model);
+		fallthrough;
+	case ETM_OSLOCK_NI:
+		return;
+	}
 	isb();
 }
 
+static inline void etm4_os_unlock_csa(struct etmv4_drvdata *drvdata,
+				      struct csdev_access *csa)
+{
+	WARN_ON(drvdata->cpu != smp_processor_id());
+
+	/* Writing 0 to OS Lock unlocks the trace unit registers */
+	etm_write_os_lock(drvdata, csa, 0x0);
+	drvdata->os_unlock = true;
+}
+
 static void etm4_os_unlock(struct etmv4_drvdata *drvdata)
 {
 	if (!WARN_ON(!drvdata->csdev))
 		etm4_os_unlock_csa(drvdata, &drvdata->csdev->access);
-
 }
 
 static void etm4_os_lock(struct etmv4_drvdata *drvdata)
 {
 	if (WARN_ON(!drvdata->csdev))
 		return;
-
-	/* Writing 0x1 to TRCOSLAR locks the trace registers */
-	etm4x_relaxed_write32(&drvdata->csdev->access, 0x1, TRCOSLAR);
+	/* Writing 0x1 to OS Lock locks the trace registers */
+	etm_write_os_lock(drvdata, &drvdata->csdev->access, 0x1);
 	drvdata->os_unlock = false;
-	isb();
 }
 
 static void etm4_cs_lock(struct etmv4_drvdata *drvdata,
@@ -371,6 +433,13 @@
 		etm4x_relaxed_write32(csa, trcpdcr | TRCPDCR_PU, TRCPDCR);
 	}
 
+	/*
+	 * ETE mandates that the TRCRSR is written to before
+	 * enabling it.
+	 */
+	if (etm4x_is_ete(drvdata))
+		etm4x_relaxed_write32(csa, TRCRSR_TA, TRCRSR);
+
 	/* Enable the trace unit */
 	etm4x_relaxed_write32(csa, 1, TRCPRGCTLR);
 
@@ -654,6 +723,7 @@
 static void etm4_disable_hw(void *info)
 {
 	u32 control;
+	u64 trfcr;
 	struct etmv4_drvdata *drvdata = info;
 	struct etmv4_config *config = &drvdata->config;
 	struct coresight_device *csdev = drvdata->csdev;
@@ -677,18 +747,32 @@
 	control &= ~0x1;
 
 	/*
+	 * If the CPU supports v8.4 Trace filter Control,
+	 * set the ETM to trace prohibited region.
+	 */
+	if (drvdata->trfc) {
+		trfcr = read_sysreg_s(SYS_TRFCR_EL1);
+		write_sysreg_s(trfcr & ~(TRFCR_ELx_ExTRE | TRFCR_ELx_E0TRE),
+			       SYS_TRFCR_EL1);
+		isb();
+	}
+	/*
 	 * Make sure everything completes before disabling, as recommended
 	 * by section 7.3.77 ("TRCVICTLR, ViewInst Main Control Register,
 	 * SSTATUS") of ARM IHI 0064D
 	 */
 	dsb(sy);
 	isb();
+	/* Trace synchronization barrier, is a nop if not supported */
+	tsb_csync();
 	etm4x_relaxed_write32(csa, control, TRCPRGCTLR);
 
 	/* wait for TRCSTATR.PMSTABLE to go to '1' */
 	if (coresight_timeout(csa, TRCSTATR, TRCSTATR_PMSTABLE_BIT, 1))
 		dev_err(etm_dev,
 			"timeout while waiting for PM stable Trace Status\n");
+	if (drvdata->trfc)
+		write_sysreg_s(trfcr, SYS_TRFCR_EL1);
 
 	/* read the status of the single shot comparators */
 	for (i = 0; i < drvdata->nr_ss_cmp; i++) {
@@ -817,13 +901,24 @@
 	 * ETMs implementing sysreg access must implement TRCDEVARCH.
 	 */
 	devarch = read_etm4x_sysreg_const_offset(TRCDEVARCH);
-	if ((devarch & ETM_DEVARCH_ID_MASK) != ETM_DEVARCH_ETMv4x_ARCH)
+	switch (devarch & ETM_DEVARCH_ID_MASK) {
+	case ETM_DEVARCH_ETMv4x_ARCH:
+		*csa = (struct csdev_access) {
+			.io_mem	= false,
+			.read	= etm4x_sysreg_read,
+			.write	= etm4x_sysreg_write,
+		};
+		break;
+	case ETM_DEVARCH_ETE_ARCH:
+		*csa = (struct csdev_access) {
+			.io_mem	= false,
+			.read	= ete_sysreg_read,
+			.write	= ete_sysreg_write,
+		};
+		break;
+	default:
 		return false;
-	*csa = (struct csdev_access) {
-		.io_mem	= false,
-		.read	= etm4x_sysreg_read,
-		.write	= etm4x_sysreg_write,
-	};
+	}
 
 	drvdata->arch = etm_devarch_to_arch(devarch);
 	return true;
@@ -873,7 +968,7 @@
 	return false;
 }
 
-static void cpu_enable_tracing(void)
+static void cpu_enable_tracing(struct etmv4_drvdata *drvdata)
 {
 	u64 dfr0 = read_sysreg(id_aa64dfr0_el1);
 	u64 trfcr;
@@ -881,6 +976,7 @@
 	if (!cpuid_feature_extract_unsigned_field(dfr0, ID_AA64DFR0_TRACE_FILT_SHIFT))
 		return;
 
+	drvdata->trfc = true;
 	/*
 	 * If the CPU supports v8.4 SelfHosted Tracing, enable
 	 * tracing at the kernel EL and EL0, forcing to use the
@@ -920,6 +1016,9 @@
 	if (!etm4_init_csdev_access(drvdata, csa))
 		return;
 
+	/* Detect the support for OS Lock before we actually use it */
+	etm_detect_os_lock(drvdata, csa);
+
 	/* Make sure all registers are accessible */
 	etm4_os_unlock_csa(drvdata, csa);
 	etm4_cs_unlock(drvdata, csa);
@@ -1082,7 +1181,7 @@
 	/* NUMCNTR, bits[30:28] number of counters available for tracing */
 	drvdata->nr_cntr = BMVAL(etmidr5, 28, 30);
 	etm4_cs_lock(drvdata, csa);
-	cpu_enable_tracing();
+	cpu_enable_tracing(drvdata);
 }
 
 static inline u32 etm4_get_victlr_access_type(struct etmv4_config *config)
@@ -1760,6 +1859,8 @@
 	struct etmv4_drvdata *drvdata;
 	struct coresight_desc desc = { 0 };
 	struct etm4_init_arg init_arg = { 0 };
+	u8 major, minor;
+	char *type_name;
 
 	drvdata = devm_kzalloc(dev, sizeof(*drvdata), GFP_KERNEL);
 	if (!drvdata)
@@ -1786,10 +1887,6 @@
 	if (drvdata->cpu < 0)
 		return drvdata->cpu;
 
-	desc.name = devm_kasprintf(dev, GFP_KERNEL, "etm%d", drvdata->cpu);
-	if (!desc.name)
-		return -ENOMEM;
-
 	init_arg.drvdata = drvdata;
 	init_arg.csa = &desc.access;
 	init_arg.pid = etm_pid;
@@ -1806,6 +1903,22 @@
 	    fwnode_property_present(dev_fwnode(dev), "qcom,skip-power-up"))
 		drvdata->skip_power_up = true;
 
+	major = ETM_ARCH_MAJOR_VERSION(drvdata->arch);
+	minor = ETM_ARCH_MINOR_VERSION(drvdata->arch);
+
+	if (etm4x_is_ete(drvdata)) {
+		type_name = "ete";
+		/* ETE v1 has major version == 0b101. Adjust this for logging.*/
+		major -= 4;
+	} else {
+		type_name = "etm";
+	}
+
+	desc.name = devm_kasprintf(dev, GFP_KERNEL,
+				   "%s%d", type_name, drvdata->cpu);
+	if (!desc.name)
+		return -ENOMEM;
+
 	etm4_init_trace_id(drvdata);
 	etm4_set_default(&drvdata->config);
 
@@ -1833,9 +1946,8 @@
 
 	etmdrvdata[drvdata->cpu] = drvdata;
 
-	dev_info(&drvdata->csdev->dev, "CPU%d: ETM v%d.%d initialized\n",
-		 drvdata->cpu, ETM_ARCH_MAJOR_VERSION(drvdata->arch),
-		 ETM_ARCH_MINOR_VERSION(drvdata->arch));
+	dev_info(&drvdata->csdev->dev, "CPU%d: %s v%d.%d initialized\n",
+		 drvdata->cpu, type_name, major, minor);
 
 	if (boot_enable) {
 		coresight_enable(drvdata->csdev);
@@ -1978,6 +2090,7 @@
 
 static const struct of_device_id etm4_sysreg_match[] = {
 	{ .compatible	= "arm,coresight-etm4x-sysreg" },
+	{ .compatible	= "arm,embedded-trace-extension" },
 	{}
 };
 
diff --git a/drivers/hwtracing/coresight/coresight-etm4x-sysfs.c b/drivers/hwtracing/coresight/coresight-etm4x-sysfs.c
index 0995a10..007bad9 100644
--- a/drivers/hwtracing/coresight/coresight-etm4x-sysfs.c
+++ b/drivers/hwtracing/coresight/coresight-etm4x-sysfs.c
@@ -2374,12 +2374,20 @@
 etm4x_register_implemented(struct etmv4_drvdata *drvdata, u32 offset)
 {
 	switch (offset) {
-	ETM4x_SYSREG_LIST_CASES
+	ETM_COMMON_SYSREG_LIST_CASES
 		/*
-		 * Registers accessible via system instructions are always
-		 * implemented.
+		 * Common registers to ETE & ETM4x accessible via system
+		 * instructions are always implemented.
 		 */
 		return true;
+
+	ETM4x_ONLY_SYSREG_LIST_CASES
+		/*
+		 * We only support etm4x and ete. So if the device is not
+		 * ETE, it must be ETMv4x.
+		 */
+		return !etm4x_is_ete(drvdata);
+
 	ETM4x_MMAP_LIST_CASES
 		/*
 		 * Registers accessible only via memory-mapped registers
@@ -2389,8 +2397,13 @@
 		 * coresight_register() and the csdev is not initialized
 		 * until that is done. So rely on the drvdata->base to
 		 * detect if we have a memory mapped access.
+		 * Also ETE doesn't implement memory mapped access, thus
+		 * it is sufficient to check that we are using mmio.
 		 */
 		return !!drvdata->base;
+
+	ETE_ONLY_SYSREG_LIST_CASES
+		return etm4x_is_ete(drvdata);
 	}
 
 	return false;
diff --git a/drivers/hwtracing/coresight/coresight-etm4x.h b/drivers/hwtracing/coresight/coresight-etm4x.h
index 0af6057..e5b79bd 100644
--- a/drivers/hwtracing/coresight/coresight-etm4x.h
+++ b/drivers/hwtracing/coresight/coresight-etm4x.h
@@ -29,6 +29,7 @@
 #define TRCAUXCTLR			0x018
 #define TRCEVENTCTL0R			0x020
 #define TRCEVENTCTL1R			0x024
+#define TRCRSR				0x028
 #define TRCSTALLCTLR			0x02C
 #define TRCTSCTLR			0x030
 #define TRCSYNCPR			0x034
@@ -49,6 +50,7 @@
 #define TRCSEQRSTEVR			0x118
 #define TRCSEQSTR			0x11C
 #define TRCEXTINSELR			0x120
+#define TRCEXTINSELRn(n)		(0x120 + (n * 4)) /* n = 0-3 */
 #define TRCCNTRLDVRn(n)			(0x140 + (n * 4)) /* n = 0-3 */
 #define TRCCNTCTLRn(n)			(0x150 + (n * 4)) /* n = 0-3 */
 #define TRCCNTVRn(n)			(0x160 + (n * 4)) /* n = 0-3 */
@@ -126,6 +128,8 @@
 #define TRCCIDR2			0xFF8
 #define TRCCIDR3			0xFFC
 
+#define TRCRSR_TA			BIT(12)
+
 /*
  * System instructions to access ETM registers.
  * See ETMv4.4 spec ARM IHI0064F section 4.3.6 System instructions
@@ -160,10 +164,22 @@
 #define CASE_NOP(__unused, x)					\
 	case (x):	/* fall through */
 
+#define ETE_ONLY_SYSREG_LIST(op, val)		\
+	CASE_##op((val), TRCRSR)		\
+	CASE_##op((val), TRCEXTINSELRn(1))	\
+	CASE_##op((val), TRCEXTINSELRn(2))	\
+	CASE_##op((val), TRCEXTINSELRn(3))
+
 /* List of registers accessible via System instructions */
-#define ETM_SYSREG_LIST(op, val)		\
-	CASE_##op((val), TRCPRGCTLR)		\
+#define ETM4x_ONLY_SYSREG_LIST(op, val)		\
 	CASE_##op((val), TRCPROCSELR)		\
+	CASE_##op((val), TRCVDCTLR)		\
+	CASE_##op((val), TRCVDSACCTLR)		\
+	CASE_##op((val), TRCVDARCCTLR)		\
+	CASE_##op((val), TRCOSLAR)
+
+#define ETM_COMMON_SYSREG_LIST(op, val)		\
+	CASE_##op((val), TRCPRGCTLR)		\
 	CASE_##op((val), TRCSTATR)		\
 	CASE_##op((val), TRCCONFIGR)		\
 	CASE_##op((val), TRCAUXCTLR)		\
@@ -180,9 +196,6 @@
 	CASE_##op((val), TRCVIIECTLR)		\
 	CASE_##op((val), TRCVISSCTLR)		\
 	CASE_##op((val), TRCVIPCSSCTLR)		\
-	CASE_##op((val), TRCVDCTLR)		\
-	CASE_##op((val), TRCVDSACCTLR)		\
-	CASE_##op((val), TRCVDARCCTLR)		\
 	CASE_##op((val), TRCSEQEVRn(0))		\
 	CASE_##op((val), TRCSEQEVRn(1))		\
 	CASE_##op((val), TRCSEQEVRn(2))		\
@@ -277,7 +290,6 @@
 	CASE_##op((val), TRCSSPCICRn(5))	\
 	CASE_##op((val), TRCSSPCICRn(6))	\
 	CASE_##op((val), TRCSSPCICRn(7))	\
-	CASE_##op((val), TRCOSLAR)		\
 	CASE_##op((val), TRCOSLSR)		\
 	CASE_##op((val), TRCACVRn(0))		\
 	CASE_##op((val), TRCACVRn(1))		\
@@ -369,12 +381,38 @@
 	CASE_##op((val), TRCPIDR2)		\
 	CASE_##op((val), TRCPIDR3)
 
-#define ETM4x_READ_SYSREG_CASES(res)	ETM_SYSREG_LIST(READ, (res))
-#define ETM4x_WRITE_SYSREG_CASES(val)	ETM_SYSREG_LIST(WRITE, (val))
+#define ETM4x_READ_SYSREG_CASES(res)		\
+	ETM_COMMON_SYSREG_LIST(READ, (res))	\
+	ETM4x_ONLY_SYSREG_LIST(READ, (res))
 
-#define ETM4x_SYSREG_LIST_CASES		ETM_SYSREG_LIST(NOP, __unused)
+#define ETM4x_WRITE_SYSREG_CASES(val)		\
+	ETM_COMMON_SYSREG_LIST(WRITE, (val))	\
+	ETM4x_ONLY_SYSREG_LIST(WRITE, (val))
+
+#define ETM_COMMON_SYSREG_LIST_CASES		\
+	ETM_COMMON_SYSREG_LIST(NOP, __unused)
+
+#define ETM4x_ONLY_SYSREG_LIST_CASES		\
+	ETM4x_ONLY_SYSREG_LIST(NOP, __unused)
+
+#define ETM4x_SYSREG_LIST_CASES			\
+	ETM_COMMON_SYSREG_LIST_CASES		\
+	ETM4x_ONLY_SYSREG_LIST(NOP, __unused)
+
 #define ETM4x_MMAP_LIST_CASES		ETM_MMAP_LIST(NOP, __unused)
 
+/* ETE only supports system register access */
+#define ETE_READ_CASES(res)			\
+	ETM_COMMON_SYSREG_LIST(READ, (res))	\
+	ETE_ONLY_SYSREG_LIST(READ, (res))
+
+#define ETE_WRITE_CASES(val)			\
+	ETM_COMMON_SYSREG_LIST(WRITE, (val))	\
+	ETE_ONLY_SYSREG_LIST(WRITE, (val))
+
+#define ETE_ONLY_SYSREG_LIST_CASES		\
+	ETE_ONLY_SYSREG_LIST(NOP, __unused)
+
 #define read_etm4x_sysreg_offset(offset, _64bit)				\
 	({									\
 		u64 __val;							\
@@ -506,6 +544,20 @@
 					 ETM_MODE_EXCL_USER)
 
 /*
+ * TRCOSLSR.OSLM advertises the OS Lock model.
+ * OSLM[2:0] = TRCOSLSR[4:3,0]
+ *
+ *	0b000 - Trace OS Lock is not implemented.
+ *	0b010 - Trace OS Lock is implemented.
+ *	0b100 - Trace OS Lock is not implemented, unit is controlled by PE OS Lock.
+ */
+#define ETM_OSLOCK_NI		0b000
+#define ETM_OSLOCK_PRESENT	0b010
+#define ETM_OSLOCK_PE		0b100
+
+#define ETM_OSLSR_OSLM(oslsr)	((((oslsr) & GENMASK(4, 3)) >> 2) | ((oslsr) & 0x1))
+
+/*
  * TRCDEVARCH Bit field definitions
  * Bits[31:21]	- ARCHITECT = Always Arm Ltd.
  *                * Bits[31:28] = 0x4
@@ -541,11 +593,14 @@
 	((ETM_DEVARCH_MAKE_ARCHID_ARCH_VER(major)) | ETM_DEVARCH_ARCHID_ARCH_PART(0xA13))
 
 #define ETM_DEVARCH_ARCHID_ETMv4x		ETM_DEVARCH_MAKE_ARCHID(0x4)
+#define ETM_DEVARCH_ARCHID_ETE			ETM_DEVARCH_MAKE_ARCHID(0x5)
 
 #define ETM_DEVARCH_ID_MASK						\
 	(ETM_DEVARCH_ARCHITECT_MASK | ETM_DEVARCH_ARCHID_MASK | ETM_DEVARCH_PRESENT)
 #define ETM_DEVARCH_ETMv4x_ARCH						\
 	(ETM_DEVARCH_ARCHITECT_ARM | ETM_DEVARCH_ARCHID_ETMv4x | ETM_DEVARCH_PRESENT)
+#define ETM_DEVARCH_ETE_ARCH						\
+	(ETM_DEVARCH_ARCHITECT_ARM | ETM_DEVARCH_ARCHID_ETE | ETM_DEVARCH_PRESENT)
 
 #define TRCSTATR_IDLE_BIT		0
 #define TRCSTATR_PMSTABLE_BIT		1
@@ -635,6 +690,8 @@
 #define ETM_ARCH_MINOR_VERSION(arch)	((arch) & 0xfU)
 
 #define ETM_ARCH_V4	ETM_ARCH_VERSION(4, 0)
+#define ETM_ARCH_ETE	ETM_ARCH_VERSION(5, 0)
+
 /* Interpretation of resource numbers change at ETM v4.3 architecture */
 #define ETM_ARCH_V4_3	ETM_ARCH_VERSION(4, 3)
 
@@ -862,6 +919,7 @@
  * @nooverflow:	Indicate if overflow prevention is supported.
  * @atbtrig:	If the implementation can support ATB triggers
  * @lpoverride:	If the implementation can support low-power state over.
+ * @trfc:	If the implementation supports Arm v8.4 trace filter controls.
  * @config:	structure holding configuration parameters.
  * @save_state:	State to be preserved across power loss
  * @state_needs_restore: True when there is context to restore after PM exit
@@ -897,6 +955,7 @@
 	u8				s_ex_level;
 	u8				ns_ex_level;
 	u8				q_support;
+	u8				os_lock_model;
 	bool				sticky_enable;
 	bool				boot_enable;
 	bool				os_unlock;
@@ -912,6 +971,7 @@
 	bool				nooverflow;
 	bool				atbtrig;
 	bool				lpoverride;
+	bool				trfc;
 	struct etmv4_config		config;
 	struct etmv4_save_state		*save_state;
 	bool				state_needs_restore;
@@ -940,4 +1000,9 @@
 
 u64 etm4x_sysreg_read(u32 offset, bool _relaxed, bool _64bit);
 void etm4x_sysreg_write(u64 val, u32 offset, bool _relaxed, bool _64bit);
+
+static inline bool etm4x_is_ete(struct etmv4_drvdata *drvdata)
+{
+	return drvdata->arch >= ETM_ARCH_ETE;
+}
 #endif
diff --git a/drivers/hwtracing/coresight/coresight-platform.c b/drivers/hwtracing/coresight/coresight-platform.c
index 3629b78..c594f45 100644
--- a/drivers/hwtracing/coresight/coresight-platform.c
+++ b/drivers/hwtracing/coresight/coresight-platform.c
@@ -90,6 +90,12 @@
 	struct of_endpoint endpoint;
 	int in = 0, out = 0;
 
+	/*
+	 * Avoid warnings in of_graph_get_next_endpoint()
+	 * if the device doesn't have any graph connections
+	 */
+	if (!of_graph_is_present(node))
+		return;
 	do {
 		ep = of_graph_get_next_endpoint(node, ep);
 		if (!ep)
diff --git a/drivers/hwtracing/coresight/coresight-priv.h b/drivers/hwtracing/coresight/coresight-priv.h
index f5f654e..ff1dd20 100644
--- a/drivers/hwtracing/coresight/coresight-priv.h
+++ b/drivers/hwtracing/coresight/coresight-priv.h
@@ -232,4 +232,7 @@
 void coresight_set_assoc_ectdev_mutex(struct coresight_device *csdev,
 				      struct coresight_device *ect_csdev);
 
+void coresight_set_percpu_sink(int cpu, struct coresight_device *csdev);
+struct coresight_device *coresight_get_percpu_sink(int cpu);
+
 #endif
diff --git a/drivers/hwtracing/coresight/coresight-trbe.c b/drivers/hwtracing/coresight/coresight-trbe.c
new file mode 100644
index 0000000..5ce2398
--- /dev/null
+++ b/drivers/hwtracing/coresight/coresight-trbe.c
@@ -0,0 +1,1157 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * This driver enables the Trace Buffer Extension (TRBE) as a per-cpu coresight
+ * sink device, which can then pair with an appropriate per-cpu coresight source
+ * device (ETE), thus generating the required trace data. Trace can be enabled
+ * via the perf framework.
+ *
+ * The AUX buffer handling is inspired from Arm SPE PMU driver.
+ *
+ * Copyright (C) 2020 ARM Ltd.
+ *
+ * Author: Anshuman Khandual <anshuman.khandual@arm.com>
+ */
+#define DRVNAME "arm_trbe"
+
+#define pr_fmt(fmt) DRVNAME ": " fmt
+
+#include <asm/barrier.h>
+#include "coresight-trbe.h"
+
+#define PERF_IDX2OFF(idx, buf) ((idx) % ((buf)->nr_pages << PAGE_SHIFT))	/* idx -> offset */
+
+/*
+ * A padding packet that helps the user space tools skip
+ * sections of the captured trace data which could not
+ * be decoded. Unlike the legacy CoreSight sinks, TRBE
+ * doesn't support formatting the trace data, so we use
+ * ETE trace packets to pad such sections of the buffer.
+ * trbe_pad_buf() fills the padding bytes with this value.
+ */
+#define ETE_IGNORE_PACKET		0x70
+
+/*
+ * A minimal meaningful trace contains the A-Sync,
+ * Trace Info, Trace On, Address and Atom packets.
+ * This is about 44 bytes of ETE trace. To be on
+ * the safer side, we assume 64 bytes is the minimum
+ * space required for a meaningful session, before
+ * we hit a "WRAP" event.
+ */
+#define TRBE_TRACE_MIN_BUF_SIZE		64
+
+enum trbe_fault_action {
+	TRBE_FAULT_ACT_WRAP,		/* buffer filled: collect and restart */
+	TRBE_FAULT_ACT_SPURIOUS,	/* reprogram and re-enable the TRBE */
+	TRBE_FAULT_ACT_FATAL,		/* stop and truncate the session */
+};
+
+struct trbe_buf {
+	/*
+	 * Even though trbe_base holds the start address
+	 * of the vmap()'ed buffer, it is stored as an
+	 * unsigned long for the various arithmetic and
+	 * comparison operations, and to be consistent
+	 * with the sibling trbe_write and trbe_limit
+	 * pointers.
+	 */
+	unsigned long trbe_base;
+	unsigned long trbe_limit;	/* limit pointer to program */
+	unsigned long trbe_write;	/* write pointer to program */
+	int nr_pages;
+	void **pages;			/* AUX pages from alloc_buffer */
+	bool snapshot;			/* driver advances handle->head */
+	struct trbe_cpudata *cpudata;	/* set by arm_trbe_enable() */
+};
+
+struct trbe_cpudata {
+	bool trbe_flag;			/* flag update support, from TRBIDR */
+	u64 trbe_align;			/* write pointer alignment in bytes */
+	int cpu;
+	enum cs_mode mode;		/* CS_MODE_PERF while enabled */
+	struct trbe_buf *buf;		/* active buffer, NULL when disabled */
+	struct trbe_drvdata *drvdata;
+};
+
+struct trbe_drvdata {
+	struct trbe_cpudata __percpu *cpudata;
+	struct perf_output_handle * __percpu *handle;	/* IRQ dev_id */
+	struct hlist_node hotplug_node;
+	int irq;			/* per-cpu interrupt (PPI) */
+	cpumask_t supported_cpus;	/* CPUs with a usable TRBE */
+	enum cpuhp_state trbe_online;	/* dynamic cpuhp state */
+	struct platform_device *pdev;
+};
+
+static int trbe_alloc_node(struct perf_event *event)
+{
+	if (event->cpu == -1)	/* event not bound to a CPU */
+		return NUMA_NO_NODE;
+	return cpu_to_node(event->cpu);
+}
+
+static void trbe_drain_buffer(void)
+{
+	tsb_csync();	/* trace synchronization barrier */
+	dsb(nsh);	/* wait for the trace writes to complete */
+}
+
+static void trbe_drain_and_disable_local(void)
+{
+	u64 trblimitr = read_sysreg_s(SYS_TRBLIMITR_EL1);
+
+	trbe_drain_buffer();
+
+	/*
+	 * Disable the TRBE without clearing LIMITPTR which
+	 * might be required for fetching the buffer limits.
+	 */
+	trblimitr &= ~TRBLIMITR_ENABLE;
+	write_sysreg_s(trblimitr, SYS_TRBLIMITR_EL1);
+	isb();	/* synchronize the disable */
+}
+
+static void trbe_reset_local(void)
+{
+	trbe_drain_and_disable_local();
+	write_sysreg_s(0, SYS_TRBLIMITR_EL1);
+	write_sysreg_s(0, SYS_TRBPTR_EL1);
+	write_sysreg_s(0, SYS_TRBBASER_EL1);
+	write_sysreg_s(0, SYS_TRBSR_EL1);	/* also clears TRBSR_IRQ */
+}
+
+static void trbe_stop_and_truncate_event(struct perf_output_handle *handle)
+{
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+
+	/*
+	 * We cannot proceed with the buffer collection and we
+	 * do not have any data for the current session. The
+	 * etm_perf driver expects to close out the aux_buffer
+	 * at event_stop(). So disable the TRBE here and leave
+	 * the update_buffer() to return a 0 size.
+	 */
+	trbe_drain_and_disable_local();
+	perf_aux_output_flag(handle, PERF_AUX_FLAG_TRUNCATED);
+	*this_cpu_ptr(buf->cpudata->drvdata->handle) = NULL;	/* detach from the IRQ path */
+}
+
+/*
+ * TRBE Buffer Management
+ *
+ * The TRBE buffer spans from the base pointer up to the limit pointer. When enabled,
+ * it starts writing trace data from the write pointer onward up to the limit pointer.
+ * When the write pointer reaches the address just before the limit pointer, it gets
+ * wrapped around again to the base pointer. This is called a TRBE wrap event, which
+ * generates a maintenance interrupt when operated in WRAP or FILL mode. This driver
+ * uses FILL mode, where the TRBE stops the trace collection at wrap event. The IRQ
+ * handler updates the AUX buffer and re-enables the TRBE with updated WRITE and
+ * LIMIT pointers.
+ *
+ *	Wrap around with an IRQ
+ *	------ < ------ < ------- < ----- < -----
+ *	|					|
+ *	------ > ------ > ------- > ----- > -----
+ *
+ *	+---------------+-----------------------+
+ *	|		|			|
+ *	+---------------+-----------------------+
+ *	Base Pointer	Write Pointer		Limit Pointer
+ *
+ * The base and limit pointers always need to be PAGE_SIZE aligned. But the write
+ * pointer can be aligned to the implementation defined TRBE trace buffer alignment
+ * as captured in trbe_cpudata->trbe_align.
+ *
+ *
+ *		head		tail		wakeup
+ *	+---------------------------------------+----- ~ ~ ------
+ *	|$$$$$$$|################|$$$$$$$$$$$$$$|		|
+ *	+---------------------------------------+----- ~ ~ ------
+ *	Base Pointer	Write Pointer		Limit Pointer
+ *
+ * The perf_output_handle indices (head, tail, wakeup) are monotonically increasing
+ * values which track all the driver writes and user reads from the perf auxiliary
+ * buffer. Generally [head..tail] is the area the driver can write into, unless the
+ * wakeup is behind the tail. The enabled TRBE buffer span needs to be adjusted and
+ * configured depending on the perf_output_handle indices, so that the driver does
+ * not overwrite areas of the perf auxiliary buffer which are being, or are yet to
+ * be, consumed by user space. The enabled TRBE buffer area is a moving subset of
+ * the allocated perf auxiliary buffer.
+ */
+static void trbe_pad_buf(struct perf_output_handle *handle, int len)
+{
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+	u64 head = PERF_IDX2OFF(handle->head, buf);
+
+	memset((void *)buf->trbe_base + head, ETE_IGNORE_PACKET, len);
+	if (!buf->snapshot)
+		perf_aux_output_skip(handle, len);	/* advance handle->head */
+}
+
+static unsigned long trbe_snapshot_offset(struct perf_output_handle *handle)
+{
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+
+	/*
+	 * The ETE trace has alignment synchronization packets allowing
+	 * the decoder to reset in case of an overflow or corruption.
+	 * So we can use the entire buffer for the snapshot mode.
+	 */
+	return buf->nr_pages * PAGE_SIZE;	/* the whole buffer */
+}
+
+/*
+ * TRBE Limit Calculation
+ *
+ * The following markers are used to illustrate various TRBE buffer situations.
+ *
+ * $$$$ - Data area, unconsumed captured trace data, not to be overwritten
+ * #### - Free area, enabled, trace will be written
+ * %%%% - Free area, disabled, trace will not be written
+ * ==== - Free area, padded with ETE_IGNORE_PACKET, trace will be skipped
+ */
+static unsigned long __trbe_normal_offset(struct perf_output_handle *handle)
+{
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+	struct trbe_cpudata *cpudata = buf->cpudata;
+	const u64 bufsize = buf->nr_pages * PAGE_SIZE;
+	u64 limit = bufsize;
+	u64 head, tail, wakeup;
+
+	head = PERF_IDX2OFF(handle->head, buf);
+
+	/*
+	 *		head
+	 *	------->|
+	 *	|
+	 *	head	TRBE align	tail
+	 * +----|-------|---------------|-------+
+	 * |$$$$|=======|###############|$$$$$$$|
+	 * +----|-------|---------------|-------+
+	 * trbe_base				trbe_base + nr_pages
+	 *
+	 * Perf aux buffer output head position can be misaligned depending on
+	 * various factors including user space reads. If it is misaligned, head
+	 * needs to be aligned before TRBE can be configured. Pad the alignment
+	 * gap with ETE_IGNORE_PACKET bytes that will be ignored by user tools
+	 * and skip this section, thus advancing the head.
+	 */
+	if (!IS_ALIGNED(head, cpudata->trbe_align)) {
+		unsigned long delta = roundup(head, cpudata->trbe_align) - head;
+
+		delta = min(delta, handle->size);
+		trbe_pad_buf(handle, delta);
+		head = PERF_IDX2OFF(handle->head, buf);
+	}
+
+	/*
+	 *	head = tail (size = 0)
+	 * +----|-------------------------------+
+	 * |$$$$|$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$	|
+	 * +----|-------------------------------+
+	 * trbe_base				trbe_base + nr_pages
+	 *
+	 * Perf aux buffer does not have any space for the driver to write into.
+	 * Just communicate the trace truncation event to user space by marking
+	 * it with PERF_AUX_FLAG_TRUNCATED.
+	 */
+	if (!handle->size) {
+		perf_aux_output_flag(handle, PERF_AUX_FLAG_TRUNCATED);
+		return 0;
+	}
+
+	/* Compute the tail and wakeup indices now that we've aligned head */
+	tail = PERF_IDX2OFF(handle->head + handle->size, buf);
+	wakeup = PERF_IDX2OFF(handle->wakeup, buf);
+
+	/*
+	 * Let's calculate the buffer area which TRBE could write into. There
+	 * are three possible scenarios here. Limit needs to be aligned with
+	 * PAGE_SIZE per the TRBE requirement. Always avoid clobbering the
+	 * unconsumed data.
+	 *
+	 * 1) head < tail
+	 *
+	 *	head			tail
+	 * +----|-----------------------|-------+
+	 * |$$$$|#######################|$$$$$$$|
+	 * +----|-----------------------|-------+
+	 * trbe_base			limit	trbe_base + nr_pages
+	 *
+	 * TRBE could write into [head..tail] area. Unless the tail is right at
+	 * the end of the buffer, neither a wrap around nor an IRQ is expected
+	 * while being enabled.
+	 *
+	 * 2) head == tail
+	 *
+	 *	head = tail (size > 0)
+	 * +----|-------------------------------+
+	 * |%%%%|###############################|
+	 * +----|-------------------------------+
+	 * trbe_base				limit = trbe_base + nr_pages
+	 *
+	 * TRBE should just write into [head..base + nr_pages] area even though
+	 * the entire buffer is empty. Reason being, when the trace reaches the
+	 * end of the buffer, it will just wrap around with an IRQ giving an
+	 * opportunity to reconfigure the buffer.
+	 *
+	 * 3) tail < head
+	 *
+	 *	tail			head
+	 * +----|-----------------------|-------+
+	 * |%%%%|$$$$$$$$$$$$$$$$$$$$$$$|#######|
+	 * +----|-----------------------|-------+
+	 * trbe_base				limit = trbe_base + nr_pages
+	 *
+	 * TRBE should just write into [head..base + nr_pages] area even though
+	 * the [trbe_base..tail] is also empty. Reason being, when the trace
+	 * reaches the end of the buffer, it will just wrap around with an IRQ
+	 * giving an opportunity to reconfigure the buffer.
+	 */
+	if (head < tail)
+		limit = round_down(tail, PAGE_SIZE);
+
+	/*
+	 * Wakeup may be arbitrarily far into the future. If it's not in the
+	 * current generation, either we'll wrap before hitting it, or it's
+	 * in the past and has been handled already.
+	 *
+	 * If there's a wakeup before we wrap, arrange to be woken up by the
+	 * page boundary following it. Keep the tail boundary if that's lower.
+	 *
+	 *	head		wakeup	tail
+	 * +----|---------------|-------|-------+
+	 * |$$$$|###############|%%%%%%%|$$$$$$$|
+	 * +----|---------------|-------|-------+
+	 * trbe_base		limit		trbe_base + nr_pages
+	 */
+	if (handle->wakeup < (handle->head + handle->size) && head <= wakeup)
+		limit = min(limit, round_up(wakeup, PAGE_SIZE));
+
+	/*
+	 * There are two situations when this can happen, i.e. limit is before
+	 * the head and hence TRBE cannot be configured.
+	 *
+	 * 1) head < tail (aligned down with PAGE_SIZE) and also they are both
+	 * within the same PAGE size range.
+	 *
+	 *			PAGE_SIZE
+	 *		|----------------------|
+	 *
+	 *		limit	head	tail
+	 * +------------|------|--------|-------+
+	 * |$$$$$$$$$$$$$$$$$$$|========|$$$$$$$|
+	 * +------------|------|--------|-------+
+	 * trbe_base				trbe_base + nr_pages
+	 *
+	 * 2) head < wakeup (aligned up with PAGE_SIZE) < tail and also both
+	 * head and wakeup are within same PAGE size range.
+	 *
+	 *		PAGE_SIZE
+	 *	|----------------------|
+	 *
+	 *	limit	head	wakeup  tail
+	 * +----|------|-------|--------|-------+
+	 * |$$$$$$$$$$$|=======|========|$$$$$$$|
+	 * +----|------|-------|--------|-------+
+	 * trbe_base				trbe_base + nr_pages
+	 */
+	if (limit > head)
+		return limit;
+
+	trbe_pad_buf(handle, handle->size);
+	perf_aux_output_flag(handle, PERF_AUX_FLAG_TRUNCATED);
+	return 0;
+}
+
+static unsigned long trbe_normal_offset(struct perf_output_handle *handle)
+{
+	struct trbe_buf *buf = etm_perf_sink_config(handle);	/* not perf_get_aux() */
+	u64 limit = __trbe_normal_offset(handle);
+	u64 head = PERF_IDX2OFF(handle->head, buf);
+
+	/*
+	 * If the head is too close to the limit and we don't have
+	 * space for a meaningful run, we'd rather pad it and start
+	 * fresh.
+	 */
+	if (limit && (limit - head < TRBE_TRACE_MIN_BUF_SIZE)) {
+		trbe_pad_buf(handle, limit - head);
+		limit = __trbe_normal_offset(handle);
+	}
+	return limit;
+}
+
+static unsigned long compute_trbe_buffer_limit(struct perf_output_handle *handle)
+{
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+	unsigned long offset;
+
+	if (buf->snapshot)
+		offset = trbe_snapshot_offset(handle);
+	else
+		offset = trbe_normal_offset(handle);
+	return buf->trbe_base + offset;	/* == trbe_base if there is no room */
+}
+
+static void clr_trbe_status(void)
+{
+	u64 trbsr = read_sysreg_s(SYS_TRBSR_EL1);
+
+	WARN_ON(is_trbe_enabled());	/* TRBE must be disabled here */
+	trbsr &= ~TRBSR_IRQ;
+	trbsr &= ~TRBSR_TRG;
+	trbsr &= ~TRBSR_WRAP;
+	trbsr &= ~(TRBSR_EC_MASK << TRBSR_EC_SHIFT);
+	trbsr &= ~(TRBSR_BSC_MASK << TRBSR_BSC_SHIFT);
+	trbsr &= ~TRBSR_STOP;
+	write_sysreg_s(trbsr, SYS_TRBSR_EL1);
+}
+
+static void set_trbe_limit_pointer_enabled(unsigned long addr)
+{
+	u64 trblimitr = read_sysreg_s(SYS_TRBLIMITR_EL1);
+
+	WARN_ON(!IS_ALIGNED(addr, (1UL << TRBLIMITR_LIMIT_SHIFT)));
+	WARN_ON(!IS_ALIGNED(addr, PAGE_SIZE));
+
+	trblimitr &= ~TRBLIMITR_NVM;
+	trblimitr &= ~(TRBLIMITR_FILL_MODE_MASK << TRBLIMITR_FILL_MODE_SHIFT);
+	trblimitr &= ~(TRBLIMITR_TRIG_MODE_MASK << TRBLIMITR_TRIG_MODE_SHIFT);
+	trblimitr &= ~(TRBLIMITR_LIMIT_MASK << TRBLIMITR_LIMIT_SHIFT);
+
+	/*
+	 * Fill trace buffer mode is used here while configuring the
+	 * TRBE for trace capture. In this particular mode, the trace
+	 * collection is stopped and a maintenance interrupt is raised
+	 * when the current write pointer wraps. This pause in trace
+	 * collection gives the software an opportunity to capture the
+	 * trace data in the interrupt handler, before reconfiguring
+	 * the TRBE.
+	 */
+	trblimitr |= (TRBE_FILL_MODE_FILL & TRBLIMITR_FILL_MODE_MASK) << TRBLIMITR_FILL_MODE_SHIFT;
+
+	/*
+	 * Trigger mode is not used here while configuring the TRBE for
+	 * the trace capture. Hence just keep this in the ignore mode.
+	 */
+	trblimitr |= (TRBE_TRIG_MODE_IGNORE & TRBLIMITR_TRIG_MODE_MASK) <<
+		      TRBLIMITR_TRIG_MODE_SHIFT;
+	trblimitr |= (addr & PAGE_MASK);	/* new limit address */
+
+	trblimitr |= TRBLIMITR_ENABLE;
+	write_sysreg_s(trblimitr, SYS_TRBLIMITR_EL1);
+
+	/* Synchronize the TRBE enable event */
+	isb();
+}
+
+static void trbe_enable_hw(struct trbe_buf *buf)
+{
+	WARN_ON(buf->trbe_write < buf->trbe_base);
+	WARN_ON(buf->trbe_write >= buf->trbe_limit);
+	set_trbe_disabled();	/* clr_trbe_status() expects TRBE off */
+	isb();
+	clr_trbe_status();
+	set_trbe_base_pointer(buf->trbe_base);
+	set_trbe_write_pointer(buf->trbe_write);
+
+	/*
+	 * Synchronize all the register updates
+	 * till now before enabling the TRBE.
+	 */
+	isb();
+	set_trbe_limit_pointer_enabled(buf->trbe_limit);
+}
+
+static enum trbe_fault_action trbe_get_fault_act(u64 trbsr)
+{
+	int ec = get_trbe_ec(trbsr);
+	int bsc = get_trbe_bsc(trbsr);
+
+	WARN_ON(is_trbe_running(trbsr));
+	if (is_trbe_trg(trbsr) || is_trbe_abort(trbsr))
+		return TRBE_FAULT_ACT_FATAL;
+
+	if ((ec == TRBE_EC_STAGE1_ABORT) || (ec == TRBE_EC_STAGE2_ABORT))
+		return TRBE_FAULT_ACT_FATAL;
+
+	if (is_trbe_wrap(trbsr) && (ec == TRBE_EC_OTHERS) && (bsc == TRBE_BSC_FILLED)) {
+		if (get_trbe_write_pointer() == get_trbe_base_pointer())
+			return TRBE_FAULT_ACT_WRAP;	/* wrapped back to base */
+	}
+	return TRBE_FAULT_ACT_SPURIOUS;
+}
+
+static void *arm_trbe_alloc_buffer(struct coresight_device *csdev,
+				   struct perf_event *event, void **pages,
+				   int nr_pages, bool snapshot)
+{
+	struct trbe_buf *buf;
+	struct page **pglist;
+	int i;
+
+	/*
+	 * TRBE LIMIT and TRBE WRITE pointers must be page aligned. But with
+	 * just a single page, there would not be any room left while writing
+	 * into a partially filled TRBE buffer after the page size alignment.
+	 * Hence restrict the minimum buffer size to two pages.
+	 */
+	if (nr_pages < 2)
+		return NULL;
+
+	buf = kzalloc_node(sizeof(*buf), GFP_KERNEL, trbe_alloc_node(event));
+	if (!buf)
+		return NULL;	/* every failure is reported as NULL */
+
+	pglist = kcalloc(nr_pages, sizeof(*pglist), GFP_KERNEL);
+	if (!pglist) {
+		kfree(buf);
+		return NULL;
+	}
+
+	for (i = 0; i < nr_pages; i++)
+		pglist[i] = virt_to_page(pages[i]);
+
+	buf->trbe_base = (unsigned long)vmap(pglist, nr_pages, VM_MAP, PAGE_KERNEL);
+	if (!buf->trbe_base) {
+		kfree(pglist);
+		kfree(buf);
+		return NULL;
+	}
+	buf->trbe_limit = buf->trbe_base + nr_pages * PAGE_SIZE;
+	buf->trbe_write = buf->trbe_base;
+	buf->snapshot = snapshot;
+	buf->nr_pages = nr_pages;
+	buf->pages = pages;
+	kfree(pglist);
+	return buf;
+}
+
+static void arm_trbe_free_buffer(void *config)
+{
+	struct trbe_buf *buf = config;
+
+	vunmap((void *)buf->trbe_base);	/* undo vmap() from alloc_buffer */
+	kfree(buf);
+}
+
+static unsigned long arm_trbe_update_buffer(struct coresight_device *csdev,
+					    struct perf_output_handle *handle,
+					    void *config)
+{
+	struct trbe_drvdata *drvdata = dev_get_drvdata(csdev->dev.parent);
+	struct trbe_cpudata *cpudata = dev_get_drvdata(&csdev->dev);
+	struct trbe_buf *buf = config;
+	enum trbe_fault_action act;
+	unsigned long size, offset;
+	unsigned long write, base, status;
+	unsigned long flags;
+
+	WARN_ON(buf->cpudata != cpudata);
+	WARN_ON(cpudata->cpu != smp_processor_id());
+	WARN_ON(cpudata->drvdata != drvdata);
+	if (cpudata->mode != CS_MODE_PERF)
+		return 0;
+
+	perf_aux_output_flag(handle, PERF_AUX_FLAG_CORESIGHT_FORMAT_RAW);
+
+	/*
+	 * We are about to disable the TRBE. And this could in turn
+	 * fill up the buffer, triggering an IRQ. This could be consumed
+	 * by the PE asynchronously, causing a race here against
+	 * the IRQ handler in closing out the handle. So, let us
+	 * make sure the IRQ can't trigger while we are collecting
+	 * the buffer. We also make sure that a WRAP event is handled
+	 * accordingly.
+	 */
+	local_irq_save(flags);
+
+	/*
+	 * If the TRBE was disabled due to lack of space in the AUX buffer or a
+	 * spurious fault, the driver leaves it disabled, truncating the buffer.
+	 * Since the etm_perf driver expects to close out the AUX buffer, the
+	 * driver skips it. Thus, just pass in 0 size here to indicate that the
+	 * buffer was truncated.
+	 */
+	if (!is_trbe_enabled()) {
+		size = 0;
+		goto done;
+	}
+	/*
+	 * The perf handle structure needs to be shared with the TRBE IRQ handler
+	 * for capturing trace data and restarting the handle. There is a risk of
+	 * a crash due to a stale handle reference when the etm event is stopped
+	 * while a TRBE IRQ is also being processed, as the perf handle is released
+	 * via perf_aux_output_end() in etm_event_stop(). Stopping the TRBE here
+	 * ensures that no IRQ can be generated when the perf handle gets freed in
+	 * etm_event_stop().
+	 */
+	trbe_drain_and_disable_local();
+	write = get_trbe_write_pointer();
+	base = get_trbe_base_pointer();
+
+	/* Check if there is a pending interrupt and handle it here */
+	status = read_sysreg_s(SYS_TRBSR_EL1);
+	if (is_trbe_irq(status)) {
+
+		/*
+		 * Now that we are handling the IRQ here, clear the IRQ
+		 * from the status, to let the irq handler know that it
+		 * is taken care of.
+		 */
+		clr_trbe_irq();
+		isb();
+
+		act = trbe_get_fault_act(status);
+		/*
+		 * If this was not due to a WRAP event, we have some
+		 * errors and as such the buffer is empty.
+		 */
+		if (act != TRBE_FAULT_ACT_WRAP) {
+			size = 0;
+			goto done;
+		}
+
+		/*
+		 * Otherwise, the buffer is full and the write pointer
+		 * has reached base. Adjust this back to the Limit pointer
+		 * for correct size. Also, mark the buffer truncated.
+		 */
+		write = get_trbe_limit_pointer();
+		perf_aux_output_flag(handle, PERF_AUX_FLAG_TRUNCATED);
+	}
+
+	offset = write - base;
+	if (WARN_ON_ONCE(offset < PERF_IDX2OFF(handle->head, buf)))
+		size = 0;
+	else
+		size = offset - PERF_IDX2OFF(handle->head, buf);
+
+done:
+	local_irq_restore(flags);
+
+	if (buf->snapshot)
+		handle->head += size;
+	return size;
+}
+
+static int arm_trbe_enable(struct coresight_device *csdev, u32 mode, void *data)
+{
+	struct trbe_drvdata *drvdata = dev_get_drvdata(csdev->dev.parent);
+	struct trbe_cpudata *cpudata = dev_get_drvdata(&csdev->dev);
+	struct perf_output_handle *handle = data;
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+
+	WARN_ON(cpudata->cpu != smp_processor_id());
+	WARN_ON(cpudata->drvdata != drvdata);
+	if (mode != CS_MODE_PERF)
+		return -EINVAL;
+
+	*this_cpu_ptr(drvdata->handle) = handle;	/* for the IRQ handler */
+	cpudata->buf = buf;
+	cpudata->mode = mode;
+	buf->cpudata = cpudata;
+	buf->trbe_limit = compute_trbe_buffer_limit(handle);
+	buf->trbe_write = buf->trbe_base + PERF_IDX2OFF(handle->head, buf);
+	if (buf->trbe_limit == buf->trbe_base) {	/* no room to trace */
+		trbe_stop_and_truncate_event(handle);
+		return 0;
+	}
+	trbe_enable_hw(buf);
+	return 0;
+}
+
+static int arm_trbe_disable(struct coresight_device *csdev)
+{
+	struct trbe_drvdata *drvdata = dev_get_drvdata(csdev->dev.parent);
+	struct trbe_cpudata *cpudata = dev_get_drvdata(&csdev->dev);
+	struct trbe_buf *buf = cpudata->buf;
+
+	WARN_ON(buf->cpudata != cpudata);
+	WARN_ON(cpudata->cpu != smp_processor_id());
+	WARN_ON(cpudata->drvdata != drvdata);
+	if (cpudata->mode != CS_MODE_PERF)
+		return -EINVAL;
+
+	trbe_drain_and_disable_local();	/* flush trace, then stop */
+	buf->cpudata = NULL;
+	cpudata->buf = NULL;
+	cpudata->mode = CS_MODE_DISABLED;
+	return 0;
+}
+
+static void trbe_handle_spurious(struct perf_output_handle *handle)
+{
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+
+	buf->trbe_limit = compute_trbe_buffer_limit(handle);
+	buf->trbe_write = buf->trbe_base + PERF_IDX2OFF(handle->head, buf);
+	if (buf->trbe_limit == buf->trbe_base) {
+		trbe_drain_and_disable_local();	/* no room: leave disabled */
+		return;
+	}
+	trbe_enable_hw(buf);
+}
+
+static void trbe_handle_overflow(struct perf_output_handle *handle)
+{
+	struct perf_event *event = handle->event;
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+	unsigned long offset, size;
+	struct etm_event_data *event_data;
+
+	offset = get_trbe_limit_pointer() - get_trbe_base_pointer();
+	size = offset - PERF_IDX2OFF(handle->head, buf);
+	if (buf->snapshot)
+		handle->head += size;
+
+	/*
+	 * Mark the buffer as truncated, as we have stopped the trace
+	 * collection upon the WRAP event, without stopping the source.
+	 */
+	perf_aux_output_flag(handle, PERF_AUX_FLAG_CORESIGHT_FORMAT_RAW |
+				     PERF_AUX_FLAG_TRUNCATED);
+	perf_aux_output_end(handle, size);
+	event_data = perf_aux_output_begin(handle, event);
+	if (!event_data) {
+		/*
+		 * We are unable to restart the trace collection,
+		 * thus leave the TRBE disabled. The etm-perf driver
+		 * is able to detect this with a disconnected handle
+		 * (handle->event = NULL).
+		 */
+		trbe_drain_and_disable_local();
+		*this_cpu_ptr(buf->cpudata->drvdata->handle) = NULL;
+		return;
+	}
+	buf->trbe_limit = compute_trbe_buffer_limit(handle);
+	buf->trbe_write = buf->trbe_base + PERF_IDX2OFF(handle->head, buf);
+	if (buf->trbe_limit == buf->trbe_base) {
+		trbe_stop_and_truncate_event(handle);
+		return;
+	}
+	*this_cpu_ptr(buf->cpudata->drvdata->handle) = handle;	/* re-publish */
+	trbe_enable_hw(buf);
+}
+
+static bool is_perf_trbe(struct perf_output_handle *handle)
+{
+	struct trbe_buf *buf = etm_perf_sink_config(handle);
+	struct trbe_cpudata *cpudata = buf->cpudata;
+	struct trbe_drvdata *drvdata = cpudata->drvdata;
+	int cpu = smp_processor_id();
+
+	WARN_ON(buf->trbe_base != get_trbe_base_pointer());	/* HW must match buf */
+	WARN_ON(buf->trbe_limit != get_trbe_limit_pointer());
+
+	if (cpudata->mode != CS_MODE_PERF)
+		return false;
+
+	if (cpudata->cpu != cpu)
+		return false;
+
+	if (!cpumask_test_cpu(cpu, &drvdata->supported_cpus))
+		return false;
+
+	return true;
+}
+
+static irqreturn_t arm_trbe_irq_handler(int irq, void *dev)
+{
+	struct perf_output_handle **handle_ptr = dev;	/* this CPU's handle slot */
+	struct perf_output_handle *handle = *handle_ptr;
+	enum trbe_fault_action act;
+	u64 status;
+
+	/*
+	 * Ensure the trace is visible to the CPUs and
+	 * any external aborts have been resolved.
+	 */
+	trbe_drain_and_disable_local();
+
+	status = read_sysreg_s(SYS_TRBSR_EL1);
+	/*
+	 * If the pending IRQ was handled by update_buffer callback
+	 * we have nothing to do here.
+	 */
+	if (!is_trbe_irq(status))
+		return IRQ_NONE;
+
+	clr_trbe_irq();
+	isb();
+
+	if (WARN_ON_ONCE(!handle) || !perf_get_aux(handle))
+		return IRQ_NONE;
+
+	if (!is_perf_trbe(handle))
+		return IRQ_NONE;
+
+	/*
+	 * Ensure perf callbacks have completed, which may disable
+	 * the trace buffer in response to a TRUNCATION flag.
+	 */
+	irq_work_run();
+
+	act = trbe_get_fault_act(status);
+	switch (act) {
+	case TRBE_FAULT_ACT_WRAP:
+		trbe_handle_overflow(handle);
+		break;
+	case TRBE_FAULT_ACT_SPURIOUS:
+		trbe_handle_spurious(handle);
+		break;
+	case TRBE_FAULT_ACT_FATAL:
+		trbe_stop_and_truncate_event(handle);
+		break;
+	}
+	return IRQ_HANDLED;
+}
+
+static const struct coresight_ops_sink arm_trbe_sink_ops = {	/* perf mode only */
+	.enable		= arm_trbe_enable,
+	.disable	= arm_trbe_disable,
+	.alloc_buffer	= arm_trbe_alloc_buffer,
+	.free_buffer	= arm_trbe_free_buffer,
+	.update_buffer	= arm_trbe_update_buffer,
+};
+
+static const struct coresight_ops arm_trbe_cs_ops = {
+	.sink_ops	= &arm_trbe_sink_ops,
+};
+
+static ssize_t align_show(struct device *dev, struct device_attribute *attr, char *buf)
+{
+	struct trbe_cpudata *cpudata = dev_get_drvdata(dev);
+
+	return sysfs_emit(buf, "%llx\n", cpudata->trbe_align);	/* hex, no 0x */
+}
+static DEVICE_ATTR_RO(align);
+
+static ssize_t flag_show(struct device *dev, struct device_attribute *attr, char *buf)
+{
+	struct trbe_cpudata *cpudata = dev_get_drvdata(dev);
+
+	return sysfs_emit(buf, "%d\n", cpudata->trbe_flag);	/* 0 or 1 */
+}
+static DEVICE_ATTR_RO(flag);
+
+static struct attribute *arm_trbe_attrs[] = {	/* sysfs: align, flag */
+	&dev_attr_align.attr,
+	&dev_attr_flag.attr,
+	NULL,
+};
+
+static const struct attribute_group arm_trbe_group = {
+	.attrs = arm_trbe_attrs,
+};
+
+static const struct attribute_group *arm_trbe_groups[] = {
+	&arm_trbe_group,
+	NULL,
+};
+
+static void arm_trbe_enable_cpu(void *info)
+{
+	struct trbe_drvdata *drvdata = info;
+
+	trbe_reset_local();	/* start from a clean TRBE state */
+	enable_percpu_irq(drvdata->irq, IRQ_TYPE_NONE);
+}
+
+static void arm_trbe_register_coresight_cpu(struct trbe_drvdata *drvdata, int cpu)
+{
+	struct trbe_cpudata *cpudata = per_cpu_ptr(drvdata->cpudata, cpu);
+	struct coresight_device *trbe_csdev = coresight_get_percpu_sink(cpu);
+	struct coresight_desc desc = { 0 };
+	struct device *dev;
+
+	if (WARN_ON(trbe_csdev))
+		return;
+
+	dev = &drvdata->pdev->dev;	/* cpudata->drvdata may not be set yet */
+	desc.name = devm_kasprintf(dev, GFP_KERNEL, "trbe%d", cpu);
+	if (!desc.name)	/* devm_kasprintf() returns NULL, not an ERR_PTR */
+		goto cpu_clear;
+
+	desc.type = CORESIGHT_DEV_TYPE_SINK;
+	desc.subtype.sink_subtype = CORESIGHT_DEV_SUBTYPE_SINK_PERCPU_SYSMEM;
+	desc.ops = &arm_trbe_cs_ops;
+	desc.pdata = dev_get_platdata(dev);
+	desc.groups = arm_trbe_groups;
+	desc.dev = dev;
+	trbe_csdev = coresight_register(&desc);
+	if (IS_ERR(trbe_csdev))
+		goto cpu_clear;
+
+	dev_set_drvdata(&trbe_csdev->dev, cpudata);
+	coresight_set_percpu_sink(cpu, trbe_csdev);
+	return;
+cpu_clear:
+	cpumask_clear_cpu(cpu, &drvdata->supported_cpus);
+}
+
+static void arm_trbe_probe_cpu(void *info)
+{
+	struct trbe_drvdata *drvdata = info;
+	int cpu = smp_processor_id();
+	struct trbe_cpudata *cpudata = per_cpu_ptr(drvdata->cpudata, cpu);
+	u64 trbidr;
+
+	if (WARN_ON(!cpudata))
+		goto cpu_clear;
+
+	if (!is_trbe_available()) {
+		pr_err("TRBE is not implemented on cpu %d\n", cpu);
+		goto cpu_clear;
+	}
+
+	trbidr = read_sysreg_s(SYS_TRBIDR_EL1);
+	if (!is_trbe_programmable(trbidr)) {
+		pr_err("TRBE is owned in higher exception level on cpu %d\n", cpu);
+		goto cpu_clear;
+	}
+
+	cpudata->trbe_align = 1ULL << get_trbe_address_align(trbidr);
+	if (cpudata->trbe_align > SZ_2K) {	/* not supported by this driver */
+		pr_err("Unsupported alignment on cpu %d\n", cpu);
+		goto cpu_clear;
+	}
+	cpudata->trbe_flag = get_trbe_flag_update(trbidr);
+	cpudata->cpu = cpu;
+	cpudata->drvdata = drvdata;
+	return;
+cpu_clear:
+	cpumask_clear_cpu(cpu, &drvdata->supported_cpus);
+}
+
+static void arm_trbe_remove_coresight_cpu(void *info)
+{
+	int cpu = smp_processor_id();
+	struct trbe_drvdata *drvdata = info;
+	struct trbe_cpudata *cpudata = per_cpu_ptr(drvdata->cpudata, cpu);
+	struct coresight_device *trbe_csdev = coresight_get_percpu_sink(cpu);
+
+	disable_percpu_irq(drvdata->irq);
+	trbe_reset_local();
+	if (trbe_csdev) {	/* only if registered for this CPU */
+		coresight_unregister(trbe_csdev);
+		cpudata->drvdata = NULL;
+		coresight_set_percpu_sink(cpu, NULL);
+	}
+}
+
+static int arm_trbe_probe_coresight(struct trbe_drvdata *drvdata)
+{
+	int cpu;
+
+	drvdata->cpudata = alloc_percpu(typeof(*drvdata->cpudata));
+	if (!drvdata->cpudata)
+		return -ENOMEM;
+	for_each_cpu(cpu, &drvdata->supported_cpus) {
+		if (smp_call_function_single(cpu, arm_trbe_probe_cpu, drvdata, 1))
+			continue;	/* offline: probed by arm_trbe_cpu_startup() */
+		if (cpumask_test_cpu(cpu, &drvdata->supported_cpus))
+			arm_trbe_register_coresight_cpu(drvdata, cpu);
+		if (cpumask_test_cpu(cpu, &drvdata->supported_cpus))
+			smp_call_function_single(cpu, arm_trbe_enable_cpu, drvdata, 1);
+	}
+	return 0;
+}
+
+static int arm_trbe_remove_coresight(struct trbe_drvdata *drvdata)
+{
+	int cpu;
+
+	for_each_cpu(cpu, &drvdata->supported_cpus)
+		smp_call_function_single(cpu, arm_trbe_remove_coresight_cpu, drvdata, 1);
+	free_percpu(drvdata->cpudata);	/* after the per-cpu removal calls */
+	return 0;
+}
+
+static int arm_trbe_cpu_startup(unsigned int cpu, struct hlist_node *node)
+{
+	struct trbe_drvdata *drvdata = hlist_entry_safe(node, struct trbe_drvdata, hotplug_node);
+
+	if (cpumask_test_cpu(cpu, &drvdata->supported_cpus)) {
+
+		/*
+		 * If this CPU was not probed for TRBE,
+		 * initialize it now.
+		 */
+		if (!coresight_get_percpu_sink(cpu)) {
+			arm_trbe_probe_cpu(drvdata);
+			if (cpumask_test_cpu(cpu, &drvdata->supported_cpus))
+				arm_trbe_register_coresight_cpu(drvdata, cpu);
+			if (cpumask_test_cpu(cpu, &drvdata->supported_cpus))
+				arm_trbe_enable_cpu(drvdata);
+		} else {	/* already registered: just re-enable */
+			arm_trbe_enable_cpu(drvdata);
+		}
+	}
+	return 0;
+}
+
+static int arm_trbe_cpu_teardown(unsigned int cpu, struct hlist_node *node)
+{
+	struct trbe_drvdata *drvdata = hlist_entry_safe(node, struct trbe_drvdata, hotplug_node);
+
+	if (cpumask_test_cpu(cpu, &drvdata->supported_cpus)) {
+		disable_percpu_irq(drvdata->irq);
+		trbe_reset_local();	/* coresight device stays registered */
+	}
+	return 0;
+}
+
+static int arm_trbe_probe_cpuhp(struct trbe_drvdata *drvdata)
+{
+	int ret;
+
+	ret = cpuhp_setup_state_multi(CPUHP_AP_ONLINE_DYN, DRVNAME,
+				      arm_trbe_cpu_startup, arm_trbe_cpu_teardown);
+	if (ret < 0)
+		return ret;
+	drvdata->trbe_online = ret;	/* dynamic state id */
+
+	ret = cpuhp_state_add_instance(drvdata->trbe_online, &drvdata->hotplug_node);
+	if (ret) {
+		cpuhp_remove_multi_state(drvdata->trbe_online);
+		return ret;
+	}
+	return 0;
+}
+
+static void arm_trbe_remove_cpuhp(struct trbe_drvdata *drvdata)
+{
+	cpuhp_state_remove_instance(drvdata->trbe_online, &drvdata->hotplug_node);
+	cpuhp_remove_multi_state(drvdata->trbe_online);
+}
+
+static int arm_trbe_probe_irq(struct platform_device *pdev,
+			      struct trbe_drvdata *drvdata)
+{
+	int ret;
+
+	drvdata->irq = platform_get_irq(pdev, 0);
+	if (drvdata->irq < 0) {
+		pr_err("IRQ not found for the platform device\n");
+		return drvdata->irq;
+	}
+
+	if (!irq_is_percpu(drvdata->irq)) {
+		pr_err("IRQ is not a PPI\n");
+		return -EINVAL;
+	}
+
+	if (irq_get_percpu_devid_partition(drvdata->irq, &drvdata->supported_cpus))
+		return -EINVAL;
+
+	drvdata->handle = alloc_percpu(struct perf_output_handle *);	/* IRQ dev_id */
+	if (!drvdata->handle)
+		return -ENOMEM;
+
+	ret = request_percpu_irq(drvdata->irq, arm_trbe_irq_handler, DRVNAME, drvdata->handle);
+	if (ret) {
+		free_percpu(drvdata->handle);
+		return ret;
+	}
+	return 0;
+}
+
+static void arm_trbe_remove_irq(struct trbe_drvdata *drvdata)
+{
+	free_percpu_irq(drvdata->irq, drvdata->handle);	/* before freeing dev_id */
+	free_percpu(drvdata->handle);
+}
+
+static int arm_trbe_device_probe(struct platform_device *pdev)
+{
+	struct coresight_platform_data *pdata;
+	struct trbe_drvdata *drvdata;
+	struct device *dev = &pdev->dev;
+	int ret;
+
+	drvdata = devm_kzalloc(dev, sizeof(*drvdata), GFP_KERNEL);
+	if (!drvdata)
+		return -ENOMEM;
+
+	pdata = coresight_get_platform_data(dev);
+	if (IS_ERR(pdata))
+		return PTR_ERR(pdata);
+
+	dev_set_drvdata(dev, drvdata);
+	dev->platform_data = pdata;	/* used as desc.pdata for each CPU */
+	drvdata->pdev = pdev;
+	ret = arm_trbe_probe_irq(pdev, drvdata);
+	if (ret)
+		return ret;
+
+	ret = arm_trbe_probe_coresight(drvdata);
+	if (ret)
+		goto probe_failed;
+
+	ret = arm_trbe_probe_cpuhp(drvdata);
+	if (ret)
+		goto cpuhp_failed;
+
+	return 0;
+cpuhp_failed:
+	arm_trbe_remove_coresight(drvdata);
+probe_failed:
+	arm_trbe_remove_irq(drvdata);
+	return ret;
+}
+
+static int arm_trbe_device_remove(struct platform_device *pdev)
+{
+	struct trbe_drvdata *drvdata = platform_get_drvdata(pdev);
+
+	arm_trbe_remove_cpuhp(drvdata);	/* reverse order of probe */
+	arm_trbe_remove_coresight(drvdata);
+	arm_trbe_remove_irq(drvdata);
+	return 0;
+}
+
+static const struct of_device_id arm_trbe_of_match[] = {
+	{ .compatible = "arm,trace-buffer-extension"},
+	{},
+};
+MODULE_DEVICE_TABLE(of, arm_trbe_of_match);
+
+static struct platform_driver arm_trbe_driver = {
+	.driver	= {
+		.name = DRVNAME,
+		.of_match_table = of_match_ptr(arm_trbe_of_match),
+		.suppress_bind_attrs = true,	/* no sysfs bind/unbind */
+	},
+	.probe	= arm_trbe_device_probe,
+	.remove	= arm_trbe_device_remove,
+};
+
+static int __init arm_trbe_init(void)
+{
+	int ret;
+
+	if (arm64_kernel_unmapped_at_el0()) {	/* e.g. KPTI */
+		pr_err("TRBE wouldn't work if kernel gets unmapped at EL0\n");
+		return -EOPNOTSUPP;
+	}
+
+	ret = platform_driver_register(&arm_trbe_driver);
+	if (!ret)
+		return 0;
+
+	pr_err("Error registering %s platform driver\n", DRVNAME);
+	return ret;
+}
+
+static void __exit arm_trbe_exit(void)
+{
+	platform_driver_unregister(&arm_trbe_driver);
+}
+module_init(arm_trbe_init);
+module_exit(arm_trbe_exit);
+
+MODULE_AUTHOR("Anshuman Khandual <anshuman.khandual@arm.com>");
+MODULE_DESCRIPTION("Arm Trace Buffer Extension (TRBE) driver");
+MODULE_LICENSE("GPL v2");
diff --git a/drivers/hwtracing/coresight/coresight-trbe.h b/drivers/hwtracing/coresight/coresight-trbe.h
new file mode 100644
index 0000000..abf3e36
--- /dev/null
+++ b/drivers/hwtracing/coresight/coresight-trbe.h
@@ -0,0 +1,170 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * This contains all required hardware related helper functions for
+ * Trace Buffer Extension (TRBE) driver in the coresight framework.
+ *
+ * Copyright (C) 2020 ARM Ltd.
+ *
+ * Author: Anshuman Khandual <anshuman.khandual@arm.com>
+ */
+#ifndef _CORESIGHT_TRBE_H
+#define _CORESIGHT_TRBE_H
+
+#include <linux/coresight.h>
+#include <linux/device.h>
+#include <linux/irq.h>
+#include <linux/kernel.h>
+#include <linux/of.h>
+#include <linux/platform_device.h>
+#include <linux/smp.h>
+
+#include "coresight-etm-perf.h"
+
+/* True if ID_AA64DFR0_EL1 advertises TRBE on the CPU this runs on. */
+static inline bool is_trbe_available(void)
+{
+	u64 aa64dfr0 = read_sysreg_s(SYS_ID_AA64DFR0_EL1);
+	unsigned int trbe = cpuid_feature_extract_unsigned_field(aa64dfr0, ID_AA64DFR0_TRBE_SHIFT);
+
+	return trbe >= 0b0001;
+}
+
+/* True if the TRBLIMITR_EL1 enable bit is set on this CPU. */
+static inline bool is_trbe_enabled(void)
+{
+	u64 trblimitr = read_sysreg_s(SYS_TRBLIMITR_EL1);
+
+	return trblimitr & TRBLIMITR_ENABLE;
+}
+
+/* TRBSR_EL1.EC values and decoder. */
+#define TRBE_EC_OTHERS		0
+#define TRBE_EC_STAGE1_ABORT	36
+#define TRBE_EC_STAGE2_ABORT	37
+
+static inline int get_trbe_ec(u64 trbsr)
+{
+	return (trbsr >> TRBSR_EC_SHIFT) & TRBSR_EC_MASK;
+}
+
+/* TRBSR_EL1.BSC values and decoder. */
+#define TRBE_BSC_NOT_STOPPED 0
+#define TRBE_BSC_FILLED      1
+#define TRBE_BSC_TRIGGERED   2
+
+static inline int get_trbe_bsc(u64 trbsr)
+{
+	return (trbsr >> TRBSR_BSC_SHIFT) & TRBSR_BSC_MASK;
+}
+
+/* Clear the TRBSR_EL1 IRQ bit, preserving the other status bits. */
+static inline void clr_trbe_irq(void)
+{
+	u64 trbsr = read_sysreg_s(SYS_TRBSR_EL1);
+
+	trbsr &= ~TRBSR_IRQ;
+	write_sysreg_s(trbsr, SYS_TRBSR_EL1);
+}
+
+/* Predicates on a TRBSR_EL1 value previously read by the caller. */
+static inline bool is_trbe_irq(u64 trbsr)
+{
+	return trbsr & TRBSR_IRQ;
+}
+
+static inline bool is_trbe_trg(u64 trbsr)
+{
+	return trbsr & TRBSR_TRG;
+}
+
+static inline bool is_trbe_wrap(u64 trbsr)
+{
+	return trbsr & TRBSR_WRAP;
+}
+
+static inline bool is_trbe_abort(u64 trbsr)
+{
+	return trbsr & TRBSR_ABORT;
+}
+
+static inline bool is_trbe_running(u64 trbsr)
+{
+	return !(trbsr & TRBSR_STOP);
+}
+
+/* Trigger mode and fill mode encodings. */
+#define TRBE_TRIG_MODE_STOP		0
+#define TRBE_TRIG_MODE_IRQ		1
+#define TRBE_TRIG_MODE_IGNORE		3
+
+#define TRBE_FILL_MODE_FILL		0
+#define TRBE_FILL_MODE_WRAP		1
+#define TRBE_FILL_MODE_CIRCULAR_BUFFER	3
+
+/* Clear the TRBLIMITR_EL1 enable bit, preserving its other fields. */
+static inline void set_trbe_disabled(void)
+{
+	u64 trblimitr = read_sysreg_s(SYS_TRBLIMITR_EL1);
+
+	trblimitr &= ~TRBLIMITR_ENABLE;
+	write_sysreg_s(trblimitr, SYS_TRBLIMITR_EL1);
+}
+
+/* Decoders for a TRBIDR_EL1 value previously read by the caller. */
+static inline bool get_trbe_flag_update(u64 trbidr)
+{
+	return trbidr & TRBIDR_FLAG;
+}
+
+static inline bool is_trbe_programmable(u64 trbidr)
+{
+	return !(trbidr & TRBIDR_PROG);
+}
+
+static inline int get_trbe_address_align(u64 trbidr)
+{
+	return (trbidr >> TRBIDR_ALIGN_SHIFT) & TRBIDR_ALIGN_MASK;
+}
+
+/*
+ * Buffer pointer accessors. The setters WARN if TRBE is still enabled (the
+ * base setter also on misalignment); base/limit getters WARN if unaligned.
+ */
+static inline unsigned long get_trbe_write_pointer(void)
+{
+	return read_sysreg_s(SYS_TRBPTR_EL1);
+}
+
+static inline void set_trbe_write_pointer(unsigned long addr)
+{
+	WARN_ON(is_trbe_enabled());
+	write_sysreg_s(addr, SYS_TRBPTR_EL1);
+}
+
+static inline unsigned long get_trbe_limit_pointer(void)
+{
+	u64 trblimitr = read_sysreg_s(SYS_TRBLIMITR_EL1);
+	unsigned long addr = trblimitr & (TRBLIMITR_LIMIT_MASK << TRBLIMITR_LIMIT_SHIFT);
+
+	WARN_ON(!IS_ALIGNED(addr, PAGE_SIZE));
+	return addr;
+}
+
+static inline unsigned long get_trbe_base_pointer(void)
+{
+	u64 trbbaser = read_sysreg_s(SYS_TRBBASER_EL1);
+	unsigned long addr = trbbaser & (TRBBASER_BASE_MASK << TRBBASER_BASE_SHIFT);
+
+	WARN_ON(!IS_ALIGNED(addr, PAGE_SIZE));
+	return addr;
+}
+
+static inline void set_trbe_base_pointer(unsigned long addr)
+{
+	WARN_ON(is_trbe_enabled());
+	WARN_ON(!IS_ALIGNED(addr, (1UL << TRBBASER_BASE_SHIFT)));
+	WARN_ON(!IS_ALIGNED(addr, PAGE_SIZE));
+	write_sysreg_s(addr, SYS_TRBBASER_EL1);
+}
+
+#endif /* _CORESIGHT_TRBE_H */
diff --git a/drivers/irqchip/irq-gic-v3-its.c b/drivers/irqchip/irq-gic-v3-its.c
index ed46e60..d205faf 100644
--- a/drivers/irqchip/irq-gic-v3-its.c
+++ b/drivers/irqchip/irq-gic-v3-its.c
@@ -794,8 +794,13 @@
 
 	its_encode_alloc(cmd, alloc);
 
-	/* We can only signal PTZ when alloc==1. Why do we have two bits? */
-	its_encode_ptz(cmd, alloc);
+	/*
+	 * GICv4.1 provides a way to get the VLPI state, which needs the vPE
+	 * to be unmapped first, and in this case, we may remap the vPE
+	 * back while the VPT is not empty. So we can't assume that the
+	 * VPT is empty on map. This is why we never advertise PTZ.
+	 */
+	its_encode_ptz(cmd, false);
 	its_encode_vconf_addr(cmd, vconf_addr);
 	its_encode_vmapp_default_db(cmd, desc->its_vmapp_cmd.vpe->vpe_db_lpi);
 
@@ -4554,6 +4559,15 @@
 
 		its_send_vmapp(its, vpe, false);
 	}
+
+	/*
+	 * There may be a direct read to the VPT after unmapping the
+	 * vPE, to guarantee the validity of this, we make the VPT
+	 * memory coherent with the CPU caches here.
+	 */
+	if (find_4_1_its() && !atomic_read(&vpe->vmapp_count))
+		gic_flush_dcache_to_poc(page_address(vpe->vpt_page),
+					LPI_PENDBASE_SZ);
 }
 
 static const struct irq_domain_ops its_vpe_domain_ops = {
diff --git a/drivers/ptp/Kconfig b/drivers/ptp/Kconfig
index f2edef0..8c20e52 100644
--- a/drivers/ptp/Kconfig
+++ b/drivers/ptp/Kconfig
@@ -108,7 +108,7 @@
 config PTP_1588_CLOCK_KVM
 	tristate "KVM virtual PTP clock"
 	depends on PTP_1588_CLOCK
-	depends on KVM_GUEST && X86
+	depends on (KVM_GUEST && X86) || (HAVE_ARM_SMCCC_DISCOVERY && ARM_ARCH_TIMER)
 	default y
 	help
 	  This driver adds support for using kvm infrastructure as a PTP
diff --git a/drivers/ptp/Makefile b/drivers/ptp/Makefile
index db5aef3..8673d17 100644
--- a/drivers/ptp/Makefile
+++ b/drivers/ptp/Makefile
@@ -4,6 +4,8 @@
 #
 
 ptp-y					:= ptp_clock.o ptp_chardev.o ptp_sysfs.o
+ptp_kvm-$(CONFIG_X86)			:= ptp_kvm_x86.o ptp_kvm_common.o
+ptp_kvm-$(CONFIG_HAVE_ARM_SMCCC)	:= ptp_kvm_arm.o ptp_kvm_common.o
 obj-$(CONFIG_PTP_1588_CLOCK)		+= ptp.o
 obj-$(CONFIG_PTP_1588_CLOCK_DTE)	+= ptp_dte.o
 obj-$(CONFIG_PTP_1588_CLOCK_INES)	+= ptp_ines.o
diff --git a/drivers/ptp/ptp_kvm_arm.c b/drivers/ptp/ptp_kvm_arm.c
new file mode 100644
index 0000000..b7d28c8
--- /dev/null
+++ b/drivers/ptp/ptp_kvm_arm.c
@@ -0,0 +1,33 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ *  Virtual PTP 1588 clock for use with KVM guests
+ *  Copyright (C) 2019 ARM Ltd.
+ *  All Rights Reserved
+ */
+
+#include <linux/arm-smccc.h>
+#include <linux/ptp_kvm.h>
+
+#include <asm/arch_timer.h>
+#include <asm/hypervisor.h>
+
+/*
+ * Report whether the host offers the KVM PTP service
+ * (ARM_SMCCC_KVM_FUNC_PTP): 0 if it does, -EOPNOTSUPP otherwise.
+ */
+int kvm_arch_ptp_init(void)
+{
+	if (kvm_arm_hyp_service_available(ARM_SMCCC_KVM_FUNC_PTP) > 0)
+		return 0;
+
+	return -EOPNOTSUPP;
+}
+
+/*
+ * Fill @ts through kvm_arch_ptp_get_crosststamp(); the counter value and
+ * clocksource outputs are not wanted here, so NULL is passed for both.
+ */
+int kvm_arch_ptp_get_clock(struct timespec64 *ts)
+{
+	return kvm_arch_ptp_get_crosststamp(NULL, ts, NULL);
+}
diff --git a/drivers/ptp/ptp_kvm.c b/drivers/ptp/ptp_kvm_common.c
similarity index 60%
rename from drivers/ptp/ptp_kvm.c
rename to drivers/ptp/ptp_kvm_common.c
index 658d33f..721ddce 100644
--- a/drivers/ptp/ptp_kvm.c
+++ b/drivers/ptp/ptp_kvm_common.c
@@ -8,11 +8,11 @@
 #include <linux/err.h>
 #include <linux/init.h>
 #include <linux/kernel.h>
+#include <linux/slab.h>
 #include <linux/module.h>
+#include <linux/ptp_kvm.h>
 #include <uapi/linux/kvm_para.h>
 #include <asm/kvm_para.h>
-#include <asm/pvclock.h>
-#include <asm/kvmclock.h>
 #include <uapi/asm/kvm_para.h>
 
 #include <linux/ptp_clock_kernel.h>
@@ -24,56 +24,29 @@
 
 static DEFINE_SPINLOCK(kvm_ptp_lock);
 
-static struct pvclock_vsyscall_time_info *hv_clock;
-
-static struct kvm_clock_pairing clock_pair;
-static phys_addr_t clock_pair_gpa;
-
 static int ptp_kvm_get_time_fn(ktime_t *device_time,
 			       struct system_counterval_t *system_counter,
 			       void *ctx)
 {
-	unsigned long ret;
+	long ret;
+	u64 cycle;
 	struct timespec64 tspec;
-	unsigned version;
-	int cpu;
-	struct pvclock_vcpu_time_info *src;
+	struct clocksource *cs;
 
 	spin_lock(&kvm_ptp_lock);
 
 	preempt_disable_notrace();
-	cpu = smp_processor_id();
-	src = &hv_clock[cpu].pvti;
-
-	do {
-		/*
-		 * We are using a TSC value read in the hosts
-		 * kvm_hc_clock_pairing handling.
-		 * So any changes to tsc_to_system_mul
-		 * and tsc_shift or any other pvclock
-		 * data invalidate that measurement.
-		 */
-		version = pvclock_read_begin(src);
-
-		ret = kvm_hypercall2(KVM_HC_CLOCK_PAIRING,
-				     clock_pair_gpa,
-				     KVM_CLOCK_PAIRING_WALLCLOCK);
-		if (ret != 0) {
-			pr_err_ratelimited("clock pairing hypercall ret %lu\n", ret);
-			spin_unlock(&kvm_ptp_lock);
-			preempt_enable_notrace();
-			return -EOPNOTSUPP;
-		}
-
-		tspec.tv_sec = clock_pair.sec;
-		tspec.tv_nsec = clock_pair.nsec;
-		ret = __pvclock_read_cycles(src, clock_pair.tsc);
-	} while (pvclock_read_retry(src, version));
+	ret = kvm_arch_ptp_get_crosststamp(&cycle, &tspec, &cs);
+	if (ret) {
+		spin_unlock(&kvm_ptp_lock);
+		preempt_enable_notrace();
+		return ret;
+	}
 
 	preempt_enable_notrace();
 
-	system_counter->cycles = ret;
-	system_counter->cs = &kvm_clock;
+	system_counter->cycles = cycle;
+	system_counter->cs = cs;
 
 	*device_time = timespec64_to_ktime(tspec);
 
@@ -111,22 +84,17 @@
 
 static int ptp_kvm_gettime(struct ptp_clock_info *ptp, struct timespec64 *ts)
 {
-	unsigned long ret;
+	long ret;
 	struct timespec64 tspec;
 
 	spin_lock(&kvm_ptp_lock);
 
-	ret = kvm_hypercall2(KVM_HC_CLOCK_PAIRING,
-			     clock_pair_gpa,
-			     KVM_CLOCK_PAIRING_WALLCLOCK);
-	if (ret != 0) {
-		pr_err_ratelimited("clock offset hypercall ret %lu\n", ret);
+	ret = kvm_arch_ptp_get_clock(&tspec);
+	if (ret) {
 		spin_unlock(&kvm_ptp_lock);
-		return -EOPNOTSUPP;
+		return ret;
 	}
 
-	tspec.tv_sec = clock_pair.sec;
-	tspec.tv_nsec = clock_pair.nsec;
 	spin_unlock(&kvm_ptp_lock);
 
 	memcpy(ts, &tspec, sizeof(struct timespec64));
@@ -168,19 +136,11 @@
 {
 	long ret;
 
-	if (!kvm_para_available())
-		return -ENODEV;
-
-	clock_pair_gpa = slow_virt_to_phys(&clock_pair);
-	hv_clock = pvclock_get_pvti_cpu0_va();
-
-	if (!hv_clock)
-		return -ENODEV;
-
-	ret = kvm_hypercall2(KVM_HC_CLOCK_PAIRING, clock_pair_gpa,
-			KVM_CLOCK_PAIRING_WALLCLOCK);
-	if (ret == -KVM_ENOSYS || ret == -KVM_EOPNOTSUPP)
-		return -ENODEV;
+	ret = kvm_arch_ptp_init();
+	if (ret) {
+		pr_err("fail to initialize ptp_kvm");
+		return ret;
+	}
 
 	kvm_ptp_clock.caps = ptp_kvm_caps;
 
diff --git a/drivers/ptp/ptp_kvm_x86.c b/drivers/ptp/ptp_kvm_x86.c
new file mode 100644
index 0000000..3dd519d
--- /dev/null
+++ b/drivers/ptp/ptp_kvm_x86.c
@@ -0,0 +1,117 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
+/*
+ * Virtual PTP 1588 clock for use with KVM guests
+ *
+ * Copyright (C) 2017 Red Hat Inc.
+ */
+
+#include <linux/device.h>
+#include <linux/kernel.h>
+#include <asm/pvclock.h>
+#include <asm/kvmclock.h>
+#include <linux/module.h>
+#include <uapi/asm/kvm_para.h>
+#include <uapi/linux/kvm_para.h>
+#include <linux/ptp_clock_kernel.h>
+#include <linux/ptp_kvm.h>
+
+static struct pvclock_vsyscall_time_info *hv_clock;
+
+static phys_addr_t clock_pair_gpa;
+static struct kvm_clock_pairing clock_pair;
+
+/*
+ * Check that the KVM paravirt interface and the pvclock time-info area are
+ * available and that the host does not reject KVM_HC_CLOCK_PAIRING as
+ * unimplemented. Caches clock_pair_gpa and hv_clock for the calls below.
+ *
+ * Returns 0 on success or -ENODEV.
+ */
+int kvm_arch_ptp_init(void)
+{
+	long ret;
+
+	if (!kvm_para_available())
+		return -ENODEV;
+
+	clock_pair_gpa = slow_virt_to_phys(&clock_pair);
+	hv_clock = pvclock_get_pvti_cpu0_va();
+	if (!hv_clock)
+		return -ENODEV;
+
+	ret = kvm_hypercall2(KVM_HC_CLOCK_PAIRING, clock_pair_gpa,
+			     KVM_CLOCK_PAIRING_WALLCLOCK);
+	if (ret == -KVM_ENOSYS || ret == -KVM_EOPNOTSUPP)
+		return -ENODEV;
+
+	return 0;
+}
+
+/*
+ * Read the host wall-clock time (KVM_CLOCK_PAIRING_WALLCLOCK) into @ts.
+ * Returns 0, or -EOPNOTSUPP if the hypercall fails.
+ */
+int kvm_arch_ptp_get_clock(struct timespec64 *ts)
+{
+	long ret;
+
+	ret = kvm_hypercall2(KVM_HC_CLOCK_PAIRING,
+			     clock_pair_gpa,
+			     KVM_CLOCK_PAIRING_WALLCLOCK);
+	if (ret != 0) {
+		pr_err_ratelimited("clock offset hypercall ret %ld\n", ret);
+		return -EOPNOTSUPP;
+	}
+
+	ts->tv_sec = clock_pair.sec;
+	ts->tv_nsec = clock_pair.nsec;
+
+	return 0;
+}
+
+/*
+ * Sample the host time and the matching kvm_clock cycle value.
+ *
+ * The hypercall is repeated until this CPU's pvclock data is unchanged
+ * across it. Stores the time in @tspec, the cycle value in @cycle and
+ * &kvm_clock in @cs; none of the pointers may be NULL. Uses
+ * smp_processor_id(), so the caller must have preemption disabled.
+ * Returns 0, or -EOPNOTSUPP if the hypercall fails.
+ */
+int kvm_arch_ptp_get_crosststamp(u64 *cycle, struct timespec64 *tspec,
+			      struct clocksource **cs)
+{
+	struct pvclock_vcpu_time_info *src;
+	unsigned int version;
+	long ret;
+	int cpu;
+
+	cpu = smp_processor_id();
+	src = &hv_clock[cpu].pvti;
+
+	do {
+		/*
+		 * We are using a TSC value read in the hosts
+		 * kvm_hc_clock_pairing handling.
+		 * So any changes to tsc_to_system_mul
+		 * and tsc_shift or any other pvclock
+		 * data invalidate that measurement.
+		 */
+		version = pvclock_read_begin(src);
+
+		ret = kvm_hypercall2(KVM_HC_CLOCK_PAIRING,
+				     clock_pair_gpa,
+				     KVM_CLOCK_PAIRING_WALLCLOCK);
+		if (ret != 0) {
+			pr_err_ratelimited("clock pairing hypercall ret %ld\n", ret);
+			return -EOPNOTSUPP;
+		}
+		tspec->tv_sec = clock_pair.sec;
+		tspec->tv_nsec = clock_pair.nsec;
+		*cycle = __pvclock_read_cycles(src, clock_pair.tsc);
+	} while (pvclock_read_retry(src, version));
+
+	*cs = &kvm_clock;
+
+	return 0;
+}
diff --git a/include/kvm/arm_vgic.h b/include/kvm/arm_vgic.h
index 3d74f10..ec62118 100644
--- a/include/kvm/arm_vgic.h
+++ b/include/kvm/arm_vgic.h
@@ -322,6 +322,7 @@
 	 */
 	struct vgic_io_device	rd_iodev;
 	struct vgic_redist_region *rdreg;
+	u32 rdreg_index;
 
 	/* Contains the attributes and gpa of the LPI pending tables. */
 	u64 pendbaser;
diff --git a/include/linux/arm-smccc.h b/include/linux/arm-smccc.h
index 62c5423..6861489 100644
--- a/include/linux/arm-smccc.h
+++ b/include/linux/arm-smccc.h
@@ -55,6 +55,8 @@
 #define ARM_SMCCC_OWNER_TRUSTED_OS	50
 #define ARM_SMCCC_OWNER_TRUSTED_OS_END	63
 
+#define ARM_SMCCC_FUNC_QUERY_CALL_UID  0xff01
+
 #define ARM_SMCCC_QUIRK_NONE		0
 #define ARM_SMCCC_QUIRK_QCOM_A6		1 /* Save/restore register a6 */
 
@@ -87,8 +89,47 @@
 			   ARM_SMCCC_SMC_32,				\
 			   0, 0x7fff)
 
+#define ARM_SMCCC_VENDOR_HYP_CALL_UID_FUNC_ID				\
+	ARM_SMCCC_CALL_VAL(ARM_SMCCC_FAST_CALL,				\
+			   ARM_SMCCC_SMC_32,				\
+			   ARM_SMCCC_OWNER_VENDOR_HYP,			\
+			   ARM_SMCCC_FUNC_QUERY_CALL_UID)
+
+/* KVM UID value: 28b46fb6-2ec5-11e9-a9ca-4b564d003a74 */
+#define ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_0	0xb66fb428U
+#define ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_1	0xe911c52eU
+#define ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_2	0x564bcaa9U
+#define ARM_SMCCC_VENDOR_HYP_UID_KVM_REG_3	0x743a004dU
+
+/* KVM "vendor specific" services */
+#define ARM_SMCCC_KVM_FUNC_FEATURES		0
+#define ARM_SMCCC_KVM_FUNC_PTP			1
+#define ARM_SMCCC_KVM_FUNC_FEATURES_2		127
+#define ARM_SMCCC_KVM_NUM_FUNCS			128
+
+#define ARM_SMCCC_VENDOR_HYP_KVM_FEATURES_FUNC_ID			\
+	ARM_SMCCC_CALL_VAL(ARM_SMCCC_FAST_CALL,				\
+			   ARM_SMCCC_SMC_32,				\
+			   ARM_SMCCC_OWNER_VENDOR_HYP,			\
+			   ARM_SMCCC_KVM_FUNC_FEATURES)
+
 #define SMCCC_ARCH_WORKAROUND_RET_UNAFFECTED	1
 
+/*
+ * ptp_kvm is a feature used for time sync between vm and host.
+ * ptp_kvm module in guest kernel will get service from host using
+ * this hypercall ID.
+ */
+#define ARM_SMCCC_VENDOR_HYP_KVM_PTP_FUNC_ID				\
+	ARM_SMCCC_CALL_VAL(ARM_SMCCC_FAST_CALL,				\
+			   ARM_SMCCC_SMC_32,				\
+			   ARM_SMCCC_OWNER_VENDOR_HYP,			\
+			   ARM_SMCCC_KVM_FUNC_PTP)
+
+/* ptp_kvm counter type ID */
+#define KVM_PTP_VIRT_COUNTER			0
+#define KVM_PTP_PHYS_COUNTER			1
+
 /* Paravirtualised time calls (defined by ARM DEN0057A) */
 #define ARM_SMCCC_HV_PV_TIME_FEATURES				\
 	ARM_SMCCC_CALL_VAL(ARM_SMCCC_FAST_CALL,			\
diff --git a/include/linux/bug.h b/include/linux/bug.h
index f639bd0..e3841be 100644
--- a/include/linux/bug.h
+++ b/include/linux/bug.h
@@ -36,6 +36,9 @@
 	return bug->flags & BUGFLAG_WARNING;
 }
 
+void bug_get_file_line(struct bug_entry *bug, const char **file,
+		       unsigned int *line);
+
 struct bug_entry *find_bug(unsigned long bugaddr);
 
 enum bug_trap_type report_bug(unsigned long bug_addr, struct pt_regs *regs);
diff --git a/include/linux/clocksource.h b/include/linux/clocksource.h
index 86d143db..1290d0d 100644
--- a/include/linux/clocksource.h
+++ b/include/linux/clocksource.h
@@ -17,6 +17,7 @@
 #include <linux/timer.h>
 #include <linux/init.h>
 #include <linux/of.h>
+#include <linux/clocksource_ids.h>
 #include <asm/div64.h>
 #include <asm/io.h>
 
@@ -62,6 +63,10 @@
  *			400-499: Perfect
  *				The ideal clocksource. A must-use where
  *				available.
+ * @id:			Defaults to CSID_GENERIC. The id value is captured
+ *			in certain snapshot functions to allow callers to
+ *			validate the clocksource from which the snapshot was
+ *			taken.
  * @flags:		Flags describing special properties
  * @enable:		Optional function to enable the clocksource
  * @disable:		Optional function to disable the clocksource
@@ -100,6 +105,7 @@
 	const char		*name;
 	struct list_head	list;
 	int			rating;
+	enum clocksource_ids	id;
 	enum vdso_clock_mode	vdso_clock_mode;
 	unsigned long		flags;
 
diff --git a/include/linux/clocksource_ids.h b/include/linux/clocksource_ids.h
new file mode 100644
index 0000000..16775d7
--- /dev/null
+++ b/include/linux/clocksource_ids.h
@@ -0,0 +1,12 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#ifndef _LINUX_CLOCKSOURCE_IDS_H
+#define _LINUX_CLOCKSOURCE_IDS_H
+
+/* Enum to give clocksources a unique identifier */
+enum clocksource_ids {
+	CSID_GENERIC		= 0,	/* default; also substituted for out-of-range ids */
+	CSID_ARM_ARCH_COUNTER,		/* presumably the Arm architected counter */
+	CSID_MAX,			/* number of ids; not a valid id itself */
+};
+
+#endif /* _LINUX_CLOCKSOURCE_IDS_H */
diff --git a/include/linux/coresight.h b/include/linux/coresight.h
index 976ec26..85008a6 100644
--- a/include/linux/coresight.h
+++ b/include/linux/coresight.h
@@ -50,6 +50,7 @@
 	CORESIGHT_DEV_SUBTYPE_SINK_PORT,
 	CORESIGHT_DEV_SUBTYPE_SINK_BUFFER,
 	CORESIGHT_DEV_SUBTYPE_SINK_SYSMEM,
+	CORESIGHT_DEV_SUBTYPE_SINK_PERCPU_SYSMEM,
 };
 
 enum coresight_dev_subtype_link {
@@ -455,6 +456,18 @@
 }
 #endif	/* CONFIG_64BIT */
 
+static inline bool coresight_is_percpu_source(struct coresight_device *csdev)
+{	/* NULL-safe: true only for a SOURCE device of subtype SOURCE_PROC */
+	return csdev && (csdev->type == CORESIGHT_DEV_TYPE_SOURCE) &&
+	       (csdev->subtype.source_subtype == CORESIGHT_DEV_SUBTYPE_SOURCE_PROC);
+}
+
+static inline bool coresight_is_percpu_sink(struct coresight_device *csdev)
+{	/* NULL-safe: true only for a SINK device of subtype SINK_PERCPU_SYSMEM */
+	return csdev && (csdev->type == CORESIGHT_DEV_TYPE_SINK) &&
+	       (csdev->subtype.sink_subtype == CORESIGHT_DEV_SUBTYPE_SINK_PERCPU_SYSMEM);
+}
+
 extern struct coresight_device *
 coresight_register(struct coresight_desc *desc);
 extern void coresight_unregister(struct coresight_device *csdev);
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 1b65e72..862d793 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -218,6 +218,20 @@
 int kvm_async_pf_wakeup_all(struct kvm_vcpu *vcpu);
 #endif
 
+#ifdef KVM_ARCH_WANT_MMU_NOTIFIER
+struct kvm_gfn_range {		/* gfn range within one memslot, passed to the hooks below */
+	struct kvm_memory_slot *slot;
+	gfn_t start;			/* NOTE(review): presumably [start, end) -- confirm */
+	gfn_t end;
+	pte_t pte;
+	bool may_block;
+};
+bool kvm_unmap_gfn_range(struct kvm *kvm, struct kvm_gfn_range *range);
+bool kvm_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range);
+bool kvm_test_age_gfn(struct kvm *kvm, struct kvm_gfn_range *range);
+bool kvm_set_spte_gfn(struct kvm *kvm, struct kvm_gfn_range *range);
+#endif /* KVM_ARCH_WANT_MMU_NOTIFIER */
+
 enum {
 	OUTSIDE_GUEST_MODE,
 	IN_GUEST_MODE,
@@ -458,6 +472,7 @@
 #endif /* KVM_HAVE_MMU_RWLOCK */
 
 	struct mutex slots_lock;
+	struct rw_semaphore mmu_notifier_slots_lock;
 	struct mm_struct *mm; /* userspace tied to this vm */
 	struct kvm_memslots __rcu *memslots[KVM_ADDRESS_SPACE_NUM];
 	struct kvm_vcpu *vcpus[KVM_MAX_VCPUS];
@@ -640,14 +655,16 @@
 
 void kvm_get_kvm(struct kvm *kvm);
 void kvm_put_kvm(struct kvm *kvm);
+bool file_is_kvm(struct file *file);
 void kvm_put_kvm_no_destroy(struct kvm *kvm);
 
 static inline struct kvm_memslots *__kvm_memslots(struct kvm *kvm, int as_id)
 {
 	as_id = array_index_nospec(as_id, KVM_ADDRESS_SPACE_NUM);
 	return srcu_dereference_check(kvm->memslots[as_id], &kvm->srcu,
-			lockdep_is_held(&kvm->slots_lock) ||
-			!refcount_read(&kvm->users_count));
+				      lockdep_is_held(&kvm->slots_lock) ||
+				      lockdep_is_held(&kvm->mmu_notifier_slots_lock) ||
+				      !refcount_read(&kvm->users_count));
 }
 
 static inline struct kvm_memslots *kvm_memslots(struct kvm *kvm)
@@ -886,7 +903,7 @@
 
 #ifdef CONFIG_KVM_GENERIC_DIRTYLOG_READ_PROTECT
 void kvm_arch_flush_remote_tlbs_memslot(struct kvm *kvm,
-					struct kvm_memory_slot *memslot);
+					const struct kvm_memory_slot *memslot);
 #else /* !CONFIG_KVM_GENERIC_DIRTYLOG_READ_PROTECT */
 int kvm_vm_ioctl_get_dirty_log(struct kvm *kvm, struct kvm_dirty_log *log);
 int kvm_get_dirty_log(struct kvm *kvm, struct kvm_dirty_log *log,
@@ -1116,7 +1133,7 @@
 }
 
 static inline unsigned long
-__gfn_to_hva_memslot(struct kvm_memory_slot *slot, gfn_t gfn)
+__gfn_to_hva_memslot(const struct kvm_memory_slot *slot, gfn_t gfn)
 {
 	return slot->userspace_addr + (gfn - slot->base_gfn) * PAGE_SIZE;
 }
diff --git a/include/linux/ptp_kvm.h b/include/linux/ptp_kvm.h
new file mode 100644
index 0000000..f960a71
--- /dev/null
+++ b/include/linux/ptp_kvm.h
@@ -0,0 +1,30 @@
+/* SPDX-License-Identifier: GPL-2.0-or-later */
+/*
+ * Virtual PTP 1588 clock for use with KVM guests
+ *
+ * Copyright (C) 2017 Red Hat Inc.
+ */
+
+#ifndef _PTP_KVM_H_
+#define _PTP_KVM_H_
+
+#include <linux/types.h>	/* u64 */
+
+struct timespec64;
+struct clocksource;
+
+/*
+ * Per-architecture hooks called by the common ptp_kvm driver.
+ * All return 0 on success or a negative errno.
+ *
+ * kvm_arch_ptp_init():            probe the hypervisor for PTP support.
+ * kvm_arch_ptp_get_clock():       read the host clock into @ts.
+ * kvm_arch_ptp_get_crosststamp(): read the host clock into @tspec together
+ *   with the counter value (@cycle) of the clocksource returned in @cs.
+ */
+int kvm_arch_ptp_init(void);
+int kvm_arch_ptp_get_clock(struct timespec64 *ts);
+int kvm_arch_ptp_get_crosststamp(u64 *cycle,
+		struct timespec64 *tspec, struct clocksource **cs);
+
+#endif /* _PTP_KVM_H_ */
diff --git a/include/linux/timekeeping.h b/include/linux/timekeeping.h
index c6792cf..78a98bd 100644
--- a/include/linux/timekeeping.h
+++ b/include/linux/timekeeping.h
@@ -3,6 +3,7 @@
 #define _LINUX_TIMEKEEPING_H
 
 #include <linux/errno.h>
+#include <linux/clocksource_ids.h>
 
 /* Included from linux/ktime.h */
 
@@ -243,11 +244,12 @@
  * @cs_was_changed_seq:	The sequence number of clocksource change events
  */
 struct system_time_snapshot {
-	u64		cycles;
-	ktime_t		real;
-	ktime_t		raw;
-	unsigned int	clock_was_set_seq;
-	u8		cs_was_changed_seq;
+	u64			cycles;
+	ktime_t			real;
+	ktime_t			raw;
+	enum clocksource_ids	cs_id;
+	unsigned int		clock_was_set_seq;
+	u8			cs_was_changed_seq;
 };
 
 /**
diff --git a/include/trace/events/kvm.h b/include/trace/events/kvm.h
index 49d7d0f..37e1e1a 100644
--- a/include/trace/events/kvm.h
+++ b/include/trace/events/kvm.h
@@ -255,30 +255,6 @@
 	TP_printk("%s", __print_symbolic(__entry->load, kvm_fpu_load_symbol))
 );
 
-TRACE_EVENT(kvm_age_page,
-	TP_PROTO(ulong gfn, int level, struct kvm_memory_slot *slot, int ref),
-	TP_ARGS(gfn, level, slot, ref),
-
-	TP_STRUCT__entry(
-		__field(	u64,	hva		)
-		__field(	u64,	gfn		)
-		__field(	u8,	level		)
-		__field(	u8,	referenced	)
-	),
-
-	TP_fast_assign(
-		__entry->gfn		= gfn;
-		__entry->level		= level;
-		__entry->hva		= ((gfn - slot->base_gfn) <<
-					    PAGE_SHIFT) + slot->userspace_addr;
-		__entry->referenced	= ref;
-	),
-
-	TP_printk("hva %llx gfn %llx level %u %s",
-		  __entry->hva, __entry->gfn, __entry->level,
-		  __entry->referenced ? "YOUNG" : "OLD")
-);
-
 #ifdef CONFIG_KVM_ASYNC_PF
 DECLARE_EVENT_CLASS(kvm_async_get_page_class,
 
@@ -462,6 +438,72 @@
 	TP_printk("vcpu %d", __entry->vcpu_id)
 );
 
+TRACE_EVENT(kvm_unmap_hva_range,	/* MMU notifier: unmap of an hva range */
+	TP_PROTO(unsigned long start, unsigned long end),
+	TP_ARGS(start, end),
+
+	TP_STRUCT__entry(
+		__field(	unsigned long,	start		)
+		__field(	unsigned long,	end		)
+	),
+
+	TP_fast_assign(
+		__entry->start		= start;
+		__entry->end		= end;
+	),
+
+	TP_printk("mmu notifier unmap range: %#016lx -- %#016lx",
+		  __entry->start, __entry->end)
+);
+
+TRACE_EVENT(kvm_set_spte_hva,	/* MMU notifier: PTE change at a single hva */
+	TP_PROTO(unsigned long hva),
+	TP_ARGS(hva),
+
+	TP_STRUCT__entry(
+		__field(	unsigned long,	hva		)
+	),
+
+	TP_fast_assign(
+		__entry->hva		= hva;
+	),
+
+	TP_printk("mmu notifier set pte hva: %#016lx", __entry->hva)
+);
+
+TRACE_EVENT(kvm_age_hva,	/* MMU notifier: age an hva range */
+	TP_PROTO(unsigned long start, unsigned long end),
+	TP_ARGS(start, end),
+
+	TP_STRUCT__entry(
+		__field(	unsigned long,	start		)
+		__field(	unsigned long,	end		)
+	),
+
+	TP_fast_assign(
+		__entry->start		= start;
+		__entry->end		= end;
+	),
+
+	TP_printk("mmu notifier age hva: %#016lx -- %#016lx",
+		  __entry->start, __entry->end)
+);
+
+TRACE_EVENT(kvm_test_age_hva,	/* MMU notifier: test-age a single hva */
+	TP_PROTO(unsigned long hva),
+	TP_ARGS(hva),
+
+	TP_STRUCT__entry(
+		__field(	unsigned long,	hva		)
+	),
+
+	TP_fast_assign(
+		__entry->hva		= hva;
+	),
+
+	TP_printk("mmu notifier test age hva: %#016lx", __entry->hva)
+);
+
 #endif /* _TRACE_KVM_MAIN_H */
 
 /* This part must be outside protection */
diff --git a/include/uapi/linux/kvm.h b/include/uapi/linux/kvm.h
index f6afee2..7c89a9c 100644
--- a/include/uapi/linux/kvm.h
+++ b/include/uapi/linux/kvm.h
@@ -1078,6 +1078,8 @@
 #define KVM_CAP_DIRTY_LOG_RING 192
 #define KVM_CAP_X86_BUS_LOCK_EXIT 193
 #define KVM_CAP_PPC_DAWR1 194
+#define KVM_CAP_VM_COPY_ENC_CONTEXT_FROM 195
+#define KVM_CAP_PTP_KVM 196
 
 #ifdef KVM_CAP_IRQ_ROUTING
 
diff --git a/include/uapi/linux/perf_event.h b/include/uapi/linux/perf_event.h
index ad15e40..63971ea 100644
--- a/include/uapi/linux/perf_event.h
+++ b/include/uapi/linux/perf_event.h
@@ -1156,10 +1156,15 @@
 /**
  * PERF_RECORD_AUX::flags bits
  */
-#define PERF_AUX_FLAG_TRUNCATED		0x01	/* record was truncated to fit */
-#define PERF_AUX_FLAG_OVERWRITE		0x02	/* snapshot from overwrite mode */
-#define PERF_AUX_FLAG_PARTIAL		0x04	/* record contains gaps */
-#define PERF_AUX_FLAG_COLLISION		0x08	/* sample collided with another */
+#define PERF_AUX_FLAG_TRUNCATED			0x01	/* record was truncated to fit */
+#define PERF_AUX_FLAG_OVERWRITE			0x02	/* snapshot from overwrite mode */
+#define PERF_AUX_FLAG_PARTIAL			0x04	/* record contains gaps */
+#define PERF_AUX_FLAG_COLLISION			0x08	/* sample collided with another */
+#define PERF_AUX_FLAG_PMU_FORMAT_TYPE_MASK	0xff00	/* PMU specific trace format type */
+
+/* CoreSight PMU AUX buffer formats */
+#define PERF_AUX_FLAG_CORESIGHT_FORMAT_CORESIGHT	0x0000 /* Default for backward compatibility */
+#define PERF_AUX_FLAG_CORESIGHT_FORMAT_RAW		0x0100 /* Raw format of the source */
 
 #define PERF_FLAG_FD_NO_GROUP		(1UL << 0)
 #define PERF_FLAG_FD_OUTPUT		(1UL << 1)
diff --git a/kernel/time/clocksource.c b/kernel/time/clocksource.c
index cce484a..4fe1df8 100644
--- a/kernel/time/clocksource.c
+++ b/kernel/time/clocksource.c
@@ -920,6 +920,8 @@
 
 	clocksource_arch_init(cs);
 
+	if (WARN_ON_ONCE((unsigned int)cs->id >= CSID_MAX))
+		cs->id = CSID_GENERIC;
 	if (cs->vdso_clock_mode < 0 ||
 	    cs->vdso_clock_mode >= VDSO_CLOCKMODE_MAX) {
 		pr_warn("clocksource %s registered with invalid VDSO mode %d. Disabling VDSO support.\n",
diff --git a/kernel/time/timekeeping.c b/kernel/time/timekeeping.c
index 6aee576..06f55f9 100644
--- a/kernel/time/timekeeping.c
+++ b/kernel/time/timekeeping.c
@@ -1048,6 +1048,7 @@
 	do {
 		seq = read_seqcount_begin(&tk_core.seq);
 		now = tk_clock_read(&tk->tkr_mono);
+		systime_snapshot->cs_id = tk->tkr_mono.clock->id;
 		systime_snapshot->cs_was_changed_seq = tk->cs_was_changed_seq;
 		systime_snapshot->clock_was_set_seq = tk->clock_was_set_seq;
 		base_real = ktime_add(tk->tkr_mono.base,
diff --git a/lib/bug.c b/lib/bug.c
index 8f9d537..45a0584f 100644
--- a/lib/bug.c
+++ b/lib/bug.c
@@ -127,6 +127,22 @@
 }
 #endif
 
+void bug_get_file_line(struct bug_entry *bug, const char **file,
+		       unsigned int *line)
+{	/* Report @bug's source location; NULL/0 without CONFIG_DEBUG_BUGVERBOSE */
+#ifdef CONFIG_DEBUG_BUGVERBOSE
+#ifndef CONFIG_GENERIC_BUG_RELATIVE_POINTERS
+	*file = bug->file;
+#else /* CONFIG_GENERIC_BUG_RELATIVE_POINTERS */
+	*file = (const char *)bug + bug->file_disp;	/* offset from the entry itself */
+#endif
+	*line = bug->line;
+#else /* !CONFIG_DEBUG_BUGVERBOSE: no location recorded */
+	*file = NULL;
+	*line = 0;
+#endif /* CONFIG_DEBUG_BUGVERBOSE */
+}
+
 struct bug_entry *find_bug(unsigned long bugaddr)
 {
 	struct bug_entry *bug;
@@ -153,32 +169,20 @@
 
 	disable_trace_on_warning();
 
-	file = NULL;
-	line = 0;
-	warning = 0;
+	bug_get_file_line(bug, &file, &line);
 
-	if (bug) {
-#ifdef CONFIG_DEBUG_BUGVERBOSE
-#ifndef CONFIG_GENERIC_BUG_RELATIVE_POINTERS
-		file = bug->file;
-#else
-		file = (const char *)bug + bug->file_disp;
-#endif
-		line = bug->line;
-#endif
-		warning = (bug->flags & BUGFLAG_WARNING) != 0;
-		once = (bug->flags & BUGFLAG_ONCE) != 0;
-		done = (bug->flags & BUGFLAG_DONE) != 0;
+	warning = (bug->flags & BUGFLAG_WARNING) != 0;
+	once = (bug->flags & BUGFLAG_ONCE) != 0;
+	done = (bug->flags & BUGFLAG_DONE) != 0;
 
-		if (warning && once) {
-			if (done)
-				return BUG_TRAP_TYPE_WARN;
+	if (warning && once) {
+		if (done)
+			return BUG_TRAP_TYPE_WARN;
 
-			/*
-			 * Since this is the only store, concurrency is not an issue.
-			 */
-			bug->flags |= BUGFLAG_DONE;
-		}
+		/*
+		 * Since this is the only store, concurrency is not an issue.
+		 */
+		bug->flags |= BUGFLAG_DONE;
 	}
 
 	/*
diff --git a/tools/testing/selftests/kvm/.gitignore b/tools/testing/selftests/kvm/.gitignore
index 7bd7e77..bb862f9 100644
--- a/tools/testing/selftests/kvm/.gitignore
+++ b/tools/testing/selftests/kvm/.gitignore
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: GPL-2.0-only
 /aarch64/get-reg-list
 /aarch64/get-reg-list-sve
+/aarch64/vgic_init
 /s390x/memop
 /s390x/resets
 /s390x/sync_regs_test
diff --git a/tools/testing/selftests/kvm/Makefile b/tools/testing/selftests/kvm/Makefile
index 67eebb5..2fd4801 100644
--- a/tools/testing/selftests/kvm/Makefile
+++ b/tools/testing/selftests/kvm/Makefile
@@ -78,6 +78,7 @@
 
 TEST_GEN_PROGS_aarch64 += aarch64/get-reg-list
 TEST_GEN_PROGS_aarch64 += aarch64/get-reg-list-sve
+TEST_GEN_PROGS_aarch64 += aarch64/vgic_init
 TEST_GEN_PROGS_aarch64 += demand_paging_test
 TEST_GEN_PROGS_aarch64 += dirty_log_test
 TEST_GEN_PROGS_aarch64 += dirty_log_perf_test
diff --git a/tools/testing/selftests/kvm/aarch64/vgic_init.c b/tools/testing/selftests/kvm/aarch64/vgic_init.c
new file mode 100644
index 0000000..623f31a
--- /dev/null
+++ b/tools/testing/selftests/kvm/aarch64/vgic_init.c
@@ -0,0 +1,551 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * vgic init sequence tests
+ *
+ * Copyright (C) 2020, Red Hat, Inc.
+ */
+#define _GNU_SOURCE
+#include <linux/kernel.h>
+#include <sys/syscall.h>
+#include <asm/kvm.h>
+#include <asm/kvm_para.h>
+
+#include "test_util.h"
+#include "kvm_util.h"
+#include "processor.h"
+
+#define NR_VCPUS		4
+/* REDIST_REGION attr: count[63:52] | base[51:16] | flags[15:12] | index[11:0] */
+#define REDIST_REGION_ATTR_ADDR(count, base, flags, index) (((uint64_t)(count) << 52) | \
+	((uint64_t)((base) >> 16) << 16) | ((uint64_t)(flags) << 12) | (index))
+#define REG_OFFSET(vcpu, offset) (((uint64_t)(vcpu) << 32) | (offset))
+
+#define GICR_TYPER 0x8	/* offset of GICR_TYPER in the redistributor */
+
+struct vm_gic {
+	struct kvm_vm *vm;	/* VM under test */
+	int gic_fd;		/* fd of its KVM_DEV_TYPE_ARM_VGIC_V3 device */
+};
+
+static int max_ipa_bits;	/* KVM_CAP_ARM_VM_IPA_SIZE, 0 if not reported */
+
+/* Read (@write false) or write one 32-bit redistributor register of @vcpu */
+static int access_redist_reg(int gicv3_fd, int vcpu, int offset,
+			     uint32_t *val, bool write)
+{
+	return _kvm_device_access(gicv3_fd,
+				  KVM_DEV_ARM_VGIC_GRP_REDIST_REGS,
+				  REG_OFFSET(vcpu, offset),
+				  val, write);
+}
+
+/* Guest payload: three GUEST_SYNC stages, then GUEST_DONE */
+static void guest_code(void)
+{
+	for (int stage = 0; stage < 3; stage++)
+		GUEST_SYNC(stage);
+
+	GUEST_DONE();
+}
+
+/* Run @vcpuid once; report failure as -errno instead of asserting */
+static int run_vcpu(struct kvm_vm *vm, uint32_t vcpuid)
+{
+	int rc;
+
+	ucall_init(vm, NULL);
+	rc = _vcpu_ioctl(vm, vcpuid, KVM_RUN, NULL);
+	return rc ? -errno : 0;
+}
+
+/* VM with NR_VCPUS vcpus plus a GICv3 device (KVM_DEV_ARM_VGIC_CTRL_INIT not issued) */
+static struct vm_gic vm_gic_create(void)
+{
+	struct kvm_vm *vm = vm_create_default_with_vcpus(NR_VCPUS, 0, 0, guest_code, NULL);
+	int fd = kvm_create_device(vm, KVM_DEV_TYPE_ARM_VGIC_V3, false);
+
+	return (struct vm_gic){ .vm = vm, .gic_fd = fd };
+}
+
+/* Release a struct vm_gic: close the device fd, then free the VM */
+static void vm_gic_destroy(struct vm_gic *v)
+{
+	close(v->gic_fd);
+	kvm_vm_free(v->vm);
+}
+
+/*
+ * Exercise KVM device attribute handling on the (already created)
+ * ARM_VGIC_V3 device @v->gic_fd. On return, a legacy RDIST region is
+ * set @0x0 and a DIST region is set @0x60000; the two overlap, which
+ * KVM only reports on the first vcpu run.
+ */
+static void subtest_dist_rdist(struct vm_gic *v)
+{
+	int ret;
+	uint64_t addr;
+
+	/* Check existing group/attributes */
+	kvm_device_check_attr(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			      KVM_VGIC_V3_ADDR_TYPE_DIST);
+
+	kvm_device_check_attr(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			      KVM_VGIC_V3_ADDR_TYPE_REDIST);
+
+	/* check non existing attribute */
+	ret = _kvm_device_check_attr(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR, 0);
+	TEST_ASSERT(ret && errno == ENXIO, "attribute not supported");
+
+	/* misaligned DIST and REDIST address settings */
+	addr = 0x1000;
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_DIST, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "GICv3 dist base not 64kB aligned");
+
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "GICv3 redist base not 64kB aligned");
+
+	/* out of range address */
+	if (max_ipa_bits) {
+		addr = 1ULL << max_ipa_bits;
+		ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+					 KVM_VGIC_V3_ADDR_TYPE_DIST, &addr, true);
+		TEST_ASSERT(ret && errno == E2BIG, "dist address beyond IPA limit");
+
+		ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+					 KVM_VGIC_V3_ADDR_TYPE_REDIST, &addr, true);
+		TEST_ASSERT(ret && errno == E2BIG, "redist address beyond IPA limit");
+	}
+
+	/* set REDIST base address @0x0 */
+	addr = 0x00000;
+	kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_REDIST, &addr, true);
+
+	/* Attempt to create a second legacy redistributor region */
+	addr = 0xE0000;
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST, &addr, true);
+	TEST_ASSERT(ret && errno == EEXIST, "GICv3 redist base set again");
+
+	/* Attempt to mix legacy and new redistributor regions */
+	addr = REDIST_REGION_ATTR_ADDR(NR_VCPUS, 0x100000, 0, 0);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "attempt to mix GICv3 REDIST and REDIST_REGION");
+
+	/*
+	 * Set overlapping DIST / REDIST, cannot be detected here. Will be detected
+	 * on first vcpu run instead.
+	 */
+	addr = 3 * 2 * 0x10000;
+	kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR, KVM_VGIC_V3_ADDR_TYPE_DIST,
+			  &addr, true);
+}
+
+/* Test the new REDIST region API */
+static void subtest_redist_regions(struct vm_gic *v)
+{
+	uint64_t addr, expected_addr;
+	int ret;
+
+	/* the REDIST_REGION address type must be advertised */
+	ret = kvm_device_check_attr(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				    KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION);
+	TEST_ASSERT(!ret, "Multiple redist regions advertised");
+
+	addr = REDIST_REGION_ATTR_ADDR(NR_VCPUS, 0x100000, 2, 0);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "redist region attr value with flags != 0");
+
+	addr = REDIST_REGION_ATTR_ADDR(0, 0x100000, 0, 0);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "redist region attr value with count== 0");
+
+	addr = REDIST_REGION_ATTR_ADDR(2, 0x200000, 0, 1);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL,
+		    "attempt to register the first rdist region with index != 0");
+
+	/* NOTE(review): the macro drops base bits [15:0]; EINVAL here comes from index 1 */
+	addr = REDIST_REGION_ATTR_ADDR(2, 0x201000, 0, 1);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "rdist region with misaligned address");
+
+	addr = REDIST_REGION_ATTR_ADDR(2, 0x200000, 0, 0);
+	kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+
+	addr = REDIST_REGION_ATTR_ADDR(2, 0x200000, 0, 1);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "register an rdist region with already used index");
+
+	/* use the valid next index (1) so that only the overlap can be rejected */
+	addr = REDIST_REGION_ATTR_ADDR(1, 0x210000, 0, 1);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL,
+		    "register an rdist region overlapping with another one");
+
+	addr = REDIST_REGION_ATTR_ADDR(1, 0x240000, 0, 2);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "register redist region with index not +1");
+
+	addr = REDIST_REGION_ATTR_ADDR(1, 0x240000, 0, 1);
+	kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+
+	if (max_ipa_bits) {
+		addr = REDIST_REGION_ATTR_ADDR(1, 1ULL << max_ipa_bits, 0, 2);
+		ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+					 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+		TEST_ASSERT(ret && errno == E2BIG,
+			    "register redist region with base address beyond IPA range");
+	}
+
+	addr = 0x260000;
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL,
+		    "Mix KVM_VGIC_V3_ADDR_TYPE_REDIST and REDIST_REGION");
+
+	/* Now: region 0 @ 0x200000 (2 redists), region 1 @ 0x240000 (1 redist) */
+
+	addr = REDIST_REGION_ATTR_ADDR(0, 0, 0, 0);
+	expected_addr = REDIST_REGION_ATTR_ADDR(2, 0x200000, 0, 0);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, false);
+	TEST_ASSERT(!ret && addr == expected_addr, "read characteristics of region #0");
+
+	addr = REDIST_REGION_ATTR_ADDR(0, 0, 0, 1);
+	expected_addr = REDIST_REGION_ATTR_ADDR(1, 0x240000, 0, 1);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, false);
+	TEST_ASSERT(!ret && addr == expected_addr, "read characteristics of region #1");
+
+	addr = REDIST_REGION_ATTR_ADDR(0, 0, 0, 2);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, false);
+	TEST_ASSERT(ret && errno == ENOENT, "read characteristics of non existing region");
+
+	addr = 0x260000;
+	kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_DIST, &addr, true);
+
+	addr = REDIST_REGION_ATTR_ADDR(1, 0x260000, 0, 2);
+	ret = _kvm_device_access(v->gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "register redist region colliding with dist");
+}
+
+/*
+ * The VGIC device is created and its DIST/RDIST addresses are set while
+ * only vcpu 0 exists; the remaining vcpus are added afterwards.
+ */
+static void test_vgic_then_vcpus(void)
+{
+	struct vm_gic v;
+	uint32_t vcpuid;
+	int ret;
+
+	v.vm = vm_create_default(0, 0, guest_code);
+	v.gic_fd = kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V3, false);
+	subtest_dist_rdist(&v);
+
+	/* Add the rest of the VCPUs */
+	for (vcpuid = 1; vcpuid < NR_VCPUS; vcpuid++)
+		vm_vcpu_add_default(v.vm, vcpuid, guest_code);
+
+	/* The DIST/RDIST overlap from subtest_dist_rdist() shows up at first run */
+	ret = run_vcpu(v.vm, 3);
+	TEST_ASSERT(ret == -EINVAL, "dist/rdist overlap detected on 1st vcpu run");
+	vm_gic_destroy(&v);
+}
+
+/* All NR_VCPUS vcpus exist before the VGIC addresses are configured */
+static void test_vcpus_then_vgic(void)
+{
+	struct vm_gic v = vm_gic_create();
+	int ret;
+
+	subtest_dist_rdist(&v);
+
+	/* subtest_dist_rdist() left DIST overlapping RDIST: KVM_RUN must fail */
+	ret = run_vcpu(v.vm, 3);
+	TEST_ASSERT(ret == -EINVAL,
+		    "dist/rdist overlap detected on 1st vcpu run");
+
+	vm_gic_destroy(&v);
+}
+
+/* Run vcpus on top of the REDIST_REGION setup left by subtest_redist_regions() */
+static void test_new_redist_regions(void)
+{
+	void *dummy = NULL;
+	struct vm_gic v;
+	uint64_t addr;
+	int ret;
+
+	/* step 1: only 3 rdists for 4 vcpus, explicit init */
+	v = vm_gic_create();
+	subtest_redist_regions(&v);
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_CTRL,
+			  KVM_DEV_ARM_VGIC_CTRL_INIT, NULL, true);
+
+	ret = run_vcpu(v.vm, 3);
+	TEST_ASSERT(ret == -ENXIO, "running without sufficient number of rdists");
+	vm_gic_destroy(&v);
+
+	/* step 2: a third region covers the 4 vcpus, but no explicit init */
+	v = vm_gic_create();
+	subtest_redist_regions(&v);
+
+	addr = REDIST_REGION_ATTR_ADDR(1, 0x280000, 0, 2);
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+
+	ret = run_vcpu(v.vm, 3);
+	TEST_ASSERT(ret == -EBUSY, "running without vgic explicit init");
+
+	vm_gic_destroy(&v);
+
+	/* step 3: bad attr address, then 4 rdists + explicit init: run succeeds */
+	v = vm_gic_create();
+	subtest_redist_regions(&v);
+
+	ret = _kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, dummy, true);
+	TEST_ASSERT(ret && errno == EFAULT,
+		    "register a redist region from a NULL attr address");
+
+	addr = REDIST_REGION_ATTR_ADDR(1, 0x280000, 0, 2);
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_CTRL,
+			  KVM_DEV_ARM_VGIC_CTRL_INIT, NULL, true);
+
+	ret = run_vcpu(v.vm, 3);
+	TEST_ASSERT(!ret, "vcpu run");
+
+	vm_gic_destroy(&v);
+}
+
+/* GICR_TYPER reads vs. vcpu creation order (0, 3, 1, 2) and rdist regions */
+static void test_typer_accesses(void)
+{
+	struct vm_gic v;
+	uint64_t addr;
+	uint32_t val;
+	int ret, i;
+
+	v.vm = vm_create_default(0, 0, guest_code);
+	v.gic_fd = kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V3, false);
+
+	vm_vcpu_add_default(v.vm, 3, guest_code);
+	ret = access_redist_reg(v.gic_fd, 1, GICR_TYPER, &val, false);
+	TEST_ASSERT(ret && errno == EINVAL, "attempting to read GICR_TYPER of non created vcpu");
+
+	vm_vcpu_add_default(v.vm, 1, guest_code);
+	ret = access_redist_reg(v.gic_fd, 1, GICR_TYPER, &val, false);
+	TEST_ASSERT(ret && errno == EBUSY, "read GICR_TYPER before GIC initialized");
+
+	vm_vcpu_add_default(v.vm, 2, guest_code);
+
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_CTRL,
+			  KVM_DEV_ARM_VGIC_CTRL_INIT, NULL, true);
+
+	/* Before any rdist region: Last clear, Processor_Number = vcpu id */
+	for (i = 0; i < NR_VCPUS; i++) {
+		ret = access_redist_reg(v.gic_fd, i, GICR_TYPER, &val, false);
+		TEST_ASSERT(!ret && val == i * 0x100,
+			    "read GICR_TYPER before rdist region setting");
+	}
+
+	addr = REDIST_REGION_ATTR_ADDR(2, 0x200000, 0, 0);
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+
+	/* The 2 first rdists should be put there (vcpu 0 and 3) */
+	ret = access_redist_reg(v.gic_fd, 0, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && !val, "read typer of rdist #0");
+
+	ret = access_redist_reg(v.gic_fd, 3, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x310, "read typer of rdist #1");
+
+	addr = REDIST_REGION_ATTR_ADDR(10, 0x100000, 0, 1);
+	ret = _kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				 KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	TEST_ASSERT(ret && errno == EINVAL, "collision with previous rdist region");
+
+	ret = access_redist_reg(v.gic_fd, 1, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x100,
+		    "no redist region attached to vcpu #1 yet, last cannot be returned");
+
+	ret = access_redist_reg(v.gic_fd, 2, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x200,
+		    "no redist region attached to vcpu #2, last cannot be returned");
+
+	addr = REDIST_REGION_ATTR_ADDR(10, 0x20000, 0, 1);
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+
+	ret = access_redist_reg(v.gic_fd, 1, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x100, "read typer of rdist #1");
+
+	ret = access_redist_reg(v.gic_fd, 2, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x210,
+		    "read typer of rdist #2, last properly returned");
+
+	vm_gic_destroy(&v);
+}
+
+/*
+ * Test the GICR_TYPER Last bit with new redist regions
+ * rdist regions #1 and #2 are contiguous
+ * rdist region #0 @0x100000 2 rdist capacity
+ *     rdists: 0, 3 (Last)
+ * rdist region #1 @0x240000 2 rdist capacity
+ *     rdists: 5, 4 (Last)
+ * rdist region #2 @0x200000 2 rdist capacity
+ *     rdists: 1, 2
+ */
+static void test_last_bit_redist_regions(void)
+{
+	/* region #i is registered at region_base[i] with room for 2 rdists */
+	static const uint64_t region_base[] = { 0x100000, 0x240000, 0x200000 };
+
+	/* expected GICR_TYPER[31:0] per vcpu, checked in this order */
+	static const struct {
+		uint32_t vcpu;
+		uint32_t typer;
+		const char *msg;
+	} expected[] = {
+		{ 0, 0x000, "read typer of rdist #0" },
+		{ 1, 0x100, "read typer of rdist #1" },
+		{ 2, 0x200, "read typer of rdist #2" },
+		{ 3, 0x310, "read typer of rdist #3" },
+		{ 5, 0x500, "read typer of rdist #5" },
+		{ 4, 0x410, "read typer of rdist #4" },
+	};
+	uint32_t vcpuids[] = { 0, 3, 5, 4, 1, 2 };
+	struct vm_gic v;
+	unsigned int i;
+	uint64_t addr;
+	uint32_t val;
+	int ret;
+
+	v.vm = vm_create_default_with_vcpus(6, 0, 0, guest_code, vcpuids);
+
+	v.gic_fd = kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V3, false);
+
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_CTRL,
+			  KVM_DEV_ARM_VGIC_CTRL_INIT, NULL, true);
+
+	for (i = 0; i < ARRAY_SIZE(region_base); i++) {
+		addr = REDIST_REGION_ATTR_ADDR(2, region_base[i], 0, i);
+		kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+				  KVM_VGIC_V3_ADDR_TYPE_REDIST_REGION, &addr, true);
+	}
+
+	for (i = 0; i < ARRAY_SIZE(expected); i++) {
+		ret = access_redist_reg(v.gic_fd, expected[i].vcpu, GICR_TYPER,
+					&val, false);
+		TEST_ASSERT(!ret && val == expected[i].typer, "%s",
+			    expected[i].msg);
+	}
+
+	vm_gic_destroy(&v);
+}
+
+/* Test the GICR_TYPER Last bit with a legacy (single) RDIST region */
+static void test_last_bit_single_rdist(void)
+{
+	uint32_t vcpuids[] = { 0, 3, 5, 4, 1, 2 };
+	struct vm_gic v;
+	uint64_t addr;
+	uint32_t val;
+	int ret;
+
+	v.vm = vm_create_default_with_vcpus(6, 0, 0, guest_code, vcpuids);
+	v.gic_fd = kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V3, false);
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_CTRL,
+			  KVM_DEV_ARM_VGIC_CTRL_INIT, NULL, true);
+	addr = 0x10000;
+	kvm_device_access(v.gic_fd, KVM_DEV_ARM_VGIC_GRP_ADDR,
+			  KVM_VGIC_V3_ADDR_TYPE_REDIST, &addr, true);
+
+	ret = access_redist_reg(v.gic_fd, 0, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x000, "read typer of rdist #0");
+
+	ret = access_redist_reg(v.gic_fd, 3, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x300, "read typer of rdist #1");
+
+	ret = access_redist_reg(v.gic_fd, 5, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x500, "read typer of rdist #2");
+
+	ret = access_redist_reg(v.gic_fd, 4, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x400, "read typer of rdist #3");
+
+	ret = access_redist_reg(v.gic_fd, 1, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x100, "read typer of rdist #4");
+
+	ret = access_redist_reg(v.gic_fd, 2, GICR_TYPER, &val, false);
+	TEST_ASSERT(!ret && val == 0x210, "read typer of rdist #5");
+
+	vm_gic_destroy(&v);
+}
+
+static void test_kvm_device(void)
+{
+	struct vm_gic v;
+	int ret, fd;
+
+	v.vm = vm_create_default_with_vcpus(NR_VCPUS, 0, 0, guest_code, NULL);
+
+	/* try to create a non existing KVM device */
+	ret = _kvm_create_device(v.vm, 0, true, &fd);
+	TEST_ASSERT(ret && errno == ENODEV, "unsupported device");
+
+	/* trial mode with VGIC_V3 device */
+	ret = _kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V3, true, &fd);
+	if (ret) {
+		print_skip("GICv3 not supported");
+		exit(KSFT_SKIP);
+	}
+	v.gic_fd = kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V3, false);
+	ret = _kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V3, false, &fd);
+	TEST_ASSERT(ret && errno == EEXIST, "create GICv3 device twice");
+
+	/* trial mode only probes device support, so it succeeds even now */
+	ret = kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V3, true);
+	TEST_ASSERT(!ret, "trial mode GICv3 creation while the device exists");
+	if (!_kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V2, true, &fd)) {
+		ret = _kvm_create_device(v.vm, KVM_DEV_TYPE_ARM_VGIC_V2, false, &fd);
+		TEST_ASSERT(ret && errno == EINVAL, "create GICv2 while v3 exists");
+	}
+
+	vm_gic_destroy(&v);
+}
+
+int main(int ac, char **av)
+{
+	static void (* const tests[])(void) = {	/* in order; the first may exit(KSFT_SKIP) */
+		test_kvm_device, test_vcpus_then_vgic, test_vgic_then_vcpus,
+		test_new_redist_regions, test_typer_accesses,
+		test_last_bit_redist_regions, test_last_bit_single_rdist,
+	};
+	unsigned int i;
+
+	max_ipa_bits = kvm_check_cap(KVM_CAP_ARM_VM_IPA_SIZE);
+	for (i = 0; i < ARRAY_SIZE(tests); i++)
+		tests[i]();
+	return 0;
+}
diff --git a/tools/testing/selftests/kvm/include/kvm_util.h b/tools/testing/selftests/kvm/include/kvm_util.h
index 0f4258e..56ac3cf 100644
--- a/tools/testing/selftests/kvm/include/kvm_util.h
+++ b/tools/testing/selftests/kvm/include/kvm_util.h
@@ -225,6 +225,15 @@
 #endif
 void *vcpu_map_dirty_ring(struct kvm_vm *vm, uint32_t vcpuid);
 
+int _kvm_device_check_attr(int dev_fd, uint32_t group, uint64_t attr);
+int kvm_device_check_attr(int dev_fd, uint32_t group, uint64_t attr);
+int _kvm_create_device(struct kvm_vm *vm, uint64_t type, bool test, int *fd);
+int kvm_create_device(struct kvm_vm *vm, uint64_t type, bool test);
+int _kvm_device_access(int dev_fd, uint32_t group, uint64_t attr,
+		       void *val, bool write);
+int kvm_device_access(int dev_fd, uint32_t group, uint64_t attr,
+		      void *val, bool write);
+
 const char *exit_reason_str(unsigned int exit_reason);
 
 void virt_pgd_alloc(struct kvm_vm *vm, uint32_t pgd_memslot);
diff --git a/tools/testing/selftests/kvm/lib/kvm_util.c b/tools/testing/selftests/kvm/lib/kvm_util.c
index b8849a1..7711124 100644
--- a/tools/testing/selftests/kvm/lib/kvm_util.c
+++ b/tools/testing/selftests/kvm/lib/kvm_util.c
@@ -1734,6 +1734,81 @@
 }
 
 /*
+ * Device Ioctl
+ */
+
+/* KVM_HAS_DEVICE_ATTR on @dev_fd; returns the raw ioctl() result. */
+int _kvm_device_check_attr(int dev_fd, uint32_t group, uint64_t attr)
+{
+	struct kvm_device_attr attribute = {
+		.group = group,
+		.attr = attr,
+	};
+
+	return ioctl(dev_fd, KVM_HAS_DEVICE_ATTR, &attribute);
+}
+
+/* Like _kvm_device_check_attr(), but asserts that the attribute is supported. */
+int kvm_device_check_attr(int dev_fd, uint32_t group, uint64_t attr)
+{
+	int ret = _kvm_device_check_attr(dev_fd, group, attr);
+
+	TEST_ASSERT(ret >= 0, "KVM_HAS_DEVICE_ATTR failed, rc: %i errno: %i", ret, errno);
+	return ret;
+}
+
+/* KVM_CREATE_DEVICE (probe only if @test); *@fd gets the new fd, or -1. */
+int _kvm_create_device(struct kvm_vm *vm, uint64_t type, bool test, int *fd)
+{
+	struct kvm_create_device create_dev = {
+		.type = type, .fd = -1, .flags = test ? KVM_CREATE_DEVICE_TEST : 0,
+	};
+	int ret = ioctl(vm_get_fd(vm), KVM_CREATE_DEVICE, &create_dev);
+
+	*fd = create_dev.fd;
+	return ret;
+}
+
+/* Non-trial mode asserts success and returns the fd; trial mode returns rc. */
+int kvm_create_device(struct kvm_vm *vm, uint64_t type, bool test)
+{
+	int fd, ret;
+
+	ret = _kvm_create_device(vm, type, test, &fd);
+	if (!test) {
+		TEST_ASSERT(ret >= 0,
+			    "KVM_CREATE_DEVICE IOCTL failed, rc: %i errno: %i", ret, errno);
+		return fd;
+	}
+	return ret;
+}
+
+/* KVM_SET_DEVICE_ATTR if @write, else KVM_GET_DEVICE_ATTR; raw ioctl() result. */
+int _kvm_device_access(int dev_fd, uint32_t group, uint64_t attr,
+		       void *val, bool write)
+{
+	struct kvm_device_attr kvmattr = {
+		.group = group,
+		.attr = attr,
+		.addr = (uintptr_t)val,
+	};
+
+	return ioctl(dev_fd, write ? KVM_SET_DEVICE_ATTR : KVM_GET_DEVICE_ATTR,
+		     &kvmattr);
+}
+
+/* Like _kvm_device_access(), but asserts that the ioctl succeeded. */
+int kvm_device_access(int dev_fd, uint32_t group, uint64_t attr,
+		      void *val, bool write)
+{
+	int ret = _kvm_device_access(dev_fd, group, attr, val, write);
+
+	TEST_ASSERT(ret >= 0, "%s IOCTL failed, rc: %i errno: %i",
+		    write ? "KVM_SET_DEVICE_ATTR" : "KVM_GET_DEVICE_ATTR", ret, errno);
+	return ret;
+}
+
+/*
  * VM Dump
  *
  * Input Args:
diff --git a/tools/testing/selftests/kvm/x86_64/xen_shinfo_test.c b/tools/testing/selftests/kvm/x86_64/xen_shinfo_test.c
index 804ff5f..1f4a059 100644
--- a/tools/testing/selftests/kvm/x86_64/xen_shinfo_test.c
+++ b/tools/testing/selftests/kvm/x86_64/xen_shinfo_test.c
@@ -186,7 +186,7 @@
 		vcpu_ioctl(vm, VCPU_ID, KVM_XEN_VCPU_SET_ATTR, &st);
 	}
 
-	struct vcpu_runstate_info *rs = addr_gpa2hva(vm, RUNSTATE_ADDR);;
+	struct vcpu_runstate_info *rs = addr_gpa2hva(vm, RUNSTATE_ADDR);
 	rs->state = 0x5a;
 
 	for (;;) {
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 383df23..5fb69f8 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -451,35 +451,189 @@
 	srcu_read_unlock(&kvm->srcu, idx);
 }
 
+typedef bool (*hva_handler_t)(struct kvm *kvm, struct kvm_gfn_range *range);
+
+typedef void (*on_lock_fn_t)(struct kvm *kvm, unsigned long start,
+			     unsigned long end);
+
+struct kvm_hva_range {
+	unsigned long start;		/* HVA range [start, end) to process */
+	unsigned long end;
+	pte_t pte;			/* copied into kvm_gfn_range.pte */
+	hva_handler_t handler;		/* per-memslot callback, under mmu_lock */
+	on_lock_fn_t on_lock;		/* run once, right after taking mmu_lock */
+	bool must_lock;			/* lock even if no memslot overlaps */
+	bool flush_on_ret;		/* flush TLBs if a handler returned true (or tlbs_dirty) */
+	bool may_block;			/* copied into kvm_gfn_range.may_block */
+};
+
+/*
+ * Use a dedicated stub instead of NULL to indicate that there is no callback
+ * function/handler.  The compiler technically can't guarantee that a real
+ * function will have a non-zero address, and so it will generate code to
+ * check for !NULL, whereas comparing against a stub will be elided at compile
+ * time (unless the compiler is getting long in the tooth, e.g. gcc 4.9).
+ */
+static void kvm_null_fn(void)
+{
+	/* Intentionally empty: only its address is compared (IS_KVM_NULL_FN). */
+}
+#define IS_KVM_NULL_FN(fn) ((fn) == (void *)kvm_null_fn)
+
+
+/* Take mmu_lock on the first call only; that call returns %true if ->handler is "null" */
+static __always_inline bool kvm_mmu_lock_and_check_handler(struct kvm *kvm,
+							   const struct kvm_hva_range *range,
+							   bool *locked)
+{
+	if (*locked)
+		return false;
+
+	*locked = true;		/* later calls take the early return above */
+
+	KVM_MMU_LOCK(kvm);
+
+	if (!IS_KVM_NULL_FN(range->on_lock))
+		range->on_lock(kvm, range->start, range->end);
+
+	return IS_KVM_NULL_FN(range->handler);
+}
+
+static __always_inline int __kvm_handle_hva_range(struct kvm *kvm,
+						  const struct kvm_hva_range *range)
+{
+	bool ret = false, locked = false;
+	struct kvm_gfn_range gfn_range;
+	struct kvm_memory_slot *slot;
+	struct kvm_memslots *slots;
+	int i, idx;
+
+	/* A null handler is allowed if and only if on_lock() is provided. */
+	if (WARN_ON_ONCE(IS_KVM_NULL_FN(range->on_lock) &&
+			 IS_KVM_NULL_FN(range->handler)))
+		return 0;
+
+	idx = srcu_read_lock(&kvm->srcu);
+	/* must_lock: lock (and run ->on_lock()) even if no memslot overlaps */
+	if (range->must_lock &&
+	    kvm_mmu_lock_and_check_handler(kvm, range, &locked))
+		goto out_unlock;
+
+	for (i = 0; i < KVM_ADDRESS_SPACE_NUM; i++) {
+		slots = __kvm_memslots(kvm, i);
+		kvm_for_each_memslot(slot, slots) {
+			unsigned long hva_start, hva_end;
+
+			hva_start = max(range->start, slot->userspace_addr);
+			hva_end = min(range->end, slot->userspace_addr +
+						  (slot->npages << PAGE_SHIFT));
+			if (hva_start >= hva_end)
+				continue;
+
+			/*
+			 * To optimize for the likely case where the address
+			 * range is covered by zero or one memslots, don't
+			 * bother making these conditional (to avoid writes on
+			 * the second or later invocation of the handler).
+			 */
+			gfn_range.pte = range->pte;
+			gfn_range.may_block = range->may_block;
+
+			/*
+			 * {gfn(page) | page intersects with [hva_start, hva_end)} =
+			 * {gfn_start, gfn_start+1, ..., gfn_end-1}.
+			 */
+			gfn_range.start = hva_to_gfn_memslot(hva_start, slot);
+			gfn_range.end = hva_to_gfn_memslot(hva_end + PAGE_SIZE - 1, slot);
+			gfn_range.slot = slot;
+			/* Lock lazily, on the first memslot overlapping the range */
+			if (kvm_mmu_lock_and_check_handler(kvm, range, &locked))
+				goto out_unlock;
+
+			ret |= range->handler(kvm, &gfn_range);
+		}
+	}
+
+	if (range->flush_on_ret && (ret || kvm->tlbs_dirty))
+		kvm_flush_remote_tlbs(kvm);
+
+out_unlock:
+	if (locked)
+		KVM_MMU_UNLOCK(kvm);
+
+	srcu_read_unlock(&kvm->srcu, idx);
+
+	/* The notifiers are averse to booleans. :-( */
+	return (int)ret;
+}
+
+static __always_inline int kvm_handle_hva_range(struct mmu_notifier *mn,
+						unsigned long start,
+						unsigned long end,
+						pte_t pte,
+						hva_handler_t handler)
+{
+	struct kvm *kvm = mmu_notifier_to_kvm(mn);
+	const struct kvm_hva_range range = {
+		.start		= start,
+		.end		= end,
+		.pte		= pte,
+		.handler	= handler,
+		.on_lock	= (void *)kvm_null_fn,	/* no on_lock() callback */
+		.must_lock	= false,		/* lock only if a memslot overlaps */
+		.flush_on_ret	= true,			/* flush if any handler returns true */
+		.may_block	= false,
+	};
+
+	return __kvm_handle_hva_range(kvm, &range);
+}
+
+static __always_inline int kvm_handle_hva_range_no_flush(struct mmu_notifier *mn,
+							 unsigned long start,
+							 unsigned long end,
+							 hva_handler_t handler)
+{
+	struct kvm *kvm = mmu_notifier_to_kvm(mn);
+	const struct kvm_hva_range range = {
+		.start		= start,
+		.end		= end,
+		.pte		= __pte(0),
+		.handler	= handler,
+		.on_lock	= (void *)kvm_null_fn,	/* no on_lock() callback */
+		.must_lock	= false,		/* lock only if a memslot overlaps */
+		.flush_on_ret	= false,		/* never flush TLBs here */
+		.may_block	= false,
+	};
+
+	return __kvm_handle_hva_range(kvm, &range);
+}
 static void kvm_mmu_notifier_change_pte(struct mmu_notifier *mn,
 					struct mm_struct *mm,
 					unsigned long address,
 					pte_t pte)
 {
 	struct kvm *kvm = mmu_notifier_to_kvm(mn);
-	int idx;
 
-	idx = srcu_read_lock(&kvm->srcu);
+	trace_kvm_set_spte_hva(address);
 
-	KVM_MMU_LOCK(kvm);
+	/*
+	 * .change_pte() must be surrounded by .invalidate_range_{start,end}().
+	 * If mmu_notifier_count is zero, then start() didn't find a relevant
+	 * memslot and wasn't forced down the slow path; rechecking here is
+	 * unnecessary.  This can only occur if memslot updates are blocked;
+	 * otherwise, mmu_notifier_count is incremented unconditionally.
+	 */
+	if (!kvm->mmu_notifier_count) {
+		lockdep_assert_held(&kvm->mmu_notifier_slots_lock);
+		return;
+	}
 
-	kvm->mmu_notifier_seq++;
-
-	if (kvm_set_spte_hva(kvm, address, pte))
-		kvm_flush_remote_tlbs(kvm);
-
-	KVM_MMU_UNLOCK(kvm);
-	srcu_read_unlock(&kvm->srcu, idx);
+	kvm_handle_hva_range(mn, address, address + 1, pte, kvm_set_spte_gfn);
 }
 
-static int kvm_mmu_notifier_invalidate_range_start(struct mmu_notifier *mn,
-					const struct mmu_notifier_range *range)
+static void kvm_inc_notifier_count(struct kvm *kvm, unsigned long start,
+				   unsigned long end)
 {
-	struct kvm *kvm = mmu_notifier_to_kvm(mn);
-	int need_tlb_flush = 0, idx;
-
-	idx = srcu_read_lock(&kvm->srcu);
-	KVM_MMU_LOCK(kvm);
 	/*
 	 * The count increase must become visible at unlock time as no
 	 * spte can be established without taking the mmu_lock and
@@ -487,8 +641,8 @@
 	 */
 	kvm->mmu_notifier_count++;
 	if (likely(kvm->mmu_notifier_count == 1)) {
-		kvm->mmu_notifier_range_start = range->start;
-		kvm->mmu_notifier_range_end = range->end;
+		kvm->mmu_notifier_range_start = start;
+		kvm->mmu_notifier_range_end = end;
 	} else {
 		/*
 		 * Fully tracking multiple concurrent ranges has dimishing
@@ -500,28 +654,54 @@
 		 * complete.
 		 */
 		kvm->mmu_notifier_range_start =
-			min(kvm->mmu_notifier_range_start, range->start);
+			min(kvm->mmu_notifier_range_start, start);
 		kvm->mmu_notifier_range_end =
-			max(kvm->mmu_notifier_range_end, range->end);
+			max(kvm->mmu_notifier_range_end, end);
 	}
-	need_tlb_flush = kvm_unmap_hva_range(kvm, range->start, range->end,
-					     range->flags);
-	/* we've to flush the tlb before the pages can be freed */
-	if (need_tlb_flush || kvm->tlbs_dirty)
-		kvm_flush_remote_tlbs(kvm);
+}
 
-	KVM_MMU_UNLOCK(kvm);
-	srcu_read_unlock(&kvm->srcu, idx);
+static int kvm_mmu_notifier_invalidate_range_start(struct mmu_notifier *mn,
+					const struct mmu_notifier_range *range)
+{
+	bool blockable = mmu_notifier_range_blockable(range);
+	struct kvm *kvm = mmu_notifier_to_kvm(mn);
+	const struct kvm_hva_range hva_range = {
+		.start		= range->start,
+		.end		= range->end,
+		.pte		= __pte(0),
+		.handler	= kvm_unmap_gfn_range,
+		.on_lock	= kvm_inc_notifier_count,	/* runs under mmu_lock */
+		.must_lock	= !blockable,
+		.flush_on_ret	= true,		/* flush before the pages can be freed */
+		.may_block	= blockable,
+	};
+
+	trace_kvm_unmap_hva_range(range->start, range->end);
+
+	/*
+	 * Prevent memslot modification between range_start() and range_end()
+	 * so that conditionally locking provides the same result in both
+	 * functions.  Without that guarantee, the mmu_notifier_count
+	 * adjustments will be imbalanced.
+	 *
+	 * Skip the memslot-lookup lock elision (set @must_lock above) to avoid
+	 * having to take the semaphore on non-blockable calls, e.g. OOM kill.
+	 * The complexity required to handle conditional locking for this case
+	 * is not worth the marginal benefits, the VM is likely doomed anyways.
+	 *
+	 * Pairs with the up_read in range_end().
+	 */
+	if (blockable)
+		down_read(&kvm->mmu_notifier_slots_lock);
+
+	__kvm_handle_hva_range(kvm, &hva_range);
+
 	return 0;
 }
 
-static void kvm_mmu_notifier_invalidate_range_end(struct mmu_notifier *mn,
-					const struct mmu_notifier_range *range)
+static void kvm_dec_notifier_count(struct kvm *kvm, unsigned long start,
+				   unsigned long end)
 {
-	struct kvm *kvm = mmu_notifier_to_kvm(mn);
-
-	KVM_MMU_LOCK(kvm);
 	/*
 	 * This sequence increase will notify the kvm page fault that
 	 * the page that is going to be mapped in the spte could have
@@ -535,7 +715,29 @@
 	 * in conjunction with the smp_rmb in mmu_notifier_retry().
 	 */
 	kvm->mmu_notifier_count--;
-	KVM_MMU_UNLOCK(kvm);
+}
+
+static void kvm_mmu_notifier_invalidate_range_end(struct mmu_notifier *mn,
+					const struct mmu_notifier_range *range)
+{
+	bool blockable = mmu_notifier_range_blockable(range);
+	struct kvm *kvm = mmu_notifier_to_kvm(mn);
+	const struct kvm_hva_range hva_range = {
+		.start		= range->start,
+		.end		= range->end,
+		.pte		= __pte(0),
+		.handler	= (void *)kvm_null_fn,		/* only on_lock() is needed */
+		.on_lock	= kvm_dec_notifier_count,	/* undo range_start()'s increment */
+		.must_lock	= !blockable,			/* must mirror range_start() */
+		.flush_on_ret	= false,
+		.may_block	= blockable,
+	};
+
+	__kvm_handle_hva_range(kvm, &hva_range);
+
+	/* Pairs with the down_read in range_start(). */
+	if (blockable)
+		up_read(&kvm->mmu_notifier_slots_lock);
 
 	BUG_ON(kvm->mmu_notifier_count < 0);
 }
@@ -545,20 +747,9 @@
 					      unsigned long start,
 					      unsigned long end)
 {
-	struct kvm *kvm = mmu_notifier_to_kvm(mn);
-	int young, idx;
+	trace_kvm_age_hva(start, end);
 
-	idx = srcu_read_lock(&kvm->srcu);
-	KVM_MMU_LOCK(kvm);
-
-	young = kvm_age_hva(kvm, start, end);
-	if (young)
-		kvm_flush_remote_tlbs(kvm);
-
-	KVM_MMU_UNLOCK(kvm);
-	srcu_read_unlock(&kvm->srcu, idx);
-
-	return young;
+	return kvm_handle_hva_range(mn, start, end, __pte(0), kvm_age_gfn);
 }
 
 static int kvm_mmu_notifier_clear_young(struct mmu_notifier *mn,
@@ -566,11 +757,8 @@
 					unsigned long start,
 					unsigned long end)
 {
-	struct kvm *kvm = mmu_notifier_to_kvm(mn);
-	int young, idx;
+	trace_kvm_age_hva(start, end);
 
-	idx = srcu_read_lock(&kvm->srcu);
-	KVM_MMU_LOCK(kvm);
 	/*
 	 * Even though we do not flush TLB, this will still adversely
 	 * affect performance on pre-Haswell Intel EPT, where there is
@@ -584,27 +772,17 @@
 	 * cadence. If we find this inaccurate, we might come up with a
 	 * more sophisticated heuristic later.
 	 */
-	young = kvm_age_hva(kvm, start, end);
-	KVM_MMU_UNLOCK(kvm);
-	srcu_read_unlock(&kvm->srcu, idx);
-
-	return young;
+	return kvm_handle_hva_range_no_flush(mn, start, end, kvm_age_gfn);
 }
 
 static int kvm_mmu_notifier_test_young(struct mmu_notifier *mn,
 				       struct mm_struct *mm,
 				       unsigned long address)
 {
-	struct kvm *kvm = mmu_notifier_to_kvm(mn);
-	int young, idx;
+	trace_kvm_test_age_hva(address);

-	idx = srcu_read_lock(&kvm->srcu);
-	KVM_MMU_LOCK(kvm);
-	young = kvm_test_age_hva(kvm, address);
-	KVM_MMU_UNLOCK(kvm);
-	srcu_read_unlock(&kvm->srcu, idx);
-
-	return young;
+	return kvm_handle_hva_range_no_flush(mn, address,
+					     address + 1, kvm_test_age_gfn);
 }
 
 static void kvm_mmu_notifier_release(struct mmu_notifier *mn,
@@ -773,6 +951,7 @@
 	mutex_init(&kvm->lock);
 	mutex_init(&kvm->irq_lock);
 	mutex_init(&kvm->slots_lock);
+	init_rwsem(&kvm->mmu_notifier_slots_lock);
 	INIT_LIST_HEAD(&kvm->devices);
 
 	BUILD_BUG_ON(KVM_MEM_SLOTS_NUM > SHRT_MAX);
@@ -893,6 +1072,16 @@
 	kvm_coalesced_mmio_free(kvm);
 #if defined(CONFIG_MMU_NOTIFIER) && defined(KVM_ARCH_WANT_MMU_NOTIFIER)
 	mmu_notifier_unregister(&kvm->mmu_notifier, kvm->mm);
+	/*
+	 * Reset the lock used to prevent memslot updates between MMU notifier
+	 * invalidate_range_start() and invalidate_range_end().  At this point,
+	 * no more MMU notifiers will run and pending calls to ...start() have
+	 * completed.  But, the lock could still be held if KVM's notifier was
+	 * removed between ...start() and ...end().  No threads can be waiting
+	 * on the lock as the last reference on KVM has been dropped.  If the
+	 * lock is still held, freeing memslots will deadlock.
+	 */
+	init_rwsem(&kvm->mmu_notifier_slots_lock);
 #else
 	kvm_arch_flush_shadow_all(kvm);
 #endif
@@ -1144,7 +1333,10 @@
 	WARN_ON(gen & KVM_MEMSLOT_GEN_UPDATE_IN_PROGRESS);
 	slots->generation = gen | KVM_MEMSLOT_GEN_UPDATE_IN_PROGRESS;
 
+	down_write(&kvm->mmu_notifier_slots_lock);
 	rcu_assign_pointer(kvm->memslots[as_id], slots);
+	up_write(&kvm->mmu_notifier_slots_lock);
+
 	synchronize_srcu_expedited(&kvm->srcu);
 
 	/*
@@ -4062,6 +4254,12 @@
 	KVM_COMPAT(kvm_vm_compat_ioctl),
 };
 
+bool file_is_kvm(struct file *file)
+{
+	return file ? file->f_op == &kvm_vm_fops : false;
+}
+EXPORT_SYMBOL_GPL(file_is_kvm);
+
 static int kvm_dev_ioctl_create_vm(unsigned long type)
 {
 	int r;