Merge branches 'fixes', 'arm/smmu/updates', 'amd/amd-vi' and 'core' into next
diff --git a/Documentation/admin-guide/kernel-parameters.txt b/Documentation/admin-guide/kernel-parameters.txt
index 1058f2a..6d4b926 100644
--- a/Documentation/admin-guide/kernel-parameters.txt
+++ b/Documentation/admin-guide/kernel-parameters.txt
@@ -2675,6 +2675,15 @@
 			1 - Bypass the IOMMU for DMA.
 			unset - Use value of CONFIG_IOMMU_DEFAULT_PASSTHROUGH.
 
+	iommu.debug_pagealloc=
+			[KNL,EARLY] When CONFIG_IOMMU_DEBUG_PAGEALLOC is set, this
+			parameter enables the feature at boot time. By default, it
+			is disabled and the system behaves the same way as a kernel
+			built without CONFIG_IOMMU_DEBUG_PAGEALLOC.
+			Format: { "0" | "1" }
+			0 - Sanitizer disabled.
+			1 - Sanitizer enabled, expect runtime overhead.
+
 	io7=		[HW] IO7 for Marvel-based Alpha systems
 			See comment before marvel_specify_io7 in
 			arch/alpha/kernel/core_marvel.c.
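A quick usage sketch for the new parameter (illustrative, not part of the diff): build with CONFIG_IOMMU_DEBUG_PAGEALLOC=y and boot with

	iommu.debug_pagealloc=1

to turn on the runtime checks; without the parameter the kernel behaves as if the option had not been built in.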
diff --git a/MAINTAINERS b/MAINTAINERS
index da9dbc1..3aa2aac 100644
--- a/MAINTAINERS
+++ b/MAINTAINERS
@@ -13249,6 +13249,7 @@
 F:	include/linux/iommu.h
 F:	include/linux/iova.h
 F:	include/linux/of_iommu.h
+F:	rust/kernel/iommu/
 
 IOMMUFD
 M:	Jason Gunthorpe <jgg@nvidia.com>
diff --git a/drivers/iommu/Kconfig b/drivers/iommu/Kconfig
index 9909564..f86262b 100644
--- a/drivers/iommu/Kconfig
+++ b/drivers/iommu/Kconfig
@@ -384,6 +384,25 @@
 
 	  Say Y here if you want to use the multimedia devices listed above.
 
+config IOMMU_DEBUG_PAGEALLOC
+	bool "Debug IOMMU mappings against page allocations"
+	depends on DEBUG_PAGEALLOC && IOMMU_API && PAGE_EXTENSION
+	help
+	  This enables a consistency check between the kernel page allocator and
+	  the IOMMU subsystem. It verifies that pages being allocated or freed
+	  are not currently mapped in any IOMMU domain.
+
+	  This helps detect DMA use-after-free bugs where a driver frees a page
+	  but forgets to unmap it from the IOMMU, potentially allowing a device
+	  to overwrite memory that the kernel has repurposed.
+
+	  These checks are best-effort and may not detect all problems.
+
+	  Due to the performance overhead, this feature is disabled by default.
+	  You must pass "iommu.debug_pagealloc=1" on the kernel command
+	  line to activate the runtime checks.
+
+	  If unsure, say N.
 endif # IOMMU_SUPPORT
 
 source "drivers/iommu/generic_pt/Kconfig"
diff --git a/drivers/iommu/Makefile b/drivers/iommu/Makefile
index 8e88433..0275821 100644
--- a/drivers/iommu/Makefile
+++ b/drivers/iommu/Makefile
@@ -36,3 +36,4 @@
 obj-$(CONFIG_IOMMU_IOPF) += io-pgfault.o
 obj-$(CONFIG_SPRD_IOMMU) += sprd-iommu.o
 obj-$(CONFIG_APPLE_DART) += apple-dart.o
+obj-$(CONFIG_IOMMU_DEBUG_PAGEALLOC) += iommu-debug-pagealloc.o
diff --git a/drivers/iommu/amd/Kconfig b/drivers/iommu/amd/Kconfig
index f2acf47..588355f 100644
--- a/drivers/iommu/amd/Kconfig
+++ b/drivers/iommu/amd/Kconfig
@@ -30,6 +30,16 @@
 	  your BIOS for an option to enable it or if you have an IVRS ACPI
 	  table.
 
+config AMD_IOMMU_IOMMUFD
+	bool "Enable IOMMUFD features for AMD IOMMU (EXPERIMENTAL)"
+	depends on IOMMUFD
+	depends on AMD_IOMMU
+	help
+	  Support for the IOMMUFD features needed to run virtual machines
+	  with accelerated virtual IOMMUs.
+
+	  Say Y here if you are doing development and testing on this feature.
+
 config AMD_IOMMU_DEBUGFS
 	bool "Enable AMD IOMMU internals in DebugFS"
 	depends on AMD_IOMMU && IOMMU_DEBUGFS
diff --git a/drivers/iommu/amd/Makefile b/drivers/iommu/amd/Makefile
index 5412a56..94b8ef2 100644
--- a/drivers/iommu/amd/Makefile
+++ b/drivers/iommu/amd/Makefile
@@ -1,3 +1,4 @@
 # SPDX-License-Identifier: GPL-2.0-only
 obj-y += iommu.o init.o quirks.o ppr.o pasid.o
+obj-$(CONFIG_AMD_IOMMU_IOMMUFD) += iommufd.o nested.o
 obj-$(CONFIG_AMD_IOMMU_DEBUGFS) += debugfs.o
diff --git a/drivers/iommu/amd/amd_iommu.h b/drivers/iommu/amd/amd_iommu.h
index b742ef1..02f1092 100644
--- a/drivers/iommu/amd/amd_iommu.h
+++ b/drivers/iommu/amd/amd_iommu.h
@@ -190,4 +190,37 @@ void amd_iommu_domain_set_pgtable(struct protection_domain *domain,
 struct dev_table_entry *get_dev_table(struct amd_iommu *iommu);
 struct iommu_dev_data *search_dev_data(struct amd_iommu *iommu, u16 devid);
 
+void amd_iommu_set_dte_v1(struct iommu_dev_data *dev_data,
+			  struct protection_domain *domain, u16 domid,
+			  struct pt_iommu_amdv1_hw_info *pt_info,
+			  struct dev_table_entry *new);
+void amd_iommu_update_dte(struct amd_iommu *iommu,
+			  struct iommu_dev_data *dev_data,
+			  struct dev_table_entry *new);
+
+static inline void
+amd_iommu_make_clear_dte(struct iommu_dev_data *dev_data, struct dev_table_entry *new)
+{
+	struct dev_table_entry *initial_dte;
+	struct amd_iommu *iommu = get_amd_iommu_from_dev(dev_data->dev);
+
+	/* All existing DTEs must have the V bit set */
+	new->data128[0] = DTE_FLAG_V;
+	new->data128[1] = 0;
+
+	/*
+	 * Restore cached persistent DTE bits, which can be set by information
+	 * in IVRS table. See set_dev_entry_from_acpi().
+	 */
+	initial_dte = amd_iommu_get_ivhd_dte_flags(iommu->pci_seg->id, dev_data->devid);
+	if (initial_dte) {
+		new->data128[0] |= initial_dte->data128[0];
+		new->data128[1] |= initial_dte->data128[1];
+	}
+}
+
+/* NESTED */
+struct iommu_domain *
+amd_iommu_alloc_domain_nested(struct iommufd_viommu *viommu, u32 flags,
+			      const struct iommu_user_data *user_data);
 #endif /* AMD_IOMMU_H */
diff --git a/drivers/iommu/amd/amd_iommu_types.h b/drivers/iommu/amd/amd_iommu_types.h
index 320733e..cfcbad6 100644
--- a/drivers/iommu/amd/amd_iommu_types.h
+++ b/drivers/iommu/amd/amd_iommu_types.h
@@ -17,9 +17,12 @@
 #include <linux/list.h>
 #include <linux/spinlock.h>
 #include <linux/pci.h>
+#include <linux/iommufd.h>
 #include <linux/irqreturn.h>
 #include <linux/generic_pt/iommu.h>
 
+#include <uapi/linux/iommufd.h>
+
 /*
  * Maximum number of IOMMUs supported
  */
@@ -108,6 +111,7 @@
 
 /* Extended Feature 2 Bits */
 #define FEATURE_SEVSNPIO_SUP	BIT_ULL(1)
+#define FEATURE_GCR3TRPMODE	BIT_ULL(3)
 #define FEATURE_SNPAVICSUP	GENMASK_ULL(7, 5)
 #define FEATURE_SNPAVICSUP_GAM(x) \
 	(FIELD_GET(FEATURE_SNPAVICSUP, x) == 0x1)
@@ -186,6 +190,7 @@
 #define CONTROL_EPH_EN		45
 #define CONTROL_XT_EN		50
 #define CONTROL_INTCAPXT_EN	51
+#define CONTROL_GCR3TRPMODE	58
 #define CONTROL_IRTCACHEDIS	59
 #define CONTROL_SNPAVIC_EN	61
 
@@ -350,6 +355,9 @@
 #define DTE_FLAG_V	BIT_ULL(0)
 #define DTE_FLAG_TV	BIT_ULL(1)
 #define DTE_FLAG_HAD	(3ULL << 7)
+#define DTE_MODE_MASK	GENMASK_ULL(11, 9)
+#define DTE_HOST_TRP	GENMASK_ULL(51, 12)
+#define DTE_FLAG_PPR	BIT_ULL(52)
 #define DTE_FLAG_GIOV	BIT_ULL(54)
 #define DTE_FLAG_GV	BIT_ULL(55)
 #define DTE_GLX		GENMASK_ULL(57, 56)
@@ -358,7 +366,7 @@
 
 #define DTE_FLAG_IOTLB	BIT_ULL(32)
 #define DTE_FLAG_MASK	(0x3ffULL << 32)
-#define DEV_DOMID_MASK	0xffffULL
+#define DTE_DOMID_MASK	GENMASK_ULL(15, 0)
 
 #define DTE_GCR3_14_12	GENMASK_ULL(60, 58)
 #define DTE_GCR3_30_15	GENMASK_ULL(31, 16)
@@ -493,6 +501,38 @@ struct pdom_iommu_info {
 	u32 refcnt;	/* Count of attached dev/pasid per domain/IOMMU */
 };
 
+struct amd_iommu_viommu {
+	struct iommufd_viommu core;
+	struct protection_domain *parent; /* nest parent domain for this viommu */
+	struct list_head pdom_list;	  /* For protection_domain->viommu_list */
+
+	/*
+	 * Per-vIOMMU guest domain ID to host domain ID mapping.
+	 * Indexed by guest domain ID.
+	 */
+	struct xarray gdomid_array;
+};
+
+/*
+ * Contains guest domain ID mapping info,
+ * which is stored in the struct xarray gdomid_array.
+ */
+struct guest_domain_mapping_info {
+	refcount_t users;
+	u32 hdom_id;		/* Host domain ID */
+};
+
+/*
+ * Nested domain is specifically used for nested translation
+ */
+struct nested_domain {
+	struct iommu_domain domain; /* generic domain handle used by iommu core code */
+	u16 gdom_id;                /* domain ID from gDTE */
+	struct guest_domain_mapping_info *gdom_info;
+	struct iommu_hwpt_amd_guest gdte; /* Guest vIOMMU DTE */
+	struct amd_iommu_viommu *viommu;  /* AMD hw-viommu this nested domain belongs to */
+};
+
 /*
  * This structure contains generic data for  IOMMU protection domains
  * independent of their use.
@@ -513,6 +553,12 @@ struct protection_domain {
 
 	struct mmu_notifier mn;	/* mmu notifier for the SVA domain */
 	struct list_head dev_data_list; /* List of pdom_dev_data */
+
+	/*
+	 * List of vIOMMUs that use this protection domain. This is used to
+	 * look up host domain IDs when flushing this domain.
+	 */
+	struct list_head viommu_list;
 };
 PT_IOMMU_CHECK_DOMAIN(struct protection_domain, iommu, domain);
 PT_IOMMU_CHECK_DOMAIN(struct protection_domain, amdv1.iommu, domain);
diff --git a/drivers/iommu/amd/init.c b/drivers/iommu/amd/init.c
index 384c90b..b1c344e 100644
--- a/drivers/iommu/amd/init.c
+++ b/drivers/iommu/amd/init.c
@@ -1122,6 +1122,14 @@ static void iommu_enable_gt(struct amd_iommu *iommu)
 		return;
 
 	iommu_feature_enable(iommu, CONTROL_GT_EN);
+
+	/*
+	 * This feature needs to be enabled prior to a call
+	 * to iommu_snp_enable(). Since this function is called
+	 * in early_enable_iommu(), it is safe to enable here.
+	 */
+	if (check_feature2(FEATURE_GCR3TRPMODE))
+		iommu_feature_enable(iommu, CONTROL_GCR3TRPMODE);
 }
 
 /* sets a specific bit in the device table entry. */
@@ -1179,7 +1187,7 @@ static bool __reuse_device_table(struct amd_iommu *iommu)
 	for (devid = 0; devid <= pci_seg->last_bdf; devid++) {
 		old_dev_tbl_entry = &pci_seg->old_dev_tbl_cpy[devid];
 		dte_v = FIELD_GET(DTE_FLAG_V, old_dev_tbl_entry->data[0]);
-		dom_id = FIELD_GET(DEV_DOMID_MASK, old_dev_tbl_entry->data[1]);
+		dom_id = FIELD_GET(DTE_DOMID_MASK, old_dev_tbl_entry->data[1]);
 
 		if (!dte_v || !dom_id)
 			continue;
diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c
index 5d45795..7e724f5 100644
--- a/drivers/iommu/amd/iommu.c
+++ b/drivers/iommu/amd/iommu.c
@@ -43,6 +43,7 @@
 #include <linux/generic_pt/iommu.h>
 
 #include "amd_iommu.h"
+#include "iommufd.h"
 #include "../irq_remapping.h"
 #include "../iommu-pages.h"
 
@@ -75,6 +76,8 @@ static void set_dte_entry(struct amd_iommu *iommu,
 			  struct iommu_dev_data *dev_data,
 			  phys_addr_t top_paddr, unsigned int top_level);
 
+static int device_flush_dte(struct iommu_dev_data *dev_data);
+
 static void amd_iommu_change_top(struct pt_iommu *iommu_table,
 				 phys_addr_t top_paddr, unsigned int top_level);
 
@@ -85,6 +88,10 @@ static bool amd_iommu_enforce_cache_coherency(struct iommu_domain *domain);
 static int amd_iommu_set_dirty_tracking(struct iommu_domain *domain,
 					bool enable);
 
+static void clone_aliases(struct amd_iommu *iommu, struct device *dev);
+
+static int iommu_completion_wait(struct amd_iommu *iommu);
+
 /****************************************************************************
  *
  * Helper functions
@@ -202,6 +209,16 @@ static void update_dte256(struct amd_iommu *iommu, struct iommu_dev_data *dev_da
 	spin_unlock_irqrestore(&dev_data->dte_lock, flags);
 }
 
+void amd_iommu_update_dte(struct amd_iommu *iommu,
+			     struct iommu_dev_data *dev_data,
+			     struct dev_table_entry *new)
+{
+	update_dte256(iommu, dev_data, new);
+	clone_aliases(iommu, dev_data->dev);
+	device_flush_dte(dev_data);
+	iommu_completion_wait(iommu);
+}
+
 static void get_dte256(struct amd_iommu *iommu, struct iommu_dev_data *dev_data,
 		      struct dev_table_entry *dte)
 {
@@ -1185,7 +1202,12 @@ static int wait_on_sem(struct amd_iommu *iommu, u64 data)
 {
 	int i = 0;
 
-	while (*iommu->cmd_sem != data && i < LOOP_TIMEOUT) {
+	/*
+	 * cmd_sem holds a monotonically non-decreasing completion sequence
+	 * number.
+	 */
+	while ((__s64)(READ_ONCE(*iommu->cmd_sem) - data) < 0 &&
+	       i < LOOP_TIMEOUT) {
 		udelay(1);
 		i += 1;
 	}
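A small userspace illustration (not part of the patch) of why the signed-difference test above stays correct if the monotonically increasing cmd_sem sequence number ever wraps around. The not_yet_reached() helper and its name are mine; only the comparison mirrors wait_on_sem():

/* Illustration: (s64)(cur - target) < 0 keeps working across wraparound,
 * assuming the usual two's-complement conversion. */
#include <stdint.h>
#include <stdio.h>

static int not_yet_reached(uint64_t cur, uint64_t target)
{
	/* Mirrors wait_on_sem(): true while cur is still behind target. */
	return (int64_t)(cur - target) < 0;
}

int main(void)
{
	uint64_t target = 2;	/* completion sequence number we wait for */

	printf("%d\n", not_yet_reached(UINT64_MAX, target));	/* 1: behind (pre-wrap) */
	printf("%d\n", not_yet_reached(1, target));		/* 1: behind */
	printf("%d\n", not_yet_reached(2, target));		/* 0: reached */
	printf("%d\n", not_yet_reached(5, target));		/* 0: passed */
	return 0;
}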
@@ -1437,14 +1459,13 @@ static int iommu_completion_wait(struct amd_iommu *iommu)
 	raw_spin_lock_irqsave(&iommu->lock, flags);
 
 	ret = __iommu_queue_command_sync(iommu, &cmd, false);
+	raw_spin_unlock_irqrestore(&iommu->lock, flags);
+
 	if (ret)
-		goto out_unlock;
+		return ret;
 
 	ret = wait_on_sem(iommu, data);
 
-out_unlock:
-	raw_spin_unlock_irqrestore(&iommu->lock, flags);
-
 	return ret;
 }
 
@@ -1522,6 +1543,32 @@ static void amd_iommu_flush_tlb_domid(struct amd_iommu *iommu, u32 dom_id)
 	iommu_completion_wait(iommu);
 }
 
+static int iommu_flush_pages_v1_hdom_ids(struct protection_domain *pdom, u64 address, size_t size)
+{
+	int ret = 0;
+	struct amd_iommu_viommu *aviommu;
+
+	list_for_each_entry(aviommu, &pdom->viommu_list, pdom_list) {
+		unsigned long i;
+		struct guest_domain_mapping_info *gdom_info;
+		struct amd_iommu *iommu = container_of(aviommu->core.iommu_dev,
+						       struct amd_iommu, iommu);
+
+		xa_lock(&aviommu->gdomid_array);
+		xa_for_each(&aviommu->gdomid_array, i, gdom_info) {
+			struct iommu_cmd cmd;
+
+			pr_debug("%s: iommu=%#x, hdom_id=%#x\n", __func__,
+				 iommu->devid, gdom_info->hdom_id);
+			build_inv_iommu_pages(&cmd, address, size, gdom_info->hdom_id,
+					      IOMMU_NO_PASID, false);
+			ret |= iommu_queue_command(iommu, &cmd);
+		}
+		xa_unlock(&aviommu->gdomid_array);
+	}
+	return ret;
+}
+
 static void amd_iommu_flush_all(struct amd_iommu *iommu)
 {
 	struct iommu_cmd cmd;
@@ -1670,6 +1717,17 @@ static int domain_flush_pages_v1(struct protection_domain *pdom,
 		ret |= iommu_queue_command(pdom_iommu_info->iommu, &cmd);
 	}
 
+	/*
+	 * A domain with a v1 table can be a nest parent, which can have
+	 * multiple nested domains. Each nested domain has a 1:1 mapping
+	 * between gDomID and hDomID. Therefore, flush every hDomID
+	 * associated with this nest parent domain.
+	 *
+	 * See drivers/iommu/amd/nested.c: amd_iommu_alloc_domain_nested()
+	 */
+	if (!list_empty(&pdom->viommu_list))
+		ret |= iommu_flush_pages_v1_hdom_ids(pdom, address, size);
+
 	return ret;
 }
 
@@ -2010,127 +2068,112 @@ int amd_iommu_clear_gcr3(struct iommu_dev_data *dev_data, ioasid_t pasid)
 	return ret;
 }
 
-static void make_clear_dte(struct iommu_dev_data *dev_data, struct dev_table_entry *ptr,
-			   struct dev_table_entry *new)
-{
-	/* All existing DTE must have V bit set */
-	new->data128[0] = DTE_FLAG_V;
-	new->data128[1] = 0;
-}
-
 /*
  * Note:
  * The old value for GCR3 table and GPT have been cleared from caller.
  */
-static void set_dte_gcr3_table(struct amd_iommu *iommu,
-			       struct iommu_dev_data *dev_data,
-			       struct dev_table_entry *target)
+static void set_dte_gcr3_table(struct iommu_dev_data *dev_data,
+			       struct dev_table_entry *new)
 {
 	struct gcr3_tbl_info *gcr3_info = &dev_data->gcr3_info;
-	u64 gcr3;
+	u64 gcr3 = iommu_virt_to_phys(gcr3_info->gcr3_tbl);
 
-	if (!gcr3_info->gcr3_tbl)
-		return;
+	new->data[0] |= DTE_FLAG_TV |
+			(dev_data->ppr ? DTE_FLAG_PPR : 0) |
+			(pdom_is_v2_pgtbl_mode(dev_data->domain) ?  DTE_FLAG_GIOV : 0) |
+			DTE_FLAG_GV |
+			FIELD_PREP(DTE_GLX, gcr3_info->glx) |
+			FIELD_PREP(DTE_GCR3_14_12, gcr3 >> 12) |
+			DTE_FLAG_IR | DTE_FLAG_IW;
 
-	pr_debug("%s: devid=%#x, glx=%#x, gcr3_tbl=%#llx\n",
-		 __func__, dev_data->devid, gcr3_info->glx,
-		 (unsigned long long)gcr3_info->gcr3_tbl);
-
-	gcr3 = iommu_virt_to_phys(gcr3_info->gcr3_tbl);
-
-	target->data[0] |= DTE_FLAG_GV |
-			   FIELD_PREP(DTE_GLX, gcr3_info->glx) |
-			   FIELD_PREP(DTE_GCR3_14_12, gcr3 >> 12);
-	if (pdom_is_v2_pgtbl_mode(dev_data->domain))
-		target->data[0] |= DTE_FLAG_GIOV;
-
-	target->data[1] |= FIELD_PREP(DTE_GCR3_30_15, gcr3 >> 15) |
-			   FIELD_PREP(DTE_GCR3_51_31, gcr3 >> 31);
+	new->data[1] |= FIELD_PREP(DTE_DOMID_MASK, dev_data->gcr3_info.domid) |
+			FIELD_PREP(DTE_GCR3_30_15, gcr3 >> 15) |
+			(dev_data->ats_enabled ? DTE_FLAG_IOTLB : 0) |
+			FIELD_PREP(DTE_GCR3_51_31, gcr3 >> 31);
 
 	/* Guest page table can only support 4 and 5 levels  */
 	if (amd_iommu_gpt_level == PAGE_MODE_5_LEVEL)
-		target->data[2] |= FIELD_PREP(DTE_GPT_LEVEL_MASK, GUEST_PGTABLE_5_LEVEL);
+		new->data[2] |= FIELD_PREP(DTE_GPT_LEVEL_MASK, GUEST_PGTABLE_5_LEVEL);
 	else
-		target->data[2] |= FIELD_PREP(DTE_GPT_LEVEL_MASK, GUEST_PGTABLE_4_LEVEL);
+		new->data[2] |= FIELD_PREP(DTE_GPT_LEVEL_MASK, GUEST_PGTABLE_4_LEVEL);
+}
+
+void amd_iommu_set_dte_v1(struct iommu_dev_data *dev_data,
+			  struct protection_domain *domain, u16 domid,
+			  struct pt_iommu_amdv1_hw_info *pt_info,
+			  struct dev_table_entry *new)
+{
+	u64 host_pt_root = __sme_set(pt_info->host_pt_root);
+
+	/* Note: dirty tracking is only used with the v1 table for now */
+	new->data[0] |= DTE_FLAG_TV |
+			FIELD_PREP(DTE_MODE_MASK, pt_info->mode) |
+			(domain->dirty_tracking ? DTE_FLAG_HAD : 0) |
+			FIELD_PREP(DTE_HOST_TRP, host_pt_root >> 12) |
+			DTE_FLAG_IR | DTE_FLAG_IW;
+
+	new->data[1] |= FIELD_PREP(DTE_DOMID_MASK, domid) |
+			(dev_data->ats_enabled ? DTE_FLAG_IOTLB : 0);
+}
+
+static void set_dte_v1(struct iommu_dev_data *dev_data,
+		       struct protection_domain *domain, u16 domid,
+		       phys_addr_t top_paddr, unsigned int top_level,
+		       struct dev_table_entry *new)
+{
+	struct pt_iommu_amdv1_hw_info pt_info;
+
+	/*
+	 * When updating the IO pagetable, the new top and level
+	 * are provided as parameters. For other operations i.e.
+	 * device attach, retrieve the current pagetable info
+	 * via the IOMMU PT API.
+	 */
+	if (top_paddr) {
+		pt_info.host_pt_root = top_paddr;
+		pt_info.mode = top_level + 1;
+	} else {
+		WARN_ON(top_paddr || top_level);
+		pt_iommu_amdv1_hw_info(&domain->amdv1, &pt_info);
+	}
+
+	amd_iommu_set_dte_v1(dev_data, domain, domid, &pt_info, new);
+}
+
+static void set_dte_passthrough(struct iommu_dev_data *dev_data,
+				struct protection_domain *domain,
+				struct dev_table_entry *new)
+{
+	new->data[0] |= DTE_FLAG_TV | DTE_FLAG_IR | DTE_FLAG_IW;
+
+	new->data[1] |= FIELD_PREP(DTE_DOMID_MASK, domain->id) |
+			(dev_data->ats_enabled ? DTE_FLAG_IOTLB : 0);
 }
 
 static void set_dte_entry(struct amd_iommu *iommu,
 			  struct iommu_dev_data *dev_data,
 			  phys_addr_t top_paddr, unsigned int top_level)
 {
-	u16 domid;
 	u32 old_domid;
-	struct dev_table_entry *initial_dte;
 	struct dev_table_entry new = {};
 	struct protection_domain *domain = dev_data->domain;
 	struct gcr3_tbl_info *gcr3_info = &dev_data->gcr3_info;
 	struct dev_table_entry *dte = &get_dev_table(iommu)[dev_data->devid];
-	struct pt_iommu_amdv1_hw_info pt_info;
 
-	make_clear_dte(dev_data, dte, &new);
+	amd_iommu_make_clear_dte(dev_data, &new);
 
-	if (gcr3_info && gcr3_info->gcr3_tbl)
-		domid = dev_data->gcr3_info.domid;
-	else {
-		domid = domain->id;
+	old_domid = READ_ONCE(dte->data[1]) & DTE_DOMID_MASK;
+	if (gcr3_info->gcr3_tbl)
+		set_dte_gcr3_table(dev_data, &new);
+	else if (domain->domain.type == IOMMU_DOMAIN_IDENTITY)
+		set_dte_passthrough(dev_data, domain, &new);
+	else if ((domain->domain.type & __IOMMU_DOMAIN_PAGING) &&
+		 domain->pd_mode == PD_MODE_V1)
+		set_dte_v1(dev_data, domain, domain->id, top_paddr, top_level, &new);
+	else
+		WARN_ON(true);
 
-		if (domain->domain.type & __IOMMU_DOMAIN_PAGING) {
-			/*
-			 * When updating the IO pagetable, the new top and level
-			 * are provided as parameters. For other operations i.e.
-			 * device attach, retrieve the current pagetable info
-			 * via the IOMMU PT API.
-			 */
-			if (top_paddr) {
-				pt_info.host_pt_root = top_paddr;
-				pt_info.mode = top_level + 1;
-			} else {
-				WARN_ON(top_paddr || top_level);
-				pt_iommu_amdv1_hw_info(&domain->amdv1,
-						       &pt_info);
-			}
-
-			new.data[0] |= __sme_set(pt_info.host_pt_root) |
-				       (pt_info.mode & DEV_ENTRY_MODE_MASK)
-					       << DEV_ENTRY_MODE_SHIFT;
-		}
-	}
-
-	new.data[0] |= DTE_FLAG_IR | DTE_FLAG_IW;
-
-	/*
-	 * When SNP is enabled, we can only support TV=1 with non-zero domain ID.
-	 * This is prevented by the SNP-enable and IOMMU_DOMAIN_IDENTITY check in
-	 * do_iommu_domain_alloc().
-	 */
-	WARN_ON(amd_iommu_snp_en && (domid == 0));
-	new.data[0] |= DTE_FLAG_TV;
-
-	if (dev_data->ppr)
-		new.data[0] |= 1ULL << DEV_ENTRY_PPR;
-
-	if (domain->dirty_tracking)
-		new.data[0] |= DTE_FLAG_HAD;
-
-	if (dev_data->ats_enabled)
-		new.data[1] |= DTE_FLAG_IOTLB;
-
-	old_domid = READ_ONCE(dte->data[1]) & DEV_DOMID_MASK;
-	new.data[1] |= domid;
-
-	/*
-	 * Restore cached persistent DTE bits, which can be set by information
-	 * in IVRS table. See set_dev_entry_from_acpi().
-	 */
-	initial_dte = amd_iommu_get_ivhd_dte_flags(iommu->pci_seg->id, dev_data->devid);
-	if (initial_dte) {
-		new.data128[0] |= initial_dte->data128[0];
-		new.data128[1] |= initial_dte->data128[1];
-	}
-
-	set_dte_gcr3_table(iommu, dev_data, &new);
-
-	update_dte256(iommu, dev_data, &new);
+	amd_iommu_update_dte(iommu, dev_data, &new);
 
 	/*
 	 * A kdump kernel might be replacing a domain ID that was copied from
@@ -2148,10 +2191,9 @@ static void set_dte_entry(struct amd_iommu *iommu,
 static void clear_dte_entry(struct amd_iommu *iommu, struct iommu_dev_data *dev_data)
 {
 	struct dev_table_entry new = {};
-	struct dev_table_entry *dte = &get_dev_table(iommu)[dev_data->devid];
 
-	make_clear_dte(dev_data, dte, &new);
-	update_dte256(iommu, dev_data, &new);
+	amd_iommu_make_clear_dte(dev_data, &new);
+	amd_iommu_update_dte(iommu, dev_data, &new);
 }
 
 /* Update and flush DTE for the given device */
@@ -2163,10 +2205,6 @@ static void dev_update_dte(struct iommu_dev_data *dev_data, bool set)
 		set_dte_entry(iommu, dev_data, 0, 0);
 	else
 		clear_dte_entry(iommu, dev_data);
-
-	clone_aliases(iommu, dev_data->dev);
-	device_flush_dte(dev_data);
-	iommu_completion_wait(iommu);
 }
 
 /*
@@ -2450,8 +2488,6 @@ static struct iommu_device *amd_iommu_probe_device(struct device *dev)
 		goto out_err;
 	}
 
-out_err:
-
 	iommu_completion_wait(iommu);
 
 	if (FEATURE_NUM_INT_REMAP_SUP_2K(amd_iommu_efr2))
@@ -2462,6 +2498,7 @@ static struct iommu_device *amd_iommu_probe_device(struct device *dev)
 	if (dev_is_pci(dev))
 		pci_prepare_ats(to_pci_dev(dev), PAGE_SHIFT);
 
+out_err:
 	return iommu_dev;
 }
 
@@ -2500,6 +2537,7 @@ static void protection_domain_init(struct protection_domain *domain)
 	spin_lock_init(&domain->lock);
 	INIT_LIST_HEAD(&domain->dev_list);
 	INIT_LIST_HEAD(&domain->dev_data_list);
+	INIT_LIST_HEAD(&domain->viommu_list);
 	xa_init(&domain->iommu_array);
 }
 
@@ -2761,6 +2799,14 @@ static struct iommu_domain *amd_iommu_domain_alloc_paging_v2(struct device *dev,
 	return &domain->domain;
 }
 
+static inline bool is_nest_parent_supported(u32 flags)
+{
+	/* Only allow nest parent when these features are supported */
+	return check_feature(FEATURE_GT) &&
+	       check_feature(FEATURE_GIOSUP) &&
+	       check_feature2(FEATURE_GCR3TRPMODE);
+}
+
 static struct iommu_domain *
 amd_iommu_domain_alloc_paging_flags(struct device *dev, u32 flags,
 				    const struct iommu_user_data *user_data)
@@ -2768,16 +2814,28 @@ amd_iommu_domain_alloc_paging_flags(struct device *dev, u32 flags,
 {
 	struct amd_iommu *iommu = get_amd_iommu_from_dev(dev);
 	const u32 supported_flags = IOMMU_HWPT_ALLOC_DIRTY_TRACKING |
-						IOMMU_HWPT_ALLOC_PASID;
+						IOMMU_HWPT_ALLOC_PASID |
+						IOMMU_HWPT_ALLOC_NEST_PARENT;
 
 	if ((flags & ~supported_flags) || user_data)
 		return ERR_PTR(-EOPNOTSUPP);
 
 	switch (flags & supported_flags) {
 	case IOMMU_HWPT_ALLOC_DIRTY_TRACKING:
-		/* Allocate domain with v1 page table for dirty tracking */
-		if (!amd_iommu_hd_support(iommu))
+	case IOMMU_HWPT_ALLOC_NEST_PARENT:
+	case IOMMU_HWPT_ALLOC_DIRTY_TRACKING | IOMMU_HWPT_ALLOC_NEST_PARENT:
+		/*
+		 * Allocate domain with v1 page table for dirty tracking
+		 * and/or Nest parent.
+		 */
+		if ((flags & IOMMU_HWPT_ALLOC_DIRTY_TRACKING) &&
+		    !amd_iommu_hd_support(iommu))
 			break;
+
+		if ((flags & IOMMU_HWPT_ALLOC_NEST_PARENT) &&
+		    !is_nest_parent_supported(flags))
+			break;
+
 		return amd_iommu_domain_alloc_paging_v1(dev, flags);
 	case IOMMU_HWPT_ALLOC_PASID:
 		/* Allocate domain with v2 page table if IOMMU supports PASID. */
@@ -3079,6 +3137,7 @@ static bool amd_iommu_enforce_cache_coherency(struct iommu_domain *domain)
 
 const struct iommu_ops amd_iommu_ops = {
 	.capable = amd_iommu_capable,
+	.hw_info = amd_iommufd_hw_info,
 	.blocked_domain = &blocked_domain,
 	.release_domain = &blocked_domain,
 	.identity_domain = &identity_domain.domain,
@@ -3091,6 +3150,8 @@ const struct iommu_ops amd_iommu_ops = {
 	.is_attach_deferred = amd_iommu_is_attach_deferred,
 	.def_domain_type = amd_iommu_def_domain_type,
 	.page_response = amd_iommu_page_response,
+	.get_viommu_size = amd_iommufd_get_viommu_size,
+	.viommu_init = amd_iommufd_viommu_init,
 };
 
 #ifdef CONFIG_IRQ_REMAP
@@ -3121,13 +3182,18 @@ static void iommu_flush_irt_and_complete(struct amd_iommu *iommu, u16 devid)
 	raw_spin_lock_irqsave(&iommu->lock, flags);
 	ret = __iommu_queue_command_sync(iommu, &cmd, true);
 	if (ret)
-		goto out;
+		goto out_err;
 	ret = __iommu_queue_command_sync(iommu, &cmd2, false);
 	if (ret)
-		goto out;
-	wait_on_sem(iommu, data);
-out:
+		goto out_err;
 	raw_spin_unlock_irqrestore(&iommu->lock, flags);
+
+	wait_on_sem(iommu, data);
+	return;
+
+out_err:
+	raw_spin_unlock_irqrestore(&iommu->lock, flags);
+	return;
 }
 
 static inline u8 iommu_get_int_tablen(struct iommu_dev_data *dev_data)
@@ -3240,7 +3306,7 @@ static struct irq_remap_table *alloc_irq_table(struct amd_iommu *iommu,
 	struct irq_remap_table *new_table = NULL;
 	struct amd_iommu_pci_seg *pci_seg;
 	unsigned long flags;
-	int nid = iommu && iommu->dev ? dev_to_node(&iommu->dev->dev) : NUMA_NO_NODE;
+	int nid = iommu->dev ? dev_to_node(&iommu->dev->dev) : NUMA_NO_NODE;
 	u16 alias;
 
 	spin_lock_irqsave(&iommu_table_lock, flags);
diff --git a/drivers/iommu/amd/iommufd.c b/drivers/iommu/amd/iommufd.c
new file mode 100644
index 0000000..ad627fd5
--- /dev/null
+++ b/drivers/iommu/amd/iommufd.c
@@ -0,0 +1,77 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2025 Advanced Micro Devices, Inc.
+ */
+
+#include <linux/iommu.h>
+
+#include "iommufd.h"
+#include "amd_iommu.h"
+#include "amd_iommu_types.h"
+
+static const struct iommufd_viommu_ops amd_viommu_ops;
+
+void *amd_iommufd_hw_info(struct device *dev, u32 *length, u32 *type)
+{
+	struct iommu_hw_info_amd *hwinfo;
+
+	if (*type != IOMMU_HW_INFO_TYPE_DEFAULT &&
+	    *type != IOMMU_HW_INFO_TYPE_AMD)
+		return ERR_PTR(-EOPNOTSUPP);
+
+	hwinfo = kzalloc(sizeof(*hwinfo), GFP_KERNEL);
+	if (!hwinfo)
+		return ERR_PTR(-ENOMEM);
+
+	*length = sizeof(*hwinfo);
+	*type = IOMMU_HW_INFO_TYPE_AMD;
+
+	hwinfo->efr = amd_iommu_efr;
+	hwinfo->efr2 = amd_iommu_efr2;
+
+	return hwinfo;
+}
+
+size_t amd_iommufd_get_viommu_size(struct device *dev, enum iommu_viommu_type viommu_type)
+{
+	return VIOMMU_STRUCT_SIZE(struct amd_iommu_viommu, core);
+}
+
+int amd_iommufd_viommu_init(struct iommufd_viommu *viommu, struct iommu_domain *parent,
+			    const struct iommu_user_data *user_data)
+{
+	unsigned long flags;
+	struct protection_domain *pdom = to_pdomain(parent);
+	struct amd_iommu_viommu *aviommu = container_of(viommu, struct amd_iommu_viommu, core);
+
+	xa_init_flags(&aviommu->gdomid_array, XA_FLAGS_ALLOC1);
+	aviommu->parent = pdom;
+
+	viommu->ops = &amd_viommu_ops;
+
+	spin_lock_irqsave(&pdom->lock, flags);
+	list_add(&aviommu->pdom_list, &pdom->viommu_list);
+	spin_unlock_irqrestore(&pdom->lock, flags);
+
+	return 0;
+}
+
+static void amd_iommufd_viommu_destroy(struct iommufd_viommu *viommu)
+{
+	unsigned long flags;
+	struct amd_iommu_viommu *aviommu = container_of(viommu, struct amd_iommu_viommu, core);
+	struct protection_domain *pdom = aviommu->parent;
+
+	spin_lock_irqsave(&pdom->lock, flags);
+	list_del(&aviommu->pdom_list);
+	spin_unlock_irqrestore(&pdom->lock, flags);
+	xa_destroy(&aviommu->gdomid_array);
+}
+
+/*
+ * See include/linux/iommufd.h
+ * struct iommufd_viommu_ops - vIOMMU specific operations
+ */
+static const struct iommufd_viommu_ops amd_viommu_ops = {
+	.destroy = amd_iommufd_viommu_destroy,
+};
diff --git a/drivers/iommu/amd/iommufd.h b/drivers/iommu/amd/iommufd.h
new file mode 100644
index 0000000..f05aad4
--- /dev/null
+++ b/drivers/iommu/amd/iommufd.h
@@ -0,0 +1,20 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * Copyright (C) 2025 Advanced Micro Devices, Inc.
+ */
+
+#ifndef AMD_IOMMUFD_H
+#define AMD_IOMMUFD_H
+
+#if IS_ENABLED(CONFIG_AMD_IOMMU_IOMMUFD)
+void *amd_iommufd_hw_info(struct device *dev, u32 *length, u32 *type);
+size_t amd_iommufd_get_viommu_size(struct device *dev, enum iommu_viommu_type viommu_type);
+int amd_iommufd_viommu_init(struct iommufd_viommu *viommu, struct iommu_domain *parent,
+			    const struct iommu_user_data *user_data);
+#else
+#define amd_iommufd_hw_info NULL
+#define amd_iommufd_viommu_init NULL
+#define amd_iommufd_get_viommu_size NULL
+#endif /* CONFIG_AMD_IOMMU_IOMMUFD */
+
+#endif /* AMD_IOMMUFD_H */
diff --git a/drivers/iommu/amd/nested.c b/drivers/iommu/amd/nested.c
new file mode 100644
index 0000000..66cc361
--- /dev/null
+++ b/drivers/iommu/amd/nested.c
@@ -0,0 +1,294 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2025 Advanced Micro Devices, Inc.
+ */
+
+#define dev_fmt(fmt)	"AMD-Vi: " fmt
+
+#include <linux/iommu.h>
+#include <linux/refcount.h>
+#include <uapi/linux/iommufd.h>
+
+#include "amd_iommu.h"
+
+static const struct iommu_domain_ops nested_domain_ops;
+
+static inline struct nested_domain *to_ndomain(struct iommu_domain *dom)
+{
+	return container_of(dom, struct nested_domain, domain);
+}
+
+/*
+ * Validate the guest DTE to make sure that the configuration for the host
+ * (v1) and guest (v2) page tables is valid when allocating a nested domain.
+ */
+static int validate_gdte_nested(struct iommu_hwpt_amd_guest *gdte)
+{
+	u32 gpt_level = FIELD_GET(DTE_GPT_LEVEL_MASK, gdte->dte[2]);
+
+	/* Must be zero: Mode, Host-TRP */
+	if (FIELD_GET(DTE_MODE_MASK, gdte->dte[0]) != 0 ||
+	    FIELD_GET(DTE_HOST_TRP, gdte->dte[0]) != 0)
+		return -EINVAL;
+
+	/* GCR3 TRP must be non-zero if V and GV are set */
+	if (FIELD_GET(DTE_FLAG_V, gdte->dte[0]) == 1 &&
+	    FIELD_GET(DTE_FLAG_GV, gdte->dte[0]) == 1 &&
+	    FIELD_GET(DTE_GCR3_14_12, gdte->dte[0]) == 0 &&
+	    FIELD_GET(DTE_GCR3_30_15, gdte->dte[1]) == 0 &&
+	    FIELD_GET(DTE_GCR3_51_31, gdte->dte[1]) == 0)
+		return -EINVAL;
+
+	/* Valid Guest Paging Mode values are 0 and 1 */
+	if (gpt_level != GUEST_PGTABLE_4_LEVEL &&
+	    gpt_level != GUEST_PGTABLE_5_LEVEL)
+		return -EINVAL;
+
+	/* GLX = 3 is reserved */
+	if (FIELD_GET(DTE_GLX, gdte->dte[0]) == 3)
+		return -EINVAL;
+
+	/*
+	 * We need to check host capability before setting
+	 * the Guest Paging Mode
+	 */
+	if (gpt_level == GUEST_PGTABLE_5_LEVEL &&
+	    amd_iommu_gpt_level < PAGE_MODE_5_LEVEL)
+		return -EOPNOTSUPP;
+
+	return 0;
+}
+
+static void *gdom_info_load_or_alloc_locked(struct xarray *xa, unsigned long index)
+{
+	struct guest_domain_mapping_info *elm, *res;
+
+	elm = xa_load(xa, index);
+	if (elm)
+		return elm;
+
+	xa_unlock(xa);
+	elm = kzalloc(sizeof(struct guest_domain_mapping_info), GFP_KERNEL);
+	xa_lock(xa);
+	if (!elm)
+		return ERR_PTR(-ENOMEM);
+
+	res = __xa_cmpxchg(xa, index, NULL, elm, GFP_KERNEL);
+	if (xa_is_err(res))
+		res = ERR_PTR(xa_err(res));
+
+	if (res) {
+		kfree(elm);
+		return res;
+	}
+
+	refcount_set(&elm->users, 0);
+	return elm;
+}
+
+/*
+ * This function is assigned to struct iommufd_viommu_ops.alloc_domain_nested()
+ * during the call to struct iommu_ops.viommu_init().
+ */
+struct iommu_domain *
+amd_iommu_alloc_domain_nested(struct iommufd_viommu *viommu, u32 flags,
+			      const struct iommu_user_data *user_data)
+{
+	int ret;
+	struct nested_domain *ndom;
+	struct guest_domain_mapping_info *gdom_info;
+	struct amd_iommu_viommu *aviommu = container_of(viommu, struct amd_iommu_viommu, core);
+
+	if (user_data->type != IOMMU_HWPT_DATA_AMD_GUEST)
+		return ERR_PTR(-EOPNOTSUPP);
+
+	ndom = kzalloc(sizeof(*ndom), GFP_KERNEL);
+	if (!ndom)
+		return ERR_PTR(-ENOMEM);
+
+	ret = iommu_copy_struct_from_user(&ndom->gdte, user_data,
+					  IOMMU_HWPT_DATA_AMD_GUEST,
+					  dte);
+	if (ret)
+		goto out_err;
+
+	ret = validate_gdte_nested(&ndom->gdte);
+	if (ret)
+		goto out_err;
+
+	ndom->gdom_id = FIELD_GET(DTE_DOMID_MASK, ndom->gdte.dte[1]);
+	ndom->domain.ops = &nested_domain_ops;
+	ndom->domain.type = IOMMU_DOMAIN_NESTED;
+	ndom->viommu = aviommu;
+
+	/*
+	 * Normally, when a guest has multiple pass-through devices, the IOMMU
+	 * driver sets up DTEs with the same stage-2 table and uses the same
+	 * host domain ID (hDomID). With nested translation, if the guest sets
+	 * up different stage-1 tables with the same PASID, the IOMMU would use
+	 * the same TLB tag, which results in TLB aliasing.
+	 *
+	 * The guest assigns gDomIDs based on its own algorithm for managing
+	 * cache tags of (DomID, PASID). Within a single vIOMMU, the nest parent
+	 * domain (w/ S2 table) is used by all DTEs, but each gDomID must map
+	 * consistently to a single hDomID. This is done using an xarray in the
+	 * vIOMMU to keep track of the gDomID mapping. When the S2 is changed,
+	 * the INVALIDATE_IOMMU_PAGES command must be issued for each hDomID in
+	 * the xarray.
+	 */
+	xa_lock(&aviommu->gdomid_array);
+
+	gdom_info = gdom_info_load_or_alloc_locked(&aviommu->gdomid_array, ndom->gdom_id);
+	if (IS_ERR(gdom_info)) {
+		xa_unlock(&aviommu->gdomid_array);
+		ret = PTR_ERR(gdom_info);
+		goto out_err;
+	}
+
+	/* Check if the gDomID mapping already exists */
+	if (refcount_inc_not_zero(&gdom_info->users)) {
+		ndom->gdom_info = gdom_info;
+		xa_unlock(&aviommu->gdomid_array);
+
+		pr_debug("%s: Found gdom_id=%#x, hdom_id=%#x\n",
+			  __func__, ndom->gdom_id, gdom_info->hdom_id);
+
+		return &ndom->domain;
+	}
+
+	/* The gDomID does not exist. Allocate a new hdom_id. */
+	gdom_info->hdom_id = amd_iommu_pdom_id_alloc();
+	if (gdom_info->hdom_id <= 0) {
+		__xa_cmpxchg(&aviommu->gdomid_array,
+			     ndom->gdom_id, gdom_info, NULL, GFP_ATOMIC);
+		xa_unlock(&aviommu->gdomid_array);
+		ret = -ENOSPC;
+		goto out_err_gdom_info;
+	}
+
+	ndom->gdom_info = gdom_info;
+	refcount_set(&gdom_info->users, 1);
+
+	xa_unlock(&aviommu->gdomid_array);
+
+	pr_debug("%s: Allocate gdom_id=%#x, hdom_id=%#x\n",
+		 __func__, ndom->gdom_id, gdom_info->hdom_id);
+
+	return &ndom->domain;
+
+out_err_gdom_info:
+	kfree(gdom_info);
+out_err:
+	kfree(ndom);
+	return ERR_PTR(ret);
+}
+
+static void set_dte_nested(struct amd_iommu *iommu, struct iommu_domain *dom,
+			   struct iommu_dev_data *dev_data, struct dev_table_entry *new)
+{
+	struct protection_domain *parent;
+	struct nested_domain *ndom = to_ndomain(dom);
+	struct iommu_hwpt_amd_guest *gdte = &ndom->gdte;
+	struct pt_iommu_amdv1_hw_info pt_info;
+
+	/*
+	 * The nest parent domain is attached during the call to the
+	 * struct iommu_ops.viommu_init(), which will be stored as part
+	 * of the struct amd_iommu_viommu.parent.
+	 */
+	if (WARN_ON(!ndom->viommu || !ndom->viommu->parent))
+		return;
+
+	parent = ndom->viommu->parent;
+	amd_iommu_make_clear_dte(dev_data, new);
+
+	/* Retrieve the current pagetable info via the IOMMU PT API. */
+	pt_iommu_amdv1_hw_info(&parent->amdv1, &pt_info);
+
+	/*
+	 * Use domain ID from nested domain to program DTE.
+	 * See amd_iommu_alloc_domain_nested().
+	 */
+	amd_iommu_set_dte_v1(dev_data, parent, ndom->gdom_info->hdom_id,
+			     &pt_info, new);
+
+	/* GV is required for nested page table */
+	new->data[0] |= DTE_FLAG_GV;
+
+	/* Guest PPR */
+	new->data[0] |= gdte->dte[0] & DTE_FLAG_PPR;
+
+	/* Guest translation settings */
+	new->data[0] |= gdte->dte[0] & (DTE_GLX | DTE_FLAG_GIOV);
+
+	/* GCR3 table */
+	new->data[0] |= gdte->dte[0] & DTE_GCR3_14_12;
+	new->data[1] |= gdte->dte[1] & (DTE_GCR3_30_15 | DTE_GCR3_51_31);
+
+	/* Guest paging mode */
+	new->data[2] |= gdte->dte[2] & DTE_GPT_LEVEL_MASK;
+}
+
+static int nested_attach_device(struct iommu_domain *dom, struct device *dev,
+				struct iommu_domain *old)
+{
+	struct dev_table_entry new = {0};
+	struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
+	struct amd_iommu *iommu = get_amd_iommu_from_dev_data(dev_data);
+	int ret = 0;
+
+	/*
+	 * Make sure PASID is not enabled
+	 * for this attach path.
+	 */
+	if (WARN_ON(dev_data->pasid_enabled))
+		return -EINVAL;
+
+	mutex_lock(&dev_data->mutex);
+
+	set_dte_nested(iommu, dom, dev_data, &new);
+
+	amd_iommu_update_dte(iommu, dev_data, &new);
+
+	mutex_unlock(&dev_data->mutex);
+
+	return ret;
+}
+
+static void nested_domain_free(struct iommu_domain *dom)
+{
+	struct guest_domain_mapping_info *curr;
+	struct nested_domain *ndom = to_ndomain(dom);
+	struct amd_iommu_viommu *aviommu = ndom->viommu;
+
+	xa_lock(&aviommu->gdomid_array);
+
+	if (!refcount_dec_and_test(&ndom->gdom_info->users)) {
+		xa_unlock(&aviommu->gdomid_array);
+		return;
+	}
+
+	/*
+	 * The refcount for the gdom_id to hdom_id mapping is zero.
+	 * It is now safe to remove the mapping.
+	 */
+	curr = __xa_cmpxchg(&aviommu->gdomid_array, ndom->gdom_id,
+			    ndom->gdom_info, NULL, GFP_ATOMIC);
+
+	xa_unlock(&aviommu->gdomid_array);
+	if (WARN_ON(!curr || xa_err(curr)))
+		return;
+
+	/* success */
+	pr_debug("%s: Free gdom_id=%#x, hdom_id=%#x\n",
+		__func__, ndom->gdom_id, curr->hdom_id);
+
+	amd_iommu_pdom_id_free(ndom->gdom_info->hdom_id);
+	kfree(curr);
+	kfree(ndom);
+}
+
+static const struct iommu_domain_ops nested_domain_ops = {
+	.attach_dev = nested_attach_device,
+	.free = nested_domain_free,
+};
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
index d16d35c..fc5a705 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
@@ -487,20 +487,26 @@ static void arm_smmu_cmdq_skip_err(struct arm_smmu_device *smmu)
  */
 static void arm_smmu_cmdq_shared_lock(struct arm_smmu_cmdq *cmdq)
 {
-	int val;
-
 	/*
-	 * We can try to avoid the cmpxchg() loop by simply incrementing the
-	 * lock counter. When held in exclusive state, the lock counter is set
-	 * to INT_MIN so these increments won't hurt as the value will remain
-	 * negative.
+	 * When held in exclusive state, the lock counter is set to INT_MIN
+	 * so these increments won't hurt as the value will remain negative.
+	 * The increment will also signal the exclusive locker that there are
+	 * shared waiters.
 	 */
 	if (atomic_fetch_inc_relaxed(&cmdq->lock) >= 0)
 		return;
 
-	do {
-		val = atomic_cond_read_relaxed(&cmdq->lock, VAL >= 0);
-	} while (atomic_cmpxchg_relaxed(&cmdq->lock, val, val + 1) != val);
+	/*
+	 * Someone else is holding the lock in exclusive state, so wait
+	 * for them to finish. Since we already incremented the lock counter,
+	 * no exclusive lock can be acquired until we finish. We don't need
+	 * the return value since we only care that the exclusive lock is
+	 * released (i.e. the lock counter is non-negative).
+	 * Once the exclusive locker releases the lock, the sign bit will
+	 * be cleared and our increment will make the lock counter positive,
+	 * allowing us to proceed.
+	 */
+	atomic_cond_read_relaxed(&cmdq->lock, VAL > 0);
 }
 
 static void arm_smmu_cmdq_shared_unlock(struct arm_smmu_cmdq *cmdq)
@@ -527,9 +533,14 @@ static bool arm_smmu_cmdq_shared_tryunlock(struct arm_smmu_cmdq *cmdq)
 	__ret;								\
 })
 
+/*
+ * Only clear the sign bit when releasing the exclusive lock; this will
+ * allow any shared_lock() waiters to proceed without the possibility
+ * of the exclusive lock being taken again in a tight loop.
+ */
 #define arm_smmu_cmdq_exclusive_unlock_irqrestore(cmdq, flags)		\
 ({									\
-	atomic_set_release(&cmdq->lock, 0);				\
+	atomic_fetch_andnot_release(INT_MIN, &cmdq->lock);		\
 	local_irq_restore(flags);					\
 })
 
@@ -2551,7 +2562,7 @@ static int arm_smmu_domain_finalise(struct arm_smmu_domain *smmu_domain,
 				     ARM_SMMU_FEAT_VAX) ? 52 : 48;
 
 		pgtbl_cfg.ias = min_t(unsigned long, ias, VA_BITS);
-		pgtbl_cfg.oas = smmu->ias;
+		pgtbl_cfg.oas = smmu->oas;
 		if (enable_dirty)
 			pgtbl_cfg.quirks |= IO_PGTABLE_QUIRK_ARM_HD;
 		fmt = ARM_64_LPAE_S1;
@@ -2561,7 +2572,7 @@ static int arm_smmu_domain_finalise(struct arm_smmu_domain *smmu_domain,
 	case ARM_SMMU_DOMAIN_S2:
 		if (enable_dirty)
 			return -EOPNOTSUPP;
-		pgtbl_cfg.ias = smmu->ias;
+		pgtbl_cfg.ias = smmu->oas;
 		pgtbl_cfg.oas = smmu->oas;
 		fmt = ARM_64_LPAE_S2;
 		finalise_stage_fn = arm_smmu_domain_finalise_s2;
@@ -3125,7 +3136,8 @@ int arm_smmu_set_pasid(struct arm_smmu_master *master,
 		       struct arm_smmu_domain *smmu_domain, ioasid_t pasid,
 		       struct arm_smmu_cd *cd, struct iommu_domain *old)
 {
-	struct iommu_domain *sid_domain = iommu_get_domain_for_dev(master->dev);
+	struct iommu_domain *sid_domain =
+		iommu_driver_get_domain_for_dev(master->dev);
 	struct arm_smmu_attach_state state = {
 		.master = master,
 		.ssid = pasid,
@@ -3191,7 +3203,7 @@ static int arm_smmu_blocking_set_dev_pasid(struct iommu_domain *new_domain,
 	 */
 	if (!arm_smmu_ssids_in_use(&master->cd_table)) {
 		struct iommu_domain *sid_domain =
-			iommu_get_domain_for_dev(master->dev);
+			iommu_driver_get_domain_for_dev(master->dev);
 
 		if (sid_domain->type == IOMMU_DOMAIN_IDENTITY ||
 		    sid_domain->type == IOMMU_DOMAIN_BLOCKED)
@@ -4395,13 +4407,7 @@ static int arm_smmu_device_hw_probe(struct arm_smmu_device *smmu)
 	}
 
 	/* We only support the AArch64 table format at present */
-	switch (FIELD_GET(IDR0_TTF, reg)) {
-	case IDR0_TTF_AARCH32_64:
-		smmu->ias = 40;
-		fallthrough;
-	case IDR0_TTF_AARCH64:
-		break;
-	default:
+	if (!(FIELD_GET(IDR0_TTF, reg) & IDR0_TTF_AARCH64)) {
 		dev_err(smmu->dev, "AArch64 table format not supported!\n");
 		return -ENXIO;
 	}
@@ -4514,8 +4520,6 @@ static int arm_smmu_device_hw_probe(struct arm_smmu_device *smmu)
 		dev_warn(smmu->dev,
 			 "failed to set DMA mask for table walker\n");
 
-	smmu->ias = max(smmu->ias, smmu->oas);
-
 	if ((smmu->features & ARM_SMMU_FEAT_TRANS_S1) &&
 	    (smmu->features & ARM_SMMU_FEAT_TRANS_S2))
 		smmu->features |= ARM_SMMU_FEAT_NESTING;
@@ -4525,8 +4529,8 @@ static int arm_smmu_device_hw_probe(struct arm_smmu_device *smmu)
 	if (arm_smmu_sva_supported(smmu))
 		smmu->features |= ARM_SMMU_FEAT_SVA;
 
-	dev_info(smmu->dev, "ias %lu-bit, oas %lu-bit (features 0x%08x)\n",
-		 smmu->ias, smmu->oas, smmu->features);
+	dev_info(smmu->dev, "oas %lu-bit (features 0x%08x)\n",
+		 smmu->oas, smmu->features);
 	return 0;
 }
 
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
index ae23aac..0a5bb57d 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
@@ -43,7 +43,6 @@ struct arm_vsmmu;
 #define IDR0_COHACC			(1 << 4)
 #define IDR0_TTF			GENMASK(3, 2)
 #define IDR0_TTF_AARCH64		2
-#define IDR0_TTF_AARCH32_64		3
 #define IDR0_S1P			(1 << 1)
 #define IDR0_S2P			(1 << 0)
 
@@ -784,7 +783,6 @@ struct arm_smmu_device {
 	int				gerr_irq;
 	int				combined_irq;
 
-	unsigned long			ias; /* IPA */
 	unsigned long			oas; /* PA */
 	unsigned long			pgsize_bitmap;
 
diff --git a/drivers/iommu/arm/arm-smmu/arm-smmu-qcom.c b/drivers/iommu/arm/arm-smmu/arm-smmu-qcom.c
index 5730853..456d514 100644
--- a/drivers/iommu/arm/arm-smmu/arm-smmu-qcom.c
+++ b/drivers/iommu/arm/arm-smmu/arm-smmu-qcom.c
@@ -41,12 +41,38 @@ static const struct of_device_id qcom_smmu_actlr_client_of_match[] = {
 			.data = (const void *) (PREFETCH_DEEP | CPRE | CMTLB) },
 	{ .compatible = "qcom,fastrpc",
 			.data = (const void *) (PREFETCH_DEEP | CPRE | CMTLB) },
+	{ .compatible = "qcom,qcm2290-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
 	{ .compatible = "qcom,sc7280-mdss",
 			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
 	{ .compatible = "qcom,sc7280-venus",
 			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sc8180x-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sc8280xp-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sm6115-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sm6125-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sm6350-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sm8150-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sm8250-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sm8350-mdss",
+			.data = (const void *) (PREFETCH_SHALLOW | CPRE | CMTLB) },
+	{ .compatible = "qcom,sm8450-mdss",
+			.data = (const void *) (PREFETCH_DEFAULT | CMTLB) },
 	{ .compatible = "qcom,sm8550-mdss",
 			.data = (const void *) (PREFETCH_DEFAULT | CMTLB) },
+	{ .compatible = "qcom,sm8650-mdss",
+			.data = (const void *) (PREFETCH_DEFAULT | CMTLB) },
+	{ .compatible = "qcom,sm8750-mdss",
+			.data = (const void *) (PREFETCH_DEFAULT | CMTLB) },
+	{ .compatible = "qcom,x1e80100-mdss",
+			.data = (const void *) (PREFETCH_DEFAULT | CMTLB) },
 	{ }
 };
 
diff --git a/drivers/iommu/arm/arm-smmu/qcom_iommu.c b/drivers/iommu/arm/arm-smmu/qcom_iommu.c
index f69d927..c98bed3 100644
--- a/drivers/iommu/arm/arm-smmu/qcom_iommu.c
+++ b/drivers/iommu/arm/arm-smmu/qcom_iommu.c
@@ -761,14 +761,10 @@ static struct platform_driver qcom_iommu_ctx_driver = {
 
 static bool qcom_iommu_has_secure_context(struct qcom_iommu_dev *qcom_iommu)
 {
-	struct device_node *child;
-
-	for_each_child_of_node(qcom_iommu->dev->of_node, child) {
+	for_each_child_of_node_scoped(qcom_iommu->dev->of_node, child) {
 		if (of_device_is_compatible(child, "qcom,msm-iommu-v1-sec") ||
-		    of_device_is_compatible(child, "qcom,msm-iommu-v2-sec")) {
-			of_node_put(child);
+		    of_device_is_compatible(child, "qcom,msm-iommu-v2-sec"))
 			return true;
-		}
 	}
 
 	return false;
diff --git a/drivers/iommu/dma-iommu.c b/drivers/iommu/dma-iommu.c
index c9208885..aeaf8fa 100644
--- a/drivers/iommu/dma-iommu.c
+++ b/drivers/iommu/dma-iommu.c
@@ -2097,10 +2097,8 @@ void dma_iova_destroy(struct device *dev, struct dma_iova_state *state,
 }
 EXPORT_SYMBOL_GPL(dma_iova_destroy);
 
-void iommu_setup_dma_ops(struct device *dev)
+void iommu_setup_dma_ops(struct device *dev, struct iommu_domain *domain)
 {
-	struct iommu_domain *domain = iommu_get_domain_for_dev(dev);
-
 	if (dev_is_pci(dev))
 		dev->iommu->pci_32bit_workaround = !iommu_dma_forcedac;
 
diff --git a/drivers/iommu/dma-iommu.h b/drivers/iommu/dma-iommu.h
index eca201c..040d002 100644
--- a/drivers/iommu/dma-iommu.h
+++ b/drivers/iommu/dma-iommu.h
@@ -9,7 +9,7 @@
 
 #ifdef CONFIG_IOMMU_DMA
 
-void iommu_setup_dma_ops(struct device *dev);
+void iommu_setup_dma_ops(struct device *dev, struct iommu_domain *domain);
 
 int iommu_get_dma_cookie(struct iommu_domain *domain);
 void iommu_put_dma_cookie(struct iommu_domain *domain);
@@ -26,7 +26,8 @@ extern bool iommu_dma_forcedac;
 
 #else /* CONFIG_IOMMU_DMA */
 
-static inline void iommu_setup_dma_ops(struct device *dev)
+static inline void iommu_setup_dma_ops(struct device *dev,
+				       struct iommu_domain *domain)
 {
 }
 
diff --git a/drivers/iommu/generic_pt/fmt/amdv1.h b/drivers/iommu/generic_pt/fmt/amdv1.h
index aa8e1a8..3b2c41d 100644
--- a/drivers/iommu/generic_pt/fmt/amdv1.h
+++ b/drivers/iommu/generic_pt/fmt/amdv1.h
@@ -354,7 +354,8 @@ static inline int amdv1pt_iommu_set_prot(struct pt_common *common,
 	 * Ideally we'd have an IOMMU_ENCRYPTED flag set by higher levels to
 	 * control this. For now if the tables use sme_set then so do the ptes.
 	 */
-	if (pt_feature(common, PT_FEAT_AMDV1_ENCRYPT_TABLES))
+	if (pt_feature(common, PT_FEAT_AMDV1_ENCRYPT_TABLES) &&
+	    !(iommu_prot & IOMMU_MMIO))
 		pte = __sme_set(pte);
 
 	attrs->descriptor_bits = pte;
diff --git a/drivers/iommu/generic_pt/fmt/x86_64.h b/drivers/iommu/generic_pt/fmt/x86_64.h
index 210748d9..ed9a47c 100644
--- a/drivers/iommu/generic_pt/fmt/x86_64.h
+++ b/drivers/iommu/generic_pt/fmt/x86_64.h
@@ -227,7 +227,8 @@ static inline int x86_64_pt_iommu_set_prot(struct pt_common *common,
 	 * Ideally we'd have an IOMMU_ENCRYPTED flag set by higher levels to
 	 * control this. For now if the tables use sme_set then so do the ptes.
 	 */
-	if (pt_feature(common, PT_FEAT_X86_64_AMD_ENCRYPT_TABLES))
+	if (pt_feature(common, PT_FEAT_X86_64_AMD_ENCRYPT_TABLES) &&
+	    !(iommu_prot & IOMMU_MMIO))
 		pte = __sme_set(pte);
 
 	attrs->descriptor_bits = pte;
diff --git a/drivers/iommu/generic_pt/iommu_pt.h b/drivers/iommu/generic_pt/iommu_pt.h
index 3327116..52ef028 100644
--- a/drivers/iommu/generic_pt/iommu_pt.h
+++ b/drivers/iommu/generic_pt/iommu_pt.h
@@ -645,7 +645,7 @@ static __always_inline int __do_map_single_page(struct pt_range *range,
 	struct pt_iommu_map_args *map = arg;
 
 	pts.type = pt_load_single_entry(&pts);
-	if (level == 0) {
+	if (pts.level == 0) {
 		if (pts.type != PT_ENTRY_EMPTY)
 			return -EADDRINUSE;
 		pt_install_leaf_entry(&pts, map->oa, PAGE_SHIFT,
diff --git a/drivers/iommu/io-pgtable-arm.c b/drivers/iommu/io-pgtable-arm.c
index e662600..05d63fe 100644
--- a/drivers/iommu/io-pgtable-arm.c
+++ b/drivers/iommu/io-pgtable-arm.c
@@ -637,7 +637,7 @@ static size_t __arm_lpae_unmap(struct arm_lpae_io_pgtable *data,
 	pte = READ_ONCE(*ptep);
 	if (!pte) {
 		WARN_ON(!(data->iop.cfg.quirks & IO_PGTABLE_QUIRK_NO_WARN));
-		return -ENOENT;
+		return 0;
 	}
 
 	/* If the size matches this level, we're in the right place */
diff --git a/drivers/iommu/iommu-debug-pagealloc.c b/drivers/iommu/iommu-debug-pagealloc.c
new file mode 100644
index 0000000..80164df
--- /dev/null
+++ b/drivers/iommu/iommu-debug-pagealloc.c
@@ -0,0 +1,164 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (C) 2025 - Google Inc
+ * Author: Mostafa Saleh <smostafa@google.com>
+ * IOMMU API debug page alloc sanitizer
+ */
+#include <linux/atomic.h>
+#include <linux/iommu.h>
+#include <linux/iommu-debug-pagealloc.h>
+#include <linux/kernel.h>
+#include <linux/page_ext.h>
+#include <linux/page_owner.h>
+
+#include "iommu-priv.h"
+
+static bool needed;
+DEFINE_STATIC_KEY_FALSE(iommu_debug_initialized);
+
+struct iommu_debug_metadata {
+	atomic_t ref;
+};
+
+static __init bool need_iommu_debug(void)
+{
+	return needed;
+}
+
+struct page_ext_operations page_iommu_debug_ops = {
+	.size = sizeof(struct iommu_debug_metadata),
+	.need = need_iommu_debug,
+};
+
+static struct iommu_debug_metadata *get_iommu_data(struct page_ext *page_ext)
+{
+	return page_ext_data(page_ext, &page_iommu_debug_ops);
+}
+
+static void iommu_debug_inc_page(phys_addr_t phys)
+{
+	struct page_ext *page_ext = page_ext_from_phys(phys);
+	struct iommu_debug_metadata *d;
+
+	if (!page_ext)
+		return;
+
+	d = get_iommu_data(page_ext);
+	WARN_ON(atomic_inc_return_relaxed(&d->ref) <= 0);
+	page_ext_put(page_ext);
+}
+
+static void iommu_debug_dec_page(phys_addr_t phys)
+{
+	struct page_ext *page_ext = page_ext_from_phys(phys);
+	struct iommu_debug_metadata *d;
+
+	if (!page_ext)
+		return;
+
+	d = get_iommu_data(page_ext);
+	WARN_ON(atomic_dec_return_relaxed(&d->ref) < 0);
+	page_ext_put(page_ext);
+}
+
+/*
+ * The IOMMU page size doesn't have to match the CPU page size, so we use
+ * the smallest IOMMU page size to refcount the pages in the vmemmap.
+ * That is important as both map and unmap have to use the same page size
+ * to update the refcount, to avoid double counting the same page.
+ * And as we can't know from iommu_unmap() which page size was originally
+ * used for the map, we just use the minimum supported one for both.
+ */
+static size_t iommu_debug_page_size(struct iommu_domain *domain)
+{
+	return 1UL << __ffs(domain->pgsize_bitmap);
+}
+
+static bool iommu_debug_page_count(const struct page *page)
+{
+	unsigned int ref;
+	struct page_ext *page_ext = page_ext_get(page);
+	struct iommu_debug_metadata *d = get_iommu_data(page_ext);
+
+	ref = atomic_read(&d->ref);
+	page_ext_put(page_ext);
+	return ref != 0;
+}
+
+void __iommu_debug_check_unmapped(const struct page *page, int numpages)
+{
+	while (numpages--) {
+		if (WARN_ON(iommu_debug_page_count(page))) {
+			pr_warn("iommu: Detected page leak!\n");
+			dump_page_owner(page);
+		}
+		page++;
+	}
+}
+
+void __iommu_debug_map(struct iommu_domain *domain, phys_addr_t phys, size_t size)
+{
+	size_t off, end;
+	size_t page_size = iommu_debug_page_size(domain);
+
+	if (WARN_ON(!phys || check_add_overflow(phys, size, &end)))
+		return;
+
+	for (off = 0 ; off < size ; off += page_size)
+		iommu_debug_inc_page(phys + off);
+}
+
+static void __iommu_debug_update_iova(struct iommu_domain *domain,
+				      unsigned long iova, size_t size, bool inc)
+{
+	size_t off, end;
+	size_t page_size = iommu_debug_page_size(domain);
+
+	if (WARN_ON(check_add_overflow(iova, size, &end)))
+		return;
+
+	for (off = 0 ; off < size ; off += page_size) {
+		phys_addr_t phys = iommu_iova_to_phys(domain, iova + off);
+
+		if (!phys)
+			continue;
+
+		if (inc)
+			iommu_debug_inc_page(phys);
+		else
+			iommu_debug_dec_page(phys);
+	}
+}
+
+void __iommu_debug_unmap_begin(struct iommu_domain *domain,
+			       unsigned long iova, size_t size)
+{
+	__iommu_debug_update_iova(domain, iova, size, false);
+}
+
+void __iommu_debug_unmap_end(struct iommu_domain *domain,
+			     unsigned long iova, size_t size,
+			     size_t unmapped)
+{
+	if ((unmapped == size) || WARN_ON_ONCE(unmapped > size))
+		return;
+
+	/* If unmap failed, re-increment the refcount. */
+	__iommu_debug_update_iova(domain, iova + unmapped,
+				  size - unmapped, true);
+}
+
+void iommu_debug_init(void)
+{
+	if (!needed)
+		return;
+
+	pr_info("iommu: Debugging page allocations, expect overhead or disable iommu.debug_pagealloc");
+	static_branch_enable(&iommu_debug_initialized);
+}
+
+static int __init iommu_debug_pagealloc(char *str)
+{
+	return kstrtobool(str, &needed);
+}
+early_param("iommu.debug_pagealloc", iommu_debug_pagealloc);
diff --git a/drivers/iommu/iommu-priv.h b/drivers/iommu/iommu-priv.h
index c95394c..aaffad5 100644
--- a/drivers/iommu/iommu-priv.h
+++ b/drivers/iommu/iommu-priv.h
@@ -5,6 +5,7 @@
 #define __LINUX_IOMMU_PRIV_H
 
 #include <linux/iommu.h>
+#include <linux/iommu-debug-pagealloc.h>
 #include <linux/msi.h>
 
 static inline const struct iommu_ops *dev_iommu_ops(struct device *dev)
@@ -65,4 +66,61 @@ static inline int iommufd_sw_msi(struct iommu_domain *domain,
 int iommu_replace_device_pasid(struct iommu_domain *domain,
 			       struct device *dev, ioasid_t pasid,
 			       struct iommu_attach_handle *handle);
+
+#ifdef CONFIG_IOMMU_DEBUG_PAGEALLOC
+
+void __iommu_debug_map(struct iommu_domain *domain, phys_addr_t phys,
+		       size_t size);
+void __iommu_debug_unmap_begin(struct iommu_domain *domain,
+			       unsigned long iova, size_t size);
+void __iommu_debug_unmap_end(struct iommu_domain *domain,
+			     unsigned long iova, size_t size, size_t unmapped);
+
+static inline void iommu_debug_map(struct iommu_domain *domain,
+				   phys_addr_t phys, size_t size)
+{
+	if (static_branch_unlikely(&iommu_debug_initialized))
+		__iommu_debug_map(domain, phys, size);
+}
+
+static inline void iommu_debug_unmap_begin(struct iommu_domain *domain,
+					   unsigned long iova, size_t size)
+{
+	if (static_branch_unlikely(&iommu_debug_initialized))
+		__iommu_debug_unmap_begin(domain, iova, size);
+}
+
+static inline void iommu_debug_unmap_end(struct iommu_domain *domain,
+					 unsigned long iova, size_t size,
+					 size_t unmapped)
+{
+	if (static_branch_unlikely(&iommu_debug_initialized))
+		__iommu_debug_unmap_end(domain, iova, size, unmapped);
+}
+
+void iommu_debug_init(void);
+
+#else
+static inline void iommu_debug_map(struct iommu_domain *domain,
+				   phys_addr_t phys, size_t size)
+{
+}
+
+static inline void iommu_debug_unmap_begin(struct iommu_domain *domain,
+					   unsigned long iova, size_t size)
+{
+}
+
+static inline void iommu_debug_unmap_end(struct iommu_domain *domain,
+					 unsigned long iova, size_t size,
+					 size_t unmapped)
+{
+}
+
+static inline void iommu_debug_init(void)
+{
+}
+
+#endif /* CONFIG_IOMMU_DEBUG_PAGEALLOC */
+
 #endif /* __LINUX_IOMMU_PRIV_H */
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 2ca990d..4926a43 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -61,6 +61,11 @@ struct iommu_group {
 	int id;
 	struct iommu_domain *default_domain;
 	struct iommu_domain *blocking_domain;
+	/*
+	 * During a group device reset, @resetting_domain points to the physical
+	 * domain, while @domain points to the attached domain before the reset.
+	 */
+	struct iommu_domain *resetting_domain;
 	struct iommu_domain *domain;
 	struct list_head entry;
 	unsigned int owner_cnt;
@@ -232,6 +237,8 @@ static int __init iommu_subsys_init(void)
 	if (!nb)
 		return -ENOMEM;
 
+	iommu_debug_init();
+
 	for (int i = 0; i < ARRAY_SIZE(iommu_buses); i++) {
 		nb[i].notifier_call = iommu_bus_notifier;
 		bus_register_notifier(iommu_buses[i], &nb[i]);
@@ -661,7 +668,7 @@ static int __iommu_probe_device(struct device *dev, struct list_head *group_list
 	}
 
 	if (group->default_domain)
-		iommu_setup_dma_ops(dev);
+		iommu_setup_dma_ops(dev, group->default_domain);
 
 	mutex_unlock(&group->mutex);
 
@@ -1173,12 +1180,11 @@ static int iommu_create_device_direct_mappings(struct iommu_domain *domain,
 					       struct device *dev)
 {
 	struct iommu_resv_region *entry;
-	struct list_head mappings;
+	LIST_HEAD(mappings);
 	unsigned long pg_size;
 	int ret = 0;
 
 	pg_size = domain->pgsize_bitmap ? 1UL << __ffs(domain->pgsize_bitmap) : 0;
-	INIT_LIST_HEAD(&mappings);
 
 	if (WARN_ON_ONCE(iommu_is_dma_domain(domain) && !pg_size))
 		return -EINVAL;
@@ -1949,7 +1955,7 @@ static int bus_iommu_probe(const struct bus_type *bus)
 			return ret;
 		}
 		for_each_group_device(group, gdev)
-			iommu_setup_dma_ops(gdev->dev);
+			iommu_setup_dma_ops(gdev->dev, group->default_domain);
 		mutex_unlock(&group->mutex);
 
 		/*
@@ -2185,10 +2191,26 @@ EXPORT_SYMBOL_GPL(iommu_attach_device);
 
 int iommu_deferred_attach(struct device *dev, struct iommu_domain *domain)
 {
-	if (dev->iommu && dev->iommu->attach_deferred)
-		return __iommu_attach_device(domain, dev, NULL);
+	/*
+	 * This check runs on the DMA mapping fast path, so it is done without
+	 * locking. That is racy, but we expect drivers to set up their DMA
+	 * mappings from probe(), which is single threaded, so the race does not
+	 * occur in practice.
+	 */
+	if (!dev->iommu || !dev->iommu->attach_deferred)
+		return 0;
 
-	return 0;
+	guard(mutex)(&dev->iommu_group->mutex);
+
+	/*
+	 * This is a concurrent attach during a device reset. Reject it until
+	 * pci_dev_reset_iommu_done() attaches the device to group->domain.
+	 *
+	 * Note that this might fail the iommu_dma_map(). But there's nothing
+	 * more we can do here.
+	 */
+	if (dev->iommu_group->resetting_domain)
+		return -EBUSY;
+	return __iommu_attach_device(domain, dev, NULL);
 }
 
 void iommu_detach_device(struct iommu_domain *domain, struct device *dev)
@@ -2210,6 +2232,15 @@ void iommu_detach_device(struct iommu_domain *domain, struct device *dev)
 }
 EXPORT_SYMBOL_GPL(iommu_detach_device);
 
+/**
+ * iommu_get_domain_for_dev() - Return the DMA API domain pointer
+ * @dev: Device to query
+ *
+ * This function can be called by a driver bound to @dev. The returned pointer
+ * is valid for the lifetime of the bound driver.
+ *
+ * It should not be called by drivers that set driver_managed_dma = true.
+ */
 struct iommu_domain *iommu_get_domain_for_dev(struct device *dev)
 {
 	/* Caller must be a probed driver on dev */
@@ -2218,10 +2249,40 @@ struct iommu_domain *iommu_get_domain_for_dev(struct device *dev)
 	if (!group)
 		return NULL;
 
+	lockdep_assert_not_held(&group->mutex);
+
 	return group->domain;
 }
 EXPORT_SYMBOL_GPL(iommu_get_domain_for_dev);
 
+/**
+ * iommu_driver_get_domain_for_dev() - Return the driver-level domain pointer
+ * @dev: Device to query
+ *
+ * This function can be called by an IOMMU driver that needs the physical
+ * domain from within an IOMMU callback, where group->mutex is held.
+ */
+struct iommu_domain *iommu_driver_get_domain_for_dev(struct device *dev)
+{
+	struct iommu_group *group = dev->iommu_group;
+
+	lockdep_assert_held(&group->mutex);
+
+	/*
+	 * The driver handles the low-level __iommu_attach_device(), including
+	 * the one invoked by pci_dev_reset_iommu_done() that re-attaches the
+	 * device to the cached group->domain. In that case, the driver must get
+	 * the old domain from group->resetting_domain rather than group->domain,
+	 * otherwise it would see group->domain as both the old and the new
+	 * domain of the re-attachment.
+	 */
+	if (group->resetting_domain)
+		return group->resetting_domain;
+
+	return group->domain;
+}
+EXPORT_SYMBOL_GPL(iommu_driver_get_domain_for_dev);
+
 /*
  * For IOMMU_DOMAIN_DMA implementations which already provide their own
  * guarantees that the group and its default domain are valid and correct.
@@ -2375,6 +2436,13 @@ static int __iommu_group_set_domain_internal(struct iommu_group *group,
 		return -EINVAL;
 
 	/*
+	 * This is a concurrent attach during a device reset. Reject it until
+	 * pci_dev_reset_iommu_done() attaches the device to group->domain.
+	 */
+	if (group->resetting_domain)
+		return -EBUSY;
+
+	/*
 	 * Changing the domain is done by calling attach_dev() on the new
 	 * domain. This switch does not have to be atomic and DMA can be
 	 * discarded during the transition. DMA must only be able to access
@@ -2562,10 +2630,12 @@ int iommu_map_nosync(struct iommu_domain *domain, unsigned long iova,
 	}
 
 	/* unroll mapping in case something went wrong */
-	if (ret)
+	if (ret) {
 		iommu_unmap(domain, orig_iova, orig_size - size);
-	else
+	} else {
 		trace_map(orig_iova, orig_paddr, orig_size);
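+		/* Record the newly mapped physical range with the pagealloc sanitizer */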
+		iommu_debug_map(domain, orig_paddr, orig_size);
+	}
 
 	return ret;
 }
@@ -2627,6 +2697,8 @@ static size_t __iommu_unmap(struct iommu_domain *domain,
 
 	pr_debug("unmap this: iova 0x%lx size 0x%zx\n", iova, size);
 
+	iommu_debug_unmap_begin(domain, iova, size);
+
 	/*
 	 * Keep iterating until we either unmap 'size' bytes (or more)
 	 * or we hit an area that isn't mapped.
@@ -2647,6 +2719,7 @@ static size_t __iommu_unmap(struct iommu_domain *domain,
 	}
 
 	trace_unmap(orig_iova, size, unmapped);
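+	/* Report the completed unmap and its actual size to the pagealloc sanitizer */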
+	iommu_debug_unmap_end(domain, orig_iova, size, unmapped);
 	return unmapped;
 }
 
@@ -3148,7 +3221,7 @@ static ssize_t iommu_group_store_type(struct iommu_group *group,
 
 	/* Make sure dma_ops is appropriatley set */
 	for_each_group_device(group, gdev)
-		iommu_setup_dma_ops(gdev->dev);
+		iommu_setup_dma_ops(gdev->dev, group->default_domain);
 
 out_unlock:
 	mutex_unlock(&group->mutex);
@@ -3492,6 +3565,16 @@ int iommu_attach_device_pasid(struct iommu_domain *domain,
 		return -EINVAL;
 
 	mutex_lock(&group->mutex);
+
+	/*
+	 * This is a concurrent attach during a device reset. Reject it until
+	 * pci_dev_reset_iommu_done() attaches the device to group->domain.
+	 */
+	if (group->resetting_domain) {
+		ret = -EBUSY;
+		goto out_unlock;
+	}
+
 	for_each_group_device(group, device) {
 		/*
 		 * Skip PASID validation for devices without PASID support
@@ -3575,6 +3658,16 @@ int iommu_replace_device_pasid(struct iommu_domain *domain,
 		return -EINVAL;
 
 	mutex_lock(&group->mutex);
+
+	/*
+	 * This is a concurrent attach during a device reset. Reject it until
+	 * pci_dev_reset_iommu_done() attaches the device to group->domain.
+	 */
+	if (group->resetting_domain) {
+		ret = -EBUSY;
+		goto out_unlock;
+	}
+
 	entry = iommu_make_pasid_array_entry(domain, handle);
 	curr = xa_cmpxchg(&group->pasid_array, pasid, NULL,
 			  XA_ZERO_ENTRY, GFP_KERNEL);
@@ -3832,6 +3925,127 @@ int iommu_replace_group_handle(struct iommu_group *group,
 }
 EXPORT_SYMBOL_NS_GPL(iommu_replace_group_handle, "IOMMUFD_INTERNAL");
 
+/**
+ * pci_dev_reset_iommu_prepare() - Block IOMMU to prepare for a PCI device reset
+ * @pdev: PCI device that is going to enter a reset routine
+ *
+ * The PCIe r6.0, sec 10.3.1 IMPLEMENTATION NOTE recommends disabling and
+ * blocking ATS before initiating a reset. This means that, during its reset
+ * routine, a PCIe device wants all IOMMU activity blocked: both translation
+ * and ATS invalidation.
+ *
+ * This function attaches the device's RID/PASID(s) to group->blocking_domain
+ * and sets group->resetting_domain. This lets the IOMMU driver pause all
+ * IOMMU activity while leaving the group->domain pointer intact. Later, when
+ * the reset is finished, pci_dev_reset_iommu_done() restores everything.
+ *
+ * Callers must pair pci_dev_reset_iommu_prepare() with
+ * pci_dev_reset_iommu_done() around the core-level reset routine, so that
+ * resetting_domain is unset again.
+ *
+ * Return: 0 on success or negative error code if the preparation failed.
+ *
+ * These two functions are designed to be used by PCI reset functions that do
+ * not race with iommu_release_device(), since the PCI sysfs node is removed
+ * before the BUS_NOTIFY_REMOVED_DEVICE notification is sent. When using them
+ * elsewhere, callers must ensure there is no racy iommu_release_device() call,
+ * which would otherwise use-after-free the dev->iommu_group pointer.
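+ *
+ * pcie_flr(), for example, calls pci_dev_reset_iommu_prepare() after waiting
+ * for pending transactions and pci_dev_reset_iommu_done() once the function
+ * level reset has completed.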
+ */
+int pci_dev_reset_iommu_prepare(struct pci_dev *pdev)
+{
+	struct iommu_group *group = pdev->dev.iommu_group;
+	unsigned long pasid;
+	void *entry;
+	int ret;
+
+	if (!pci_ats_supported(pdev) || !dev_has_iommu(&pdev->dev))
+		return 0;
+
+	guard(mutex)(&group->mutex);
+
+	/* Re-entry is not allowed */
+	if (WARN_ON(group->resetting_domain))
+		return -EBUSY;
+
+	ret = __iommu_group_alloc_blocking_domain(group);
+	if (ret)
+		return ret;
+
+	/* Stage RID domain at blocking_domain while retaining group->domain */
+	if (group->domain != group->blocking_domain) {
+		ret = __iommu_attach_device(group->blocking_domain, &pdev->dev,
+					    group->domain);
+		if (ret)
+			return ret;
+	}
+
+	/*
+	 * Stage PASID domains at blocking_domain while retaining pasid_array.
+	 *
+	 * The pasid_array is mostly fenced by group->mutex, except one reader
+	 * in iommu_attach_handle_get(), so it's safe to read without xa_lock.
+	 */
+	xa_for_each_start(&group->pasid_array, pasid, entry, 1)
+		iommu_remove_dev_pasid(&pdev->dev, pasid,
+				       pasid_array_entry_to_domain(entry));
+
+	group->resetting_domain = group->blocking_domain;
+	return ret;
+}
+EXPORT_SYMBOL_GPL(pci_dev_reset_iommu_prepare);
+
+/**
+ * pci_dev_reset_iommu_done() - Restore IOMMU after a PCI device reset is done
+ * @pdev: PCI device that has finished a reset routine
+ *
+ * After a PCIe device finishes a reset routine, it wants its IOMMU activity
+ * restored, including new translations as well as cache invalidations, by
+ * re-attaching all of the device's RIDs/PASIDs back to the domains retained
+ * in the core-level structure.
+ *
+ * Callers must pair it with a successful pci_dev_reset_iommu_prepare().
+ *
+ * Note that, although unlikely, re-attaching the domains might fail due to an
+ * unexpected condition such as OOM.
+ */
+void pci_dev_reset_iommu_done(struct pci_dev *pdev)
+{
+	struct iommu_group *group = pdev->dev.iommu_group;
+	unsigned long pasid;
+	void *entry;
+
+	if (!pci_ats_supported(pdev) || !dev_has_iommu(&pdev->dev))
+		return;
+
+	guard(mutex)(&group->mutex);
+
+	/* pci_dev_reset_iommu_prepare() was bypassed for the device */
+	if (!group->resetting_domain)
+		return;
+
+	/* pci_dev_reset_iommu_prepare() was not successfully called */
+	if (WARN_ON(!group->blocking_domain))
+		return;
+
+	/* Re-attach RID domain back to group->domain */
+	if (group->domain != group->blocking_domain) {
+		WARN_ON(__iommu_attach_device(group->domain, &pdev->dev,
+					      group->blocking_domain));
+	}
+
+	/*
+	 * Re-attach PASID domains back to the domains retained in pasid_array.
+	 *
+	 * The pasid_array is mostly fenced by group->mutex, except one reader
+	 * in iommu_attach_handle_get(), so it's safe to read without xa_lock.
+	 */
+	xa_for_each_start(&group->pasid_array, pasid, entry, 1)
+		WARN_ON(__iommu_set_group_pasid(
+			pasid_array_entry_to_domain(entry), group, pasid,
+			group->blocking_domain));
+
+	group->resetting_domain = NULL;
+}
+EXPORT_SYMBOL_GPL(pci_dev_reset_iommu_done);
+
 #if IS_ENABLED(CONFIG_IRQ_MSI_IOMMU)
 /**
  * iommu_dma_prepare_msi() - Map the MSI page in the IOMMU domain
diff --git a/drivers/pci/pci-acpi.c b/drivers/pci/pci-acpi.c
index 93693777..651d9b5 100644
--- a/drivers/pci/pci-acpi.c
+++ b/drivers/pci/pci-acpi.c
@@ -9,6 +9,7 @@
 
 #include <linux/delay.h>
 #include <linux/init.h>
+#include <linux/iommu.h>
 #include <linux/irqdomain.h>
 #include <linux/pci.h>
 #include <linux/msi.h>
@@ -971,6 +972,7 @@ void pci_set_acpi_fwnode(struct pci_dev *dev)
 int pci_dev_acpi_reset(struct pci_dev *dev, bool probe)
 {
 	acpi_handle handle = ACPI_HANDLE(&dev->dev);
+	int ret;
 
 	if (!handle || !acpi_has_method(handle, "_RST"))
 		return -ENOTTY;
@@ -978,12 +980,19 @@ int pci_dev_acpi_reset(struct pci_dev *dev, bool probe)
 	if (probe)
 		return 0;
 
-	if (ACPI_FAILURE(acpi_evaluate_object(handle, "_RST", NULL, NULL))) {
-		pci_warn(dev, "ACPI _RST failed\n");
-		return -ENOTTY;
+	ret = pci_dev_reset_iommu_prepare(dev);
+	if (ret) {
+		pci_err(dev, "failed to stop IOMMU for a PCI reset: %d\n", ret);
+		return ret;
 	}
 
-	return 0;
+	if (ACPI_FAILURE(acpi_evaluate_object(handle, "_RST", NULL, NULL))) {
+		pci_warn(dev, "ACPI _RST failed\n");
+		ret = -ENOTTY;
+	}
+
+	pci_dev_reset_iommu_done(dev);
+	return ret;
 }
 
 bool acpi_pci_power_manageable(struct pci_dev *dev)
diff --git a/drivers/pci/pci.c b/drivers/pci/pci.c
index 13dbb40..a0ba42a 100644
--- a/drivers/pci/pci.c
+++ b/drivers/pci/pci.c
@@ -13,6 +13,7 @@
 #include <linux/delay.h>
 #include <linux/dmi.h>
 #include <linux/init.h>
+#include <linux/iommu.h>
 #include <linux/msi.h>
 #include <linux/of.h>
 #include <linux/pci.h>
@@ -25,6 +26,7 @@
 #include <linux/logic_pio.h>
 #include <linux/device.h>
 #include <linux/pm_runtime.h>
+#include <linux/pci-ats.h>
 #include <linux/pci_hotplug.h>
 #include <linux/vmalloc.h>
 #include <asm/dma.h>
@@ -4330,13 +4332,22 @@ EXPORT_SYMBOL(pci_wait_for_pending_transaction);
  */
 int pcie_flr(struct pci_dev *dev)
 {
+	int ret;
+
 	if (!pci_wait_for_pending_transaction(dev))
 		pci_err(dev, "timed out waiting for pending transaction; performing function level reset anyway\n");
 
+	/* Must be called after waiting for pending DMA transactions */
+	ret = pci_dev_reset_iommu_prepare(dev);
+	if (ret) {
+		pci_err(dev, "failed to stop IOMMU for a PCI reset: %d\n", ret);
+		return ret;
+	}
+
 	pcie_capability_set_word(dev, PCI_EXP_DEVCTL, PCI_EXP_DEVCTL_BCR_FLR);
 
 	if (dev->imm_ready)
-		return 0;
+		goto done;
 
 	/*
 	 * Per PCIe r4.0, sec 6.6.2, a device must complete an FLR within
@@ -4345,7 +4356,10 @@ int pcie_flr(struct pci_dev *dev)
 	 */
 	msleep(100);
 
-	return pci_dev_wait(dev, "FLR", PCIE_RESET_READY_POLL_MS);
+	ret = pci_dev_wait(dev, "FLR", PCIE_RESET_READY_POLL_MS);
+done:
+	pci_dev_reset_iommu_done(dev);
+	return ret;
 }
 EXPORT_SYMBOL_GPL(pcie_flr);
 
@@ -4373,6 +4387,7 @@ EXPORT_SYMBOL_GPL(pcie_reset_flr);
 
 static int pci_af_flr(struct pci_dev *dev, bool probe)
 {
+	int ret;
 	int pos;
 	u8 cap;
 
@@ -4399,10 +4414,17 @@ static int pci_af_flr(struct pci_dev *dev, bool probe)
 				 PCI_AF_STATUS_TP << 8))
 		pci_err(dev, "timed out waiting for pending transaction; performing AF function level reset anyway\n");
 
+	/* Must be called after waiting for pending DMA transactions */
+	ret = pci_dev_reset_iommu_prepare(dev);
+	if (ret) {
+		pci_err(dev, "failed to stop IOMMU for a PCI reset: %d\n", ret);
+		return ret;
+	}
+
 	pci_write_config_byte(dev, pos + PCI_AF_CTRL, PCI_AF_CTRL_FLR);
 
 	if (dev->imm_ready)
-		return 0;
+		goto done;
 
 	/*
 	 * Per Advanced Capabilities for Conventional PCI ECN, 13 April 2006,
@@ -4412,7 +4434,10 @@ static int pci_af_flr(struct pci_dev *dev, bool probe)
 	 */
 	msleep(100);
 
-	return pci_dev_wait(dev, "AF_FLR", PCIE_RESET_READY_POLL_MS);
+	ret = pci_dev_wait(dev, "AF_FLR", PCIE_RESET_READY_POLL_MS);
+done:
+	pci_dev_reset_iommu_done(dev);
+	return ret;
 }
 
 /**
@@ -4433,6 +4458,7 @@ static int pci_af_flr(struct pci_dev *dev, bool probe)
 static int pci_pm_reset(struct pci_dev *dev, bool probe)
 {
 	u16 csr;
+	int ret;
 
 	if (!dev->pm_cap || dev->dev_flags & PCI_DEV_FLAGS_NO_PM_RESET)
 		return -ENOTTY;
@@ -4447,6 +4473,12 @@ static int pci_pm_reset(struct pci_dev *dev, bool probe)
 	if (dev->current_state != PCI_D0)
 		return -EINVAL;
 
+	ret = pci_dev_reset_iommu_prepare(dev);
+	if (ret) {
+		pci_err(dev, "failed to stop IOMMU for a PCI reset: %d\n", ret);
+		return ret;
+	}
+
 	csr &= ~PCI_PM_CTRL_STATE_MASK;
 	csr |= PCI_D3hot;
 	pci_write_config_word(dev, dev->pm_cap + PCI_PM_CTRL, csr);
@@ -4457,7 +4489,9 @@ static int pci_pm_reset(struct pci_dev *dev, bool probe)
 	pci_write_config_word(dev, dev->pm_cap + PCI_PM_CTRL, csr);
 	pci_dev_d3_sleep(dev);
 
-	return pci_dev_wait(dev, "PM D3hot->D0", PCIE_RESET_READY_POLL_MS);
+	ret = pci_dev_wait(dev, "PM D3hot->D0", PCIE_RESET_READY_POLL_MS);
+	pci_dev_reset_iommu_done(dev);
+	return ret;
 }
 
 /**
@@ -4885,10 +4919,20 @@ static int pci_reset_bus_function(struct pci_dev *dev, bool probe)
 		return -ENOTTY;
 	}
 
+	rc = pci_dev_reset_iommu_prepare(dev);
+	if (rc) {
+		pci_err(dev, "failed to stop IOMMU for a PCI reset: %d\n", rc);
+		return rc;
+	}
+
 	rc = pci_dev_reset_slot_function(dev, probe);
 	if (rc != -ENOTTY)
-		return rc;
-	return pci_parent_bus_reset(dev, probe);
+		goto done;
+
+	rc = pci_parent_bus_reset(dev, probe);
+done:
+	pci_dev_reset_iommu_done(dev);
+	return rc;
 }
 
 static int cxl_reset_bus_function(struct pci_dev *dev, bool probe)
@@ -4912,6 +4956,12 @@ static int cxl_reset_bus_function(struct pci_dev *dev, bool probe)
 	if (rc)
 		return -ENOTTY;
 
+	rc = pci_dev_reset_iommu_prepare(dev);
+	if (rc) {
+		pci_err(dev, "failed to stop IOMMU for a PCI reset: %d\n", rc);
+		return rc;
+	}
+
 	if (reg & PCI_DVSEC_CXL_PORT_CTL_UNMASK_SBR) {
 		val = reg;
 	} else {
@@ -4926,6 +4976,7 @@ static int cxl_reset_bus_function(struct pci_dev *dev, bool probe)
 		pci_write_config_word(bridge, dvsec + PCI_DVSEC_CXL_PORT_CTL,
 				      reg);
 
+	pci_dev_reset_iommu_done(dev);
 	return rc;
 }
 
diff --git a/drivers/pci/quirks.c b/drivers/pci/quirks.c
index 280cd50..6df24dd 100644
--- a/drivers/pci/quirks.c
+++ b/drivers/pci/quirks.c
@@ -21,6 +21,7 @@
 #include <linux/pci.h>
 #include <linux/isa-dma.h> /* isa_dma_bridge_buggy */
 #include <linux/init.h>
+#include <linux/iommu.h>
 #include <linux/delay.h>
 #include <linux/acpi.h>
 #include <linux/dmi.h>
@@ -4228,6 +4229,22 @@ static const struct pci_dev_reset_methods pci_dev_reset_methods[] = {
 	{ 0 }
 };
 
+static int __pci_dev_specific_reset(struct pci_dev *dev, bool probe,
+				    const struct pci_dev_reset_methods *i)
+{
+	int ret;
+
+	ret = pci_dev_reset_iommu_prepare(dev);
+	if (ret) {
+		pci_err(dev, "failed to stop IOMMU for a PCI reset: %d\n", ret);
+		return ret;
+	}
+
+	ret = i->reset(dev, probe);
+	pci_dev_reset_iommu_done(dev);
+	return ret;
+}
+
 /*
  * These device-specific reset methods are here rather than in a driver
  * because when a host assigns a device to a guest VM, the host may need
@@ -4242,7 +4259,7 @@ int pci_dev_specific_reset(struct pci_dev *dev, bool probe)
 		     i->vendor == (u16)PCI_ANY_ID) &&
 		    (i->device == dev->device ||
 		     i->device == (u16)PCI_ANY_ID))
-			return i->reset(dev, probe);
+			return __pci_dev_specific_reset(dev, probe, i);
 	}
 
 	return -ENOTTY;
diff --git a/include/linux/iommu-debug-pagealloc.h b/include/linux/iommu-debug-pagealloc.h
new file mode 100644
index 0000000..46c3c1f
--- /dev/null
+++ b/include/linux/iommu-debug-pagealloc.h
@@ -0,0 +1,32 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * Copyright (C) 2025 - Google Inc
+ * Author: Mostafa Saleh <smostafa@google.com>
+ * IOMMU API debug page alloc sanitizer
+ */
+
+#ifndef __LINUX_IOMMU_DEBUG_PAGEALLOC_H
+#define __LINUX_IOMMU_DEBUG_PAGEALLOC_H
+
+#ifdef CONFIG_IOMMU_DEBUG_PAGEALLOC
+DECLARE_STATIC_KEY_FALSE(iommu_debug_initialized);
+
+extern struct page_ext_operations page_iommu_debug_ops;
+
+void __iommu_debug_check_unmapped(const struct page *page, int numpages);
+
+static inline void iommu_debug_check_unmapped(const struct page *page, int numpages)
+{
+	if (static_branch_unlikely(&iommu_debug_initialized))
+		__iommu_debug_check_unmapped(page, numpages);
+}
+
+#else
+static inline void iommu_debug_check_unmapped(const struct page *page,
+					      int numpages)
+{
+}
+
+#endif /* CONFIG_IOMMU_DEBUG_PAGEALLOC */
+
+#endif /* __LINUX_IOMMU_DEBUG_PAGEALLOC_H */
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 8c66284..54b8b48 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -910,6 +910,7 @@ extern int iommu_attach_device(struct iommu_domain *domain,
 extern void iommu_detach_device(struct iommu_domain *domain,
 				struct device *dev);
 extern struct iommu_domain *iommu_get_domain_for_dev(struct device *dev);
+struct iommu_domain *iommu_driver_get_domain_for_dev(struct device *dev);
 extern struct iommu_domain *iommu_get_dma_domain(struct device *dev);
 extern int iommu_map(struct iommu_domain *domain, unsigned long iova,
 		     phys_addr_t paddr, size_t size, int prot, gfp_t gfp);
@@ -1187,6 +1188,10 @@ void iommu_detach_device_pasid(struct iommu_domain *domain,
 			       struct device *dev, ioasid_t pasid);
 ioasid_t iommu_alloc_global_pasid(struct device *dev);
 void iommu_free_global_pasid(ioasid_t pasid);
+
+/* PCI device reset functions */
+int pci_dev_reset_iommu_prepare(struct pci_dev *pdev);
+void pci_dev_reset_iommu_done(struct pci_dev *pdev);
 #else /* CONFIG_IOMMU_API */
 
 struct iommu_ops {};
@@ -1510,6 +1515,15 @@ static inline ioasid_t iommu_alloc_global_pasid(struct device *dev)
 }
 
 static inline void iommu_free_global_pasid(ioasid_t pasid) {}
+
+static inline int pci_dev_reset_iommu_prepare(struct pci_dev *pdev)
+{
+	return 0;
+}
+
+static inline void pci_dev_reset_iommu_done(struct pci_dev *pdev)
+{
+}
 #endif /* CONFIG_IOMMU_API */
 
 #ifdef CONFIG_IRQ_MSI_IOMMU
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 6f959d8..32205d2 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -36,6 +36,7 @@
 #include <linux/rcuwait.h>
 #include <linux/bitmap.h>
 #include <linux/bitops.h>
+#include <linux/iommu-debug-pagealloc.h>
 
 struct mempolicy;
 struct anon_vma;
@@ -4133,12 +4134,16 @@ extern void __kernel_map_pages(struct page *page, int numpages, int enable);
 #ifdef CONFIG_DEBUG_PAGEALLOC
 static inline void debug_pagealloc_map_pages(struct page *page, int numpages)
 {
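+	/* Newly allocated pages must not still be mapped in any IOMMU domain */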
+	iommu_debug_check_unmapped(page, numpages);
+
 	if (debug_pagealloc_enabled_static())
 		__kernel_map_pages(page, numpages, 1);
 }
 
 static inline void debug_pagealloc_unmap_pages(struct page *page, int numpages)
 {
+	iommu_debug_check_unmapped(page, numpages);
+
 	if (debug_pagealloc_enabled_static())
 		__kernel_map_pages(page, numpages, 0);
 }
diff --git a/include/linux/page_ext.h b/include/linux/page_ext.h
index 76c8171..61e876e 100644
--- a/include/linux/page_ext.h
+++ b/include/linux/page_ext.h
@@ -93,6 +93,7 @@ static inline bool page_ext_iter_next_fast_possible(unsigned long next_pfn)
 #endif
 
 extern struct page_ext *page_ext_get(const struct page *page);
+extern struct page_ext *page_ext_from_phys(phys_addr_t phys);
 extern void page_ext_put(struct page_ext *page_ext);
 extern struct page_ext *page_ext_lookup(unsigned long pfn);
 
@@ -215,6 +216,11 @@ static inline struct page_ext *page_ext_get(const struct page *page)
 	return NULL;
 }
 
+static inline struct page_ext *page_ext_from_phys(phys_addr_t phys)
+{
+	return NULL;
+}
+
 static inline void page_ext_put(struct page_ext *page_ext)
 {
 }
diff --git a/include/uapi/linux/iommufd.h b/include/uapi/linux/iommufd.h
index 2c41920..1dafbc5 100644
--- a/include/uapi/linux/iommufd.h
+++ b/include/uapi/linux/iommufd.h
@@ -466,15 +466,26 @@ struct iommu_hwpt_arm_smmuv3 {
 };
 
 /**
+ * struct iommu_hwpt_amd_guest - AMD IOMMU guest I/O page table data
+ *				 (IOMMU_HWPT_DATA_AMD_GUEST)
+ * @dte: Guest Device Table Entry (DTE)
+ */
+struct iommu_hwpt_amd_guest {
+	__aligned_u64 dte[4];
+};
+
+/**
  * enum iommu_hwpt_data_type - IOMMU HWPT Data Type
  * @IOMMU_HWPT_DATA_NONE: no data
  * @IOMMU_HWPT_DATA_VTD_S1: Intel VT-d stage-1 page table
  * @IOMMU_HWPT_DATA_ARM_SMMUV3: ARM SMMUv3 Context Descriptor Table
+ * @IOMMU_HWPT_DATA_AMD_GUEST: AMD IOMMU guest page table
  */
 enum iommu_hwpt_data_type {
 	IOMMU_HWPT_DATA_NONE = 0,
 	IOMMU_HWPT_DATA_VTD_S1 = 1,
 	IOMMU_HWPT_DATA_ARM_SMMUV3 = 2,
+	IOMMU_HWPT_DATA_AMD_GUEST = 3,
 };
 
 /**
@@ -624,6 +635,32 @@ struct iommu_hw_info_tegra241_cmdqv {
 };
 
 /**
+ * struct iommu_hw_info_amd - AMD IOMMU device info
+ *
+ * @efr: Value of the AMD IOMMU Extended Feature Register (EFR)
+ * @efr2: Value of the AMD IOMMU Extended Feature 2 Register (EFR2)
+ *
+ * Please see the description of these registers in the following sections of
+ * the AMD I/O Virtualization Technology (IOMMU) Specification
+ * (https://docs.amd.com/v/u/en-US/48882_3.10_PUB):
+ *
+ * - MMIO Offset 0030h IOMMU Extended Feature Register
+ * - MMIO Offset 01A0h IOMMU Extended Feature 2 Register
+ *
+ * Note: The EFR and EFR2 are raw values reported by hardware. The VMM is
+ * responsible for determining the appropriate flags to expose to the VM,
+ * since certain features are not currently supported by the kernel for the
+ * HW-vIOMMU.
+ *
+ * The current VMM-allowed list of feature flags is:
+ * - EFR[GTSup, GASup, GioSup, PPRSup, EPHSup, GATS, GLX, PASmax]
+ */
+struct iommu_hw_info_amd {
+	__aligned_u64 efr;
+	__aligned_u64 efr2;
+};
+
+/**
  * enum iommu_hw_info_type - IOMMU Hardware Info Types
  * @IOMMU_HW_INFO_TYPE_NONE: Output by the drivers that do not report hardware
  *                           info
@@ -632,6 +669,7 @@ struct iommu_hw_info_tegra241_cmdqv {
  * @IOMMU_HW_INFO_TYPE_ARM_SMMUV3: ARM SMMUv3 iommu info type
  * @IOMMU_HW_INFO_TYPE_TEGRA241_CMDQV: NVIDIA Tegra241 CMDQV (extension for ARM
  *                                     SMMUv3) info type
+ * @IOMMU_HW_INFO_TYPE_AMD: AMD IOMMU info type
  */
 enum iommu_hw_info_type {
 	IOMMU_HW_INFO_TYPE_NONE = 0,
@@ -639,6 +677,7 @@ enum iommu_hw_info_type {
 	IOMMU_HW_INFO_TYPE_INTEL_VTD = 1,
 	IOMMU_HW_INFO_TYPE_ARM_SMMUV3 = 2,
 	IOMMU_HW_INFO_TYPE_TEGRA241_CMDQV = 3,
+	IOMMU_HW_INFO_TYPE_AMD = 4,
 };
 
 /**
diff --git a/include/uapi/linux/vfio.h b/include/uapi/linux/vfio.h
index ac2329f..bb7b893 100644
--- a/include/uapi/linux/vfio.h
+++ b/include/uapi/linux/vfio.h
@@ -964,6 +964,10 @@ struct vfio_device_bind_iommufd {
  * hwpt corresponding to the given pt_id.
  *
  * Return: 0 on success, -errno on failure.
+ *
+ * While a device is resetting, -EBUSY will be returned to reject any concurrent
+ * attachment to the resetting device itself or to any sibling device in the
+ * same IOMMU group as the resetting device.
  */
 struct vfio_device_attach_iommufd_pt {
 	__u32	argsz;
diff --git a/mm/page_ext.c b/mm/page_ext.c
index d7396a8..e2e92bd 100644
--- a/mm/page_ext.c
+++ b/mm/page_ext.c
@@ -11,6 +11,7 @@
 #include <linux/page_table_check.h>
 #include <linux/rcupdate.h>
 #include <linux/pgalloc_tag.h>
+#include <linux/iommu-debug-pagealloc.h>
 
 /*
  * struct page extension
@@ -89,6 +90,9 @@ static struct page_ext_operations *page_ext_ops[] __initdata = {
 #ifdef CONFIG_PAGE_TABLE_CHECK
 	&page_table_check_ops,
 #endif
+#ifdef CONFIG_IOMMU_DEBUG_PAGEALLOC
+	&page_iommu_debug_ops,
+#endif
 };
 
 unsigned long page_ext_size;
@@ -535,6 +539,29 @@ struct page_ext *page_ext_get(const struct page *page)
 }
 
 /**
+ * page_ext_from_phys() - Get the page_ext structure for a physical address.
+ * @phys: The physical address to query.
+ *
+ * This function safely gets the struct page_ext associated with a given
+ * physical address. It validates that the address corresponds to a valid,
+ * online struct page before attempting to access it, and returns NULL for
+ * MMIO, ZONE_DEVICE, holes and offline memory.
+ *
+ * Return: the page_ext, or NULL if no page_ext exists for this physical
+ * address.
+ * Context: Any context.  Caller may not sleep until they have called
+ * page_ext_put().
+ */
+struct page_ext *page_ext_from_phys(phys_addr_t phys)
+{
+	struct page *page = pfn_to_online_page(__phys_to_pfn(phys));
+
+	if (!page)
+		return NULL;
+
+	return page_ext_get(page);
+}
+
+/**
  * page_ext_put() - Working with page extended information is done.
  * @page_ext: Page extended information received from page_ext_get().
  *
diff --git a/rust/bindings/bindings_helper.h b/rust/bindings/bindings_helper.h
index a067038..1b05a5e 100644
--- a/rust/bindings/bindings_helper.h
+++ b/rust/bindings/bindings_helper.h
@@ -56,9 +56,10 @@
 #include <linux/fdtable.h>
 #include <linux/file.h>
 #include <linux/firmware.h>
-#include <linux/interrupt.h>
 #include <linux/fs.h>
 #include <linux/i2c.h>
+#include <linux/interrupt.h>
+#include <linux/io-pgtable.h>
 #include <linux/ioport.h>
 #include <linux/jiffies.h>
 #include <linux/jump_label.h>
diff --git a/rust/kernel/iommu/mod.rs b/rust/kernel/iommu/mod.rs
new file mode 100644
index 0000000..1423d7b
--- /dev/null
+++ b/rust/kernel/iommu/mod.rs
@@ -0,0 +1,5 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! Rust support related to IOMMU.
+
+pub mod pgtable;
diff --git a/rust/kernel/iommu/pgtable.rs b/rust/kernel/iommu/pgtable.rs
new file mode 100644
index 0000000..6135ba1
--- /dev/null
+++ b/rust/kernel/iommu/pgtable.rs
@@ -0,0 +1,279 @@
+// SPDX-License-Identifier: GPL-2.0
+
+//! IOMMU page table management.
+//!
+//! C header: [`include/linux/io-pgtable.h`](srctree/include/linux/io-pgtable.h)
+
+use core::{
+    marker::PhantomData,
+    ptr::NonNull, //
+};
+
+use crate::{
+    alloc,
+    bindings,
+    device::{
+        Bound,
+        Device //
+    },
+    devres::Devres,
+    error::to_result,
+    io::PhysAddr,
+    prelude::*, //
+};
+
+use bindings::io_pgtable_fmt;
+
+/// Protection flags used with IOMMU mappings.
+pub mod prot {
+    /// Read access.
+    pub const READ: u32 = bindings::IOMMU_READ;
+    /// Write access.
+    pub const WRITE: u32 = bindings::IOMMU_WRITE;
+    /// Request cache coherency.
+    pub const CACHE: u32 = bindings::IOMMU_CACHE;
+    /// Request no-execute permission.
+    pub const NOEXEC: u32 = bindings::IOMMU_NOEXEC;
+    /// MMIO peripheral mapping.
+    pub const MMIO: u32 = bindings::IOMMU_MMIO;
+    /// Privileged mapping.
+    pub const PRIVILEGED: u32 = bindings::IOMMU_PRIV;
+}
+
+/// Represents a requested `io_pgtable` configuration.
+pub struct Config {
+    /// Quirk bitmask (type-specific).
+    pub quirks: usize,
+    /// Valid page sizes, as a bitmask of powers of two.
+    pub pgsize_bitmap: usize,
+    /// Input address space size in bits.
+    pub ias: u32,
+    /// Output address space size in bits.
+    pub oas: u32,
+    /// IOMMU uses coherent accesses for page table walks.
+    pub coherent_walk: bool,
+}
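+
+// As a rough illustration (the field values below are made up rather than
+// taken from any real driver), a caller might build a configuration like:
+//
+//     let config = Config {
+//         quirks: 0,
+//         pgsize_bitmap: 0x1000, // 4 KiB pages only
+//         ias: 39,
+//         oas: 40,
+//         coherent_walk: true,
+//     };
+//
+// and pass it to `IoPageTable::<ARM64LPAES1>::new(dev, config)`.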
+
+/// An io page table using a specific format.
+///
+/// # Invariants
+///
+/// The pointer references a valid io page table.
+pub struct IoPageTable<F: IoPageTableFmt> {
+    ptr: NonNull<bindings::io_pgtable_ops>,
+    _marker: PhantomData<F>,
+}
+
+// SAFETY: `struct io_pgtable_ops` is not restricted to a single thread.
+unsafe impl<F: IoPageTableFmt> Send for IoPageTable<F> {}
+// SAFETY: `struct io_pgtable_ops` may be accessed concurrently.
+unsafe impl<F: IoPageTableFmt> Sync for IoPageTable<F> {}
+
+/// The format used by this page table.
+pub trait IoPageTableFmt: 'static {
+    /// The value representing this format.
+    const FORMAT: io_pgtable_fmt;
+}
+
+impl<F: IoPageTableFmt> IoPageTable<F> {
+    /// Create a new `IoPageTable` as a device resource.
+    #[inline]
+    pub fn new(
+        dev: &Device<Bound>,
+        config: Config,
+    ) -> impl PinInit<Devres<IoPageTable<F>>, Error> + '_ {
+        // SAFETY: Devres ensures that the value is dropped during device unbind.
+        Devres::new(dev, unsafe { Self::new_raw(dev, config) })
+    }
+
+    /// Create a new `IoPageTable`.
+    ///
+    /// # Safety
+    ///
+    /// If successful, then the returned `IoPageTable` must be dropped before the device is
+    /// unbound.
+    #[inline]
+    pub unsafe fn new_raw(dev: &Device<Bound>, config: Config) -> Result<IoPageTable<F>> {
+        let mut raw_cfg = bindings::io_pgtable_cfg {
+            quirks: config.quirks,
+            pgsize_bitmap: config.pgsize_bitmap,
+            ias: config.ias,
+            oas: config.oas,
+            coherent_walk: config.coherent_walk,
+            tlb: &raw const NOOP_FLUSH_OPS,
+            iommu_dev: dev.as_raw(),
+            // SAFETY: All zeroes is a valid value for `struct io_pgtable_cfg`.
+            ..unsafe { core::mem::zeroed() }
+        };
+
+        // SAFETY:
+        // * The raw_cfg pointer is valid for the duration of this call.
+        // * The provided `NOOP_FLUSH_OPS` contains valid function pointers that accept a null pointer
+        //   as cookie.
+        // * The caller ensures that the io pgtable does not outlive the device.
+        let ops = unsafe {
+            bindings::alloc_io_pgtable_ops(F::FORMAT, &mut raw_cfg, core::ptr::null_mut())
+        };
+
+        // INVARIANT: We successfully created a valid page table.
+        Ok(IoPageTable {
+            ptr: NonNull::new(ops).ok_or(ENOMEM)?,
+            _marker: PhantomData,
+        })
+    }
+
+    /// Obtain a raw pointer to the underlying `struct io_pgtable_ops`.
+    #[inline]
+    pub fn raw_ops(&self) -> *mut bindings::io_pgtable_ops {
+        self.ptr.as_ptr()
+    }
+
+    /// Obtain a raw pointer to the underlying `struct io_pgtable`.
+    #[inline]
+    pub fn raw_pgtable(&self) -> *mut bindings::io_pgtable {
+        // SAFETY: `self.raw_ops()` always points at the `ops` field embedded in a `struct io_pgtable`.
+        unsafe { kernel::container_of!(self.raw_ops(), bindings::io_pgtable, ops) }
+    }
+
+    /// Obtain a raw pointer to the underlying `struct io_pgtable_cfg`.
+    #[inline]
+    pub fn raw_cfg(&self) -> *mut bindings::io_pgtable_cfg {
+        // SAFETY: The `raw_pgtable()` method returns a valid pointer.
+        unsafe { &raw mut (*self.raw_pgtable()).cfg }
+    }
+
+    /// Map a physically contiguous range of pages of the same size.
+    ///
+    /// Even if successful, this operation may not map the entire range. In that case, only a
+    /// prefix of the range is mapped, and the returned integer indicates its length in bytes;
+    /// the caller will usually call `map_pages` again for the remaining range.
+    ///
+    /// The returned [`Result`] indicates whether an error was encountered while mapping pages.
+    /// Note that this may return a non-zero length even if an error was encountered. The caller
+    /// will usually [unmap the relevant pages](Self::unmap_pages) on error.
+    ///
+    /// The caller must flush the TLB before using the pgtable to access the newly created mapping.
+    ///
+    /// # Safety
+    ///
+    /// * No other io-pgtable operation may access the range `iova .. iova+pgsize*pgcount` while
+    ///   this `map_pages` operation executes.
+    /// * This page table must not contain any mapping that overlaps with the mapping created by
+    ///   this call.
+    /// * If this page table is live, then the caller must ensure that it's okay to access the
+    ///   physical address being mapped for the duration in which it is mapped.
+    #[inline]
+    pub unsafe fn map_pages(
+        &self,
+        iova: usize,
+        paddr: PhysAddr,
+        pgsize: usize,
+        pgcount: usize,
+        prot: u32,
+        flags: alloc::Flags,
+    ) -> (usize, Result) {
+        let mut mapped: usize = 0;
+
+        // SAFETY: The `map_pages` function in `io_pgtable_ops` is never null.
+        let map_pages = unsafe { (*self.raw_ops()).map_pages.unwrap_unchecked() };
+
+        // SAFETY: The safety requirements of this method are sufficient to call `map_pages`.
+        let ret = to_result(unsafe {
+            (map_pages)(
+                self.raw_ops(),
+                iova,
+                paddr,
+                pgsize,
+                pgcount,
+                prot as i32,
+                flags.as_raw(),
+                &mut mapped,
+            )
+        });
+
+        (mapped, ret)
+    }
+
+    /// Unmap a range of virtually contiguous pages of the same size.
+    ///
+    /// This may not unmap the entire range, and returns the length of the unmapped prefix in
+    /// bytes.
+    ///
+    /// # Safety
+    ///
+    /// * No other io-pgtable operation may access the range `iova .. iova+pgsize*pgcount` while
+    ///   this `unmap_pages` operation executes.
+    /// * This page table must contain one or more consecutive mappings starting at `iova` whose
+    ///   total size is `pgcount * pgsize`.
+    #[inline]
+    #[must_use]
+    pub unsafe fn unmap_pages(&self, iova: usize, pgsize: usize, pgcount: usize) -> usize {
+        // SAFETY: The `unmap_pages` function in `io_pgtable_ops` is never null.
+        let unmap_pages = unsafe { (*self.raw_ops()).unmap_pages.unwrap_unchecked() };
+
+        // SAFETY: The safety requirements of this method are sufficient to call `unmap_pages`.
+        unsafe { (unmap_pages)(self.raw_ops(), iova, pgsize, pgcount, core::ptr::null_mut()) }
+    }
+}
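+
+// A minimal, hypothetical mapping sketch (identifiers such as `pgtable`, `iova`
+// and `paddr` are illustrative; the caller must uphold the safety contract
+// documented on `map_pages`):
+//
+//     // SAFETY: nothing else touches this IOVA range and nothing is mapped
+//     // there yet.
+//     let (mapped, res) = unsafe {
+//         pgtable.map_pages(iova, paddr, 0x1000, 1, prot::READ | prot::WRITE,
+//                           GFP_KERNEL)
+//     };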
+
+// For the initial users of these Rust bindings, the GPU firmware manages the IOTLB and performs
+// all required invalidations by range. There is no need for it to get ARM style invalidation
+// instructions from the page table code.
+//
+// Support for flushing the TLB with ARM style invalidation instructions may be added in the
+// future.
+static NOOP_FLUSH_OPS: bindings::iommu_flush_ops = bindings::iommu_flush_ops {
+    tlb_flush_all: Some(rust_tlb_flush_all_noop),
+    tlb_flush_walk: Some(rust_tlb_flush_walk_noop),
+    tlb_add_page: None,
+};
+
+#[no_mangle]
+extern "C" fn rust_tlb_flush_all_noop(_cookie: *mut core::ffi::c_void) {}
+
+#[no_mangle]
+extern "C" fn rust_tlb_flush_walk_noop(
+    _iova: usize,
+    _size: usize,
+    _granule: usize,
+    _cookie: *mut core::ffi::c_void,
+) {
+}
+
+impl<F: IoPageTableFmt> Drop for IoPageTable<F> {
+    fn drop(&mut self) {
+        // SAFETY: The caller of `Self::ttbr()` promised that the page table is not live when this
+        // destructor runs.
+        unsafe { bindings::free_io_pgtable_ops(self.raw_ops()) };
+    }
+}
+
+/// The `ARM_64_LPAE_S1` page table format.
+pub enum ARM64LPAES1 {}
+
+impl IoPageTableFmt for ARM64LPAES1 {
+    const FORMAT: io_pgtable_fmt = bindings::io_pgtable_fmt_ARM_64_LPAE_S1 as io_pgtable_fmt;
+}
+
+impl IoPageTable<ARM64LPAES1> {
+    /// Access the `ttbr` field of the configuration.
+    ///
+    /// This is the physical address of the page table, which may be passed to the device that
+    /// needs to use it.
+    ///
+    /// # Safety
+    ///
+    /// The caller must ensure that the device stops using the page table before dropping it.
+    #[inline]
+    pub unsafe fn ttbr(&self) -> u64 {
+        // SAFETY: `arm_lpae_s1_cfg` is the right cfg type for `ARM64LPAES1`.
+        unsafe { (*self.raw_cfg()).__bindgen_anon_1.arm_lpae_s1_cfg.ttbr }
+    }
+
+    /// Access the `mair` field of the configuration.
+    #[inline]
+    pub fn mair(&self) -> u64 {
+        // SAFETY: `arm_lpae_s1_cfg` is the right cfg type for `ARM64LPAES1`.
+        unsafe { (*self.raw_cfg()).__bindgen_anon_1.arm_lpae_s1_cfg.mair }
+    }
+}
diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs
index f812cf1..e7fba6f 100644
--- a/rust/kernel/lib.rs
+++ b/rust/kernel/lib.rs
@@ -103,6 +103,7 @@
 pub mod init;
 pub mod io;
 pub mod ioctl;
+pub mod iommu;
 pub mod iov;
 pub mod irq;
 pub mod jump_label;