Merge branch 'next'
diff --git a/Documentation/arch/arm64/silicon-errata.rst b/Documentation/arch/arm64/silicon-errata.rst
index 4c300ca..c81d7fc 100644
--- a/Documentation/arch/arm64/silicon-errata.rst
+++ b/Documentation/arch/arm64/silicon-errata.rst
@@ -207,8 +207,14 @@
 +----------------+-----------------+-----------------+-----------------------------+
 | ARM            | MMU-600         | #1076982,1209401| N/A                         |
 +----------------+-----------------+-----------------+-----------------------------+
-| ARM            | MMU-700         | #2268618,2812531| N/A                         |
+| ARM            | MMU-700         | #2133013,       | N/A                         |
+|                |                 | #2268618,       |                             |
+|                |                 | #2812531,       |                             |
+|                |                 | #3777127        |                             |
 +----------------+-----------------+-----------------+-----------------------------+
+| ARM            | MMU L1          | #3878312        | N/A                         |
++----------------+-----------------+-----------------+-----------------------------+
+| ARM            | MMU S3          | #3995052        | N/A                         |
 +----------------+-----------------+-----------------+-----------------------------+
 | ARM            | GIC-700         | #2941627        | ARM64_ERRATUM_2941627       |
 +----------------+-----------------+-----------------+-----------------------------+
diff --git a/Documentation/devicetree/bindings/iommu/arm,smmu.yaml b/Documentation/devicetree/bindings/iommu/arm,smmu.yaml
index cdbd23b..27d25bc 100644
--- a/Documentation/devicetree/bindings/iommu/arm,smmu.yaml
+++ b/Documentation/devicetree/bindings/iommu/arm,smmu.yaml
@@ -35,6 +35,7 @@
       - description: Qcom SoCs implementing "qcom,smmu-500" and "arm,mmu-500"
         items:
           - enum:
+              - qcom,eliza-smmu-500
               - qcom,glymur-smmu-500
               - qcom,kaanapali-smmu-500
               - qcom,milos-smmu-500
diff --git a/drivers/iommu/amd/debugfs.c b/drivers/iommu/amd/debugfs.c
index 20b0499..4e66473 100644
--- a/drivers/iommu/amd/debugfs.c
+++ b/drivers/iommu/amd/debugfs.c
@@ -26,22 +26,19 @@ static ssize_t iommu_mmio_write(struct file *filp, const char __user *ubuf,
 {
 	struct seq_file *m = filp->private_data;
 	struct amd_iommu *iommu = m->private;
-	int ret;
-
-	iommu->dbg_mmio_offset = -1;
+	int ret, dbg_mmio_offset = iommu->dbg_mmio_offset = -1;
 
 	if (cnt > OFS_IN_SZ)
 		return -EINVAL;
 
-	ret = kstrtou32_from_user(ubuf, cnt, 0, &iommu->dbg_mmio_offset);
+	ret = kstrtou32_from_user(ubuf, cnt, 0, &dbg_mmio_offset);
 	if (ret)
 		return ret;
 
-	if (iommu->dbg_mmio_offset > iommu->mmio_phys_end - sizeof(u64)) {
-		iommu->dbg_mmio_offset = -1;
-		return  -EINVAL;
-	}
+	if (dbg_mmio_offset > iommu->mmio_phys_end - sizeof(u64))
+		return -EINVAL;
 
+	iommu->dbg_mmio_offset = dbg_mmio_offset;
 	return cnt;
 }
 
@@ -49,14 +46,16 @@ static int iommu_mmio_show(struct seq_file *m, void *unused)
 {
 	struct amd_iommu *iommu = m->private;
 	u64 value;
+	int dbg_mmio_offset = iommu->dbg_mmio_offset;
 
-	if (iommu->dbg_mmio_offset < 0) {
+	if (dbg_mmio_offset < 0 || dbg_mmio_offset >
+			iommu->mmio_phys_end - sizeof(u64)) {
 		seq_puts(m, "Please provide mmio register's offset\n");
 		return 0;
 	}
 
-	value = readq(iommu->mmio_base + iommu->dbg_mmio_offset);
-	seq_printf(m, "Offset:0x%x Value:0x%016llx\n", iommu->dbg_mmio_offset, value);
+	value = readq(iommu->mmio_base + dbg_mmio_offset);
+	seq_printf(m, "Offset:0x%x Value:0x%016llx\n", dbg_mmio_offset, value);
 
 	return 0;
 }
@@ -67,23 +66,20 @@ static ssize_t iommu_capability_write(struct file *filp, const char __user *ubuf
 {
 	struct seq_file *m = filp->private_data;
 	struct amd_iommu *iommu = m->private;
-	int ret;
-
-	iommu->dbg_cap_offset = -1;
+	int ret, dbg_cap_offset = iommu->dbg_cap_offset = -1;
 
 	if (cnt > OFS_IN_SZ)
 		return -EINVAL;
 
-	ret = kstrtou32_from_user(ubuf, cnt, 0, &iommu->dbg_cap_offset);
+	ret = kstrtou32_from_user(ubuf, cnt, 0, &dbg_cap_offset);
 	if (ret)
 		return ret;
 
 	/* Capability register at offset 0x14 is the last IOMMU capability register. */
-	if (iommu->dbg_cap_offset > 0x14) {
-		iommu->dbg_cap_offset = -1;
+	if (dbg_cap_offset > 0x14)
 		return -EINVAL;
-	}
 
+	iommu->dbg_cap_offset = dbg_cap_offset;
 	return cnt;
 }
 
@@ -91,21 +87,21 @@ static int iommu_capability_show(struct seq_file *m, void *unused)
 {
 	struct amd_iommu *iommu = m->private;
 	u32 value;
-	int err;
+	int err, dbg_cap_offset = iommu->dbg_cap_offset;
 
-	if (iommu->dbg_cap_offset < 0) {
+	if (dbg_cap_offset < 0 || dbg_cap_offset > 0x14) {
 		seq_puts(m, "Please provide capability register's offset in the range [0x00 - 0x14]\n");
 		return 0;
 	}
 
-	err = pci_read_config_dword(iommu->dev, iommu->cap_ptr + iommu->dbg_cap_offset, &value);
+	err = pci_read_config_dword(iommu->dev, iommu->cap_ptr + dbg_cap_offset, &value);
 	if (err) {
 		seq_printf(m, "Not able to read capability register at 0x%x\n",
-			   iommu->dbg_cap_offset);
+			   dbg_cap_offset);
 		return 0;
 	}
 
-	seq_printf(m, "Offset:0x%x Value:0x%08x\n", iommu->dbg_cap_offset, value);
+	seq_printf(m, "Offset:0x%x Value:0x%08x\n", dbg_cap_offset, value);
 
 	return 0;
 }
@@ -197,10 +193,11 @@ static ssize_t devid_write(struct file *filp, const char __user *ubuf,
 static int devid_show(struct seq_file *m, void *unused)
 {
 	u16 devid;
+	int sbdf_shadow = sbdf;
 
-	if (sbdf >= 0) {
-		devid = PCI_SBDF_TO_DEVID(sbdf);
-		seq_printf(m, "%04x:%02x:%02x.%x\n", PCI_SBDF_TO_SEGID(sbdf),
+	if (sbdf_shadow >= 0) {
+		devid = PCI_SBDF_TO_DEVID(sbdf_shadow);
+		seq_printf(m, "%04x:%02x:%02x.%x\n", PCI_SBDF_TO_SEGID(sbdf_shadow),
 			   PCI_BUS_NUM(devid), PCI_SLOT(devid), PCI_FUNC(devid));
 	} else
 		seq_puts(m, "No or Invalid input provided\n");
@@ -237,13 +234,14 @@ static int iommu_devtbl_show(struct seq_file *m, void *unused)
 {
 	struct amd_iommu_pci_seg *pci_seg;
 	u16 seg, devid;
+	int sbdf_shadow = sbdf;
 
-	if (sbdf < 0) {
+	if (sbdf_shadow < 0) {
 		seq_puts(m, "Enter a valid device ID to 'devid' file\n");
 		return 0;
 	}
-	seg = PCI_SBDF_TO_SEGID(sbdf);
-	devid = PCI_SBDF_TO_DEVID(sbdf);
+	seg = PCI_SBDF_TO_SEGID(sbdf_shadow);
+	devid = PCI_SBDF_TO_DEVID(sbdf_shadow);
 
 	for_each_pci_segment(pci_seg) {
 		if (pci_seg->id != seg)
@@ -336,19 +334,20 @@ static int iommu_irqtbl_show(struct seq_file *m, void *unused)
 {
 	struct amd_iommu_pci_seg *pci_seg;
 	u16 devid, seg;
+	int sbdf_shadow = sbdf;
 
 	if (!irq_remapping_enabled) {
 		seq_puts(m, "Interrupt remapping is disabled\n");
 		return 0;
 	}
 
-	if (sbdf < 0) {
+	if (sbdf_shadow < 0) {
 		seq_puts(m, "Enter a valid device ID to 'devid' file\n");
 		return 0;
 	}
 
-	seg = PCI_SBDF_TO_SEGID(sbdf);
-	devid = PCI_SBDF_TO_DEVID(sbdf);
+	seg = PCI_SBDF_TO_SEGID(sbdf_shadow);
+	devid = PCI_SBDF_TO_DEVID(sbdf_shadow);
 
 	for_each_pci_segment(pci_seg) {
 		if (pci_seg->id != seg)
diff --git a/drivers/iommu/amd/init.c b/drivers/iommu/amd/init.c
index f3fd7f3..56ad020 100644
--- a/drivers/iommu/amd/init.c
+++ b/drivers/iommu/amd/init.c
@@ -848,10 +848,11 @@ static void __init free_command_buffer(struct amd_iommu *iommu)
 void *__init iommu_alloc_4k_pages(struct amd_iommu *iommu, gfp_t gfp,
 				  size_t size)
 {
+	int nid = iommu->dev ? dev_to_node(&iommu->dev->dev) : NUMA_NO_NODE;
 	void *buf;
 
 	size = PAGE_ALIGN(size);
-	buf = iommu_alloc_pages_sz(gfp, size);
+	buf = iommu_alloc_pages_node_sz(nid, gfp, size);
 	if (!buf)
 		return NULL;
 	if (check_feature(FEATURE_SNP) &&
@@ -954,14 +955,16 @@ static int iommu_ga_log_enable(struct amd_iommu *iommu)
 
 static int iommu_init_ga_log(struct amd_iommu *iommu)
 {
+	int nid = iommu->dev ? dev_to_node(&iommu->dev->dev) : NUMA_NO_NODE;
+
 	if (!AMD_IOMMU_GUEST_IR_VAPIC(amd_iommu_guest_ir))
 		return 0;
 
-	iommu->ga_log = iommu_alloc_pages_sz(GFP_KERNEL, GA_LOG_SIZE);
+	iommu->ga_log = iommu_alloc_pages_node_sz(nid, GFP_KERNEL, GA_LOG_SIZE);
 	if (!iommu->ga_log)
 		goto err_out;
 
-	iommu->ga_log_tail = iommu_alloc_pages_sz(GFP_KERNEL, 8);
+	iommu->ga_log_tail = iommu_alloc_pages_node_sz(nid, GFP_KERNEL, 8);
 	if (!iommu->ga_log_tail)
 		goto err_out;
 
diff --git a/drivers/iommu/amd/iommu.c b/drivers/iommu/amd/iommu.c
index 760d5f46..0117136 100644
--- a/drivers/iommu/amd/iommu.c
+++ b/drivers/iommu/amd/iommu.c
@@ -403,11 +403,12 @@ struct iommu_dev_data *search_dev_data(struct amd_iommu *iommu, u16 devid)
 	return NULL;
 }
 
-static int clone_alias(struct pci_dev *pdev, u16 alias, void *data)
+static int clone_alias(struct pci_dev *pdev_origin, u16 alias, void *data)
 {
 	struct dev_table_entry new;
 	struct amd_iommu *iommu;
 	struct iommu_dev_data *dev_data, *alias_data;
+	struct pci_dev *pdev = data;
 	u16 devid = pci_dev_id(pdev);
 	int ret = 0;
 
@@ -454,9 +455,9 @@ static void clone_aliases(struct amd_iommu *iommu, struct device *dev)
 	 * part of the PCI DMA aliases if it's bus differs
 	 * from the original device.
 	 */
-	clone_alias(pdev, iommu->pci_seg->alias_table[pci_dev_id(pdev)], NULL);
+	clone_alias(pdev, iommu->pci_seg->alias_table[pci_dev_id(pdev)], pdev);
 
-	pci_for_each_dma_alias(pdev, clone_alias, NULL);
+	pci_for_each_dma_alias(pdev, clone_alias, pdev);
 }
 
 static void setup_aliases(struct amd_iommu *iommu, struct device *dev)
@@ -2991,13 +2992,17 @@ static bool amd_iommu_capable(struct device *dev, enum iommu_cap cap)
 		return amdr_ivrs_remap_support;
 	case IOMMU_CAP_ENFORCE_CACHE_COHERENCY:
 		return true;
-	case IOMMU_CAP_DEFERRED_FLUSH:
-		return true;
 	case IOMMU_CAP_DIRTY_TRACKING: {
 		struct amd_iommu *iommu = get_amd_iommu_from_dev(dev);
 
 		return amd_iommu_hd_support(iommu);
 	}
+	case IOMMU_CAP_PCI_ATS_SUPPORTED: {
+		struct iommu_dev_data *dev_data = dev_iommu_priv_get(dev);
+
+		return amd_iommu_iotlb_sup &&
+			 (dev_data->flags & AMD_IOMMU_DEVICE_FLAG_ATS_SUP);
+	}
 	default:
 		break;
 	}
@@ -3179,26 +3184,44 @@ const struct iommu_ops amd_iommu_ops = {
 static struct irq_chip amd_ir_chip;
 static DEFINE_SPINLOCK(iommu_table_lock);
 
+static int iommu_flush_dev_irt(struct pci_dev *unused, u16 devid, void *data)
+{
+	int ret;
+	struct iommu_cmd cmd;
+	struct amd_iommu *iommu = data;
+
+	build_inv_irt(&cmd, devid);
+	ret = __iommu_queue_command_sync(iommu, &cmd, true);
+	return ret;
+}
+
 static void iommu_flush_irt_and_complete(struct amd_iommu *iommu, u16 devid)
 {
 	int ret;
 	u64 data;
 	unsigned long flags;
-	struct iommu_cmd cmd, cmd2;
+	struct iommu_cmd cmd;
+	struct pci_dev *pdev = NULL;
+	struct iommu_dev_data *dev_data = search_dev_data(iommu, devid);
 
 	if (iommu->irtcachedis_enabled)
 		return;
 
-	build_inv_irt(&cmd, devid);
+	if (dev_data && dev_data->dev && dev_is_pci(dev_data->dev))
+		pdev = to_pci_dev(dev_data->dev);
 
 	raw_spin_lock_irqsave(&iommu->lock, flags);
 	data = get_cmdsem_val(iommu);
-	build_completion_wait(&cmd2, iommu, data);
+	build_completion_wait(&cmd, iommu, data);
 
-	ret = __iommu_queue_command_sync(iommu, &cmd, true);
+	if (pdev)
+		ret = pci_for_each_dma_alias(pdev, iommu_flush_dev_irt, iommu);
+	else
+		ret = iommu_flush_dev_irt(NULL, devid, iommu);
 	if (ret)
 		goto out_err;
-	ret = __iommu_queue_command_sync(iommu, &cmd2, false);
+
+	ret = __iommu_queue_command_sync(iommu, &cmd, false);
 	if (ret)
 		goto out_err;
 	raw_spin_unlock_irqrestore(&iommu->lock, flags);
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
index 59a4809..f1f8e01a 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
@@ -122,15 +122,6 @@ void arm_smmu_make_sva_cd(struct arm_smmu_cd *target,
 }
 EXPORT_SYMBOL_IF_KUNIT(arm_smmu_make_sva_cd);
 
-/*
- * Cloned from the MAX_TLBI_OPS in arch/arm64/include/asm/tlbflush.h, this
- * is used as a threshold to replace per-page TLBI commands to issue in the
- * command queue with an address-space TLBI command, when SMMU w/o a range
- * invalidation feature handles too many per-page TLBI commands, which will
- * otherwise result in a soft lockup.
- */
-#define CMDQ_MAX_TLBI_OPS		(1 << (PAGE_SHIFT - 3))
-
 static void arm_smmu_mm_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
 						struct mm_struct *mm,
 						unsigned long start,
@@ -146,21 +137,8 @@ static void arm_smmu_mm_arch_invalidate_secondary_tlbs(struct mmu_notifier *mn,
 	 * range. So do a simple translation here by calculating size correctly.
 	 */
 	size = end - start;
-	if (!(smmu_domain->smmu->features & ARM_SMMU_FEAT_RANGE_INV)) {
-		if (size >= CMDQ_MAX_TLBI_OPS * PAGE_SIZE)
-			size = 0;
-	} else {
-		if (size == ULONG_MAX)
-			size = 0;
-	}
 
-	if (!size)
-		arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
-	else
-		arm_smmu_tlb_inv_range_asid(start, size, smmu_domain->cd.asid,
-					    PAGE_SIZE, false, smmu_domain);
-
-	arm_smmu_atc_inv_domain(smmu_domain, start, size);
+	arm_smmu_domain_inv_range(smmu_domain, start, size, PAGE_SIZE, false);
 }
 
 static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
@@ -191,13 +169,13 @@ static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 	}
 	spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
 
-	arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
-	arm_smmu_atc_inv_domain(smmu_domain, 0, 0);
+	arm_smmu_domain_inv(smmu_domain);
 }
 
 static void arm_smmu_mmu_notifier_free(struct mmu_notifier *mn)
 {
-	kfree(container_of(mn, struct arm_smmu_domain, mmu_notifier));
+	arm_smmu_domain_free(
+		container_of(mn, struct arm_smmu_domain, mmu_notifier));
 }
 
 static const struct mmu_notifier_ops arm_smmu_mmu_notifier_ops = {
@@ -301,7 +279,7 @@ static void arm_smmu_sva_domain_free(struct iommu_domain *domain)
 	/*
 	 * Ensure the ASID is empty in the iommu cache before allowing reuse.
 	 */
-	arm_smmu_tlb_inv_asid(smmu_domain->smmu, smmu_domain->cd.asid);
+	arm_smmu_domain_inv(smmu_domain);
 
 	/*
 	 * Notice that the arm_smmu_mm_arch_invalidate_secondary_tlbs op can
@@ -346,6 +324,7 @@ struct iommu_domain *arm_smmu_sva_domain_alloc(struct device *dev,
 	 * ARM_SMMU_FEAT_RANGE_INV is present
 	 */
 	smmu_domain->domain.pgsize_bitmap = PAGE_SIZE;
+	smmu_domain->stage = ARM_SMMU_DOMAIN_SVA;
 	smmu_domain->smmu = smmu;
 
 	ret = xa_alloc(&arm_smmu_asid_xa, &asid, smmu_domain,
@@ -364,6 +343,6 @@ struct iommu_domain *arm_smmu_sva_domain_alloc(struct device *dev,
 err_asid:
 	xa_erase(&arm_smmu_asid_xa, smmu_domain->cd.asid);
 err_free:
-	kfree(smmu_domain);
+	arm_smmu_domain_free(smmu_domain);
 	return ERR_PTR(ret);
 }
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-test.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-test.c
index 69c9ef4..add6713 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-test.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-test.c
@@ -637,6 +637,140 @@ static void arm_smmu_v3_write_cd_test_sva_release(struct kunit *test)
 						      NUM_EXPECTED_SYNCS(2));
 }
 
+static void arm_smmu_v3_invs_test_verify(struct kunit *test,
+					 struct arm_smmu_invs *invs,
+					 int num_invs, const int num_trashes,
+					 const int *ids, const int *users,
+					 const int *ssids)
+{
+	KUNIT_EXPECT_EQ(test, invs->num_invs, num_invs);
+	KUNIT_EXPECT_EQ(test, invs->num_trashes, num_trashes);
+	while (num_invs--) {
+		KUNIT_EXPECT_EQ(test, invs->inv[num_invs].id, ids[num_invs]);
+		KUNIT_EXPECT_EQ(test, READ_ONCE(invs->inv[num_invs].users),
+				users[num_invs]);
+		KUNIT_EXPECT_EQ(test, invs->inv[num_invs].ssid, ssids[num_invs]);
+	}
+}
+
+static struct arm_smmu_invs invs1 = {
+	.num_invs = 3,
+	.inv = { { .type = INV_TYPE_S2_VMID, .id = 1, },
+		 { .type = INV_TYPE_S2_VMID_S1_CLEAR, .id = 1, },
+		 { .type = INV_TYPE_ATS, .id = 3, }, },
+};
+
+static struct arm_smmu_invs invs2 = {
+	.num_invs = 3,
+	.inv = { { .type = INV_TYPE_S2_VMID, .id = 1, }, /* duplicated */
+		 { .type = INV_TYPE_ATS, .id = 4, },
+		 { .type = INV_TYPE_ATS, .id = 5, }, },
+};
+
+static struct arm_smmu_invs invs3 = {
+	.num_invs = 3,
+	.inv = { { .type = INV_TYPE_S2_VMID, .id = 1, }, /* duplicated */
+		 { .type = INV_TYPE_ATS, .id = 5, }, /* recover a trash */
+		 { .type = INV_TYPE_ATS, .id = 6, }, },
+};
+
+static struct arm_smmu_invs invs4 = {
+	.num_invs = 3,
+	.inv = { { .type = INV_TYPE_ATS, .id = 10, .ssid = 1 },
+		 { .type = INV_TYPE_ATS, .id = 10, .ssid = 3 },
+		 { .type = INV_TYPE_ATS, .id = 12, .ssid = 1 }, },
+};
+
+static struct arm_smmu_invs invs5 = {
+	.num_invs = 3,
+	.inv = { { .type = INV_TYPE_ATS, .id = 10, .ssid = 2 },
+		 { .type = INV_TYPE_ATS, .id = 10, .ssid = 3 }, /* duplicate */
+		 { .type = INV_TYPE_ATS, .id = 12, .ssid = 2 }, },
+};
+
+static void arm_smmu_v3_invs_test(struct kunit *test)
+{
+	const int results1[3][3] = { { 1, 1, 3, }, { 1, 1, 1, }, { 0, 0, 0, } };
+	const int results2[3][5] = { { 1, 1, 3, 4, 5, }, { 2, 1, 1, 1, 1, }, { 0, 0, 0, 0, 0, } };
+	const int results3[3][3] = { { 1, 1, 3, }, { 1, 1, 1, }, { 0, 0, 0, } };
+	const int results4[3][5] = { { 1, 1, 3, 5, 6, }, { 2, 1, 1, 1, 1, }, { 0, 0, 0, 0, 0, } };
+	const int results5[3][5] = { { 1, 1, 3, 5, 6, }, { 1, 0, 0, 1, 1, }, { 0, 0, 0, 0, 0, } };
+	const int results6[3][3] = { { 1, 5, 6, }, { 1, 1, 1, }, { 0, 0, 0, } };
+	const int results7[3][3] = { { 10, 10, 12, }, { 1, 1, 1, }, { 1, 3, 1, } };
+	const int results8[3][5] = { { 10, 10, 10, 12, 12, }, { 1, 1, 2, 1, 1, }, { 1, 2, 3, 1, 2, } };
+	const int results9[3][4] = { { 10, 10, 10, 12, }, { 1, 0, 1, 1, }, { 1, 2, 3, 1, } };
+	const int results10[3][3] = { { 10, 10, 12, }, { 1, 1, 1, }, { 1, 3, 1, } };
+	struct arm_smmu_invs *test_a, *test_b;
+
+	/* New array */
+	test_a = arm_smmu_invs_alloc(0);
+	KUNIT_EXPECT_EQ(test, test_a->num_invs, 0);
+
+	/* Test1: merge invs1 (new array) */
+	test_b = arm_smmu_invs_merge(test_a, &invs1);
+	kfree(test_a);
+	arm_smmu_v3_invs_test_verify(test, test_b, ARRAY_SIZE(results1[0]), 0,
+				     results1[0], results1[1], results1[2]);
+
+	/* Test2: merge invs2 (new array) */
+	test_a = arm_smmu_invs_merge(test_b, &invs2);
+	kfree(test_b);
+	arm_smmu_v3_invs_test_verify(test, test_a, ARRAY_SIZE(results2[0]), 0,
+				     results2[0], results2[1], results2[2]);
+
+	/* Test3: unref invs2 (same array) */
+	arm_smmu_invs_unref(test_a, &invs2);
+	arm_smmu_v3_invs_test_verify(test, test_a, ARRAY_SIZE(results3[0]), 0,
+				     results3[0], results3[1], results3[2]);
+
+	/* Test4: merge invs3 (new array) */
+	test_b = arm_smmu_invs_merge(test_a, &invs3);
+	kfree(test_a);
+	arm_smmu_v3_invs_test_verify(test, test_b, ARRAY_SIZE(results4[0]), 0,
+				     results4[0], results4[1], results4[2]);
+
+	/* Test5: unref invs1 (same array) */
+	arm_smmu_invs_unref(test_b, &invs1);
+	arm_smmu_v3_invs_test_verify(test, test_b, ARRAY_SIZE(results5[0]), 2,
+				     results5[0], results5[1], results5[2]);
+
+	/* Test6: purge test_b (new array) */
+	test_a = arm_smmu_invs_purge(test_b);
+	kfree(test_b);
+	arm_smmu_v3_invs_test_verify(test, test_a, ARRAY_SIZE(results6[0]), 0,
+				     results6[0], results6[1], results6[2]);
+
+	/* Test7: unref invs3 (same array) */
+	arm_smmu_invs_unref(test_a, &invs3);
+	KUNIT_EXPECT_EQ(test, test_a->num_invs, 0);
+	KUNIT_EXPECT_EQ(test, test_a->num_trashes, 0);
+
+	/* Test8: merge invs4 (new array) */
+	test_b = arm_smmu_invs_merge(test_a, &invs4);
+	kfree(test_a);
+	arm_smmu_v3_invs_test_verify(test, test_b, ARRAY_SIZE(results7[0]), 0,
+				     results7[0], results7[1], results7[2]);
+
+	/* Test9: merge invs5 (new array) */
+	test_a = arm_smmu_invs_merge(test_b, &invs5);
+	kfree(test_b);
+	arm_smmu_v3_invs_test_verify(test, test_a, ARRAY_SIZE(results8[0]), 0,
+				     results8[0], results8[1], results8[2]);
+
+	/* Test10: unref invs5 (same array) */
+	arm_smmu_invs_unref(test_a, &invs5);
+	arm_smmu_v3_invs_test_verify(test, test_a, ARRAY_SIZE(results9[0]), 1,
+				     results9[0], results9[1], results9[2]);
+
+	/* Test11: purge test_a (new array) */
+	test_b = arm_smmu_invs_purge(test_a);
+	kfree(test_a);
+	arm_smmu_v3_invs_test_verify(test, test_b, ARRAY_SIZE(results10[0]), 0,
+				     results10[0], results10[1], results10[2]);
+
+	kfree(test_b);
+}
+
 static struct kunit_case arm_smmu_v3_test_cases[] = {
 	KUNIT_CASE(arm_smmu_v3_write_ste_test_bypass_to_abort),
 	KUNIT_CASE(arm_smmu_v3_write_ste_test_abort_to_bypass),
@@ -662,6 +796,7 @@ static struct kunit_case arm_smmu_v3_test_cases[] = {
 	KUNIT_CASE(arm_smmu_v3_write_ste_test_nested_s1bypass_to_s1dssbypass),
 	KUNIT_CASE(arm_smmu_v3_write_cd_test_sva_clear),
 	KUNIT_CASE(arm_smmu_v3_write_cd_test_sva_release),
+	KUNIT_CASE(arm_smmu_v3_invs_test),
 	{},
 };
 
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
index 4d00d79..e8d7dbe 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.c
@@ -26,6 +26,7 @@
 #include <linux/pci.h>
 #include <linux/pci-ats.h>
 #include <linux/platform_device.h>
+#include <linux/sort.h>
 #include <linux/string_choices.h>
 #include <kunit/visibility.h>
 #include <uapi/linux/iommufd.h>
@@ -107,6 +108,7 @@ static const char * const event_class_str[] = {
 };
 
 static int arm_smmu_alloc_cd_tables(struct arm_smmu_master *master);
+static bool arm_smmu_ats_supported(struct arm_smmu_master *master);
 
 static void parse_driver_options(struct arm_smmu_device *smmu)
 {
@@ -1026,18 +1028,269 @@ static void arm_smmu_page_response(struct device *dev, struct iopf_fault *unused
 	 */
 }
 
-/* Context descriptor manipulation functions */
-void arm_smmu_tlb_inv_asid(struct arm_smmu_device *smmu, u16 asid)
+/* Invalidation array manipulation functions */
+static inline struct arm_smmu_inv *
+arm_smmu_invs_iter_next(struct arm_smmu_invs *invs, size_t next, size_t *idx)
 {
-	struct arm_smmu_cmdq_ent cmd = {
-		.opcode	= smmu->features & ARM_SMMU_FEAT_E2H ?
-			CMDQ_OP_TLBI_EL2_ASID : CMDQ_OP_TLBI_NH_ASID,
-		.tlbi.asid = asid,
-	};
-
-	arm_smmu_cmdq_issue_cmd_with_sync(smmu, &cmd);
+	while (true) {
+		if (next >= invs->num_invs) {
+			*idx = next;
+			return NULL;
+		}
+		if (!READ_ONCE(invs->inv[next].users)) {
+			next++;
+			continue;
+		}
+		*idx = next;
+		return &invs->inv[next];
+	}
 }
 
+/**
+ * arm_smmu_invs_for_each_entry - Iterate over all non-trash entries in invs
+ * @invs: the base invalidation array
+ * @idx: a stack variable of 'size_t', to store the array index
+ * @cur: a stack variable of 'struct arm_smmu_inv *'
+ */
+#define arm_smmu_invs_for_each_entry(invs, idx, cur)                           \
+	for (cur = arm_smmu_invs_iter_next(invs, 0, &(idx)); cur;              \
+	     cur = arm_smmu_invs_iter_next(invs, idx + 1, &(idx)))
+
+static int arm_smmu_inv_cmp(const struct arm_smmu_inv *inv_l,
+			    const struct arm_smmu_inv *inv_r)
+{
+	if (inv_l->smmu != inv_r->smmu)
+		return cmp_int((uintptr_t)inv_l->smmu, (uintptr_t)inv_r->smmu);
+	if (inv_l->type != inv_r->type)
+		return cmp_int(inv_l->type, inv_r->type);
+	if (inv_l->id != inv_r->id)
+		return cmp_int(inv_l->id, inv_r->id);
+	if (arm_smmu_inv_is_ats(inv_l))
+		return cmp_int(inv_l->ssid, inv_r->ssid);
+	return 0;
+}
+
+static inline int arm_smmu_invs_iter_next_cmp(struct arm_smmu_invs *invs_l,
+					      size_t next_l, size_t *idx_l,
+					      struct arm_smmu_invs *invs_r,
+					      size_t next_r, size_t *idx_r)
+{
+	struct arm_smmu_inv *cur_l =
+		arm_smmu_invs_iter_next(invs_l, next_l, idx_l);
+
+	/*
+	 * We have to update the idx_r manually, because the invs_r cannot call
+	 * arm_smmu_invs_iter_next() as the invs_r never sets any users counter.
+	 */
+	*idx_r = next_r;
+
+	/*
+	 * Compare of two sorted arrays items. If one side is past the end of
+	 * the array, return the other side to let it run out the iteration.
+	 *
+	 * If the left entry is empty, return 1 to pick the right entry.
+	 * If the right entry is empty, return -1 to pick the left entry.
+	 */
+	if (!cur_l)
+		return 1;
+	if (next_r >= invs_r->num_invs)
+		return -1;
+	return arm_smmu_inv_cmp(cur_l, &invs_r->inv[next_r]);
+}
+
+/**
+ * arm_smmu_invs_for_each_cmp - Iterate over two sorted arrays computing for
+ *                              arm_smmu_invs_merge() or arm_smmu_invs_unref()
+ * @invs_l: the base invalidation array
+ * @idx_l: a stack variable of 'size_t', to store the base array index
+ * @invs_r: the build_invs array as to_merge or to_unref
+ * @idx_r: a stack variable of 'size_t', to store the build_invs index
+ * @cmp: a stack variable of 'int', to store return value (-1, 0, or 1)
+ */
+#define arm_smmu_invs_for_each_cmp(invs_l, idx_l, invs_r, idx_r, cmp)          \
+	for (idx_l = idx_r = 0,                                                \
+	     cmp = arm_smmu_invs_iter_next_cmp(invs_l, 0, &(idx_l),            \
+					       invs_r, 0, &(idx_r));           \
+	     idx_l < invs_l->num_invs || idx_r < invs_r->num_invs;             \
+	     cmp = arm_smmu_invs_iter_next_cmp(                                \
+		     invs_l, idx_l + (cmp <= 0 ? 1 : 0), &(idx_l),             \
+		     invs_r, idx_r + (cmp >= 0 ? 1 : 0), &(idx_r)))
+
+/**
+ * arm_smmu_invs_merge() - Merge @to_merge into @invs and generate a new array
+ * @invs: the base invalidation array
+ * @to_merge: an array of invalidations to merge
+ *
+ * Return: a newly allocated array on success, or ERR_PTR
+ *
+ * This function must be locked and serialized with arm_smmu_invs_unref() and
+ * arm_smmu_invs_purge(), but do not lockdep on any lock for KUNIT test.
+ *
+ * Both @invs and @to_merge must be sorted, to ensure the returned array will be
+ * sorted as well.
+ *
+ * Caller is responsible for freeing the @invs and the returned new one.
+ *
+ * Entries marked as trash will be purged in the returned array.
+ */
+VISIBLE_IF_KUNIT
+struct arm_smmu_invs *arm_smmu_invs_merge(struct arm_smmu_invs *invs,
+					  struct arm_smmu_invs *to_merge)
+{
+	struct arm_smmu_invs *new_invs;
+	struct arm_smmu_inv *new;
+	size_t num_invs = 0;
+	size_t i, j;
+	int cmp;
+
+	arm_smmu_invs_for_each_cmp(invs, i, to_merge, j, cmp)
+		num_invs++;
+
+	new_invs = arm_smmu_invs_alloc(num_invs);
+	if (!new_invs)
+		return ERR_PTR(-ENOMEM);
+
+	new = new_invs->inv;
+	arm_smmu_invs_for_each_cmp(invs, i, to_merge, j, cmp) {
+		if (cmp < 0) {
+			*new = invs->inv[i];
+		} else if (cmp == 0) {
+			*new = invs->inv[i];
+			WRITE_ONCE(new->users, READ_ONCE(new->users) + 1);
+		} else {
+			*new = to_merge->inv[j];
+			WRITE_ONCE(new->users, 1);
+		}
+
+		/*
+		 * Check that the new array is sorted. This also validates that
+		 * to_merge is sorted.
+		 */
+		if (new != new_invs->inv)
+			WARN_ON_ONCE(arm_smmu_inv_cmp(new - 1, new) == 1);
+		if (arm_smmu_inv_is_ats(new))
+			new_invs->has_ats = true;
+		new++;
+	}
+
+	WARN_ON(new != new_invs->inv + new_invs->num_invs);
+
+	return new_invs;
+}
+EXPORT_SYMBOL_IF_KUNIT(arm_smmu_invs_merge);
+
+/**
+ * arm_smmu_invs_unref() - Find in @invs for all entries in @to_unref, decrease
+ *                         the user counts without deletions
+ * @invs: the base invalidation array
+ * @to_unref: an array of invalidations to decrease their user counts
+ *
+ * Entries whose user count drops to zero become trash, for arm_smmu_invs_purge()
+ *
+ * This function will not fail. Any entry with users=0 will be marked as trash,
+ * and caller will be notified about the trashed entry via @to_unref by setting
+ * a users=0.
+ *
+ * All trailing trash entries in the array will be dropped. And the size of the
+ * array will be trimmed properly. All trash entries in-between will remain in
+ * the @invs until being completely deleted by the next arm_smmu_invs_merge()
+ * or an arm_smmu_invs_purge() function call.
+ *
+ * This function must be locked and serialized with arm_smmu_invs_merge() and
+ * arm_smmu_invs_purge(), but do not lockdep on any mutex for KUNIT test.
+ *
+ * Note that the final @invs->num_invs might not reflect the actual number of
+ * invalidations due to trash entries. Any reader should take the read lock to
+ * iterate each entry and check its users counter till the last entry.
+ */
+VISIBLE_IF_KUNIT
+void arm_smmu_invs_unref(struct arm_smmu_invs *invs,
+			 struct arm_smmu_invs *to_unref)
+{
+	unsigned long flags;
+	size_t num_invs = 0;
+	size_t i, j;
+	int cmp;
+
+	arm_smmu_invs_for_each_cmp(invs, i, to_unref, j, cmp) {
+		if (cmp < 0) {
+			/* not found in to_unref, leave alone */
+			num_invs = i + 1;
+		} else if (cmp == 0) {
+			int users = READ_ONCE(invs->inv[i].users) - 1;
+
+			if (WARN_ON(users < 0))
+				continue;
+
+			/* same item */
+			WRITE_ONCE(invs->inv[i].users, users);
+			if (users) {
+				WRITE_ONCE(to_unref->inv[j].users, 1);
+				num_invs = i + 1;
+				continue;
+			}
+
+			/* Notify the caller about the trash entry */
+			WRITE_ONCE(to_unref->inv[j].users, 0);
+			invs->num_trashes++;
+		} else {
+			/* item in to_unref is not in invs or is already trash */
+			WARN_ON(true);
+		}
+	}
+
+	/* Exclude any trailing trash */
+	invs->num_trashes -= invs->num_invs - num_invs;
+
+	/* The lock is required to fence concurrent ATS operations. */
+	write_lock_irqsave(&invs->rwlock, flags);
+	WRITE_ONCE(invs->num_invs, num_invs); /* Remove trailing trash entries */
+	write_unlock_irqrestore(&invs->rwlock, flags);
+}
+EXPORT_SYMBOL_IF_KUNIT(arm_smmu_invs_unref);
+
+/**
+ * arm_smmu_invs_purge() - Purge all the trash entries in the @invs
+ * @invs: the base invalidation array
+ *
+ * Return: a newly allocated array on success removing all the trash entries, or
+ *         NULL if there is no trash entry in the array or if allocation failed
+ *
+ * This function must be locked and serialized with arm_smmu_invs_merge() and
+ * arm_smmu_invs_unref(), but do not lockdep on any lock for KUNIT test.
+ *
+ * Caller is responsible for freeing the @invs and the returned new one.
+ */
+VISIBLE_IF_KUNIT
+struct arm_smmu_invs *arm_smmu_invs_purge(struct arm_smmu_invs *invs)
+{
+	struct arm_smmu_invs *new_invs;
+	struct arm_smmu_inv *inv;
+	size_t i, num_invs = 0;
+
+	if (WARN_ON(invs->num_invs < invs->num_trashes))
+		return NULL;
+	if (!invs->num_invs || !invs->num_trashes)
+		return NULL;
+
+	new_invs = arm_smmu_invs_alloc(invs->num_invs - invs->num_trashes);
+	if (!new_invs)
+		return NULL;
+
+	arm_smmu_invs_for_each_entry(invs, i, inv) {
+		new_invs->inv[num_invs] = *inv;
+		if (arm_smmu_inv_is_ats(inv))
+			new_invs->has_ats = true;
+		num_invs++;
+	}
+
+	WARN_ON(num_invs != new_invs->num_invs);
+	return new_invs;
+}
+EXPORT_SYMBOL_IF_KUNIT(arm_smmu_invs_purge);
+
+/* Context descriptor manipulation functions */
+
 /*
  * Based on the value of ent report which bits of the STE the HW will access. It
  * would be nice if this was complete according to the spec, but minimally it
@@ -1236,6 +1489,13 @@ void arm_smmu_write_entry(struct arm_smmu_entry_writer *writer, __le64 *entry,
 	__le64 unused_update[NUM_ENTRY_QWORDS];
 	u8 used_qword_diff;
 
+	/*
+	 * Many of the entry structures have pointers to other structures that
+	 * need to have their updates be visible before any writes of the entry
+	 * happen.
+	 */
+	dma_wmb();
+
 	used_qword_diff =
 		arm_smmu_entry_qword_diff(writer, entry, target, unused_update);
 	if (hweight8(used_qword_diff) == 1) {
@@ -2240,109 +2500,42 @@ static int arm_smmu_atc_inv_master(struct arm_smmu_master *master,
 	return arm_smmu_cmdq_batch_submit(master->smmu, &cmds);
 }
 
-int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
-			    unsigned long iova, size_t size)
-{
-	struct arm_smmu_master_domain *master_domain;
-	int i;
-	unsigned long flags;
-	struct arm_smmu_cmdq_ent cmd = {
-		.opcode = CMDQ_OP_ATC_INV,
-	};
-	struct arm_smmu_cmdq_batch cmds;
-
-	if (!(smmu_domain->smmu->features & ARM_SMMU_FEAT_ATS))
-		return 0;
-
-	/*
-	 * Ensure that we've completed prior invalidation of the main TLBs
-	 * before we read 'nr_ats_masters' in case of a concurrent call to
-	 * arm_smmu_enable_ats():
-	 *
-	 *	// unmap()			// arm_smmu_enable_ats()
-	 *	TLBI+SYNC			atomic_inc(&nr_ats_masters);
-	 *	smp_mb();			[...]
-	 *	atomic_read(&nr_ats_masters);	pci_enable_ats() // writel()
-	 *
-	 * Ensures that we always see the incremented 'nr_ats_masters' count if
-	 * ATS was enabled at the PCI device before completion of the TLBI.
-	 */
-	smp_mb();
-	if (!atomic_read(&smmu_domain->nr_ats_masters))
-		return 0;
-
-	arm_smmu_cmdq_batch_init(smmu_domain->smmu, &cmds, &cmd);
-
-	spin_lock_irqsave(&smmu_domain->devices_lock, flags);
-	list_for_each_entry(master_domain, &smmu_domain->devices,
-			    devices_elm) {
-		struct arm_smmu_master *master = master_domain->master;
-
-		if (!master->ats_enabled)
-			continue;
-
-		if (master_domain->nested_ats_flush) {
-			/*
-			 * If a S2 used as a nesting parent is changed we have
-			 * no option but to completely flush the ATC.
-			 */
-			arm_smmu_atc_inv_to_cmd(IOMMU_NO_PASID, 0, 0, &cmd);
-		} else {
-			arm_smmu_atc_inv_to_cmd(master_domain->ssid, iova, size,
-						&cmd);
-		}
-
-		for (i = 0; i < master->num_streams; i++) {
-			cmd.atc.sid = master->streams[i].id;
-			arm_smmu_cmdq_batch_add(smmu_domain->smmu, &cmds, &cmd);
-		}
-	}
-	spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
-
-	return arm_smmu_cmdq_batch_submit(smmu_domain->smmu, &cmds);
-}
-
 /* IO_PGTABLE API */
 static void arm_smmu_tlb_inv_context(void *cookie)
 {
 	struct arm_smmu_domain *smmu_domain = cookie;
-	struct arm_smmu_device *smmu = smmu_domain->smmu;
-	struct arm_smmu_cmdq_ent cmd;
 
 	/*
-	 * NOTE: when io-pgtable is in non-strict mode, we may get here with
-	 * PTEs previously cleared by unmaps on the current CPU not yet visible
-	 * to the SMMU. We are relying on the dma_wmb() implicit during cmd
-	 * insertion to guarantee those are observed before the TLBI. Do be
-	 * careful, 007.
+	 * If the DMA API is running in non-strict mode then another CPU could
+	 * have changed the page table and not invoked any flush op. Instead the
+	 * other CPU will do an atomic_read() and this CPU will have done an
+	 * atomic_write(). That handshake is enough to acquire the page table
+	 * writes from the other CPU.
+	 *
+	 * All command execution has a dma_wmb() to release all the in-memory
+	 * structures written by this CPU, that barrier must also release the
+	 * writes acquired from all the other CPUs too.
+	 *
+	 * There are other barriers and atomics on this path, but the above is
+	 * the essential mechanism for ensuring that HW sees the page table
+	 * writes from another CPU before it executes the IOTLB invalidation.
 	 */
-	if (smmu_domain->stage == ARM_SMMU_DOMAIN_S1) {
-		arm_smmu_tlb_inv_asid(smmu, smmu_domain->cd.asid);
-	} else {
-		cmd.opcode	= CMDQ_OP_TLBI_S12_VMALL;
-		cmd.tlbi.vmid	= smmu_domain->s2_cfg.vmid;
-		arm_smmu_cmdq_issue_cmd_with_sync(smmu, &cmd);
-	}
-	arm_smmu_atc_inv_domain(smmu_domain, 0, 0);
+	arm_smmu_domain_inv(smmu_domain);
 }
 
-static void __arm_smmu_tlb_inv_range(struct arm_smmu_cmdq_ent *cmd,
-				     unsigned long iova, size_t size,
-				     size_t granule,
-				     struct arm_smmu_domain *smmu_domain)
+static void arm_smmu_cmdq_batch_add_range(struct arm_smmu_device *smmu,
+					  struct arm_smmu_cmdq_batch *cmds,
+					  struct arm_smmu_cmdq_ent *cmd,
+					  unsigned long iova, size_t size,
+					  size_t granule, size_t pgsize)
 {
-	struct arm_smmu_device *smmu = smmu_domain->smmu;
-	unsigned long end = iova + size, num_pages = 0, tg = 0;
+	unsigned long end = iova + size, num_pages = 0, tg = pgsize;
 	size_t inv_range = granule;
-	struct arm_smmu_cmdq_batch cmds;
 
-	if (!size)
+	if (WARN_ON_ONCE(!size))
 		return;
 
 	if (smmu->features & ARM_SMMU_FEAT_RANGE_INV) {
-		/* Get the leaf page size */
-		tg = __ffs(smmu_domain->domain.pgsize_bitmap);
-
 		num_pages = size >> tg;
 
 		/* Convert page size of 12,14,16 (log2) to 1,2,3 */
@@ -2362,8 +2555,6 @@ static void __arm_smmu_tlb_inv_range(struct arm_smmu_cmdq_ent *cmd,
 			num_pages++;
 	}
 
-	arm_smmu_cmdq_batch_init(smmu, &cmds, cmd);
-
 	while (iova < end) {
 		if (smmu->features & ARM_SMMU_FEAT_RANGE_INV) {
 			/*
@@ -2391,62 +2582,197 @@ static void __arm_smmu_tlb_inv_range(struct arm_smmu_cmdq_ent *cmd,
 		}
 
 		cmd->tlbi.addr = iova;
-		arm_smmu_cmdq_batch_add(smmu, &cmds, cmd);
+		arm_smmu_cmdq_batch_add(smmu, cmds, cmd);
 		iova += inv_range;
 	}
-	arm_smmu_cmdq_batch_submit(smmu, &cmds);
 }
 
-static void arm_smmu_tlb_inv_range_domain(unsigned long iova, size_t size,
-					  size_t granule, bool leaf,
-					  struct arm_smmu_domain *smmu_domain)
+static bool arm_smmu_inv_size_too_big(struct arm_smmu_device *smmu, size_t size,
+				      size_t granule)
 {
-	struct arm_smmu_cmdq_ent cmd = {
-		.tlbi = {
-			.leaf	= leaf,
-		},
-	};
+	size_t max_tlbi_ops;
 
-	if (smmu_domain->stage == ARM_SMMU_DOMAIN_S1) {
-		cmd.opcode	= smmu_domain->smmu->features & ARM_SMMU_FEAT_E2H ?
-				  CMDQ_OP_TLBI_EL2_VA : CMDQ_OP_TLBI_NH_VA;
-		cmd.tlbi.asid	= smmu_domain->cd.asid;
-	} else {
-		cmd.opcode	= CMDQ_OP_TLBI_S2_IPA;
-		cmd.tlbi.vmid	= smmu_domain->s2_cfg.vmid;
-	}
-	__arm_smmu_tlb_inv_range(&cmd, iova, size, granule, smmu_domain);
+	/* 0 size means invalidate all */
+	if (!size || size == SIZE_MAX)
+		return true;
 
-	if (smmu_domain->nest_parent) {
-		/*
-		 * When the S2 domain changes all the nested S1 ASIDs have to be
-		 * flushed too.
-		 */
-		cmd.opcode = CMDQ_OP_TLBI_NH_ALL;
-		arm_smmu_cmdq_issue_cmd_with_sync(smmu_domain->smmu, &cmd);
-	}
+	if (smmu->features & ARM_SMMU_FEAT_RANGE_INV)
+		return false;
 
 	/*
-	 * Unfortunately, this can't be leaf-only since we may have
-	 * zapped an entire table.
+	 * Borrowed from the MAX_TLBI_OPS in arch/arm64/include/asm/tlbflush.h,
+	 * this is used as a threshold to replace "size_opcode" commands with a
+	 * single "nsize_opcode" command, when SMMU doesn't implement the range
+	 * invalidation feature, where there can be too many per-granule TLBIs,
+	 * resulting in a soft lockup.
 	 */
-	arm_smmu_atc_inv_domain(smmu_domain, iova, size);
+	max_tlbi_ops = 1 << (ilog2(granule) - 3);
+	return size >= max_tlbi_ops * granule;
 }
 
-void arm_smmu_tlb_inv_range_asid(unsigned long iova, size_t size, int asid,
-				 size_t granule, bool leaf,
-				 struct arm_smmu_domain *smmu_domain)
+/* Used by non INV_TYPE_ATS* invalidations */
+static void arm_smmu_inv_to_cmdq_batch(struct arm_smmu_inv *inv,
+				       struct arm_smmu_cmdq_batch *cmds,
+				       struct arm_smmu_cmdq_ent *cmd,
+				       unsigned long iova, size_t size,
+				       unsigned int granule)
 {
-	struct arm_smmu_cmdq_ent cmd = {
-		.opcode	= smmu_domain->smmu->features & ARM_SMMU_FEAT_E2H ?
-			  CMDQ_OP_TLBI_EL2_VA : CMDQ_OP_TLBI_NH_VA,
-		.tlbi = {
-			.asid	= asid,
-			.leaf	= leaf,
-		},
-	};
+	if (arm_smmu_inv_size_too_big(inv->smmu, size, granule)) {
+		cmd->opcode = inv->nsize_opcode;
+		arm_smmu_cmdq_batch_add(inv->smmu, cmds, cmd);
+		return;
+	}
 
-	__arm_smmu_tlb_inv_range(&cmd, iova, size, granule, smmu_domain);
+	cmd->opcode = inv->size_opcode;
+	arm_smmu_cmdq_batch_add_range(inv->smmu, cmds, cmd, iova, size, granule,
+				      inv->pgsize);
+}
+
+static inline bool arm_smmu_invs_end_batch(struct arm_smmu_inv *cur,
+					   struct arm_smmu_inv *next)
+{
+	/* Changing smmu means changing command queue */
+	if (cur->smmu != next->smmu)
+		return true;
+	/* The batch for S2 TLBI must be done before nested S1 ASIDs */
+	if (cur->type != INV_TYPE_S2_VMID_S1_CLEAR &&
+	    next->type == INV_TYPE_S2_VMID_S1_CLEAR)
+		return true;
+	/* ATS must be after a sync of the S1/S2 invalidations */
+	if (!arm_smmu_inv_is_ats(cur) && arm_smmu_inv_is_ats(next))
+		return true;
+	return false;
+}
+
+static void __arm_smmu_domain_inv_range(struct arm_smmu_invs *invs,
+					unsigned long iova, size_t size,
+					unsigned int granule, bool leaf)
+{
+	struct arm_smmu_cmdq_batch cmds = {};
+	struct arm_smmu_inv *cur;
+	struct arm_smmu_inv *end;
+
+	cur = invs->inv;
+	end = cur + READ_ONCE(invs->num_invs);
+	/* Skip any leading entry marked as a trash */
+	for (; cur != end; cur++)
+		if (READ_ONCE(cur->users))
+			break;
+	while (cur != end) {
+		struct arm_smmu_device *smmu = cur->smmu;
+		struct arm_smmu_cmdq_ent cmd = {
+			/*
+			 * Pick size_opcode to run arm_smmu_get_cmdq(). This can
+			 * be changed to nsize_opcode, which would result in the
+			 * same CMDQ pointer.
+			 */
+			.opcode = cur->size_opcode,
+		};
+		struct arm_smmu_inv *next;
+
+		if (!cmds.num)
+			arm_smmu_cmdq_batch_init(smmu, &cmds, &cmd);
+
+		switch (cur->type) {
+		case INV_TYPE_S1_ASID:
+			cmd.tlbi.asid = cur->id;
+			cmd.tlbi.leaf = leaf;
+			arm_smmu_inv_to_cmdq_batch(cur, &cmds, &cmd, iova, size,
+						   granule);
+			break;
+		case INV_TYPE_S2_VMID:
+			cmd.tlbi.vmid = cur->id;
+			cmd.tlbi.leaf = leaf;
+			arm_smmu_inv_to_cmdq_batch(cur, &cmds, &cmd, iova, size,
+						   granule);
+			break;
+		case INV_TYPE_S2_VMID_S1_CLEAR:
+			/* CMDQ_OP_TLBI_S12_VMALL already flushed S1 entries */
+			if (arm_smmu_inv_size_too_big(cur->smmu, size, granule))
+				break;
+			cmd.tlbi.vmid = cur->id;
+			arm_smmu_cmdq_batch_add(smmu, &cmds, &cmd);
+			break;
+		case INV_TYPE_ATS:
+			arm_smmu_atc_inv_to_cmd(cur->ssid, iova, size, &cmd);
+			cmd.atc.sid = cur->id;
+			arm_smmu_cmdq_batch_add(smmu, &cmds, &cmd);
+			break;
+		case INV_TYPE_ATS_FULL:
+			arm_smmu_atc_inv_to_cmd(IOMMU_NO_PASID, 0, 0, &cmd);
+			cmd.atc.sid = cur->id;
+			arm_smmu_cmdq_batch_add(smmu, &cmds, &cmd);
+			break;
+		default:
+			WARN_ON_ONCE(1);
+			break;
+		}
+
+		/* Skip any trash entry in-between */
+		for (next = cur + 1; next != end; next++)
+			if (READ_ONCE(next->users))
+				break;
+
+		if (cmds.num &&
+		    (next == end || arm_smmu_invs_end_batch(cur, next))) {
+			arm_smmu_cmdq_batch_submit(smmu, &cmds);
+			cmds.num = 0;
+		}
+		cur = next;
+	}
+}
+
+void arm_smmu_domain_inv_range(struct arm_smmu_domain *smmu_domain,
+			       unsigned long iova, size_t size,
+			       unsigned int granule, bool leaf)
+{
+	struct arm_smmu_invs *invs;
+
+	/*
+	 * An invalidation request must follow some IOPTE change and then load
+	 * an invalidation array. In the meantime, a domain attachment mutates
+	 * the array and then stores an STE/CD asking SMMU HW to acquire those
+	 * changed IOPTEs.
+	 *
+	 * When running alone, a domain attachment relies on the dma_wmb() in
+	 * arm_smmu_write_entry() used by arm_smmu_install_ste_for_dev().
+	 *
+	 * But in a race, these two can be interdependent, making it a special
+	 * case requiring an additional smp_mb() for the write->read ordering.
+	 * Pairing with the dma_wmb() in arm_smmu_install_ste_for_dev(), this
+	 * makes sure that IOPTE update prior to this point is visible to SMMU
+	 * hardware before we load the updated invalidation array.
+	 *
+	 *  [CPU0]                        | [CPU1]
+	 *  change IOPTE on new domain:   |
+	 *  arm_smmu_domain_inv_range() { | arm_smmu_install_new_domain_invs()
+	 *    smp_mb(); // ensures IOPTE  | arm_smmu_install_ste_for_dev {
+	 *              // seen by SMMU   |   dma_wmb(); // ensures invs update
+	 *    // load the updated invs    |              // before updating STE
+	 *    invs = rcu_dereference();   |   STE = TTB0;
+	 *    ...                         |   ...
+	 *  }                             | }
+	 */
+	smp_mb();
+
+	rcu_read_lock();
+	invs = rcu_dereference(smmu_domain->invs);
+
+	/*
+	 * Avoid locking unless ATS is being used. No ATC invalidation can be
+	 * going on after a domain is detached.
+	 */
+	if (invs->has_ats) {
+		unsigned long flags;
+
+		read_lock_irqsave(&invs->rwlock, flags);
+		__arm_smmu_domain_inv_range(invs, iova, size, granule, leaf);
+		read_unlock_irqrestore(&invs->rwlock, flags);
+	} else {
+		__arm_smmu_domain_inv_range(invs, iova, size, granule, leaf);
+	}
+
+	rcu_read_unlock();
 }
 
 static void arm_smmu_tlb_inv_page_nosync(struct iommu_iotlb_gather *gather,
@@ -2462,7 +2788,9 @@ static void arm_smmu_tlb_inv_page_nosync(struct iommu_iotlb_gather *gather,
 static void arm_smmu_tlb_inv_walk(unsigned long iova, size_t size,
 				  size_t granule, void *cookie)
 {
-	arm_smmu_tlb_inv_range_domain(iova, size, granule, false, cookie);
+	struct arm_smmu_domain *smmu_domain = cookie;
+
+	arm_smmu_domain_inv_range(smmu_domain, iova, size, granule, false);
 }
 
 static const struct iommu_flush_ops arm_smmu_flush_ops = {
@@ -2494,6 +2822,8 @@ static bool arm_smmu_capable(struct device *dev, enum iommu_cap cap)
 		return true;
 	case IOMMU_CAP_DIRTY_TRACKING:
 		return arm_smmu_dbm_capable(master->smmu);
+	case IOMMU_CAP_PCI_ATS_SUPPORTED:
+		return arm_smmu_ats_supported(master);
 	default:
 		return false;
 	}
@@ -2522,13 +2852,21 @@ static bool arm_smmu_enforce_cache_coherency(struct iommu_domain *domain)
 struct arm_smmu_domain *arm_smmu_domain_alloc(void)
 {
 	struct arm_smmu_domain *smmu_domain;
+	struct arm_smmu_invs *new_invs;
 
 	smmu_domain = kzalloc_obj(*smmu_domain);
 	if (!smmu_domain)
 		return ERR_PTR(-ENOMEM);
 
+	new_invs = arm_smmu_invs_alloc(0);
+	if (!new_invs) {
+		kfree(smmu_domain);
+		return ERR_PTR(-ENOMEM);
+	}
+
 	INIT_LIST_HEAD(&smmu_domain->devices);
 	spin_lock_init(&smmu_domain->devices_lock);
+	rcu_assign_pointer(smmu_domain->invs, new_invs);
 
 	return smmu_domain;
 }
@@ -2552,7 +2890,7 @@ static void arm_smmu_domain_free_paging(struct iommu_domain *domain)
 			ida_free(&smmu->vmid_map, cfg->vmid);
 	}
 
-	kfree(smmu_domain);
+	arm_smmu_domain_free(smmu_domain);
 }
 
 static int arm_smmu_domain_finalise_s1(struct arm_smmu_device *smmu,
@@ -2870,6 +3208,121 @@ static void arm_smmu_disable_iopf(struct arm_smmu_master *master,
 		iopf_queue_remove_device(master->smmu->evtq.iopf, master->dev);
 }
 
+static struct arm_smmu_inv *
+arm_smmu_master_build_inv(struct arm_smmu_master *master,
+			  enum arm_smmu_inv_type type, u32 id, ioasid_t ssid,
+			  size_t pgsize)
+{
+	struct arm_smmu_invs *build_invs = master->build_invs;
+	struct arm_smmu_inv *cur, inv = {
+		.smmu = master->smmu,
+		.type = type,
+		.id = id,
+		.pgsize = pgsize,
+	};
+
+	if (WARN_ON(build_invs->num_invs >= build_invs->max_invs))
+		return NULL;
+	cur = &build_invs->inv[build_invs->num_invs];
+	build_invs->num_invs++;
+
+	*cur = inv;
+	switch (type) {
+	case INV_TYPE_S1_ASID:
+		/*
+		 * For S1 page tables the driver always uses VMID=0, and the
+		 * invalidation logic for this type will set it as well.
+		 */
+		if (master->smmu->features & ARM_SMMU_FEAT_E2H) {
+			cur->size_opcode = CMDQ_OP_TLBI_EL2_VA;
+			cur->nsize_opcode = CMDQ_OP_TLBI_EL2_ASID;
+		} else {
+			cur->size_opcode = CMDQ_OP_TLBI_NH_VA;
+			cur->nsize_opcode = CMDQ_OP_TLBI_NH_ASID;
+		}
+		break;
+	case INV_TYPE_S2_VMID:
+		cur->size_opcode = CMDQ_OP_TLBI_S2_IPA;
+		cur->nsize_opcode = CMDQ_OP_TLBI_S12_VMALL;
+		break;
+	case INV_TYPE_S2_VMID_S1_CLEAR:
+		cur->size_opcode = cur->nsize_opcode = CMDQ_OP_TLBI_NH_ALL;
+		break;
+	case INV_TYPE_ATS:
+	case INV_TYPE_ATS_FULL:
+		cur->size_opcode = cur->nsize_opcode = CMDQ_OP_ATC_INV;
+		cur->ssid = ssid;
+		break;
+	}
+
+	return cur;
+}
+
+/*
+ * Use the preallocated scratch array at master->build_invs, to build a to_merge
+ * or to_unref array, to pass into a following arm_smmu_invs_merge/unref() call.
+ *
+ * Do not free the returned invs array. It is reused, and will be overwritten by
+ * the next arm_smmu_master_build_invs() call.
+ */
+static struct arm_smmu_invs *
+arm_smmu_master_build_invs(struct arm_smmu_master *master, bool ats_enabled,
+			   ioasid_t ssid, struct arm_smmu_domain *smmu_domain)
+{
+	const bool nesting = smmu_domain->nest_parent;
+	size_t pgsize = 0, i;
+
+	iommu_group_mutex_assert(master->dev);
+
+	master->build_invs->num_invs = 0;
+
+	/* Range-based invalidation requires the leaf pgsize for calculation */
+	if (master->smmu->features & ARM_SMMU_FEAT_RANGE_INV)
+		pgsize = __ffs(smmu_domain->domain.pgsize_bitmap);
+
+	switch (smmu_domain->stage) {
+	case ARM_SMMU_DOMAIN_SVA:
+	case ARM_SMMU_DOMAIN_S1:
+		if (!arm_smmu_master_build_inv(master, INV_TYPE_S1_ASID,
+					       smmu_domain->cd.asid,
+					       IOMMU_NO_PASID, pgsize))
+			return NULL;
+		break;
+	case ARM_SMMU_DOMAIN_S2:
+		if (!arm_smmu_master_build_inv(master, INV_TYPE_S2_VMID,
+					       smmu_domain->s2_cfg.vmid,
+					       IOMMU_NO_PASID, pgsize))
+			return NULL;
+		break;
+	default:
+		WARN_ON(true);
+		return NULL;
+	}
+
+	/* All the nested S1 ASIDs have to be flushed when S2 parent changes */
+	if (nesting) {
+		if (!arm_smmu_master_build_inv(
+			    master, INV_TYPE_S2_VMID_S1_CLEAR,
+			    smmu_domain->s2_cfg.vmid, IOMMU_NO_PASID, 0))
+			return NULL;
+	}
+
+	for (i = 0; ats_enabled && i < master->num_streams; i++) {
+		/*
+		 * If an S2 used as a nesting parent is changed we have no
+		 * option but to completely flush the ATC.
+		 */
+		if (!arm_smmu_master_build_inv(
+			    master, nesting ? INV_TYPE_ATS_FULL : INV_TYPE_ATS,
+			    master->streams[i].id, ssid, 0))
+			return NULL;
+	}
+
+	/* Note this build_invs must have been sorted */
+
+	return master->build_invs;
+}
+
 static void arm_smmu_remove_master_domain(struct arm_smmu_master *master,
 					  struct iommu_domain *domain,
 					  ioasid_t ssid)
@@ -2900,6 +3353,135 @@ static void arm_smmu_remove_master_domain(struct arm_smmu_master *master,
 }
 
 /*
+ * During attachment, the updates of the two domain->invs arrays are sequenced:
+ *  1. new domain updates its invs array, merging master->build_invs
+ *  2. new domain starts to include the master during its invalidation
+ *  3. master updates its STE switching from the old domain to the new domain
+ *  4. old domain still includes the master during its invalidation
+ *  5. old domain updates its invs array, unreferencing master->build_invs
+ *
+ * For 1 and 5, prepare the two updated arrays in advance, handling any changes
+ * that can possibly fail. So the actual update of either 1 or 5 won't fail.
+ * arm_smmu_asid_lock ensures that the old invs in the domains are intact while
+ * we are sequencing to update them.
+ */
+static int arm_smmu_attach_prepare_invs(struct arm_smmu_attach_state *state,
+					struct iommu_domain *new_domain)
+{
+	struct arm_smmu_domain *old_smmu_domain =
+		to_smmu_domain_devices(state->old_domain);
+	struct arm_smmu_domain *new_smmu_domain =
+		to_smmu_domain_devices(new_domain);
+	struct arm_smmu_master *master = state->master;
+	ioasid_t ssid = state->ssid;
+
+	/*
+	 * At this point a NULL domain indicates the domain doesn't use the
+	 * IOTLB, see to_smmu_domain_devices().
+	 */
+	if (new_smmu_domain) {
+		struct arm_smmu_inv_state *invst = &state->new_domain_invst;
+		struct arm_smmu_invs *build_invs;
+
+		invst->invs_ptr = &new_smmu_domain->invs;
+		invst->old_invs = rcu_dereference_protected(
+			new_smmu_domain->invs,
+			lockdep_is_held(&arm_smmu_asid_lock));
+		build_invs = arm_smmu_master_build_invs(
+			master, state->ats_enabled, ssid, new_smmu_domain);
+		if (!build_invs)
+			return -EINVAL;
+
+		invst->new_invs =
+			arm_smmu_invs_merge(invst->old_invs, build_invs);
+		if (IS_ERR(invst->new_invs))
+			return PTR_ERR(invst->new_invs);
+	}
+
+	if (old_smmu_domain) {
+		struct arm_smmu_inv_state *invst = &state->old_domain_invst;
+
+		invst->invs_ptr = &old_smmu_domain->invs;
+		/* A re-attach case might have a different ats_enabled state */
+		if (new_smmu_domain == old_smmu_domain)
+			invst->old_invs = state->new_domain_invst.new_invs;
+		else
+			invst->old_invs = rcu_dereference_protected(
+				old_smmu_domain->invs,
+				lockdep_is_held(&arm_smmu_asid_lock));
+		/* For old_smmu_domain, new_invs points to master->build_invs */
+		invst->new_invs = arm_smmu_master_build_invs(
+			master, master->ats_enabled, ssid, old_smmu_domain);
+	}
+
+	return 0;
+}
+
+/* Must be installed before arm_smmu_install_ste_for_dev() */
+static void
+arm_smmu_install_new_domain_invs(struct arm_smmu_attach_state *state)
+{
+	struct arm_smmu_inv_state *invst = &state->new_domain_invst;
+
+	if (!invst->invs_ptr)
+		return;
+
+	rcu_assign_pointer(*invst->invs_ptr, invst->new_invs);
+	kfree_rcu(invst->old_invs, rcu);
+}
+
+static void arm_smmu_inv_flush_iotlb_tag(struct arm_smmu_inv *inv)
+{
+	struct arm_smmu_cmdq_ent cmd = {};
+
+	switch (inv->type) {
+	case INV_TYPE_S1_ASID:
+		cmd.tlbi.asid = inv->id;
+		break;
+	case INV_TYPE_S2_VMID:
+		/* S2_VMID using nsize_opcode covers S2_VMID_S1_CLEAR */
+		cmd.tlbi.vmid = inv->id;
+		break;
+	default:
+		return;
+	}
+
+	cmd.opcode = inv->nsize_opcode;
+	arm_smmu_cmdq_issue_cmd_with_sync(inv->smmu, &cmd);
+}
+
+/* Should be installed after arm_smmu_install_ste_for_dev() */
+static void
+arm_smmu_install_old_domain_invs(struct arm_smmu_attach_state *state)
+{
+	struct arm_smmu_inv_state *invst = &state->old_domain_invst;
+	struct arm_smmu_invs *old_invs = invst->old_invs;
+	struct arm_smmu_invs *new_invs;
+
+	lockdep_assert_held(&arm_smmu_asid_lock);
+
+	if (!invst->invs_ptr)
+		return;
+
+	arm_smmu_invs_unref(old_invs, invst->new_invs);
+	/*
+	 * When an IOTLB tag (the first entry in invst->new_invs) is no longer used,
+	 * it means the ASID or VMID will no longer be invalidated by map/unmap and
+	 * must be cleaned right now. The rule is that any ASID/VMID not in an invs
+	 * array must be left cleared in the IOTLB.
+	 */
+	if (!READ_ONCE(invst->new_invs->inv[0].users))
+		arm_smmu_inv_flush_iotlb_tag(&invst->new_invs->inv[0]);
+
+	new_invs = arm_smmu_invs_purge(old_invs);
+	if (!new_invs)
+		return;
+
+	rcu_assign_pointer(*invst->invs_ptr, new_invs);
+	kfree_rcu(old_invs, rcu);
+}
+
+/*
  * Start the sequence to attach a domain to a master. The sequence contains three
  * steps:
  *  arm_smmu_attach_prepare()
@@ -2956,12 +3538,16 @@ int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
 				     arm_smmu_ats_supported(master);
 	}
 
+	ret = arm_smmu_attach_prepare_invs(state, new_domain);
+	if (ret)
+		return ret;
+
 	if (smmu_domain) {
 		if (new_domain->type == IOMMU_DOMAIN_NESTED) {
 			ret = arm_smmu_attach_prepare_vmaster(
 				state, to_smmu_nested_domain(new_domain));
 			if (ret)
-				return ret;
+				goto err_unprepare_invs;
 		}
 
 		master_domain = kzalloc_obj(*master_domain);
@@ -3009,6 +3595,8 @@ int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
 			atomic_inc(&smmu_domain->nr_ats_masters);
 		list_add(&master_domain->devices_elm, &smmu_domain->devices);
 		spin_unlock_irqrestore(&smmu_domain->devices_lock, flags);
+
+		arm_smmu_install_new_domain_invs(state);
 	}
 
 	if (!state->ats_enabled && master->ats_enabled) {
@@ -3028,6 +3616,8 @@ int arm_smmu_attach_prepare(struct arm_smmu_attach_state *state,
 	kfree(master_domain);
 err_free_vmaster:
 	kfree(state->vmaster);
+err_unprepare_invs:
+	kfree(state->new_domain_invst.new_invs);
 	return ret;
 }
 
@@ -3059,6 +3649,7 @@ void arm_smmu_attach_commit(struct arm_smmu_attach_state *state)
 	}
 
 	arm_smmu_remove_master_domain(master, state->old_domain, state->ssid);
+	arm_smmu_install_old_domain_invs(state);
 	master->ats_enabled = state->ats_enabled;
 }
 
@@ -3125,6 +3716,9 @@ static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev,
 		arm_smmu_install_ste_for_dev(master, &target);
 		arm_smmu_clear_cd(master, IOMMU_NO_PASID);
 		break;
+	default:
+		WARN_ON(true);
+		break;
 	}
 
 	arm_smmu_attach_commit(&state);
@@ -3238,12 +3832,19 @@ static int arm_smmu_blocking_set_dev_pasid(struct iommu_domain *new_domain,
 {
 	struct arm_smmu_domain *smmu_domain = to_smmu_domain(old_domain);
 	struct arm_smmu_master *master = dev_iommu_priv_get(dev);
+	struct arm_smmu_attach_state state = {
+		.master = master,
+		.old_domain = old_domain,
+		.ssid = pasid,
+	};
 
 	mutex_lock(&arm_smmu_asid_lock);
+	arm_smmu_attach_prepare_invs(&state, NULL);
 	arm_smmu_clear_cd(master, pasid);
 	if (master->ats_enabled)
 		arm_smmu_atc_inv_master(master, pasid);
 	arm_smmu_remove_master_domain(master, &smmu_domain->domain, pasid);
+	arm_smmu_install_old_domain_invs(&state);
 	mutex_unlock(&arm_smmu_asid_lock);
 
 	/*
@@ -3417,7 +4018,7 @@ arm_smmu_domain_alloc_paging_flags(struct device *dev, u32 flags,
 	return &smmu_domain->domain;
 
 err_free:
-	kfree(smmu_domain);
+	arm_smmu_domain_free(smmu_domain);
 	return ERR_PTR(ret);
 }
 
@@ -3462,9 +4063,9 @@ static void arm_smmu_iotlb_sync(struct iommu_domain *domain,
 	if (!gather->pgsize)
 		return;
 
-	arm_smmu_tlb_inv_range_domain(gather->start,
-				      gather->end - gather->start + 1,
-				      gather->pgsize, true, smmu_domain);
+	arm_smmu_domain_inv_range(smmu_domain, gather->start,
+				  gather->end - gather->start + 1,
+				  gather->pgsize, true);
 }
 
 static phys_addr_t
@@ -3509,26 +4110,57 @@ static int arm_smmu_init_sid_strtab(struct arm_smmu_device *smmu, u32 sid)
 	return 0;
 }
 
+static int arm_smmu_stream_id_cmp(const void *_l, const void *_r)
+{
+	const typeof_member(struct arm_smmu_stream, id) *l = _l;
+	const typeof_member(struct arm_smmu_stream, id) *r = _r;
+
+	return cmp_int(*l, *r);
+}
+
 static int arm_smmu_insert_master(struct arm_smmu_device *smmu,
 				  struct arm_smmu_master *master)
 {
 	int i;
 	int ret = 0;
 	struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(master->dev);
+	bool ats_supported = dev_is_pci(master->dev) &&
+			     pci_ats_supported(to_pci_dev(master->dev));
 
 	master->streams = kzalloc_objs(*master->streams, fwspec->num_ids);
 	if (!master->streams)
 		return -ENOMEM;
 	master->num_streams = fwspec->num_ids;
 
+	if (!ats_supported) {
+		/* Base case has 1 ASID entry or maximum 2 VMID entries */
+		master->build_invs = arm_smmu_invs_alloc(2);
+	} else {
+		/* ATS case adds num_ids of entries, on top of the base case */
+		master->build_invs = arm_smmu_invs_alloc(2 + fwspec->num_ids);
+	}
+	if (!master->build_invs) {
+		kfree(master->streams);
+		return -ENOMEM;
+	}
+
+	for (i = 0; i < fwspec->num_ids; i++) {
+		struct arm_smmu_stream *new_stream = &master->streams[i];
+
+		new_stream->id = fwspec->ids[i];
+		new_stream->master = master;
+	}
+
+	/* Put the ids into order for sorted to_merge/to_unref arrays */
+	sort_nonatomic(master->streams, master->num_streams,
+		       sizeof(master->streams[0]), arm_smmu_stream_id_cmp,
+		       NULL);
+
 	mutex_lock(&smmu->streams_mutex);
 	for (i = 0; i < fwspec->num_ids; i++) {
 		struct arm_smmu_stream *new_stream = &master->streams[i];
 		struct rb_node *existing;
-		u32 sid = fwspec->ids[i];
-
-		new_stream->id = sid;
-		new_stream->master = master;
+		u32 sid = new_stream->id;
 
 		ret = arm_smmu_init_sid_strtab(smmu, sid);
 		if (ret)
@@ -3558,6 +4190,7 @@ static int arm_smmu_insert_master(struct arm_smmu_device *smmu,
 		for (i--; i >= 0; i--)
 			rb_erase(&master->streams[i].node, &smmu->streams);
 		kfree(master->streams);
+		kfree(master->build_invs);
 	}
 	mutex_unlock(&smmu->streams_mutex);
 
@@ -3579,6 +4212,7 @@ static void arm_smmu_remove_master(struct arm_smmu_master *master)
 	mutex_unlock(&smmu->streams_mutex);
 
 	kfree(master->streams);
+	kfree(master->build_invs);
 }
 
 static struct iommu_device *arm_smmu_probe_device(struct device *dev)
@@ -4308,6 +4942,8 @@ static int arm_smmu_device_reset(struct arm_smmu_device *smmu)
 #define IIDR_IMPLEMENTER_ARM		0x43b
 #define IIDR_PRODUCTID_ARM_MMU_600	0x483
 #define IIDR_PRODUCTID_ARM_MMU_700	0x487
+#define IIDR_PRODUCTID_ARM_MMU_L1	0x48a
+#define IIDR_PRODUCTID_ARM_MMU_S3	0x498
 
 static void arm_smmu_device_iidr_probe(struct arm_smmu_device *smmu)
 {
@@ -4332,11 +4968,19 @@ static void arm_smmu_device_iidr_probe(struct arm_smmu_device *smmu)
 				smmu->features &= ~ARM_SMMU_FEAT_NESTING;
 			break;
 		case IIDR_PRODUCTID_ARM_MMU_700:
-			/* Arm erratum 2812531 */
+			/* Many errata... */
 			smmu->features &= ~ARM_SMMU_FEAT_BTM;
-			smmu->options |= ARM_SMMU_OPT_CMDQ_FORCE_SYNC;
-			/* Arm errata 2268618, 2812531 */
-			smmu->features &= ~ARM_SMMU_FEAT_NESTING;
+			if (variant < 1 || revision < 1) {
+				/* Arm erratum 2812531 */
+				smmu->options |= ARM_SMMU_OPT_CMDQ_FORCE_SYNC;
+				/* Arm errata 2268618, 2812531 */
+				smmu->features &= ~ARM_SMMU_FEAT_NESTING;
+			}
+			break;
+		case IIDR_PRODUCTID_ARM_MMU_L1:
+		case IIDR_PRODUCTID_ARM_MMU_S3:
+			/* Arm errata 3878312/3995052 */
+			smmu->features &= ~ARM_SMMU_FEAT_BTM;
 			break;
 		}
 		break;
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
index 3c6d65d..ef42df4 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3.h
@@ -648,6 +648,93 @@ struct arm_smmu_cmdq_batch {
 	int				num;
 };
 
+/*
+ * The order here also determines the sequence in which commands are sent to the
+ * command queue. E.g. TLBI must be done before ATC_INV.
+ */
+enum arm_smmu_inv_type {
+	INV_TYPE_S1_ASID,
+	INV_TYPE_S2_VMID,
+	INV_TYPE_S2_VMID_S1_CLEAR,
+	INV_TYPE_ATS,
+	INV_TYPE_ATS_FULL,
+};
+
+struct arm_smmu_inv {
+	struct arm_smmu_device *smmu;
+	u8 type;
+	u8 size_opcode;
+	u8 nsize_opcode;
+	u32 id; /* ASID or VMID or SID */
+	union {
+		size_t pgsize; /* ARM_SMMU_FEAT_RANGE_INV */
+		u32 ssid; /* INV_TYPE_ATS */
+	};
+
+	int users; /* users=0 to mark as a trash to be purged */
+};
+
+static inline bool arm_smmu_inv_is_ats(const struct arm_smmu_inv *inv)
+{
+	return inv->type == INV_TYPE_ATS || inv->type == INV_TYPE_ATS_FULL;
+}
+
+/**
+ * struct arm_smmu_invs - Per-domain invalidation array
+ * @max_invs: maximum capacity of the flexible array
+ * @num_invs: number of invalidations in the flexible array. May be smaller than
+ *            @max_invs after a trailing trash entry is excluded, but must not be
+ *            greater than @max_invs
+ * @num_trashes: number of trash entries in the array for arm_smmu_invs_purge().
+ *               Must not be greater than @num_invs
+ * @rwlock: optional rwlock to fence ATS operations
+ * @has_ats: flag if the array contains an INV_TYPE_ATS or INV_TYPE_ATS_FULL
+ * @rcu: rcu head for kfree_rcu()
+ * @inv: flexible invalidation array
+ *
+ * The arm_smmu_invs is an RCU data structure. During a ->attach_dev callback,
+ * arm_smmu_invs_merge(), arm_smmu_invs_unref() and arm_smmu_invs_purge() will
+ * be used to allocate a new copy of an old array for addition and deletion in
+ * the old domain's and new domain's invs arrays.
+ *
+ * The arm_smmu_invs_unref() mutates a given array, by internally reducing the
+ * users counts of some given entries. This exists to support a no-fail routine
+ * like attaching to an IOMMU_DOMAIN_BLOCKED. And it could pair with a followup
+ * arm_smmu_invs_purge() call to generate a new clean array.
+ *
+ * Concurrent invalidation threads will push every invalidation described in the
+ * array into the command queue for each invalidation event. It is designed like
+ * this to optimize the invalidation fast path by avoiding locks.
+ *
+ * A domain can be shared across SMMU instances. When an instance gets removed,
+ * it would delete all the entries that belong to that SMMU instance. Then, a
+ * synchronize_rcu() would have to be called to sync the array, to prevent any
+ * concurrent invalidation thread accessing the old array from issuing commands
+ * to the command queue of a removed SMMU instance.
+ */
+struct arm_smmu_invs {
+	size_t max_invs;
+	size_t num_invs;
+	size_t num_trashes;
+	rwlock_t rwlock;
+	bool has_ats;
+	struct rcu_head rcu;
+	struct arm_smmu_inv inv[] __counted_by(max_invs);
+};
+
+static inline struct arm_smmu_invs *arm_smmu_invs_alloc(size_t num_invs)
+{
+	struct arm_smmu_invs *new_invs;
+
+	new_invs = kzalloc(struct_size(new_invs, inv, num_invs), GFP_KERNEL);
+	if (!new_invs)
+		return NULL;
+	new_invs->max_invs = num_invs;
+	new_invs->num_invs = num_invs;
+	rwlock_init(&new_invs->rwlock);
+	return new_invs;
+}
+
 struct arm_smmu_evtq {
 	struct arm_smmu_queue		q;
 	struct iopf_queue		*iopf;
@@ -841,6 +928,14 @@ struct arm_smmu_master {
 	struct arm_smmu_device		*smmu;
 	struct device			*dev;
 	struct arm_smmu_stream		*streams;
+	/*
+	 * Scratch memory for a to_merge or to_unref array to build a per-domain
+	 * invalidation array. It'll be pre-allocated with enough entries for all
+	 * possible build scenarios. It can be used by only one caller at a time
+	 * until the arm_smmu_invs_merge/unref() finishes. Must be locked by the
+	 * iommu_group mutex.
+	 */
+	struct arm_smmu_invs		*build_invs;
 	struct arm_smmu_vmaster		*vmaster; /* use smmu->streams_mutex */
 	/* Locked by the iommu core using the group mutex */
 	struct arm_smmu_ctx_desc_cfg	cd_table;
@@ -856,6 +951,7 @@ struct arm_smmu_master {
 enum arm_smmu_domain_stage {
 	ARM_SMMU_DOMAIN_S1 = 0,
 	ARM_SMMU_DOMAIN_S2,
+	ARM_SMMU_DOMAIN_SVA,
 };
 
 struct arm_smmu_domain {
@@ -872,6 +968,8 @@ struct arm_smmu_domain {
 
 	struct iommu_domain		domain;
 
+	struct arm_smmu_invs __rcu	*invs;
+
 	/* List of struct arm_smmu_master_domain */
 	struct list_head		devices;
 	spinlock_t			devices_lock;
@@ -924,6 +1022,12 @@ void arm_smmu_make_cdtable_ste(struct arm_smmu_ste *target,
 void arm_smmu_make_sva_cd(struct arm_smmu_cd *target,
 			  struct arm_smmu_master *master, struct mm_struct *mm,
 			  u16 asid);
+
+struct arm_smmu_invs *arm_smmu_invs_merge(struct arm_smmu_invs *invs,
+					  struct arm_smmu_invs *to_merge);
+void arm_smmu_invs_unref(struct arm_smmu_invs *invs,
+			 struct arm_smmu_invs *to_unref);
+struct arm_smmu_invs *arm_smmu_invs_purge(struct arm_smmu_invs *invs);
 #endif
 
 struct arm_smmu_master_domain {
@@ -955,6 +1059,13 @@ extern struct mutex arm_smmu_asid_lock;
 
 struct arm_smmu_domain *arm_smmu_domain_alloc(void);
 
+static inline void arm_smmu_domain_free(struct arm_smmu_domain *smmu_domain)
+{
+	/* No concurrency with invalidation is possible at this point */
+	kfree(rcu_dereference_protected(smmu_domain->invs, true));
+	kfree(smmu_domain);
+}
+
 void arm_smmu_clear_cd(struct arm_smmu_master *master, ioasid_t ssid);
 struct arm_smmu_cd *arm_smmu_get_cd_ptr(struct arm_smmu_master *master,
 					u32 ssid);
@@ -969,12 +1080,14 @@ int arm_smmu_set_pasid(struct arm_smmu_master *master,
 		       struct arm_smmu_domain *smmu_domain, ioasid_t pasid,
 		       struct arm_smmu_cd *cd, struct iommu_domain *old);
 
-void arm_smmu_tlb_inv_asid(struct arm_smmu_device *smmu, u16 asid);
-void arm_smmu_tlb_inv_range_asid(unsigned long iova, size_t size, int asid,
-				 size_t granule, bool leaf,
-				 struct arm_smmu_domain *smmu_domain);
-int arm_smmu_atc_inv_domain(struct arm_smmu_domain *smmu_domain,
-			    unsigned long iova, size_t size);
+void arm_smmu_domain_inv_range(struct arm_smmu_domain *smmu_domain,
+			       unsigned long iova, size_t size,
+			       unsigned int granule, bool leaf);
+
+static inline void arm_smmu_domain_inv(struct arm_smmu_domain *smmu_domain)
+{
+	arm_smmu_domain_inv_range(smmu_domain, 0, 0, 0, false);
+}
 
 void __arm_smmu_cmdq_skip_err(struct arm_smmu_device *smmu,
 			      struct arm_smmu_cmdq *cmdq);
@@ -991,6 +1104,21 @@ static inline bool arm_smmu_master_canwbs(struct arm_smmu_master *master)
 	       IOMMU_FWSPEC_PCI_RC_CANWBS;
 }
 
+/**
+ * struct arm_smmu_inv_state - Per-domain invalidation array state
+ * @invs_ptr: points to the domain->invs (unwinding nesting/etc.) or is NULL if
+ *            no change should be made
+ * @old_invs: the original invs array
+ * @new_invs: for new domain, this is the new invs array to update domain->invs;
+ *            for old domain, this is the master->build_invs to pass in as the
+ *            to_unref argument to an arm_smmu_invs_unref() call
+ */
+struct arm_smmu_inv_state {
+	struct arm_smmu_invs __rcu **invs_ptr;
+	struct arm_smmu_invs *old_invs;
+	struct arm_smmu_invs *new_invs;
+};
+
 struct arm_smmu_attach_state {
 	/* Inputs */
 	struct iommu_domain *old_domain;
@@ -1000,6 +1128,8 @@ struct arm_smmu_attach_state {
 	ioasid_t ssid;
 	/* Resulting state */
 	struct arm_smmu_vmaster *vmaster;
+	struct arm_smmu_inv_state old_domain_invst;
+	struct arm_smmu_inv_state new_domain_invst;
 	bool ats_enabled;
 };
 
diff --git a/drivers/iommu/arm/arm-smmu-v3/tegra241-cmdqv.c b/drivers/iommu/arm/arm-smmu-v3/tegra241-cmdqv.c
index 6fe5563..83f6e9f 100644
--- a/drivers/iommu/arm/arm-smmu-v3/tegra241-cmdqv.c
+++ b/drivers/iommu/arm/arm-smmu-v3/tegra241-cmdqv.c
@@ -479,6 +479,10 @@ static int tegra241_vcmdq_hw_init(struct tegra241_vcmdq *vcmdq)
 	/* Reset VCMDQ */
 	tegra241_vcmdq_hw_deinit(vcmdq);
 
+	/* vintf->hyp_own is a HW state finalized in tegra241_vintf_hw_init() */
+	if (!vcmdq->vintf->hyp_own)
+		vcmdq->cmdq.supports_cmd = tegra241_guest_vcmdq_supports_cmd;
+
 	/* Configure and enable VCMDQ */
 	writeq_relaxed(vcmdq->cmdq.q.q_base, REG_VCMDQ_PAGE1(vcmdq, BASE));
 
@@ -639,9 +643,6 @@ static int tegra241_vcmdq_alloc_smmu_cmdq(struct tegra241_vcmdq *vcmdq)
 	q->q_base = q->base_dma & VCMDQ_ADDR;
 	q->q_base |= FIELD_PREP(VCMDQ_LOG2SIZE, q->llq.max_n_shift);
 
-	if (!vcmdq->vintf->hyp_own)
-		cmdq->supports_cmd = tegra241_guest_vcmdq_supports_cmd;
-
 	return arm_smmu_cmdq_init(smmu, cmdq);
 }
 
diff --git a/drivers/iommu/dma-iommu.c b/drivers/iommu/dma-iommu.c
index 94d5141..0952353 100644
--- a/drivers/iommu/dma-iommu.c
+++ b/drivers/iommu/dma-iommu.c
@@ -14,6 +14,7 @@
 #include <linux/device.h>
 #include <linux/dma-direct.h>
 #include <linux/dma-map-ops.h>
+#include <linux/generic_pt/iommu.h>
 #include <linux/gfp.h>
 #include <linux/huge_mm.h>
 #include <linux/iommu.h>
@@ -648,6 +649,15 @@ static void iommu_dma_init_options(struct iommu_dma_options *options,
 	}
 }
 
+static bool iommu_domain_supports_fq(struct device *dev,
+				     struct iommu_domain *domain)
+{
+	/* iommupt always supports DMA-FQ */
+	if (iommupt_from_domain(domain))
+		return true;
+	return device_iommu_capable(dev, IOMMU_CAP_DEFERRED_FLUSH);
+}
+
 /**
  * iommu_dma_init_domain - Initialise a DMA mapping domain
  * @domain: IOMMU domain previously prepared by iommu_get_dma_cookie()
@@ -706,7 +716,8 @@ static int iommu_dma_init_domain(struct iommu_domain *domain, struct device *dev
 
 	/* If the FQ fails we can simply fall back to strict mode */
 	if (domain->type == IOMMU_DOMAIN_DMA_FQ &&
-	    (!device_iommu_capable(dev, IOMMU_CAP_DEFERRED_FLUSH) || iommu_dma_init_fq(domain)))
+	    (!iommu_domain_supports_fq(dev, domain) ||
+	     iommu_dma_init_fq(domain)))
 		domain->type = IOMMU_DOMAIN_DMA;
 
 	return iova_reserve_iommu_regions(dev, domain);
diff --git a/drivers/iommu/generic_pt/.kunitconfig b/drivers/iommu/generic_pt/.kunitconfig
index a78b295..0bb98fe 100644
--- a/drivers/iommu/generic_pt/.kunitconfig
+++ b/drivers/iommu/generic_pt/.kunitconfig
@@ -5,6 +5,7 @@
 CONFIG_IOMMU_PT=y
 CONFIG_IOMMU_PT_AMDV1=y
 CONFIG_IOMMU_PT_VTDSS=y
+CONFIG_IOMMU_PT_RISCV64=y
 CONFIG_IOMMU_PT_X86_64=y
 CONFIG_IOMMU_PT_KUNIT_TEST=y
 
diff --git a/drivers/iommu/generic_pt/Kconfig b/drivers/iommu/generic_pt/Kconfig
index ce4fb47..f4ed1ad 100644
--- a/drivers/iommu/generic_pt/Kconfig
+++ b/drivers/iommu/generic_pt/Kconfig
@@ -52,6 +52,16 @@
 
 	  Selected automatically by an IOMMU driver that uses this format.
 
+config IOMMU_PT_RISCV64
+	tristate "IOMMU page table for RISC-V 64 bit Sv57/Sv48/Sv39"
+	depends on !GENERIC_ATOMIC64 # for cmpxchg64
+	help
+	  iommu_domain implementation for RISC-V 64 bit 3/4/5 level page table.
+	  It supports 4K/2M/1G/512G/256T page sizes and can decode a sign
+	  extended portion of the 64 bit IOVA space.
+
+	  Selected automatically by an IOMMU driver that uses this format.
+
 config IOMMU_PT_X86_64
 	tristate "IOMMU page table for x86 64-bit, 4/5 levels"
 	depends on !GENERIC_ATOMIC64 # for cmpxchg64
@@ -66,6 +76,7 @@
 	tristate "IOMMU Page Table KUnit Test" if !KUNIT_ALL_TESTS
 	depends on KUNIT
 	depends on IOMMU_PT_AMDV1 || !IOMMU_PT_AMDV1
+	depends on IOMMU_PT_RISCV64 || !IOMMU_PT_RISCV64
 	depends on IOMMU_PT_X86_64 || !IOMMU_PT_X86_64
 	depends on IOMMU_PT_VTDSS || !IOMMU_PT_VTDSS
 	default KUNIT_ALL_TESTS
diff --git a/drivers/iommu/generic_pt/fmt/Makefile b/drivers/iommu/generic_pt/fmt/Makefile
index 976b49e..ea024d5 100644
--- a/drivers/iommu/generic_pt/fmt/Makefile
+++ b/drivers/iommu/generic_pt/fmt/Makefile
@@ -5,6 +5,8 @@
 
 iommu_pt_fmt-$(CONFIG_IOMMU_PT_VTDSS) += vtdss
 
+iommu_pt_fmt-$(CONFIG_IOMMU_PT_RISCV64) += riscv64
+
 iommu_pt_fmt-$(CONFIG_IOMMU_PT_X86_64) += x86_64
 
 IOMMU_PT_KUNIT_TEST :=
diff --git a/drivers/iommu/generic_pt/fmt/amdv1.h b/drivers/iommu/generic_pt/fmt/amdv1.h
index 3b2c41d..8d11b082 100644
--- a/drivers/iommu/generic_pt/fmt/amdv1.h
+++ b/drivers/iommu/generic_pt/fmt/amdv1.h
@@ -191,7 +191,7 @@ static inline enum pt_entry_type amdv1pt_load_entry_raw(struct pt_state *pts)
 }
 #define pt_load_entry_raw amdv1pt_load_entry_raw
 
-static inline void
+static __always_inline void
 amdv1pt_install_leaf_entry(struct pt_state *pts, pt_oaddr_t oa,
 			   unsigned int oasz_lg2,
 			   const struct pt_write_attrs *attrs)
diff --git a/drivers/iommu/generic_pt/fmt/defs_riscv.h b/drivers/iommu/generic_pt/fmt/defs_riscv.h
new file mode 100644
index 0000000..cf67474
--- /dev/null
+++ b/drivers/iommu/generic_pt/fmt/defs_riscv.h
@@ -0,0 +1,29 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
+ *
+ */
+#ifndef __GENERIC_PT_FMT_DEFS_RISCV_H
+#define __GENERIC_PT_FMT_DEFS_RISCV_H
+
+#include <linux/generic_pt/common.h>
+#include <linux/types.h>
+
+#ifdef PT_RISCV_32BIT
+typedef u32 pt_riscv_entry_t;
+#define riscvpt_write_attrs riscv32pt_write_attrs
+#else
+typedef u64 pt_riscv_entry_t;
+#define riscvpt_write_attrs riscv64pt_write_attrs
+#endif
+
+typedef pt_riscv_entry_t pt_vaddr_t;
+typedef u64 pt_oaddr_t;
+
+struct riscvpt_write_attrs {
+	pt_riscv_entry_t descriptor_bits;
+	gfp_t gfp;
+};
+#define pt_write_attrs riscvpt_write_attrs
+
+#endif
diff --git a/drivers/iommu/generic_pt/fmt/iommu_riscv64.c b/drivers/iommu/generic_pt/fmt/iommu_riscv64.c
new file mode 100644
index 0000000..cbf60ff
--- /dev/null
+++ b/drivers/iommu/generic_pt/fmt/iommu_riscv64.c
@@ -0,0 +1,11 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
+ */
+#define PT_FMT riscv
+#define PT_FMT_VARIANT 64
+#define PT_SUPPORTED_FEATURES                                  \
+	(BIT(PT_FEAT_SIGN_EXTEND) | BIT(PT_FEAT_FLUSH_RANGE) | \
+	 BIT(PT_FEAT_RISCV_SVNAPOT_64K))
+
+#include "iommu_template.h"
diff --git a/drivers/iommu/generic_pt/fmt/riscv.h b/drivers/iommu/generic_pt/fmt/riscv.h
new file mode 100644
index 0000000..a7fef62
--- /dev/null
+++ b/drivers/iommu/generic_pt/fmt/riscv.h
@@ -0,0 +1,313 @@
+/* SPDX-License-Identifier: GPL-2.0-only */
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES
+ *
+ * RISC-V page table
+ *
+ * This is described in Sections:
+ *  12.3. Sv32: Page-Based 32-bit Virtual-Memory Systems
+ *  12.4. Sv39: Page-Based 39-bit Virtual-Memory System
+ *  12.5. Sv48: Page-Based 48-bit Virtual-Memory System
+ *  12.6. Sv57: Page-Based 57-bit Virtual-Memory System
+ * of the "The RISC-V Instruction Set Manual: Volume II"
+ *
+ * This includes the contiguous page extension from:
+ *  Chapter 13. "Svnapot" Extension for NAPOT Translation Contiguity,
+ *     Version 1.0
+ *
+ * The table format is sign extended and supports leaves at every level. The spec
+ * doesn't talk a lot about levels, but level here is the same as i=LEVELS-1 in
+ * the spec.
+ */
+#ifndef __GENERIC_PT_FMT_RISCV_H
+#define __GENERIC_PT_FMT_RISCV_H
+
+#include "defs_riscv.h"
+#include "../pt_defs.h"
+
+#include <linux/bitfield.h>
+#include <linux/container_of.h>
+#include <linux/log2.h>
+#include <linux/sizes.h>
+
+enum {
+	PT_ITEM_WORD_SIZE = sizeof(pt_riscv_entry_t),
+#ifdef PT_RISCV_32BIT
+	PT_MAX_VA_ADDRESS_LG2 = 32,
+	PT_MAX_OUTPUT_ADDRESS_LG2 = 34,
+	PT_MAX_TOP_LEVEL = 1,
+#else
+	PT_MAX_VA_ADDRESS_LG2 = 57,
+	PT_MAX_OUTPUT_ADDRESS_LG2 = 56,
+	PT_MAX_TOP_LEVEL = 4,
+#endif
+	PT_GRANULE_LG2SZ = 12,
+	PT_TABLEMEM_LG2SZ = 12,
+
+	/* fsc.PPN is 44 bits wide, all PPNs are 4k aligned */
+	PT_TOP_PHYS_MASK = GENMASK_ULL(55, 12),
+};
+
+/* PTE bits */
+enum {
+	RISCVPT_V = BIT(0),
+	RISCVPT_R = BIT(1),
+	RISCVPT_W = BIT(2),
+	RISCVPT_X = BIT(3),
+	RISCVPT_U = BIT(4),
+	RISCVPT_G = BIT(5),
+	RISCVPT_A = BIT(6),
+	RISCVPT_D = BIT(7),
+	RISCVPT_RSW = GENMASK(9, 8),
+	RISCVPT_PPN32 = GENMASK(31, 10),
+
+	RISCVPT_PPN64 = GENMASK_ULL(53, 10),
+	RISCVPT_PPN64_64K = GENMASK_ULL(53, 14),
+	RISCVPT_PBMT = GENMASK_ULL(62, 61),
+	RISCVPT_N = BIT_ULL(63),
+
+	/* Svnapot encodings for ppn[0] */
+	RISCVPT_PPN64_64K_SZ = BIT(13),
+};
+
+#ifdef PT_RISCV_32BIT
+#define RISCVPT_PPN RISCVPT_PPN32
+#define pt_riscv pt_riscv_32
+#else
+#define RISCVPT_PPN RISCVPT_PPN64
+#define pt_riscv pt_riscv_64
+#endif
+
+#define common_to_riscvpt(common_ptr) \
+	container_of_const(common_ptr, struct pt_riscv, common)
+#define to_riscvpt(pts) common_to_riscvpt((pts)->range->common)
+
+static inline pt_oaddr_t riscvpt_table_pa(const struct pt_state *pts)
+{
+	return oalog2_mul(FIELD_GET(RISCVPT_PPN, pts->entry), PT_GRANULE_LG2SZ);
+}
+#define pt_table_pa riscvpt_table_pa
+
+static inline pt_oaddr_t riscvpt_entry_oa(const struct pt_state *pts)
+{
+	if (pts_feature(pts, PT_FEAT_RISCV_SVNAPOT_64K) &&
+	    pts->entry & RISCVPT_N) {
+		PT_WARN_ON(pts->level != 0);
+		return oalog2_mul(FIELD_GET(RISCVPT_PPN64_64K, pts->entry),
+				  ilog2(SZ_64K));
+	}
+	return oalog2_mul(FIELD_GET(RISCVPT_PPN, pts->entry), PT_GRANULE_LG2SZ);
+}
+#define pt_entry_oa riscvpt_entry_oa
+
+static inline bool riscvpt_can_have_leaf(const struct pt_state *pts)
+{
+	return true;
+}
+#define pt_can_have_leaf riscvpt_can_have_leaf
+
+/* Body in pt_fmt_defaults.h */
+static inline unsigned int pt_table_item_lg2sz(const struct pt_state *pts);
+
+static inline unsigned int
+riscvpt_entry_num_contig_lg2(const struct pt_state *pts)
+{
+	if (PT_SUPPORTED_FEATURE(PT_FEAT_RISCV_SVNAPOT_64K) &&
+	    pts->entry & RISCVPT_N) {
+		PT_WARN_ON(!pts_feature(pts, PT_FEAT_RISCV_SVNAPOT_64K));
+		PT_WARN_ON(pts->level);
+		return ilog2(16);
+	}
+	return ilog2(1);
+}
+#define pt_entry_num_contig_lg2 riscvpt_entry_num_contig_lg2
+
+static inline unsigned int riscvpt_num_items_lg2(const struct pt_state *pts)
+{
+	return PT_TABLEMEM_LG2SZ - ilog2(sizeof(u64));
+}
+#define pt_num_items_lg2 riscvpt_num_items_lg2
+
+static inline unsigned short
+riscvpt_contig_count_lg2(const struct pt_state *pts)
+{
+	if (pts->level == 0 && pts_feature(pts, PT_FEAT_RISCV_SVNAPOT_64K))
+		return ilog2(16);
+	return ilog2(1);
+}
+#define pt_contig_count_lg2 riscvpt_contig_count_lg2
+
+static inline enum pt_entry_type riscvpt_load_entry_raw(struct pt_state *pts)
+{
+	const pt_riscv_entry_t *tablep = pt_cur_table(pts, pt_riscv_entry_t);
+	pt_riscv_entry_t entry;
+
+	pts->entry = entry = READ_ONCE(tablep[pts->index]);
+	if (!(entry & RISCVPT_V))
+		return PT_ENTRY_EMPTY;
+	if (pts->level == 0 ||
+	    ((entry & (RISCVPT_X | RISCVPT_W | RISCVPT_R)) != 0))
+		return PT_ENTRY_OA;
+	return PT_ENTRY_TABLE;
+}
+#define pt_load_entry_raw riscvpt_load_entry_raw
+
+static inline void
+riscvpt_install_leaf_entry(struct pt_state *pts, pt_oaddr_t oa,
+			   unsigned int oasz_lg2,
+			   const struct pt_write_attrs *attrs)
+{
+	pt_riscv_entry_t *tablep = pt_cur_table(pts, pt_riscv_entry_t);
+	pt_riscv_entry_t entry;
+
+	if (!pt_check_install_leaf_args(pts, oa, oasz_lg2))
+		return;
+
+	entry = RISCVPT_V |
+		FIELD_PREP(RISCVPT_PPN, log2_div(oa, PT_GRANULE_LG2SZ)) |
+		attrs->descriptor_bits;
+
+	if (pts_feature(pts, PT_FEAT_RISCV_SVNAPOT_64K) && pts->level == 0 &&
+	    oasz_lg2 != PT_GRANULE_LG2SZ) {
+		u64 *end;
+
+		entry |= RISCVPT_N | RISCVPT_PPN64_64K_SZ;
+		tablep += pts->index;
+		end = tablep + log2_div(SZ_64K, PT_GRANULE_LG2SZ);
+		for (; tablep != end; tablep++)
+			WRITE_ONCE(*tablep, entry);
+	} else {
+		/* FIXME does riscv need this to be cmpxchg? */
+		WRITE_ONCE(tablep[pts->index], entry);
+	}
+	pts->entry = entry;
+}
+#define pt_install_leaf_entry riscvpt_install_leaf_entry
+
+static inline bool riscvpt_install_table(struct pt_state *pts,
+					 pt_oaddr_t table_pa,
+					 const struct pt_write_attrs *attrs)
+{
+	pt_riscv_entry_t entry;
+
+	entry = RISCVPT_V |
+		FIELD_PREP(RISCVPT_PPN, log2_div(table_pa, PT_GRANULE_LG2SZ));
+	return pt_table_install64(pts, entry);
+}
+#define pt_install_table riscvpt_install_table
+
+static inline void riscvpt_attr_from_entry(const struct pt_state *pts,
+					   struct pt_write_attrs *attrs)
+{
+	attrs->descriptor_bits =
+		pts->entry & (RISCVPT_R | RISCVPT_W | RISCVPT_X | RISCVPT_U |
+			      RISCVPT_G | RISCVPT_A | RISCVPT_D);
+}
+#define pt_attr_from_entry riscvpt_attr_from_entry
+
+/* --- iommu */
+#include <linux/generic_pt/iommu.h>
+#include <linux/iommu.h>
+
+#define pt_iommu_table pt_iommu_riscv_64
+
+/* The common struct is in the per-format common struct */
+static inline struct pt_common *common_from_iommu(struct pt_iommu *iommu_table)
+{
+	return &container_of(iommu_table, struct pt_iommu_table, iommu)
+			->riscv_64pt.common;
+}
+
+static inline struct pt_iommu *iommu_from_common(struct pt_common *common)
+{
+	return &container_of(common, struct pt_iommu_table, riscv_64pt.common)
+			->iommu;
+}
+
+static inline int riscvpt_iommu_set_prot(struct pt_common *common,
+					 struct pt_write_attrs *attrs,
+					 unsigned int iommu_prot)
+{
+	u64 pte;
+
+	pte = RISCVPT_A | RISCVPT_U;
+	if (iommu_prot & IOMMU_WRITE)
+		pte |= RISCVPT_W | RISCVPT_R | RISCVPT_D;
+	if (iommu_prot & IOMMU_READ)
+		pte |= RISCVPT_R;
+	if (!(iommu_prot & IOMMU_NOEXEC))
+		pte |= RISCVPT_X;
+
+	/* Caller must specify a supported combination of flags */
+	if (unlikely((pte & (RISCVPT_X | RISCVPT_W | RISCVPT_R)) == 0))
+		return -EOPNOTSUPP;
+
+	attrs->descriptor_bits = pte;
+	return 0;
+}
+#define pt_iommu_set_prot riscvpt_iommu_set_prot
+
+static inline int
+riscvpt_iommu_fmt_init(struct pt_iommu_riscv_64 *iommu_table,
+		       const struct pt_iommu_riscv_64_cfg *cfg)
+{
+	struct pt_riscv *table = &iommu_table->riscv_64pt;
+
+	switch (cfg->common.hw_max_vasz_lg2) {
+	case 39:
+		pt_top_set_level(&table->common, 2);
+		break;
+	case 48:
+		pt_top_set_level(&table->common, 3);
+		break;
+	case 57:
+		pt_top_set_level(&table->common, 4);
+		break;
+	default:
+		return -EINVAL;
+	}
+	table->common.max_oasz_lg2 =
+		min(PT_MAX_OUTPUT_ADDRESS_LG2, cfg->common.hw_max_oasz_lg2);
+	return 0;
+}
+#define pt_iommu_fmt_init riscvpt_iommu_fmt_init
+
+static inline void
+riscvpt_iommu_fmt_hw_info(struct pt_iommu_riscv_64 *table,
+			  const struct pt_range *top_range,
+			  struct pt_iommu_riscv_64_hw_info *info)
+{
+	phys_addr_t top_phys = virt_to_phys(top_range->top_table);
+
+	info->ppn = oalog2_div(top_phys, PT_GRANULE_LG2SZ);
+	PT_WARN_ON(top_phys & ~PT_TOP_PHYS_MASK);
+
+	/*
+	 * See "Table 3. Encodings of iosatp.MODE field" for DC.tc.SXL = 0:
+	 *  8 = Sv39 = top level 2
+	 *  9 = Sv48 = top level 3
+	 *  10 = Sv57 = top level 4
+	 */
+	info->fsc_iosatp_mode = top_range->top_level + 6;
+}
+#define pt_iommu_fmt_hw_info riscvpt_iommu_fmt_hw_info
+
+#if defined(GENERIC_PT_KUNIT)
+static const struct pt_iommu_riscv_64_cfg riscv_64_kunit_fmt_cfgs[] = {
+	[0] = { .common.features = BIT(PT_FEAT_RISCV_SVNAPOT_64K),
+		.common.hw_max_oasz_lg2 = 56,
+		.common.hw_max_vasz_lg2 = 39 },
+	[1] = { .common.features = 0,
+		.common.hw_max_oasz_lg2 = 56,
+		.common.hw_max_vasz_lg2 = 48 },
+	[2] = { .common.features = BIT(PT_FEAT_RISCV_SVNAPOT_64K),
+		.common.hw_max_oasz_lg2 = 56,
+		.common.hw_max_vasz_lg2 = 57 },
+};
+#define kunit_fmt_cfgs riscv_64_kunit_fmt_cfgs
+enum {
+	KUNIT_FMT_FEATURES = BIT(PT_FEAT_RISCV_SVNAPOT_64K),
+};
+#endif
+
+#endif
diff --git a/drivers/iommu/generic_pt/iommu_pt.h b/drivers/iommu/generic_pt/iommu_pt.h
index 3e33fe6..19b6daf 100644
--- a/drivers/iommu/generic_pt/iommu_pt.h
+++ b/drivers/iommu/generic_pt/iommu_pt.h
@@ -51,16 +51,27 @@ static void gather_range_pages(struct iommu_iotlb_gather *iotlb_gather,
 		iommu_pages_stop_incoherent_list(free_list,
 						 iommu_table->iommu_device);
 
-	if (pt_feature(common, PT_FEAT_FLUSH_RANGE_NO_GAPS) &&
-	    iommu_iotlb_gather_is_disjoint(iotlb_gather, iova, len)) {
-		iommu_iotlb_sync(&iommu_table->domain, iotlb_gather);
-		/*
-		 * Note that the sync frees the gather's free list, so we must
-		 * not have any pages on that list that are covered by iova/len
-		 */
+	/*
+	 * If running in DMA-FQ mode then the unmap will be followed by an IOTLB
+	 * flush all so we need to optimize by never flushing the IOTLB here.
+	 *
+	 * For NO_GAPS the user gets to pick if flushing all or doing micro
+	 * flushes is better for their work load by choosing DMA vs DMA-FQ
+	 * operation. Drivers should also see shadow_on_flush.
+	 */
+	if (!iommu_iotlb_gather_queued(iotlb_gather)) {
+		if (pt_feature(common, PT_FEAT_FLUSH_RANGE_NO_GAPS) &&
+		    iommu_iotlb_gather_is_disjoint(iotlb_gather, iova, len)) {
+			iommu_iotlb_sync(&iommu_table->domain, iotlb_gather);
+			/*
+			 * Note that the sync frees the gather's free list, so
+			 * we must not have any pages on that list that are
+			 * covered by iova/len
+			 */
+		}
+		iommu_iotlb_gather_add_range(iotlb_gather, iova, len);
 	}
 
-	iommu_iotlb_gather_add_range(iotlb_gather, iova, len);
 	iommu_pages_list_splice(free_list, &iotlb_gather->freelist);
 }
 
@@ -466,6 +477,7 @@ struct pt_iommu_map_args {
 	pt_oaddr_t oa;
 	unsigned int leaf_pgsize_lg2;
 	unsigned int leaf_level;
+	pt_vaddr_t num_leaves;
 };
 
 /*
@@ -518,11 +530,15 @@ static int clear_contig(const struct pt_state *start_pts,
 static int __map_range_leaf(struct pt_range *range, void *arg,
 			    unsigned int level, struct pt_table_p *table)
 {
+	struct pt_iommu *iommu_table = iommu_from_common(range->common);
 	struct pt_state pts = pt_init(range, level, table);
 	struct pt_iommu_map_args *map = arg;
 	unsigned int leaf_pgsize_lg2 = map->leaf_pgsize_lg2;
 	unsigned int start_index;
 	pt_oaddr_t oa = map->oa;
+	unsigned int num_leaves;
+	unsigned int orig_end;
+	pt_vaddr_t last_va;
 	unsigned int step;
 	bool need_contig;
 	int ret = 0;
@@ -536,6 +552,15 @@ static int __map_range_leaf(struct pt_range *range, void *arg,
 
 	_pt_iter_first(&pts);
 	start_index = pts.index;
+	orig_end = pts.end_index;
+	if (pts.index + map->num_leaves < pts.end_index) {
+		/* Need to stop in the middle of the table to change sizes */
+		pts.end_index = pts.index + map->num_leaves;
+		num_leaves = 0;
+	} else {
+		num_leaves = map->num_leaves - (pts.end_index - pts.index);
+	}
+
 	do {
 		pts.type = pt_load_entry_raw(&pts);
 		if (pts.type != PT_ENTRY_EMPTY || need_contig) {
@@ -561,7 +586,40 @@ static int __map_range_leaf(struct pt_range *range, void *arg,
 	flush_writes_range(&pts, start_index, pts.index);
 
 	map->oa = oa;
-	return ret;
+	map->num_leaves = num_leaves;
+	if (ret || num_leaves)
+		return ret;
+
+	/* range->va is not valid if we reached the end of the table */
+	pts.index -= step;
+	pt_index_to_va(&pts);
+	pts.index += step;
+	last_va = range->va + log2_to_int(leaf_pgsize_lg2);
+
+	if (last_va - 1 == range->last_va) {
+		PT_WARN_ON(pts.index != orig_end);
+		return 0;
+	}
+
+	/*
+	 * Reached a point where the page size changed, compute the new
+	 * parameters.
+	 */
+	map->leaf_pgsize_lg2 = pt_compute_best_pgsize(
+		iommu_table->domain.pgsize_bitmap, last_va, range->last_va, oa);
+	map->leaf_level =
+		pt_pgsz_lg2_to_level(range->common, map->leaf_pgsize_lg2);
+	map->num_leaves = pt_pgsz_count(iommu_table->domain.pgsize_bitmap,
+					last_va, range->last_va, oa,
+					map->leaf_pgsize_lg2);
+
+	/* Didn't finish this table level, caller will repeat it */
+	if (pts.index != orig_end) {
+		if (pts.index != start_index)
+			pt_index_to_va(&pts);
+		return -EAGAIN;
+	}
+	return 0;
 }
 
 static int __map_range(struct pt_range *range, void *arg, unsigned int level,
@@ -584,14 +642,9 @@ static int __map_range(struct pt_range *range, void *arg, unsigned int level,
 			if (pts.type != PT_ENTRY_EMPTY)
 				return -EADDRINUSE;
 			ret = pt_iommu_new_table(&pts, &map->attrs);
-			if (ret) {
-				/*
-				 * Racing with another thread installing a table
-				 */
-				if (ret == -EAGAIN)
-					continue;
+			/* EAGAIN on a race will loop again */
+			if (ret)
 				return ret;
-			}
 		} else {
 			pts.table_lower = pt_table_ptr(&pts);
 			/*
@@ -615,10 +668,12 @@ static int __map_range(struct pt_range *range, void *arg, unsigned int level,
 		 * The already present table can possibly be shared with another
 		 * concurrent map.
 		 */
-		if (map->leaf_level == level - 1)
-			ret = pt_descend(&pts, arg, __map_range_leaf);
-		else
-			ret = pt_descend(&pts, arg, __map_range);
+		do {
+			if (map->leaf_level == level - 1)
+				ret = pt_descend(&pts, arg, __map_range_leaf);
+			else
+				ret = pt_descend(&pts, arg, __map_range);
+		} while (ret == -EAGAIN);
 		if (ret)
 			return ret;
 
@@ -626,6 +681,14 @@ static int __map_range(struct pt_range *range, void *arg, unsigned int level,
 		pt_index_to_va(&pts);
 		if (pts.index >= pts.end_index)
 			break;
+
+		/*
+		 * This level is currently running __map_range_leaf() which is
+		 * not correct if the target level has been updated to this
+		 * level. Have the caller invoke __map_range_leaf.
+		 */
+		if (map->leaf_level == level)
+			return -EAGAIN;
 	} while (true);
 	return 0;
 }
@@ -797,12 +860,13 @@ static int check_map_range(struct pt_iommu *iommu_table, struct pt_range *range,
 static int do_map(struct pt_range *range, struct pt_common *common,
 		  bool single_page, struct pt_iommu_map_args *map)
 {
+	int ret;
+
 	/*
 	 * The __map_single_page() fast path does not support DMA_INCOHERENT
 	 * flushing to keep its .text small.
 	 */
 	if (single_page && !pt_feature(common, PT_FEAT_DMA_INCOHERENT)) {
-		int ret;
 
 		ret = pt_walk_range(range, __map_single_page, map);
 		if (ret != -EAGAIN)
@@ -810,50 +874,25 @@ static int do_map(struct pt_range *range, struct pt_common *common,
 		/* EAGAIN falls through to the full path */
 	}
 
-	if (map->leaf_level == range->top_level)
-		return pt_walk_range(range, __map_range_leaf, map);
-	return pt_walk_range(range, __map_range, map);
+	do {
+		if (map->leaf_level == range->top_level)
+			ret = pt_walk_range(range, __map_range_leaf, map);
+		else
+			ret = pt_walk_range(range, __map_range, map);
+	} while (ret == -EAGAIN);
+	return ret;
 }
 
-/**
- * map_pages() - Install translation for an IOVA range
- * @domain: Domain to manipulate
- * @iova: IO virtual address to start
- * @paddr: Physical/Output address to start
- * @pgsize: Length of each page
- * @pgcount: Length of the range in pgsize units starting from @iova
- * @prot: A bitmap of IOMMU_READ/WRITE/CACHE/NOEXEC/MMIO
- * @gfp: GFP flags for any memory allocations
- * @mapped: Total bytes successfully mapped
- *
- * The range starting at IOVA will have paddr installed into it. The caller
- * must specify a valid pgsize and pgcount to segment the range into compatible
- * blocks.
- *
- * On error the caller will probably want to invoke unmap on the range from iova
- * up to the amount indicated by @mapped to return the table back to an
- * unchanged state.
- *
- * Context: The caller must hold a write range lock that includes the whole
- * range.
- *
- * Returns: -ERRNO on failure, 0 on success. The number of bytes of VA that were
- * mapped are added to @mapped, @mapped is not zerod first.
- */
-int DOMAIN_NS(map_pages)(struct iommu_domain *domain, unsigned long iova,
-			 phys_addr_t paddr, size_t pgsize, size_t pgcount,
-			 int prot, gfp_t gfp, size_t *mapped)
+static int NS(map_range)(struct pt_iommu *iommu_table, dma_addr_t iova,
+			 phys_addr_t paddr, dma_addr_t len, unsigned int prot,
+			 gfp_t gfp, size_t *mapped)
 {
-	struct pt_iommu *iommu_table =
-		container_of(domain, struct pt_iommu, domain);
 	pt_vaddr_t pgsize_bitmap = iommu_table->domain.pgsize_bitmap;
 	struct pt_common *common = common_from_iommu(iommu_table);
 	struct iommu_iotlb_gather iotlb_gather;
-	pt_vaddr_t len = pgsize * pgcount;
 	struct pt_iommu_map_args map = {
 		.iotlb_gather = &iotlb_gather,
 		.oa = paddr,
-		.leaf_pgsize_lg2 = vaffs(pgsize),
 	};
 	bool single_page = false;
 	struct pt_range range;
@@ -881,13 +920,13 @@ int DOMAIN_NS(map_pages)(struct iommu_domain *domain, unsigned long iova,
 		return ret;
 
 	/* Calculate target page size and level for the leaves */
-	if (pt_has_system_page_size(common) && pgsize == PAGE_SIZE &&
-	    pgcount == 1) {
+	if (pt_has_system_page_size(common) && len == PAGE_SIZE) {
 		PT_WARN_ON(!(pgsize_bitmap & PAGE_SIZE));
 		if (log2_mod(iova | paddr, PAGE_SHIFT))
 			return -ENXIO;
 		map.leaf_pgsize_lg2 = PAGE_SHIFT;
 		map.leaf_level = 0;
+		map.num_leaves = 1;
 		single_page = true;
 	} else {
 		map.leaf_pgsize_lg2 = pt_compute_best_pgsize(
@@ -896,6 +935,9 @@ int DOMAIN_NS(map_pages)(struct iommu_domain *domain, unsigned long iova,
 			return -ENXIO;
 		map.leaf_level =
 			pt_pgsz_lg2_to_level(common, map.leaf_pgsize_lg2);
+		map.num_leaves = pt_pgsz_count(pgsize_bitmap, range.va,
+					       range.last_va, paddr,
+					       map.leaf_pgsize_lg2);
 	}
 
 	ret = check_map_range(iommu_table, &range, &map);
@@ -918,7 +960,6 @@ int DOMAIN_NS(map_pages)(struct iommu_domain *domain, unsigned long iova,
 	*mapped += map.oa - paddr;
 	return ret;
 }
-EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(map_pages), "GENERIC_PT_IOMMU");
 
 struct pt_unmap_args {
 	struct iommu_pages_list free_list;
@@ -1020,34 +1061,12 @@ static __maybe_unused int __unmap_range(struct pt_range *range, void *arg,
 	return ret;
 }
 
-/**
- * unmap_pages() - Make a range of IOVA empty/not present
- * @domain: Domain to manipulate
- * @iova: IO virtual address to start
- * @pgsize: Length of each page
- * @pgcount: Length of the range in pgsize units starting from @iova
- * @iotlb_gather: Gather struct that must be flushed on return
- *
- * unmap_pages() will remove a translation created by map_pages(). It cannot
- * subdivide a mapping created by map_pages(), so it should be called with IOVA
- * ranges that match those passed to map_pages(). The IOVA range can aggregate
- * contiguous map_pages() calls so long as no individual range is split.
- *
- * Context: The caller must hold a write range lock that includes
- * the whole range.
- *
- * Returns: Number of bytes of VA unmapped. iova + res will be the point
- * unmapping stopped.
- */
-size_t DOMAIN_NS(unmap_pages)(struct iommu_domain *domain, unsigned long iova,
-			      size_t pgsize, size_t pgcount,
+static size_t NS(unmap_range)(struct pt_iommu *iommu_table, dma_addr_t iova,
+			      dma_addr_t len,
 			      struct iommu_iotlb_gather *iotlb_gather)
 {
-	struct pt_iommu *iommu_table =
-		container_of(domain, struct pt_iommu, domain);
 	struct pt_unmap_args unmap = { .free_list = IOMMU_PAGES_LIST_INIT(
 					       unmap.free_list) };
-	pt_vaddr_t len = pgsize * pgcount;
 	struct pt_range range;
 	int ret;
 
@@ -1057,12 +1076,11 @@ size_t DOMAIN_NS(unmap_pages)(struct iommu_domain *domain, unsigned long iova,
 
 	pt_walk_range(&range, __unmap_range, &unmap);
 
-	gather_range_pages(iotlb_gather, iommu_table, iova, len,
+	gather_range_pages(iotlb_gather, iommu_table, iova, unmap.unmapped,
 			   &unmap.free_list);
 
 	return unmap.unmapped;
 }
-EXPORT_SYMBOL_NS_GPL(DOMAIN_NS(unmap_pages), "GENERIC_PT_IOMMU");
 
 static void NS(get_info)(struct pt_iommu *iommu_table,
 			 struct pt_iommu_info *info)
@@ -1110,6 +1128,8 @@ static void NS(deinit)(struct pt_iommu *iommu_table)
 }
 
 static const struct pt_iommu_ops NS(ops) = {
+	.map_range = NS(map_range),
+	.unmap_range = NS(unmap_range),
 #if IS_ENABLED(CONFIG_IOMMUFD_DRIVER) && defined(pt_entry_is_write_dirty) && \
 	IS_ENABLED(CONFIG_IOMMUFD_TEST) && defined(pt_entry_make_write_dirty)
 	.set_dirty = NS(set_dirty),
@@ -1172,6 +1192,7 @@ static int pt_iommu_init_domain(struct pt_iommu *iommu_table,
 
 	domain->type = __IOMMU_DOMAIN_PAGING;
 	domain->pgsize_bitmap = info.pgsize_bitmap;
+	domain->is_iommupt = true;
 
 	if (pt_feature(common, PT_FEAT_DYNAMIC_TOP))
 		range = _pt_top_range(common,
diff --git a/drivers/iommu/generic_pt/kunit_generic_pt.h b/drivers/iommu/generic_pt/kunit_generic_pt.h
index 68278bf..374e475 100644
--- a/drivers/iommu/generic_pt/kunit_generic_pt.h
+++ b/drivers/iommu/generic_pt/kunit_generic_pt.h
@@ -312,6 +312,17 @@ static void test_best_pgsize(struct kunit *test)
 	}
 }
 
+static void test_pgsz_count(struct kunit *test)
+{
+	KUNIT_EXPECT_EQ(test,
+			pt_pgsz_count(SZ_4K, 0, SZ_1G - 1, 0, ilog2(SZ_4K)),
+			SZ_1G / SZ_4K);
+	KUNIT_EXPECT_EQ(test,
+			pt_pgsz_count(SZ_2M | SZ_4K, SZ_4K, SZ_1G - 1, SZ_4K,
+				      ilog2(SZ_4K)),
+			(SZ_2M - SZ_4K) / SZ_4K);
+}
+
 /*
  * Check that pt_install_table() and pt_table_pa() match
  */
@@ -770,6 +781,7 @@ static struct kunit_case generic_pt_test_cases[] = {
 	KUNIT_CASE_FMT(test_init),
 	KUNIT_CASE_FMT(test_bitops),
 	KUNIT_CASE_FMT(test_best_pgsize),
+	KUNIT_CASE_FMT(test_pgsz_count),
 	KUNIT_CASE_FMT(test_table_ptr),
 	KUNIT_CASE_FMT(test_max_va),
 	KUNIT_CASE_FMT(test_table_radix),
diff --git a/drivers/iommu/generic_pt/pt_iter.h b/drivers/iommu/generic_pt/pt_iter.h
index c0d8617..3e45dbd 100644
--- a/drivers/iommu/generic_pt/pt_iter.h
+++ b/drivers/iommu/generic_pt/pt_iter.h
@@ -569,6 +569,28 @@ static inline unsigned int pt_compute_best_pgsize(pt_vaddr_t pgsz_bitmap,
 	return pgsz_lg2;
 }
 
+/*
+ * Return the number of pgsize_lg2 leaf entries that can be mapped across
+ * the VA range [va, last_va] with output address oa. This accounts for any
+ * requirement to reduce or increase the page size partway through the range.
+ */
+static inline pt_vaddr_t pt_pgsz_count(pt_vaddr_t pgsz_bitmap, pt_vaddr_t va,
+				       pt_vaddr_t last_va, pt_oaddr_t oa,
+				       unsigned int pgsize_lg2)
+{
+	pt_vaddr_t len = last_va - va + 1;
+	pt_vaddr_t next_pgsizes = log2_set_mod(pgsz_bitmap, 0, pgsize_lg2 + 1);
+
+	if (next_pgsizes) {
+		unsigned int next_pgsize_lg2 = vaffs(next_pgsizes);
+
+		if (log2_mod(va ^ oa, next_pgsize_lg2) == 0)
+			len = min(len, log2_set_mod_max(va, next_pgsize_lg2) -
+					       va + 1);
+	}
+	return log2_div(len, pgsize_lg2);
+}
+
 #define _PT_MAKE_CALL_LEVEL(fn)                                          \
 	static __always_inline int fn(struct pt_range *range, void *arg, \
 				      unsigned int level,                \
diff --git a/drivers/iommu/intel/cache.c b/drivers/iommu/intel/cache.c
index 249ab58..be8410f 100644
--- a/drivers/iommu/intel/cache.c
+++ b/drivers/iommu/intel/cache.c
@@ -255,7 +255,6 @@ void cache_tag_unassign_domain(struct dmar_domain *domain,
 
 static unsigned long calculate_psi_aligned_address(unsigned long start,
 						   unsigned long end,
-						   unsigned long *_pages,
 						   unsigned long *_mask)
 {
 	unsigned long pages = aligned_nrpages(start, end - start + 1);
@@ -281,10 +280,8 @@ static unsigned long calculate_psi_aligned_address(unsigned long start,
 		 */
 		shared_bits = ~(pfn ^ end_pfn) & ~bitmask;
 		mask = shared_bits ? __ffs(shared_bits) : MAX_AGAW_PFN_WIDTH;
-		aligned_pages = 1UL << mask;
 	}
 
-	*_pages = aligned_pages;
 	*_mask = mask;
 
 	return ALIGN_DOWN(start, VTD_PAGE_SIZE << mask);
@@ -330,19 +327,19 @@ static void qi_batch_add_dev_iotlb(struct intel_iommu *iommu, u16 sid, u16 pfsid
 	qi_batch_increment_index(iommu, batch);
 }
 
+static void qi_batch_add_piotlb_all(struct intel_iommu *iommu, u16 did,
+				    u32 pasid, struct qi_batch *batch)
+{
+	qi_desc_piotlb_all(did, pasid, &batch->descs[batch->index]);
+	qi_batch_increment_index(iommu, batch);
+}
+
 static void qi_batch_add_piotlb(struct intel_iommu *iommu, u16 did, u32 pasid,
-				u64 addr, unsigned long npages, bool ih,
+				u64 addr, unsigned int size_order, bool ih,
 				struct qi_batch *batch)
 {
-	/*
-	 * npages == -1 means a PASID-selective invalidation, otherwise,
-	 * a positive value for Page-selective-within-PASID invalidation.
-	 * 0 is not a valid input.
-	 */
-	if (!npages)
-		return;
-
-	qi_desc_piotlb(did, pasid, addr, npages, ih, &batch->descs[batch->index]);
+	qi_desc_piotlb(did, pasid, addr, size_order, ih,
+		       &batch->descs[batch->index]);
 	qi_batch_increment_index(iommu, batch);
 }
 
@@ -371,15 +368,18 @@ static bool intel_domain_use_piotlb(struct dmar_domain *domain)
 }
 
 static void cache_tag_flush_iotlb(struct dmar_domain *domain, struct cache_tag *tag,
-				  unsigned long addr, unsigned long pages,
-				  unsigned long mask, int ih)
+				  unsigned long addr, unsigned long mask, int ih)
 {
 	struct intel_iommu *iommu = tag->iommu;
 	u64 type = DMA_TLB_PSI_FLUSH;
 
 	if (intel_domain_use_piotlb(domain)) {
-		qi_batch_add_piotlb(iommu, tag->domain_id, tag->pasid, addr,
-				    pages, ih, domain->qi_batch);
+		if (mask >= MAX_AGAW_PFN_WIDTH)
+			qi_batch_add_piotlb_all(iommu, tag->domain_id,
+						tag->pasid, domain->qi_batch);
+		else
+			qi_batch_add_piotlb(iommu, tag->domain_id, tag->pasid,
+					    addr, mask, ih, domain->qi_batch);
 		return;
 	}
 
@@ -388,7 +388,7 @@ static void cache_tag_flush_iotlb(struct dmar_domain *domain, struct cache_tag *
 	 * is too big.
 	 */
 	if (!cap_pgsel_inv(iommu->cap) ||
-	    mask > cap_max_amask_val(iommu->cap) || pages == -1) {
+	    mask > cap_max_amask_val(iommu->cap)) {
 		addr = 0;
 		mask = 0;
 		ih = 0;
@@ -437,16 +437,15 @@ void cache_tag_flush_range(struct dmar_domain *domain, unsigned long start,
 			   unsigned long end, int ih)
 {
 	struct intel_iommu *iommu = NULL;
-	unsigned long pages, mask, addr;
+	unsigned long mask, addr;
 	struct cache_tag *tag;
 	unsigned long flags;
 
 	if (start == 0 && end == ULONG_MAX) {
 		addr = 0;
-		pages = -1;
 		mask = MAX_AGAW_PFN_WIDTH;
 	} else {
-		addr = calculate_psi_aligned_address(start, end, &pages, &mask);
+		addr = calculate_psi_aligned_address(start, end, &mask);
 	}
 
 	spin_lock_irqsave(&domain->cache_lock, flags);
@@ -458,7 +457,7 @@ void cache_tag_flush_range(struct dmar_domain *domain, unsigned long start,
 		switch (tag->type) {
 		case CACHE_TAG_IOTLB:
 		case CACHE_TAG_NESTING_IOTLB:
-			cache_tag_flush_iotlb(domain, tag, addr, pages, mask, ih);
+			cache_tag_flush_iotlb(domain, tag, addr, mask, ih);
 			break;
 		case CACHE_TAG_NESTING_DEVTLB:
 			/*
@@ -476,7 +475,7 @@ void cache_tag_flush_range(struct dmar_domain *domain, unsigned long start,
 			break;
 		}
 
-		trace_cache_tag_flush_range(tag, start, end, addr, pages, mask);
+		trace_cache_tag_flush_range(tag, start, end, addr, mask);
 	}
 	qi_batch_flush_descs(iommu, domain->qi_batch);
 	spin_unlock_irqrestore(&domain->cache_lock, flags);
@@ -506,11 +505,11 @@ void cache_tag_flush_range_np(struct dmar_domain *domain, unsigned long start,
 			      unsigned long end)
 {
 	struct intel_iommu *iommu = NULL;
-	unsigned long pages, mask, addr;
+	unsigned long mask, addr;
 	struct cache_tag *tag;
 	unsigned long flags;
 
-	addr = calculate_psi_aligned_address(start, end, &pages, &mask);
+	addr = calculate_psi_aligned_address(start, end, &mask);
 
 	spin_lock_irqsave(&domain->cache_lock, flags);
 	list_for_each_entry(tag, &domain->cache_tags, node) {
@@ -526,9 +525,9 @@ void cache_tag_flush_range_np(struct dmar_domain *domain, unsigned long start,
 
 		if (tag->type == CACHE_TAG_IOTLB ||
 		    tag->type == CACHE_TAG_NESTING_IOTLB)
-			cache_tag_flush_iotlb(domain, tag, addr, pages, mask, 0);
+			cache_tag_flush_iotlb(domain, tag, addr, mask, 0);
 
-		trace_cache_tag_flush_range_np(tag, start, end, addr, pages, mask);
+		trace_cache_tag_flush_range_np(tag, start, end, addr, mask);
 	}
 	qi_batch_flush_descs(iommu, domain->qi_batch);
 	spin_unlock_irqrestore(&domain->cache_lock, flags);
diff --git a/drivers/iommu/intel/debugfs.c b/drivers/iommu/intel/debugfs.c
index 617fd81..21e4e46 100644
--- a/drivers/iommu/intel/debugfs.c
+++ b/drivers/iommu/intel/debugfs.c
@@ -133,13 +133,13 @@ static int iommu_regset_show(struct seq_file *m, void *unused)
 		 */
 		raw_spin_lock_irqsave(&iommu->register_lock, flag);
 		for (i = 0 ; i < ARRAY_SIZE(iommu_regs_32); i++) {
-			value = dmar_readl(iommu->reg + iommu_regs_32[i].offset);
+			value = readl(iommu->reg + iommu_regs_32[i].offset);
 			seq_printf(m, "%-16s\t0x%02x\t\t0x%016llx\n",
 				   iommu_regs_32[i].regs, iommu_regs_32[i].offset,
 				   value);
 		}
 		for (i = 0 ; i < ARRAY_SIZE(iommu_regs_64); i++) {
-			value = dmar_readq(iommu->reg + iommu_regs_64[i].offset);
+			value = readq(iommu->reg + iommu_regs_64[i].offset);
 			seq_printf(m, "%-16s\t0x%02x\t\t0x%016llx\n",
 				   iommu_regs_64[i].regs, iommu_regs_64[i].offset,
 				   value);
@@ -247,7 +247,7 @@ static void ctx_tbl_walk(struct seq_file *m, struct intel_iommu *iommu, u16 bus)
 		tbl_wlk.ctx_entry = context;
 		m->private = &tbl_wlk;
 
-		if (dmar_readq(iommu->reg + DMAR_RTADDR_REG) & DMA_RTADDR_SMT) {
+		if (readq(iommu->reg + DMAR_RTADDR_REG) & DMA_RTADDR_SMT) {
 			pasid_dir_ptr = context->lo & VTD_PAGE_MASK;
 			pasid_dir_size = get_pasid_dir_size(context);
 			pasid_dir_walk(m, pasid_dir_ptr, pasid_dir_size);
@@ -285,7 +285,7 @@ static int dmar_translation_struct_show(struct seq_file *m, void *unused)
 
 	rcu_read_lock();
 	for_each_active_iommu(iommu, drhd) {
-		sts = dmar_readl(iommu->reg + DMAR_GSTS_REG);
+		sts = readl(iommu->reg + DMAR_GSTS_REG);
 		if (!(sts & DMA_GSTS_TES)) {
 			seq_printf(m, "DMA Remapping is not enabled on %s\n",
 				   iommu->name);
@@ -364,13 +364,13 @@ static int domain_translation_struct_show(struct seq_file *m,
 		if (seg != iommu->segment)
 			continue;
 
-		sts = dmar_readl(iommu->reg + DMAR_GSTS_REG);
+		sts = readl(iommu->reg + DMAR_GSTS_REG);
 		if (!(sts & DMA_GSTS_TES)) {
 			seq_printf(m, "DMA Remapping is not enabled on %s\n",
 				   iommu->name);
 			continue;
 		}
-		if (dmar_readq(iommu->reg + DMAR_RTADDR_REG) & DMA_RTADDR_SMT)
+		if (readq(iommu->reg + DMAR_RTADDR_REG) & DMA_RTADDR_SMT)
 			scalable = true;
 		else
 			scalable = false;
@@ -538,8 +538,8 @@ static int invalidation_queue_show(struct seq_file *m, void *unused)
 		raw_spin_lock_irqsave(&qi->q_lock, flags);
 		seq_printf(m, " Base: 0x%llx\tHead: %lld\tTail: %lld\n",
 			   (u64)virt_to_phys(qi->desc),
-			   dmar_readq(iommu->reg + DMAR_IQH_REG) >> shift,
-			   dmar_readq(iommu->reg + DMAR_IQT_REG) >> shift);
+			   readq(iommu->reg + DMAR_IQH_REG) >> shift,
+			   readq(iommu->reg + DMAR_IQT_REG) >> shift);
 		invalidation_queue_entry_show(m, iommu);
 		raw_spin_unlock_irqrestore(&qi->q_lock, flags);
 		seq_putc(m, '\n');
@@ -620,7 +620,7 @@ static int ir_translation_struct_show(struct seq_file *m, void *unused)
 		seq_printf(m, "Remapped Interrupt supported on IOMMU: %s\n",
 			   iommu->name);
 
-		sts = dmar_readl(iommu->reg + DMAR_GSTS_REG);
+		sts = readl(iommu->reg + DMAR_GSTS_REG);
 		if (iommu->ir_table && (sts & DMA_GSTS_IRES)) {
 			irta = virt_to_phys(iommu->ir_table->base);
 			seq_printf(m, " IR table address:%llx\n", irta);
diff --git a/drivers/iommu/intel/dmar.c b/drivers/iommu/intel/dmar.c
index 69222db..d33c119 100644
--- a/drivers/iommu/intel/dmar.c
+++ b/drivers/iommu/intel/dmar.c
@@ -899,8 +899,8 @@ dmar_validate_one_drhd(struct acpi_dmar_header *entry, void *arg)
 		return -EINVAL;
 	}
 
-	cap = dmar_readq(addr + DMAR_CAP_REG);
-	ecap = dmar_readq(addr + DMAR_ECAP_REG);
+	cap = readq(addr + DMAR_CAP_REG);
+	ecap = readq(addr + DMAR_ECAP_REG);
 
 	if (arg)
 		iounmap(addr);
@@ -982,8 +982,8 @@ static int map_iommu(struct intel_iommu *iommu, struct dmar_drhd_unit *drhd)
 		goto release;
 	}
 
-	iommu->cap = dmar_readq(iommu->reg + DMAR_CAP_REG);
-	iommu->ecap = dmar_readq(iommu->reg + DMAR_ECAP_REG);
+	iommu->cap = readq(iommu->reg + DMAR_CAP_REG);
+	iommu->ecap = readq(iommu->reg + DMAR_ECAP_REG);
 
 	if (iommu->cap == (uint64_t)-1 && iommu->ecap == (uint64_t)-1) {
 		err = -EINVAL;
@@ -1017,8 +1017,8 @@ static int map_iommu(struct intel_iommu *iommu, struct dmar_drhd_unit *drhd)
 		int i;
 
 		for (i = 0; i < DMA_MAX_NUM_ECMDCAP; i++) {
-			iommu->ecmdcap[i] = dmar_readq(iommu->reg + DMAR_ECCAP_REG +
-						       i * DMA_ECMD_REG_STEP);
+			iommu->ecmdcap[i] = readq(iommu->reg + DMAR_ECCAP_REG +
+						  i * DMA_ECMD_REG_STEP);
 		}
 	}
 
@@ -1239,8 +1239,8 @@ static const char *qi_type_string(u8 type)
 
 static void qi_dump_fault(struct intel_iommu *iommu, u32 fault)
 {
-	unsigned int head = dmar_readl(iommu->reg + DMAR_IQH_REG);
-	u64 iqe_err = dmar_readq(iommu->reg + DMAR_IQER_REG);
+	unsigned int head = readl(iommu->reg + DMAR_IQH_REG);
+	u64 iqe_err = readq(iommu->reg + DMAR_IQER_REG);
 	struct qi_desc *desc = iommu->qi->desc + head;
 
 	if (fault & DMA_FSTS_IQE)
@@ -1321,7 +1321,7 @@ static int qi_check_fault(struct intel_iommu *iommu, int index, int wait_index)
 		 * SID field is valid only when the ITE field is Set in FSTS_REG
 		 * see Intel VT-d spec r4.1, section 11.4.9.9
 		 */
-		iqe_err = dmar_readq(iommu->reg + DMAR_IQER_REG);
+		iqe_err = readq(iommu->reg + DMAR_IQER_REG);
 		ite_sid = DMAR_IQER_REG_ITESID(iqe_err);
 
 		writel(DMA_FSTS_ITE, iommu->reg + DMAR_FSTS_REG);
@@ -1550,23 +1550,12 @@ void qi_flush_dev_iotlb(struct intel_iommu *iommu, u16 sid, u16 pfsid,
 	qi_submit_sync(iommu, &desc, 1, 0);
 }
 
-/* PASID-based IOTLB invalidation */
-void qi_flush_piotlb(struct intel_iommu *iommu, u16 did, u32 pasid, u64 addr,
-		     unsigned long npages, bool ih)
+/* PASID-selective IOTLB invalidation */
+void qi_flush_piotlb_all(struct intel_iommu *iommu, u16 did, u32 pasid)
 {
-	struct qi_desc desc = {.qw2 = 0, .qw3 = 0};
+	struct qi_desc desc = {};
 
-	/*
-	 * npages == -1 means a PASID-selective invalidation, otherwise,
-	 * a positive value for Page-selective-within-PASID invalidation.
-	 * 0 is not a valid input.
-	 */
-	if (WARN_ON(!npages)) {
-		pr_err("Invalid input npages = %ld\n", npages);
-		return;
-	}
-
-	qi_desc_piotlb(did, pasid, addr, npages, ih, &desc);
+	qi_desc_piotlb_all(did, pasid, &desc);
 	qi_submit_sync(iommu, &desc, 1, 0);
 }
 
@@ -1661,7 +1650,7 @@ static void __dmar_enable_qi(struct intel_iommu *iommu)
 	/* write zero to the tail reg */
 	writel(0, iommu->reg + DMAR_IQT_REG);
 
-	dmar_writeq(iommu->reg + DMAR_IQA_REG, val);
+	writeq(val, iommu->reg + DMAR_IQA_REG);
 
 	iommu->gcmd |= DMA_GCMD_QIE;
 	writel(iommu->gcmd, iommu->reg + DMAR_GCMD_REG);
@@ -1980,8 +1969,8 @@ irqreturn_t dmar_fault(int irq, void *dev_id)
 			source_id = dma_frcd_source_id(data);
 
 			pasid_present = dma_frcd_pasid_present(data);
-			guest_addr = dmar_readq(iommu->reg + reg +
-					fault_index * PRIMARY_FAULT_REG_LEN);
+			guest_addr = readq(iommu->reg + reg +
+					   fault_index * PRIMARY_FAULT_REG_LEN);
 			guest_addr = dma_frcd_page_addr(guest_addr);
 		}
 
diff --git a/drivers/iommu/intel/iommu.c b/drivers/iommu/intel/iommu.c
index ef7613b..63cee27 100644
--- a/drivers/iommu/intel/iommu.c
+++ b/drivers/iommu/intel/iommu.c
@@ -697,7 +697,7 @@ static void iommu_set_root_entry(struct intel_iommu *iommu)
 		addr |= DMA_RTADDR_SMT;
 
 	raw_spin_lock_irqsave(&iommu->register_lock, flag);
-	dmar_writeq(iommu->reg + DMAR_RTADDR_REG, addr);
+	writeq(addr, iommu->reg + DMAR_RTADDR_REG);
 
 	writel(iommu->gcmd | DMA_GCMD_SRTP, iommu->reg + DMAR_GCMD_REG);
 
@@ -765,11 +765,11 @@ static void __iommu_flush_context(struct intel_iommu *iommu,
 	val |= DMA_CCMD_ICC;
 
 	raw_spin_lock_irqsave(&iommu->register_lock, flag);
-	dmar_writeq(iommu->reg + DMAR_CCMD_REG, val);
+	writeq(val, iommu->reg + DMAR_CCMD_REG);
 
 	/* Make sure hardware complete it */
 	IOMMU_WAIT_OP(iommu, DMAR_CCMD_REG,
-		dmar_readq, (!(val & DMA_CCMD_ICC)), val);
+		readq, (!(val & DMA_CCMD_ICC)), val);
 
 	raw_spin_unlock_irqrestore(&iommu->register_lock, flag);
 }
@@ -806,12 +806,12 @@ void __iommu_flush_iotlb(struct intel_iommu *iommu, u16 did, u64 addr,
 	raw_spin_lock_irqsave(&iommu->register_lock, flag);
 	/* Note: Only uses first TLB reg currently */
 	if (val_iva)
-		dmar_writeq(iommu->reg + tlb_offset, val_iva);
-	dmar_writeq(iommu->reg + tlb_offset + 8, val);
+		writeq(val_iva, iommu->reg + tlb_offset);
+	writeq(val, iommu->reg + tlb_offset + 8);
 
 	/* Make sure hardware complete it */
 	IOMMU_WAIT_OP(iommu, tlb_offset + 8,
-		dmar_readq, (!(val & DMA_TLB_IVT)), val);
+		readq, (!(val & DMA_TLB_IVT)), val);
 
 	raw_spin_unlock_irqrestore(&iommu->register_lock, flag);
 
@@ -1533,7 +1533,7 @@ static int copy_translation_tables(struct intel_iommu *iommu)
 	int bus, ret;
 	bool new_ext, ext;
 
-	rtaddr_reg = dmar_readq(iommu->reg + DMAR_RTADDR_REG);
+	rtaddr_reg = readq(iommu->reg + DMAR_RTADDR_REG);
 	ext        = !!(rtaddr_reg & DMA_RTADDR_SMT);
 	new_ext    = !!sm_supported(iommu);
 
@@ -3212,14 +3212,14 @@ static bool intel_iommu_capable(struct device *dev, enum iommu_cap cap)
 
 	switch (cap) {
 	case IOMMU_CAP_CACHE_COHERENCY:
-	case IOMMU_CAP_DEFERRED_FLUSH:
-		return true;
 	case IOMMU_CAP_PRE_BOOT_PROTECTION:
 		return dmar_platform_optin();
 	case IOMMU_CAP_ENFORCE_CACHE_COHERENCY:
 		return ecap_sc_support(info->iommu->ecap);
 	case IOMMU_CAP_DIRTY_TRACKING:
 		return ssads_supported(info->iommu);
+	case IOMMU_CAP_PCI_ATS_SUPPORTED:
+		return info->ats_supported;
 	default:
 		return false;
 	}
@@ -3618,9 +3618,6 @@ static int intel_iommu_set_dev_pasid(struct iommu_domain *domain,
 	if (!pasid_supported(iommu) || dev_is_real_dma_subdevice(dev))
 		return -EOPNOTSUPP;
 
-	if (domain->dirty_ops)
-		return -EINVAL;
-
 	if (context_copied(iommu, info->bus, info->devfn))
 		return -EBUSY;
 
@@ -3684,19 +3681,27 @@ static void *intel_iommu_hw_info(struct device *dev, u32 *length,
 	return vtd;
 }
 
-/*
- * Set dirty tracking for the device list of a domain. The caller must
- * hold the domain->lock when calling it.
- */
-static int device_set_dirty_tracking(struct list_head *devices, bool enable)
+/* Set dirty tracking for the devices that the domain has been attached. */
+static int domain_set_dirty_tracking(struct dmar_domain *domain, bool enable)
 {
 	struct device_domain_info *info;
+	struct dev_pasid_info *dev_pasid;
 	int ret = 0;
 
-	list_for_each_entry(info, devices, link) {
+	lockdep_assert_held(&domain->lock);
+
+	list_for_each_entry(info, &domain->devices, link) {
 		ret = intel_pasid_setup_dirty_tracking(info->iommu, info->dev,
 						       IOMMU_NO_PASID, enable);
 		if (ret)
+			return ret;
+	}
+
+	list_for_each_entry(dev_pasid, &domain->dev_pasids, link_domain) {
+		info = dev_iommu_priv_get(dev_pasid->dev);
+		ret = intel_pasid_setup_dirty_tracking(info->iommu, info->dev,
+						       dev_pasid->pasid, enable);
+		if (ret)
 			break;
 	}
 
@@ -3713,7 +3718,7 @@ static int parent_domain_set_dirty_tracking(struct dmar_domain *domain,
 	spin_lock(&domain->s1_lock);
 	list_for_each_entry(s1_domain, &domain->s1_domains, s2_link) {
 		spin_lock_irqsave(&s1_domain->lock, flags);
-		ret = device_set_dirty_tracking(&s1_domain->devices, enable);
+		ret = domain_set_dirty_tracking(s1_domain, enable);
 		spin_unlock_irqrestore(&s1_domain->lock, flags);
 		if (ret)
 			goto err_unwind;
@@ -3724,8 +3729,7 @@ static int parent_domain_set_dirty_tracking(struct dmar_domain *domain,
 err_unwind:
 	list_for_each_entry(s1_domain, &domain->s1_domains, s2_link) {
 		spin_lock_irqsave(&s1_domain->lock, flags);
-		device_set_dirty_tracking(&s1_domain->devices,
-					  domain->dirty_tracking);
+		domain_set_dirty_tracking(s1_domain, domain->dirty_tracking);
 		spin_unlock_irqrestore(&s1_domain->lock, flags);
 	}
 	spin_unlock(&domain->s1_lock);
@@ -3742,7 +3746,7 @@ static int intel_iommu_set_dirty_tracking(struct iommu_domain *domain,
 	if (dmar_domain->dirty_tracking == enable)
 		goto out_unlock;
 
-	ret = device_set_dirty_tracking(&dmar_domain->devices, enable);
+	ret = domain_set_dirty_tracking(dmar_domain, enable);
 	if (ret)
 		goto err_unwind;
 
@@ -3759,8 +3763,7 @@ static int intel_iommu_set_dirty_tracking(struct iommu_domain *domain,
 	return 0;
 
 err_unwind:
-	device_set_dirty_tracking(&dmar_domain->devices,
-				  dmar_domain->dirty_tracking);
+	domain_set_dirty_tracking(dmar_domain, dmar_domain->dirty_tracking);
 	spin_unlock(&dmar_domain->lock);
 	return ret;
 }
@@ -4185,7 +4188,7 @@ int ecmd_submit_sync(struct intel_iommu *iommu, u8 ecmd, u64 oa, u64 ob)
 
 	raw_spin_lock_irqsave(&iommu->register_lock, flags);
 
-	res = dmar_readq(iommu->reg + DMAR_ECRSP_REG);
+	res = readq(iommu->reg + DMAR_ECRSP_REG);
 	if (res & DMA_ECMD_ECRSP_IP) {
 		ret = -EBUSY;
 		goto err;
@@ -4198,10 +4201,10 @@ int ecmd_submit_sync(struct intel_iommu *iommu, u8 ecmd, u64 oa, u64 ob)
 	 * - It's not invoked in any critical path. The extra MMIO
 	 *   write doesn't bring any performance concerns.
 	 */
-	dmar_writeq(iommu->reg + DMAR_ECEO_REG, ob);
-	dmar_writeq(iommu->reg + DMAR_ECMD_REG, ecmd | (oa << DMA_ECMD_OA_SHIFT));
+	writeq(ob, iommu->reg + DMAR_ECEO_REG);
+	writeq(ecmd | (oa << DMA_ECMD_OA_SHIFT), iommu->reg + DMAR_ECMD_REG);
 
-	IOMMU_WAIT_OP(iommu, DMAR_ECRSP_REG, dmar_readq,
+	IOMMU_WAIT_OP(iommu, DMAR_ECRSP_REG, readq,
 		      !(res & DMA_ECMD_ECRSP_IP), res);
 
 	if (res & DMA_ECMD_ECRSP_IP) {
diff --git a/drivers/iommu/intel/iommu.h b/drivers/iommu/intel/iommu.h
index 599913f..ef14556 100644
--- a/drivers/iommu/intel/iommu.h
+++ b/drivers/iommu/intel/iommu.h
@@ -148,11 +148,6 @@
 
 #define OFFSET_STRIDE		(9)
 
-#define dmar_readq(a) readq(a)
-#define dmar_writeq(a,v) writeq(v,a)
-#define dmar_readl(a) readl(a)
-#define dmar_writel(a, v) writel(v, a)
-
 #define DMAR_VER_MAJOR(v)		(((v) & 0xf0) >> 4)
 #define DMAR_VER_MINOR(v)		((v) & 0x0f)
 
@@ -1082,31 +1077,26 @@ static inline void qi_desc_dev_iotlb(u16 sid, u16 pfsid, u16 qdep, u64 addr,
 	desc->qw3 = 0;
 }
 
+/* PASID-selective IOTLB invalidation */
+static inline void qi_desc_piotlb_all(u16 did, u32 pasid, struct qi_desc *desc)
+{
+	desc->qw0 = QI_EIOTLB_PASID(pasid) | QI_EIOTLB_DID(did) |
+		    QI_EIOTLB_GRAN(QI_GRAN_NONG_PASID) | QI_EIOTLB_TYPE;
+	desc->qw1 = 0;
+}
+
+/* Page-selective-within-PASID IOTLB invalidation */
 static inline void qi_desc_piotlb(u16 did, u32 pasid, u64 addr,
-				  unsigned long npages, bool ih,
+				  unsigned int size_order, bool ih,
 				  struct qi_desc *desc)
 {
-	if (npages == -1) {
-		desc->qw0 = QI_EIOTLB_PASID(pasid) |
-				QI_EIOTLB_DID(did) |
-				QI_EIOTLB_GRAN(QI_GRAN_NONG_PASID) |
-				QI_EIOTLB_TYPE;
-		desc->qw1 = 0;
-	} else {
-		int mask = ilog2(__roundup_pow_of_two(npages));
-		unsigned long align = (1ULL << (VTD_PAGE_SHIFT + mask));
-
-		if (WARN_ON_ONCE(!IS_ALIGNED(addr, align)))
-			addr = ALIGN_DOWN(addr, align);
-
-		desc->qw0 = QI_EIOTLB_PASID(pasid) |
-				QI_EIOTLB_DID(did) |
-				QI_EIOTLB_GRAN(QI_GRAN_PSI_PASID) |
-				QI_EIOTLB_TYPE;
-		desc->qw1 = QI_EIOTLB_ADDR(addr) |
-				QI_EIOTLB_IH(ih) |
-				QI_EIOTLB_AM(mask);
-	}
+	/*
+	 * calculate_psi_aligned_address() must be used for addr and size_order
+	 */
+	desc->qw0 = QI_EIOTLB_PASID(pasid) | QI_EIOTLB_DID(did) |
+		    QI_EIOTLB_GRAN(QI_GRAN_PSI_PASID) | QI_EIOTLB_TYPE;
+	desc->qw1 = QI_EIOTLB_ADDR(addr) | QI_EIOTLB_IH(ih) |
+		    QI_EIOTLB_AM(size_order);
 }
 
 static inline void qi_desc_dev_iotlb_pasid(u16 sid, u16 pfsid, u32 pasid,
@@ -1168,8 +1158,7 @@ void qi_flush_iotlb(struct intel_iommu *iommu, u16 did, u64 addr,
 void qi_flush_dev_iotlb(struct intel_iommu *iommu, u16 sid, u16 pfsid,
 			u16 qdep, u64 addr, unsigned mask);
 
-void qi_flush_piotlb(struct intel_iommu *iommu, u16 did, u32 pasid, u64 addr,
-		     unsigned long npages, bool ih);
+void qi_flush_piotlb_all(struct intel_iommu *iommu, u16 did, u32 pasid);
 
 void qi_flush_dev_iotlb_pasid(struct intel_iommu *iommu, u16 sid, u16 pfsid,
 			      u32 pasid, u16 qdep, u64 addr,
diff --git a/drivers/iommu/intel/irq_remapping.c b/drivers/iommu/intel/irq_remapping.c
index 1cd2101..25c26f7 100644
--- a/drivers/iommu/intel/irq_remapping.c
+++ b/drivers/iommu/intel/irq_remapping.c
@@ -422,7 +422,7 @@ static int iommu_load_old_irte(struct intel_iommu *iommu)
 	u64 irta;
 
 	/* Check whether the old ir-table has the same size as ours */
-	irta = dmar_readq(iommu->reg + DMAR_IRTA_REG);
+	irta = readq(iommu->reg + DMAR_IRTA_REG);
 	if ((irta & INTR_REMAP_TABLE_REG_SIZE_MASK)
 	     != INTR_REMAP_TABLE_REG_SIZE)
 		return -EINVAL;
@@ -465,8 +465,8 @@ static void iommu_set_irq_remapping(struct intel_iommu *iommu, int mode)
 
 	raw_spin_lock_irqsave(&iommu->register_lock, flags);
 
-	dmar_writeq(iommu->reg + DMAR_IRTA_REG,
-		    (addr) | IR_X2APIC_MODE(mode) | INTR_REMAP_TABLE_REG_SIZE);
+	writeq((addr) | IR_X2APIC_MODE(mode) | INTR_REMAP_TABLE_REG_SIZE,
+	       iommu->reg + DMAR_IRTA_REG);
 
 	/* Set interrupt-remapping table pointer */
 	writel(iommu->gcmd | DMA_GCMD_SIRTP, iommu->reg + DMAR_GCMD_REG);
diff --git a/drivers/iommu/intel/pasid.c b/drivers/iommu/intel/pasid.c
index 9d30015b8..89541b7 100644
--- a/drivers/iommu/intel/pasid.c
+++ b/drivers/iommu/intel/pasid.c
@@ -282,7 +282,7 @@ void intel_pasid_tear_down_entry(struct intel_iommu *iommu, struct device *dev,
 	pasid_cache_invalidation_with_pasid(iommu, did, pasid);
 
 	if (pgtt == PASID_ENTRY_PGTT_PT || pgtt == PASID_ENTRY_PGTT_FL_ONLY)
-		qi_flush_piotlb(iommu, did, pasid, 0, -1, 0);
+		qi_flush_piotlb_all(iommu, did, pasid);
 	else
 		iommu->flush.flush_iotlb(iommu, did, 0, 0, DMA_TLB_DSI_FLUSH);
 
@@ -308,7 +308,7 @@ static void pasid_flush_caches(struct intel_iommu *iommu,
 
 	if (cap_caching_mode(iommu->cap)) {
 		pasid_cache_invalidation_with_pasid(iommu, did, pasid);
-		qi_flush_piotlb(iommu, did, pasid, 0, -1, 0);
+		qi_flush_piotlb_all(iommu, did, pasid);
 	} else {
 		iommu_flush_write_buffer(iommu);
 	}
@@ -342,7 +342,7 @@ static void intel_pasid_flush_present(struct intel_iommu *iommu,
 	 *      Addr[63:12]=0x7FFFFFFF_FFFFF) to affected functions
 	 */
 	pasid_cache_invalidation_with_pasid(iommu, did, pasid);
-	qi_flush_piotlb(iommu, did, pasid, 0, -1, 0);
+	qi_flush_piotlb_all(iommu, did, pasid);
 
 	devtlb_invalidation_with_pasid(iommu, dev, pasid);
 }
diff --git a/drivers/iommu/intel/perfmon.c b/drivers/iommu/intel/perfmon.c
index fec51b6..eb1df7a 100644
--- a/drivers/iommu/intel/perfmon.c
+++ b/drivers/iommu/intel/perfmon.c
@@ -99,20 +99,20 @@ IOMMU_PMU_ATTR(filter_page_table,	"config2:32-36",	IOMMU_PMU_FILTER_PAGE_TABLE);
 #define iommu_pmu_set_filter(_name, _config, _filter, _idx, _econfig)		\
 {										\
 	if ((iommu_pmu->filter & _filter) && iommu_pmu_en_##_name(_econfig)) {	\
-		dmar_writel(iommu_pmu->cfg_reg + _idx * IOMMU_PMU_CFG_OFFSET +	\
-			    IOMMU_PMU_CFG_SIZE +				\
-			    (ffs(_filter) - 1) * IOMMU_PMU_CFG_FILTERS_OFFSET,	\
-			    iommu_pmu_get_##_name(_config) | IOMMU_PMU_FILTER_EN);\
+		writel(iommu_pmu_get_##_name(_config) | IOMMU_PMU_FILTER_EN,	\
+		       iommu_pmu->cfg_reg + _idx * IOMMU_PMU_CFG_OFFSET +	\
+		       IOMMU_PMU_CFG_SIZE +					\
+		       (ffs(_filter) - 1) * IOMMU_PMU_CFG_FILTERS_OFFSET);	\
 	}									\
 }
 
 #define iommu_pmu_clear_filter(_filter, _idx)					\
 {										\
 	if (iommu_pmu->filter & _filter) {					\
-		dmar_writel(iommu_pmu->cfg_reg + _idx * IOMMU_PMU_CFG_OFFSET +	\
-			    IOMMU_PMU_CFG_SIZE +				\
-			    (ffs(_filter) - 1) * IOMMU_PMU_CFG_FILTERS_OFFSET,	\
-			    0);							\
+		writel(0,							\
+		       iommu_pmu->cfg_reg + _idx * IOMMU_PMU_CFG_OFFSET +	\
+		       IOMMU_PMU_CFG_SIZE +					\
+		       (ffs(_filter) - 1) * IOMMU_PMU_CFG_FILTERS_OFFSET);	\
 	}									\
 }
 
@@ -307,7 +307,7 @@ static void iommu_pmu_event_update(struct perf_event *event)
 
 again:
 	prev_count = local64_read(&hwc->prev_count);
-	new_count = dmar_readq(iommu_event_base(iommu_pmu, hwc->idx));
+	new_count = readq(iommu_event_base(iommu_pmu, hwc->idx));
 	if (local64_xchg(&hwc->prev_count, new_count) != prev_count)
 		goto again;
 
@@ -340,7 +340,7 @@ static void iommu_pmu_start(struct perf_event *event, int flags)
 	hwc->state = 0;
 
 	/* Always reprogram the period */
-	count = dmar_readq(iommu_event_base(iommu_pmu, hwc->idx));
+	count = readq(iommu_event_base(iommu_pmu, hwc->idx));
 	local64_set((&hwc->prev_count), count);
 
 	/*
@@ -411,7 +411,7 @@ static int iommu_pmu_assign_event(struct iommu_pmu *iommu_pmu,
 	hwc->idx = idx;
 
 	/* config events */
-	dmar_writeq(iommu_config_base(iommu_pmu, idx), hwc->config);
+	writeq(hwc->config, iommu_config_base(iommu_pmu, idx));
 
 	iommu_pmu_set_filter(requester_id, event->attr.config1,
 			     IOMMU_PMU_FILTER_REQUESTER_ID, idx,
@@ -496,7 +496,7 @@ static void iommu_pmu_counter_overflow(struct iommu_pmu *iommu_pmu)
 	 * Two counters may be overflowed very close. Always check
 	 * whether there are more to handle.
 	 */
-	while ((status = dmar_readq(iommu_pmu->overflow))) {
+	while ((status = readq(iommu_pmu->overflow))) {
 		for_each_set_bit(i, (unsigned long *)&status, iommu_pmu->num_cntr) {
 			/*
 			 * Find the assigned event of the counter.
@@ -510,7 +510,7 @@ static void iommu_pmu_counter_overflow(struct iommu_pmu *iommu_pmu)
 			iommu_pmu_event_update(event);
 		}
 
-		dmar_writeq(iommu_pmu->overflow, status);
+		writeq(status, iommu_pmu->overflow);
 	}
 }
 
@@ -518,13 +518,13 @@ static irqreturn_t iommu_pmu_irq_handler(int irq, void *dev_id)
 {
 	struct intel_iommu *iommu = dev_id;
 
-	if (!dmar_readl(iommu->reg + DMAR_PERFINTRSTS_REG))
+	if (!readl(iommu->reg + DMAR_PERFINTRSTS_REG))
 		return IRQ_NONE;
 
 	iommu_pmu_counter_overflow(iommu->pmu);
 
 	/* Clear the status bit */
-	dmar_writel(iommu->reg + DMAR_PERFINTRSTS_REG, DMA_PERFINTRSTS_PIS);
+	writel(DMA_PERFINTRSTS_PIS, iommu->reg + DMAR_PERFINTRSTS_REG);
 
 	return IRQ_HANDLED;
 }
@@ -555,7 +555,7 @@ static int __iommu_pmu_register(struct intel_iommu *iommu)
 static inline void __iomem *
 get_perf_reg_address(struct intel_iommu *iommu, u32 offset)
 {
-	u32 off = dmar_readl(iommu->reg + offset);
+	u32 off = readl(iommu->reg + offset);
 
 	return iommu->reg + off;
 }
@@ -574,7 +574,7 @@ int alloc_iommu_pmu(struct intel_iommu *iommu)
 	if (!cap_ecmds(iommu->cap))
 		return -ENODEV;
 
-	perfcap = dmar_readq(iommu->reg + DMAR_PERFCAP_REG);
+	perfcap = readq(iommu->reg + DMAR_PERFCAP_REG);
 	/* The performance monitoring is not supported. */
 	if (!perfcap)
 		return -ENODEV;
@@ -617,8 +617,8 @@ int alloc_iommu_pmu(struct intel_iommu *iommu)
 	for (i = 0; i < iommu_pmu->num_eg; i++) {
 		u64 pcap;
 
-		pcap = dmar_readq(iommu->reg + DMAR_PERFEVNTCAP_REG +
-				  i * IOMMU_PMU_CAP_REGS_STEP);
+		pcap = readq(iommu->reg + DMAR_PERFEVNTCAP_REG +
+			     i * IOMMU_PMU_CAP_REGS_STEP);
 		iommu_pmu->evcap[i] = pecap_es(pcap);
 	}
 
@@ -651,9 +651,9 @@ int alloc_iommu_pmu(struct intel_iommu *iommu)
 	 * Width.
 	 */
 	for (i = 0; i < iommu_pmu->num_cntr; i++) {
-		cap = dmar_readl(iommu_pmu->cfg_reg +
-				 i * IOMMU_PMU_CFG_OFFSET +
-				 IOMMU_PMU_CFG_CNTRCAP_OFFSET);
+		cap = readl(iommu_pmu->cfg_reg +
+			    i * IOMMU_PMU_CFG_OFFSET +
+			    IOMMU_PMU_CFG_CNTRCAP_OFFSET);
 		if (!iommu_cntrcap_pcc(cap))
 			continue;
 
@@ -675,9 +675,9 @@ int alloc_iommu_pmu(struct intel_iommu *iommu)
 
 		/* Override with per-counter event capabilities */
 		for (j = 0; j < iommu_cntrcap_egcnt(cap); j++) {
-			cap = dmar_readl(iommu_pmu->cfg_reg + i * IOMMU_PMU_CFG_OFFSET +
-					 IOMMU_PMU_CFG_CNTREVCAP_OFFSET +
-					 (j * IOMMU_PMU_OFF_REGS_STEP));
+			cap = readl(iommu_pmu->cfg_reg + i * IOMMU_PMU_CFG_OFFSET +
+				    IOMMU_PMU_CFG_CNTREVCAP_OFFSET +
+				    (j * IOMMU_PMU_OFF_REGS_STEP));
 			iommu_pmu->cntr_evcap[i][iommu_event_group(cap)] = iommu_event_select(cap);
 			/*
 			 * Some events may only be supported by a specific counter.
diff --git a/drivers/iommu/intel/prq.c b/drivers/iommu/intel/prq.c
index ff63c22..586055e 100644
--- a/drivers/iommu/intel/prq.c
+++ b/drivers/iommu/intel/prq.c
@@ -81,8 +81,8 @@ void intel_iommu_drain_pasid_prq(struct device *dev, u32 pasid)
 	 */
 prq_retry:
 	reinit_completion(&iommu->prq_complete);
-	tail = dmar_readq(iommu->reg + DMAR_PQT_REG) & PRQ_RING_MASK;
-	head = dmar_readq(iommu->reg + DMAR_PQH_REG) & PRQ_RING_MASK;
+	tail = readq(iommu->reg + DMAR_PQT_REG) & PRQ_RING_MASK;
+	head = readq(iommu->reg + DMAR_PQH_REG) & PRQ_RING_MASK;
 	while (head != tail) {
 		struct page_req_dsc *req;
 
@@ -113,7 +113,7 @@ void intel_iommu_drain_pasid_prq(struct device *dev, u32 pasid)
 		qi_desc_dev_iotlb(sid, info->pfsid, info->ats_qdep, 0,
 				  MAX_AGAW_PFN_WIDTH, &desc[2]);
 	} else {
-		qi_desc_piotlb(did, pasid, 0, -1, 0, &desc[1]);
+		qi_desc_piotlb_all(did, pasid, &desc[1]);
 		qi_desc_dev_iotlb_pasid(sid, info->pfsid, pasid, info->ats_qdep,
 					0, MAX_AGAW_PFN_WIDTH, &desc[2]);
 	}
@@ -208,8 +208,8 @@ static irqreturn_t prq_event_thread(int irq, void *d)
 	 */
 	writel(DMA_PRS_PPR, iommu->reg + DMAR_PRS_REG);
 
-	tail = dmar_readq(iommu->reg + DMAR_PQT_REG) & PRQ_RING_MASK;
-	head = dmar_readq(iommu->reg + DMAR_PQH_REG) & PRQ_RING_MASK;
+	tail = readq(iommu->reg + DMAR_PQT_REG) & PRQ_RING_MASK;
+	head = readq(iommu->reg + DMAR_PQH_REG) & PRQ_RING_MASK;
 	handled = (head != tail);
 	while (head != tail) {
 		req = &iommu->prq[head / sizeof(*req)];
@@ -259,7 +259,7 @@ static irqreturn_t prq_event_thread(int irq, void *d)
 		head = (head + sizeof(*req)) & PRQ_RING_MASK;
 	}
 
-	dmar_writeq(iommu->reg + DMAR_PQH_REG, tail);
+	writeq(tail, iommu->reg + DMAR_PQH_REG);
 
 	/*
 	 * Clear the page request overflow bit and wake up all threads that
@@ -268,8 +268,8 @@ static irqreturn_t prq_event_thread(int irq, void *d)
 	if (readl(iommu->reg + DMAR_PRS_REG) & DMA_PRS_PRO) {
 		pr_info_ratelimited("IOMMU: %s: PRQ overflow detected\n",
 				    iommu->name);
-		head = dmar_readq(iommu->reg + DMAR_PQH_REG) & PRQ_RING_MASK;
-		tail = dmar_readq(iommu->reg + DMAR_PQT_REG) & PRQ_RING_MASK;
+		head = readq(iommu->reg + DMAR_PQH_REG) & PRQ_RING_MASK;
+		tail = readq(iommu->reg + DMAR_PQT_REG) & PRQ_RING_MASK;
 		if (head == tail) {
 			iopf_queue_discard_partial(iommu->iopf_queue);
 			writel(DMA_PRS_PRO, iommu->reg + DMAR_PRS_REG);
@@ -325,9 +325,9 @@ int intel_iommu_enable_prq(struct intel_iommu *iommu)
 		       iommu->name);
 		goto free_iopfq;
 	}
-	dmar_writeq(iommu->reg + DMAR_PQH_REG, 0ULL);
-	dmar_writeq(iommu->reg + DMAR_PQT_REG, 0ULL);
-	dmar_writeq(iommu->reg + DMAR_PQA_REG, virt_to_phys(iommu->prq) | PRQ_ORDER);
+	writeq(0ULL, iommu->reg + DMAR_PQH_REG);
+	writeq(0ULL, iommu->reg + DMAR_PQT_REG);
+	writeq(virt_to_phys(iommu->prq) | PRQ_ORDER, iommu->reg + DMAR_PQA_REG);
 
 	init_completion(&iommu->prq_complete);
 
@@ -348,9 +348,9 @@ int intel_iommu_enable_prq(struct intel_iommu *iommu)
 
 int intel_iommu_finish_prq(struct intel_iommu *iommu)
 {
-	dmar_writeq(iommu->reg + DMAR_PQH_REG, 0ULL);
-	dmar_writeq(iommu->reg + DMAR_PQT_REG, 0ULL);
-	dmar_writeq(iommu->reg + DMAR_PQA_REG, 0ULL);
+	writeq(0ULL, iommu->reg + DMAR_PQH_REG);
+	writeq(0ULL, iommu->reg + DMAR_PQT_REG);
+	writeq(0ULL, iommu->reg + DMAR_PQA_REG);
 
 	if (iommu->pr_irq) {
 		free_irq(iommu->pr_irq, iommu);
diff --git a/drivers/iommu/intel/trace.h b/drivers/iommu/intel/trace.h
index 6311ba3..9f0ab43 100644
--- a/drivers/iommu/intel/trace.h
+++ b/drivers/iommu/intel/trace.h
@@ -132,8 +132,8 @@ DEFINE_EVENT(cache_tag_log, cache_tag_unassign,
 
 DECLARE_EVENT_CLASS(cache_tag_flush,
 	TP_PROTO(struct cache_tag *tag, unsigned long start, unsigned long end,
-		 unsigned long addr, unsigned long pages, unsigned long mask),
-	TP_ARGS(tag, start, end, addr, pages, mask),
+		 unsigned long addr, unsigned long mask),
+	TP_ARGS(tag, start, end, addr, mask),
 	TP_STRUCT__entry(
 		__string(iommu, tag->iommu->name)
 		__string(dev, dev_name(tag->dev))
@@ -143,7 +143,6 @@ DECLARE_EVENT_CLASS(cache_tag_flush,
 		__field(unsigned long, start)
 		__field(unsigned long, end)
 		__field(unsigned long, addr)
-		__field(unsigned long, pages)
 		__field(unsigned long, mask)
 	),
 	TP_fast_assign(
@@ -155,10 +154,9 @@ DECLARE_EVENT_CLASS(cache_tag_flush,
 		__entry->start = start;
 		__entry->end = end;
 		__entry->addr = addr;
-		__entry->pages = pages;
 		__entry->mask = mask;
 	),
-	TP_printk("%s %s[%d] type %s did %d [0x%lx-0x%lx] addr 0x%lx pages 0x%lx mask 0x%lx",
+	TP_printk("%s %s[%d] type %s did %d [0x%lx-0x%lx] addr 0x%lx mask 0x%lx",
 		  __get_str(iommu), __get_str(dev), __entry->pasid,
 		  __print_symbolic(__entry->type,
 			{ CACHE_TAG_IOTLB,		"iotlb" },
@@ -166,20 +164,20 @@ DECLARE_EVENT_CLASS(cache_tag_flush,
 			{ CACHE_TAG_NESTING_IOTLB,	"nesting_iotlb" },
 			{ CACHE_TAG_NESTING_DEVTLB,	"nesting_devtlb" }),
 		__entry->domain_id, __entry->start, __entry->end,
-		__entry->addr, __entry->pages, __entry->mask
+		__entry->addr, __entry->mask
 	)
 );
 
 DEFINE_EVENT(cache_tag_flush, cache_tag_flush_range,
 	TP_PROTO(struct cache_tag *tag, unsigned long start, unsigned long end,
-		 unsigned long addr, unsigned long pages, unsigned long mask),
-	TP_ARGS(tag, start, end, addr, pages, mask)
+		 unsigned long addr, unsigned long mask),
+	TP_ARGS(tag, start, end, addr, mask)
 );
 
 DEFINE_EVENT(cache_tag_flush, cache_tag_flush_range_np,
 	TP_PROTO(struct cache_tag *tag, unsigned long start, unsigned long end,
-		 unsigned long addr, unsigned long pages, unsigned long mask),
-	TP_ARGS(tag, start, end, addr, pages, mask)
+		 unsigned long addr, unsigned long mask),
+	TP_ARGS(tag, start, end, addr, mask)
 );
 #endif /* _TRACE_INTEL_IOMMU_H */
 
diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 50718ab..78756c3 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -34,6 +34,7 @@
 #include <linux/sched/mm.h>
 #include <linux/msi.h>
 #include <uapi/linux/iommufd.h>
+#include <linux/generic_pt/iommu.h>
 
 #include "dma-iommu.h"
 #include "iommu-priv.h"
@@ -2572,14 +2573,14 @@ static size_t iommu_pgsize(struct iommu_domain *domain, unsigned long iova,
 	return pgsize;
 }
 
-int iommu_map_nosync(struct iommu_domain *domain, unsigned long iova,
-		phys_addr_t paddr, size_t size, int prot, gfp_t gfp)
+static int __iommu_map_domain_pgtbl(struct iommu_domain *domain,
+				    unsigned long iova, phys_addr_t paddr,
+				    size_t size, int prot, gfp_t gfp)
 {
 	const struct iommu_domain_ops *ops = domain->ops;
 	unsigned long orig_iova = iova;
 	unsigned int min_pagesz;
 	size_t orig_size = size;
-	phys_addr_t orig_paddr = paddr;
 	int ret = 0;
 
 	might_sleep_if(gfpflags_allow_blocking(gfp));
@@ -2636,12 +2637,9 @@ int iommu_map_nosync(struct iommu_domain *domain, unsigned long iova,
 	/* unroll mapping in case something went wrong */
 	if (ret) {
 		iommu_unmap(domain, orig_iova, orig_size - size);
-	} else {
-		trace_map(orig_iova, orig_paddr, orig_size);
-		iommu_debug_map(domain, orig_paddr, orig_size);
+		return ret;
 	}
-
-	return ret;
+	return 0;
 }
 
 int iommu_sync_map(struct iommu_domain *domain, unsigned long iova, size_t size)
@@ -2653,6 +2651,32 @@ int iommu_sync_map(struct iommu_domain *domain, unsigned long iova, size_t size)
 	return ops->iotlb_sync_map(domain, iova, size);
 }
 
+int iommu_map_nosync(struct iommu_domain *domain, unsigned long iova,
+		phys_addr_t paddr, size_t size, int prot, gfp_t gfp)
+{
+	struct pt_iommu *pt = iommupt_from_domain(domain);
+	int ret;
+
+	if (pt) {
+		size_t mapped = 0;
+
+		ret = pt->ops->map_range(pt, iova, paddr, size, prot, gfp,
+					 &mapped);
+		if (ret) {
+			iommu_unmap(domain, iova, mapped);
+			return ret;
+		}
+		return 0;
+	}
+	ret = __iommu_map_domain_pgtbl(domain, iova, paddr, size, prot, gfp);
+	if (ret)
+		return ret;
+
+	trace_map(iova, paddr, size);
+	iommu_debug_map(domain, paddr, size);
+	return 0;
+}
+
 int iommu_map(struct iommu_domain *domain, unsigned long iova,
 	      phys_addr_t paddr, size_t size, int prot, gfp_t gfp)
 {
@@ -2670,13 +2694,12 @@ int iommu_map(struct iommu_domain *domain, unsigned long iova,
 }
 EXPORT_SYMBOL_GPL(iommu_map);
 
-static size_t __iommu_unmap(struct iommu_domain *domain,
-			    unsigned long iova, size_t size,
-			    struct iommu_iotlb_gather *iotlb_gather)
+static size_t
+__iommu_unmap_domain_pgtbl(struct iommu_domain *domain, unsigned long iova,
+			   size_t size, struct iommu_iotlb_gather *iotlb_gather)
 {
 	const struct iommu_domain_ops *ops = domain->ops;
 	size_t unmapped_page, unmapped = 0;
-	unsigned long orig_iova = iova;
 	unsigned int min_pagesz;
 
 	if (unlikely(!(domain->type & __IOMMU_DOMAIN_PAGING)))
@@ -2722,8 +2745,23 @@ static size_t __iommu_unmap(struct iommu_domain *domain,
 		unmapped += unmapped_page;
 	}
 
-	trace_unmap(orig_iova, size, unmapped);
-	iommu_debug_unmap_end(domain, orig_iova, size, unmapped);
+	return unmapped;
+}
+
+static size_t __iommu_unmap(struct iommu_domain *domain, unsigned long iova,
+			    size_t size,
+			    struct iommu_iotlb_gather *iotlb_gather)
+{
+	struct pt_iommu *pt = iommupt_from_domain(domain);
+	size_t unmapped;
+
+	if (pt)
+		unmapped = pt->ops->unmap_range(pt, iova, size, iotlb_gather);
+	else
+		unmapped = __iommu_unmap_domain_pgtbl(domain, iova, size,
+						      iotlb_gather);
+	trace_unmap(iova, size, unmapped);
+	iommu_debug_unmap_end(domain, iova, size, unmapped);
 	return unmapped;
 }
 
diff --git a/drivers/iommu/iommufd/device.c b/drivers/iommu/iommufd/device.c
index 344d620..92c5d5e 100644
--- a/drivers/iommu/iommufd/device.c
+++ b/drivers/iommu/iommufd/device.c
@@ -1624,6 +1624,10 @@ int iommufd_get_hw_info(struct iommufd_ucmd *ucmd)
 	if (device_iommu_capable(idev->dev, IOMMU_CAP_DIRTY_TRACKING))
 		cmd->out_capabilities |= IOMMU_HW_CAP_DIRTY_TRACKING;
 
+	/* Report when ATS cannot be used for this device */
+	if (!device_iommu_capable(idev->dev, IOMMU_CAP_PCI_ATS_SUPPORTED))
+		cmd->out_capabilities |= IOMMU_HW_CAP_PCI_ATS_NOT_SUPPORTED;
+
 	cmd->out_max_pasid_log2 = 0;
 	/*
 	 * Currently, all iommu drivers enable PASID in the probe_device()
diff --git a/drivers/iommu/iova.c b/drivers/iommu/iova.c
index f9cd183..021daf6 100644
--- a/drivers/iommu/iova.c
+++ b/drivers/iommu/iova.c
@@ -611,7 +611,8 @@ static struct iova_magazine *iova_magazine_alloc(gfp_t flags)
 
 static void iova_magazine_free(struct iova_magazine *mag)
 {
-	kmem_cache_free(iova_magazine_cache, mag);
+	if (mag)
+		kmem_cache_free(iova_magazine_cache, mag);
 }
 
 static void
diff --git a/drivers/iommu/riscv/Kconfig b/drivers/iommu/riscv/Kconfig
index c071816..b86e5ab 100644
--- a/drivers/iommu/riscv/Kconfig
+++ b/drivers/iommu/riscv/Kconfig
@@ -3,9 +3,13 @@
 
 config RISCV_IOMMU
 	bool "RISC-V IOMMU Support"
-	depends on RISCV && 64BIT
-	default y
+	default RISCV
+	depends on GENERIC_MSI_IRQ
+	depends on (RISCV || COMPILE_TEST) && 64BIT
 	select IOMMU_API
+	select GENERIC_PT
+	select IOMMU_PT
+	select IOMMU_PT_RISCV64
 	help
 	  Support for implementations of the RISC-V IOMMU architecture that
 	  complements the RISC-V MMU capabilities, providing similar address
diff --git a/drivers/iommu/riscv/iommu-bits.h b/drivers/iommu/riscv/iommu-bits.h
index 98daf0e..29a0040 100644
--- a/drivers/iommu/riscv/iommu-bits.h
+++ b/drivers/iommu/riscv/iommu-bits.h
@@ -17,6 +17,7 @@
 #include <linux/types.h>
 #include <linux/bitfield.h>
 #include <linux/bits.h>
+#include <asm/page.h>
 
 /*
  * Chapter 5: Memory Mapped register interface
@@ -718,7 +719,8 @@ static inline void riscv_iommu_cmd_inval_vma(struct riscv_iommu_command *cmd)
 static inline void riscv_iommu_cmd_inval_set_addr(struct riscv_iommu_command *cmd,
 						  u64 addr)
 {
-	cmd->dword1 = FIELD_PREP(RISCV_IOMMU_CMD_IOTINVAL_ADDR, phys_to_pfn(addr));
+	cmd->dword1 =
+		FIELD_PREP(RISCV_IOMMU_CMD_IOTINVAL_ADDR, PHYS_PFN(addr));
 	cmd->dword0 |= RISCV_IOMMU_CMD_IOTINVAL_AV;
 }
 
diff --git a/drivers/iommu/riscv/iommu-platform.c b/drivers/iommu/riscv/iommu-platform.c
index 83a28c8..399ba8f 100644
--- a/drivers/iommu/riscv/iommu-platform.c
+++ b/drivers/iommu/riscv/iommu-platform.c
@@ -68,12 +68,7 @@ static int riscv_iommu_platform_probe(struct platform_device *pdev)
 	iommu->caps = riscv_iommu_readq(iommu, RISCV_IOMMU_REG_CAPABILITIES);
 	iommu->fctl = riscv_iommu_readl(iommu, RISCV_IOMMU_REG_FCTL);
 
-	iommu->irqs_count = platform_irq_count(pdev);
-	if (iommu->irqs_count <= 0)
-		return dev_err_probe(dev, -ENODEV,
-				     "no IRQ resources provided\n");
-	if (iommu->irqs_count > RISCV_IOMMU_INTR_COUNT)
-		iommu->irqs_count = RISCV_IOMMU_INTR_COUNT;
+	iommu->irqs_count = RISCV_IOMMU_INTR_COUNT;
 
 	igs = FIELD_GET(RISCV_IOMMU_CAPABILITIES_IGS, iommu->caps);
 	switch (igs) {
@@ -120,6 +115,16 @@ static int riscv_iommu_platform_probe(struct platform_device *pdev)
 		fallthrough;
 
 	case RISCV_IOMMU_CAPABILITIES_IGS_WSI:
+		ret = platform_irq_count(pdev);
+		if (ret <= 0)
+			return dev_err_probe(dev, -ENODEV,
+					     "no IRQ resources provided\n");
+
+		iommu->irqs_count = ret;
+
+		if (iommu->irqs_count > RISCV_IOMMU_INTR_COUNT)
+			iommu->irqs_count = RISCV_IOMMU_INTR_COUNT;
+
 		for (vec = 0; vec < iommu->irqs_count; vec++)
 			iommu->irqs[vec] = platform_get_irq(pdev, vec);
 
diff --git a/drivers/iommu/riscv/iommu.c b/drivers/iommu/riscv/iommu.c
index fa2ebfd..a31f50b 100644
--- a/drivers/iommu/riscv/iommu.c
+++ b/drivers/iommu/riscv/iommu.c
@@ -21,6 +21,7 @@
 #include <linux/iopoll.h>
 #include <linux/kernel.h>
 #include <linux/pci.h>
+#include <linux/generic_pt/iommu.h>
 
 #include "../iommu-pages.h"
 #include "iommu-bits.h"
@@ -159,7 +160,7 @@ static int riscv_iommu_queue_alloc(struct riscv_iommu_device *iommu,
 	if (FIELD_GET(RISCV_IOMMU_PPN_FIELD, qb)) {
 		const size_t queue_size = entry_size << (logsz + 1);
 
-		queue->phys = pfn_to_phys(FIELD_GET(RISCV_IOMMU_PPN_FIELD, qb));
+		queue->phys = PFN_PHYS(FIELD_GET(RISCV_IOMMU_PPN_FIELD, qb));
 		queue->base = devm_ioremap(iommu->dev, queue->phys, queue_size);
 	} else {
 		do {
@@ -368,6 +369,8 @@ static int riscv_iommu_queue_wait(struct riscv_iommu_queue *queue,
 				  unsigned int timeout_us)
 {
 	unsigned int cons = atomic_read(&queue->head);
+	unsigned int flags = RISCV_IOMMU_CQCSR_CQMF | RISCV_IOMMU_CQCSR_CMD_TO |
+			     RISCV_IOMMU_CQCSR_CMD_ILL;
 
 	/* Already processed by the consumer */
 	if ((int)(cons - index) > 0)
@@ -375,6 +378,7 @@ static int riscv_iommu_queue_wait(struct riscv_iommu_queue *queue,
 
 	/* Monitor consumer index */
 	return readx_poll_timeout(riscv_iommu_queue_cons, queue, cons,
+				 (riscv_iommu_readl(queue->iommu, queue->qcr) & flags) ||
 				 (int)(cons - index) > 0, 0, timeout_us);
 }
 
@@ -435,7 +439,9 @@ static unsigned int riscv_iommu_queue_send(struct riscv_iommu_queue *queue,
 	 * 6. Make sure the doorbell write to the device has finished before updating
 	 *    the shadow tail index in normal memory. 'fence o, w'
 	 */
+#ifdef CONFIG_MMIOWB
 	mmiowb();
+#endif
 	atomic_inc(&queue->tail);
 
 	/* 7. Complete submission and restore local interrupts */
@@ -806,15 +812,15 @@ static int riscv_iommu_iodir_set_mode(struct riscv_iommu_device *iommu,
 
 /* This struct contains protection domain specific IOMMU driver data. */
 struct riscv_iommu_domain {
-	struct iommu_domain domain;
+	union {
+		struct iommu_domain domain;
+		struct pt_iommu_riscv_64 riscvpt;
+	};
 	struct list_head bonds;
 	spinlock_t lock;		/* protect bonds list updates. */
 	int pscid;
-	bool amo_enabled;
-	int numa_node;
-	unsigned int pgd_mode;
-	unsigned long *pgd_root;
 };
+PT_IOMMU_CHECK_DOMAIN(struct riscv_iommu_domain, riscvpt.iommu, domain);
 
 #define iommu_domain_to_riscv(iommu_domain) \
 	container_of(iommu_domain, struct riscv_iommu_domain, domain)
@@ -928,8 +934,6 @@ static void riscv_iommu_iotlb_inval(struct riscv_iommu_domain *domain,
 	struct riscv_iommu_bond *bond;
 	struct riscv_iommu_device *iommu, *prev;
 	struct riscv_iommu_command cmd;
-	unsigned long len = end - start + 1;
-	unsigned long iova;
 
 	/*
 	 * For each IOMMU linked with this protection domain (via bonds->dev),
@@ -972,11 +976,14 @@ static void riscv_iommu_iotlb_inval(struct riscv_iommu_domain *domain,
 
 		riscv_iommu_cmd_inval_vma(&cmd);
 		riscv_iommu_cmd_inval_set_pscid(&cmd, domain->pscid);
-		if (len && len < RISCV_IOMMU_IOTLB_INVAL_LIMIT) {
-			for (iova = start; iova < end; iova += PAGE_SIZE) {
+		if (end - start < RISCV_IOMMU_IOTLB_INVAL_LIMIT - 1) {
+			unsigned long iova = start;
+
+			do {
 				riscv_iommu_cmd_inval_set_addr(&cmd, iova);
 				riscv_iommu_cmd_send(iommu, &cmd);
-			}
+			} while (!check_add_overflow(iova, PAGE_SIZE, &iova) &&
+				 iova < end);
 		} else {
 			riscv_iommu_cmd_send(iommu, &cmd);
 		}
@@ -996,7 +1003,67 @@ static void riscv_iommu_iotlb_inval(struct riscv_iommu_domain *domain,
 }
 
 #define RISCV_IOMMU_FSC_BARE 0
+/*
+ * This function sends IOTINVAL commands as required by the RISC-V
+ * IOMMU specification (Section 6.3.1 and 6.3.2 in 1.0 spec version)
+ * after modifying DDT or PDT entries
+ */
+static void riscv_iommu_iodir_iotinval(struct riscv_iommu_device *iommu,
+				       bool inval_pdt, unsigned long iohgatp,
+				       struct riscv_iommu_dc *dc,
+				       struct riscv_iommu_pc *pc)
+{
+	struct riscv_iommu_command cmd;
 
+	riscv_iommu_cmd_inval_vma(&cmd);
+
+	if (FIELD_GET(RISCV_IOMMU_DC_IOHGATP_MODE, iohgatp) ==
+	    RISCV_IOMMU_DC_IOHGATP_MODE_BARE) {
+		if (inval_pdt) {
+			/*
+			 * IOTINVAL.VMA with GV=AV=0, and PSCV=1, and
+			 * PSCID=PC.PSCID
+			 */
+			riscv_iommu_cmd_inval_set_pscid(&cmd,
+				FIELD_GET(RISCV_IOMMU_PC_TA_PSCID, pc->ta));
+		} else {
+			if (!FIELD_GET(RISCV_IOMMU_DC_TC_PDTV, dc->tc) &&
+			    FIELD_GET(RISCV_IOMMU_DC_FSC_MODE, dc->fsc) !=
+			    RISCV_IOMMU_DC_FSC_MODE_BARE) {
+				/*
+				 * DC.tc.PDTV == 0 && DC.fsc.MODE != Bare
+				 * IOTINVAL.VMA with GV=AV=0, and PSCV=1, and
+				 * PSCID=DC.ta.PSCID
+				 */
+				riscv_iommu_cmd_inval_set_pscid(&cmd,
+					FIELD_GET(RISCV_IOMMU_DC_TA_PSCID, dc->ta));
+			}
+			/* else: IOTINVAL.VMA with GV=AV=PSCV=0 */
+		}
+	} else {
+		riscv_iommu_cmd_inval_set_gscid(&cmd,
+			FIELD_GET(RISCV_IOMMU_DC_IOHGATP_GSCID, iohgatp));
+
+		if (inval_pdt) {
+			/*
+			 * IOTINVAL.VMA with GV=1, AV=0, and PSCV=1, and
+			 * GSCID=DC.iohgatp.GSCID, PSCID=PC.PSCID
+			 */
+			riscv_iommu_cmd_inval_set_pscid(&cmd,
+				FIELD_GET(RISCV_IOMMU_PC_TA_PSCID, pc->ta));
+		}
+		/*
+		 * else: IOTINVAL.VMA with GV=1,AV=PSCV=0,and
+		 * GSCID=DC.iohgatp.GSCID
+		 *
+		 * IOTINVAL.GVMA with GV=1,AV=0,and
+		 * GSCID=DC.iohgatp.GSCID
+		 * TODO: For now, the Second-Stage feature has not yet been merged,
+		 * also issue IOTINVAL.GVMA once second-stage support is merged.
+		 */
+	}
+	riscv_iommu_cmd_send(iommu, &cmd);
+}
 /*
  * Update IODIR for the device.
  *
@@ -1031,6 +1098,11 @@ static void riscv_iommu_iodir_update(struct riscv_iommu_device *iommu,
 		riscv_iommu_cmd_iodir_inval_ddt(&cmd);
 		riscv_iommu_cmd_iodir_set_did(&cmd, fwspec->ids[i]);
 		riscv_iommu_cmd_send(iommu, &cmd);
+		/*
+		 * For now, the SVA and PASID features have not yet been merged, the
+		 * default configuration is inval_pdt=false and pc=NULL.
+		 */
+		riscv_iommu_iodir_iotinval(iommu, false, dc->iohgatp, dc, NULL);
 		sync_required = true;
 	}
 
@@ -1056,6 +1128,11 @@ static void riscv_iommu_iodir_update(struct riscv_iommu_device *iommu,
 		riscv_iommu_cmd_iodir_inval_ddt(&cmd);
 		riscv_iommu_cmd_iodir_set_did(&cmd, fwspec->ids[i]);
 		riscv_iommu_cmd_send(iommu, &cmd);
+		/*
+		 * For now, the SVA and PASID features have not yet been merged, the
+		 * default configuration is inval_pdt=false and pc=NULL.
+		 */
+		riscv_iommu_iodir_iotinval(iommu, false, dc->iohgatp, dc, NULL);
 	}
 
 	riscv_iommu_cmd_sync(iommu, RISCV_IOMMU_IOTINVAL_TIMEOUT);
@@ -1077,158 +1154,9 @@ static void riscv_iommu_iotlb_sync(struct iommu_domain *iommu_domain,
 {
 	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
 
-	riscv_iommu_iotlb_inval(domain, gather->start, gather->end);
-}
-
-#define PT_SHIFT (PAGE_SHIFT - ilog2(sizeof(pte_t)))
-
-#define _io_pte_present(pte)	((pte) & (_PAGE_PRESENT | _PAGE_PROT_NONE))
-#define _io_pte_leaf(pte)	((pte) & _PAGE_LEAF)
-#define _io_pte_none(pte)	((pte) == 0)
-#define _io_pte_entry(pn, prot)	((_PAGE_PFN_MASK & ((pn) << _PAGE_PFN_SHIFT)) | (prot))
-
-static void riscv_iommu_pte_free(struct riscv_iommu_domain *domain,
-				 unsigned long pte,
-				 struct iommu_pages_list *freelist)
-{
-	unsigned long *ptr;
-	int i;
-
-	if (!_io_pte_present(pte) || _io_pte_leaf(pte))
-		return;
-
-	ptr = (unsigned long *)pfn_to_virt(__page_val_to_pfn(pte));
-
-	/* Recursively free all sub page table pages */
-	for (i = 0; i < PTRS_PER_PTE; i++) {
-		pte = READ_ONCE(ptr[i]);
-		if (!_io_pte_none(pte) && cmpxchg_relaxed(ptr + i, pte, 0) == pte)
-			riscv_iommu_pte_free(domain, pte, freelist);
-	}
-
-	if (freelist)
-		iommu_pages_list_add(freelist, ptr);
-	else
-		iommu_free_pages(ptr);
-}
-
-static unsigned long *riscv_iommu_pte_alloc(struct riscv_iommu_domain *domain,
-					    unsigned long iova, size_t pgsize,
-					    gfp_t gfp)
-{
-	unsigned long *ptr = domain->pgd_root;
-	unsigned long pte, old;
-	int level = domain->pgd_mode - RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV39 + 2;
-	void *addr;
-
-	do {
-		const int shift = PAGE_SHIFT + PT_SHIFT * level;
-
-		ptr += ((iova >> shift) & (PTRS_PER_PTE - 1));
-		/*
-		 * Note: returned entry might be a non-leaf if there was
-		 * existing mapping with smaller granularity. Up to the caller
-		 * to replace and invalidate.
-		 */
-		if (((size_t)1 << shift) == pgsize)
-			return ptr;
-pte_retry:
-		pte = READ_ONCE(*ptr);
-		/*
-		 * This is very likely incorrect as we should not be adding
-		 * new mapping with smaller granularity on top
-		 * of existing 2M/1G mapping. Fail.
-		 */
-		if (_io_pte_present(pte) && _io_pte_leaf(pte))
-			return NULL;
-		/*
-		 * Non-leaf entry is missing, allocate and try to add to the
-		 * page table. This might race with other mappings, retry.
-		 */
-		if (_io_pte_none(pte)) {
-			addr = iommu_alloc_pages_node_sz(domain->numa_node, gfp,
-							 SZ_4K);
-			if (!addr)
-				return NULL;
-			old = pte;
-			pte = _io_pte_entry(virt_to_pfn(addr), _PAGE_TABLE);
-			if (cmpxchg_relaxed(ptr, old, pte) != old) {
-				iommu_free_pages(addr);
-				goto pte_retry;
-			}
-		}
-		ptr = (unsigned long *)pfn_to_virt(__page_val_to_pfn(pte));
-	} while (level-- > 0);
-
-	return NULL;
-}
-
-static unsigned long *riscv_iommu_pte_fetch(struct riscv_iommu_domain *domain,
-					    unsigned long iova, size_t *pte_pgsize)
-{
-	unsigned long *ptr = domain->pgd_root;
-	unsigned long pte;
-	int level = domain->pgd_mode - RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV39 + 2;
-
-	do {
-		const int shift = PAGE_SHIFT + PT_SHIFT * level;
-
-		ptr += ((iova >> shift) & (PTRS_PER_PTE - 1));
-		pte = READ_ONCE(*ptr);
-		if (_io_pte_present(pte) && _io_pte_leaf(pte)) {
-			*pte_pgsize = (size_t)1 << shift;
-			return ptr;
-		}
-		if (_io_pte_none(pte))
-			return NULL;
-		ptr = (unsigned long *)pfn_to_virt(__page_val_to_pfn(pte));
-	} while (level-- > 0);
-
-	return NULL;
-}
-
-static int riscv_iommu_map_pages(struct iommu_domain *iommu_domain,
-				 unsigned long iova, phys_addr_t phys,
-				 size_t pgsize, size_t pgcount, int prot,
-				 gfp_t gfp, size_t *mapped)
-{
-	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
-	size_t size = 0;
-	unsigned long *ptr;
-	unsigned long pte, old, pte_prot;
-	int rc = 0;
-	struct iommu_pages_list freelist = IOMMU_PAGES_LIST_INIT(freelist);
-
-	if (!(prot & IOMMU_WRITE))
-		pte_prot = _PAGE_BASE | _PAGE_READ;
-	else if (domain->amo_enabled)
-		pte_prot = _PAGE_BASE | _PAGE_READ | _PAGE_WRITE;
-	else
-		pte_prot = _PAGE_BASE | _PAGE_READ | _PAGE_WRITE | _PAGE_DIRTY;
-
-	while (pgcount) {
-		ptr = riscv_iommu_pte_alloc(domain, iova, pgsize, gfp);
-		if (!ptr) {
-			rc = -ENOMEM;
-			break;
-		}
-
-		old = READ_ONCE(*ptr);
-		pte = _io_pte_entry(phys_to_pfn(phys), pte_prot);
-		if (cmpxchg_relaxed(ptr, old, pte) != old)
-			continue;
-
-		riscv_iommu_pte_free(domain, old, &freelist);
-
-		size += pgsize;
-		iova += pgsize;
-		phys += pgsize;
-		--pgcount;
-	}
-
-	*mapped = size;
-
-	if (!iommu_pages_list_empty(&freelist)) {
+	if (iommu_pages_list_empty(&gather->freelist)) {
+		riscv_iommu_iotlb_inval(domain, gather->start, gather->end);
+	} else {
 		/*
 		 * In 1.0 spec version, the smallest scope we can use to
 		 * invalidate all levels of page table (i.e. leaf and non-leaf)
@@ -1237,71 +1165,20 @@ static int riscv_iommu_map_pages(struct iommu_domain *iommu_domain,
 		 * capability.NL (non-leaf) IOTINVAL command.
 		 */
 		riscv_iommu_iotlb_inval(domain, 0, ULONG_MAX);
-		iommu_put_pages_list(&freelist);
+		iommu_put_pages_list(&gather->freelist);
 	}
-
-	return rc;
-}
-
-static size_t riscv_iommu_unmap_pages(struct iommu_domain *iommu_domain,
-				      unsigned long iova, size_t pgsize,
-				      size_t pgcount,
-				      struct iommu_iotlb_gather *gather)
-{
-	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
-	size_t size = pgcount << __ffs(pgsize);
-	unsigned long *ptr, old;
-	size_t unmapped = 0;
-	size_t pte_size;
-
-	while (unmapped < size) {
-		ptr = riscv_iommu_pte_fetch(domain, iova, &pte_size);
-		if (!ptr)
-			return unmapped;
-
-		/* partial unmap is not allowed, fail. */
-		if (iova & (pte_size - 1))
-			return unmapped;
-
-		old = READ_ONCE(*ptr);
-		if (cmpxchg_relaxed(ptr, old, 0) != old)
-			continue;
-
-		iommu_iotlb_gather_add_page(&domain->domain, gather, iova,
-					    pte_size);
-
-		iova += pte_size;
-		unmapped += pte_size;
-	}
-
-	return unmapped;
-}
-
-static phys_addr_t riscv_iommu_iova_to_phys(struct iommu_domain *iommu_domain,
-					    dma_addr_t iova)
-{
-	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
-	size_t pte_size;
-	unsigned long *ptr;
-
-	ptr = riscv_iommu_pte_fetch(domain, iova, &pte_size);
-	if (!ptr)
-		return 0;
-
-	return pfn_to_phys(__page_val_to_pfn(*ptr)) | (iova & (pte_size - 1));
 }
 
 static void riscv_iommu_free_paging_domain(struct iommu_domain *iommu_domain)
 {
 	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
-	const unsigned long pfn = virt_to_pfn(domain->pgd_root);
 
 	WARN_ON(!list_empty(&domain->bonds));
 
 	if ((int)domain->pscid > 0)
 		ida_free(&riscv_iommu_pscids, domain->pscid);
 
-	riscv_iommu_pte_free(domain, _io_pte_entry(pfn, _PAGE_TABLE), NULL);
+	pt_iommu_deinit(&domain->riscvpt.iommu);
 	kfree(domain);
 }
 
@@ -1327,13 +1204,16 @@ static int riscv_iommu_attach_paging_domain(struct iommu_domain *iommu_domain,
 	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
 	struct riscv_iommu_device *iommu = dev_to_iommu(dev);
 	struct riscv_iommu_info *info = dev_iommu_priv_get(dev);
+	struct pt_iommu_riscv_64_hw_info pt_info;
 	u64 fsc, ta;
 
-	if (!riscv_iommu_pt_supported(iommu, domain->pgd_mode))
+	pt_iommu_riscv_64_hw_info(&domain->riscvpt, &pt_info);
+
+	if (!riscv_iommu_pt_supported(iommu, pt_info.fsc_iosatp_mode))
 		return -ENODEV;
 
-	fsc = FIELD_PREP(RISCV_IOMMU_PC_FSC_MODE, domain->pgd_mode) |
-	      FIELD_PREP(RISCV_IOMMU_PC_FSC_PPN, virt_to_pfn(domain->pgd_root));
+	fsc = FIELD_PREP(RISCV_IOMMU_PC_FSC_MODE, pt_info.fsc_iosatp_mode) |
+	      FIELD_PREP(RISCV_IOMMU_PC_FSC_PPN, pt_info.ppn);
 	ta = FIELD_PREP(RISCV_IOMMU_PC_TA_PSCID, domain->pscid) |
 	     RISCV_IOMMU_PC_TA_V;
 
@@ -1348,37 +1228,32 @@ static int riscv_iommu_attach_paging_domain(struct iommu_domain *iommu_domain,
 }
 
 static const struct iommu_domain_ops riscv_iommu_paging_domain_ops = {
+	IOMMU_PT_DOMAIN_OPS(riscv_64),
 	.attach_dev = riscv_iommu_attach_paging_domain,
 	.free = riscv_iommu_free_paging_domain,
-	.map_pages = riscv_iommu_map_pages,
-	.unmap_pages = riscv_iommu_unmap_pages,
-	.iova_to_phys = riscv_iommu_iova_to_phys,
 	.iotlb_sync = riscv_iommu_iotlb_sync,
 	.flush_iotlb_all = riscv_iommu_iotlb_flush_all,
 };
 
 static struct iommu_domain *riscv_iommu_alloc_paging_domain(struct device *dev)
 {
+	struct pt_iommu_riscv_64_cfg cfg = {};
 	struct riscv_iommu_domain *domain;
 	struct riscv_iommu_device *iommu;
-	unsigned int pgd_mode;
-	dma_addr_t va_mask;
-	int va_bits;
+	int ret;
 
 	iommu = dev_to_iommu(dev);
 	if (iommu->caps & RISCV_IOMMU_CAPABILITIES_SV57) {
-		pgd_mode = RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV57;
-		va_bits = 57;
+		cfg.common.hw_max_vasz_lg2 = 57;
 	} else if (iommu->caps & RISCV_IOMMU_CAPABILITIES_SV48) {
-		pgd_mode = RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV48;
-		va_bits = 48;
+		cfg.common.hw_max_vasz_lg2 = 48;
 	} else if (iommu->caps & RISCV_IOMMU_CAPABILITIES_SV39) {
-		pgd_mode = RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV39;
-		va_bits = 39;
+		cfg.common.hw_max_vasz_lg2 = 39;
 	} else {
 		dev_err(dev, "cannot find supported page table mode\n");
 		return ERR_PTR(-ENODEV);
 	}
+	cfg.common.hw_max_oasz_lg2 = 56;
 
 	domain = kzalloc_obj(*domain);
 	if (!domain)
@@ -1386,43 +1261,28 @@ static struct iommu_domain *riscv_iommu_alloc_paging_domain(struct device *dev)
 
 	INIT_LIST_HEAD_RCU(&domain->bonds);
 	spin_lock_init(&domain->lock);
-	domain->numa_node = dev_to_node(iommu->dev);
-	domain->amo_enabled = !!(iommu->caps & RISCV_IOMMU_CAPABILITIES_AMO_HWAD);
-	domain->pgd_mode = pgd_mode;
-	domain->pgd_root = iommu_alloc_pages_node_sz(domain->numa_node,
-						     GFP_KERNEL_ACCOUNT, SZ_4K);
-	if (!domain->pgd_root) {
-		kfree(domain);
-		return ERR_PTR(-ENOMEM);
-	}
+	/*
+	 * 6.4 IOMMU capabilities [..] IOMMU implementations must support the
+	 * Svnapot standard extension for NAPOT Translation Contiguity.
+	 */
+	cfg.common.features = BIT(PT_FEAT_SIGN_EXTEND) |
+			      BIT(PT_FEAT_FLUSH_RANGE) |
+			      BIT(PT_FEAT_RISCV_SVNAPOT_64K);
+	domain->riscvpt.iommu.nid = dev_to_node(iommu->dev);
+	domain->domain.ops = &riscv_iommu_paging_domain_ops;
 
 	domain->pscid = ida_alloc_range(&riscv_iommu_pscids, 1,
 					RISCV_IOMMU_MAX_PSCID, GFP_KERNEL);
 	if (domain->pscid < 0) {
-		iommu_free_pages(domain->pgd_root);
-		kfree(domain);
+		riscv_iommu_free_paging_domain(&domain->domain);
 		return ERR_PTR(-ENOMEM);
 	}
 
-	/*
-	 * Note: RISC-V Privilege spec mandates that virtual addresses
-	 * need to be sign-extended, so if (VA_BITS - 1) is set, all
-	 * bits >= VA_BITS need to also be set or else we'll get a
-	 * page fault. However the code that creates the mappings
-	 * above us (e.g. iommu_dma_alloc_iova()) won't do that for us
-	 * for now, so we'll end up with invalid virtual addresses
-	 * to map. As a workaround until we get this sorted out
-	 * limit the available virtual addresses to VA_BITS - 1.
-	 */
-	va_mask = DMA_BIT_MASK(va_bits - 1);
-
-	domain->domain.geometry.aperture_start = 0;
-	domain->domain.geometry.aperture_end = va_mask;
-	domain->domain.geometry.force_aperture = true;
-	domain->domain.pgsize_bitmap = va_mask & (SZ_4K | SZ_2M | SZ_1G | SZ_512G);
-
-	domain->domain.ops = &riscv_iommu_paging_domain_ops;
-
+	ret = pt_iommu_riscv_64_init(&domain->riscvpt, &cfg, GFP_KERNEL);
+	if (ret) {
+		riscv_iommu_free_paging_domain(&domain->domain);
+		return ERR_PTR(ret);
+	}
 	return &domain->domain;
 }
 
@@ -1512,8 +1372,6 @@ static struct iommu_device *riscv_iommu_probe_device(struct device *dev)
 	 * the device directory. Do not mark the context valid yet.
 	 */
 	tc = 0;
-	if (iommu->caps & RISCV_IOMMU_CAPABILITIES_AMO_HWAD)
-		tc |= RISCV_IOMMU_DC_TC_SADE;
 	for (i = 0; i < fwspec->num_ids; i++) {
 		dc = riscv_iommu_get_dc(iommu, fwspec->ids[i]);
 		if (!dc) {
@@ -1680,3 +1538,5 @@ int riscv_iommu_init(struct riscv_iommu_device *iommu)
 	riscv_iommu_queue_disable(&iommu->cmdq);
 	return rc;
 }
+
+MODULE_IMPORT_NS("GENERIC_PT_IOMMU");
diff --git a/include/linux/generic_pt/common.h b/include/linux/generic_pt/common.h
index 6a9a1ac..fc5d0b5 100644
--- a/include/linux/generic_pt/common.h
+++ b/include/linux/generic_pt/common.h
@@ -175,6 +175,22 @@ enum {
 	PT_FEAT_VTDSS_FORCE_WRITEABLE,
 };
 
+struct pt_riscv_32 {
+	struct pt_common common;
+};
+
+struct pt_riscv_64 {
+	struct pt_common common;
+};
+
+enum {
+	/*
+	 * Support the 64k contiguous page size following the Svnapot extension.
+	 */
+	PT_FEAT_RISCV_SVNAPOT_64K = PT_FEAT_FMT_START,
+
+};
+
 struct pt_x86_64 {
 	struct pt_common common;
 };
diff --git a/include/linux/generic_pt/iommu.h b/include/linux/generic_pt/iommu.h
index 9eefbb7..dd0edd0 100644
--- a/include/linux/generic_pt/iommu.h
+++ b/include/linux/generic_pt/iommu.h
@@ -66,6 +66,13 @@ struct pt_iommu {
 	struct device *iommu_device;
 };
 
+static inline struct pt_iommu *iommupt_from_domain(struct iommu_domain *domain)
+{
+	if (!IS_ENABLED(CONFIG_IOMMU_PT) || !domain->is_iommupt)
+		return NULL;
+	return container_of(domain, struct pt_iommu, domain);
+}
+
 /**
  * struct pt_iommu_info - Details about the IOMMU page table
  *
@@ -81,6 +88,56 @@ struct pt_iommu_info {
 
 struct pt_iommu_ops {
 	/**
+	 * @map_range: Install translation for an IOVA range
+	 * @iommu_table: Table to manipulate
+	 * @iova: IO virtual address to start
+	 * @paddr: Physical/Output address to start
+	 * @len: Length of the range starting from @iova
+	 * @prot: A bitmap of IOMMU_READ/WRITE/CACHE/NOEXEC/MMIO
+	 * @gfp: GFP flags for any memory allocations
+	 *
+	 * The range starting at IOVA will have paddr installed into it. The
+	 * rage is automatically segmented into optimally sized table entries,
+	 * and can have any valid alignment.
+	 *
+	 * On error the caller will probably want to invoke unmap on the range
+	 * from iova up to the amount indicated by @mapped to return the table
+	 * back to an unchanged state.
+	 *
+	 * Context: The caller must hold a write range lock that includes
+	 * the whole range.
+	 *
+	 * Returns: -ERRNO on failure, 0 on success. The number of bytes of VA
+	 * that were mapped are added to @mapped, @mapped is not zeroed first.
+	 */
+	int (*map_range)(struct pt_iommu *iommu_table, dma_addr_t iova,
+			 phys_addr_t paddr, dma_addr_t len, unsigned int prot,
+			 gfp_t gfp, size_t *mapped);
+
+	/**
+	 * @unmap_range: Make a range of IOVA empty/not present
+	 * @iommu_table: Table to manipulate
+	 * @iova: IO virtual address to start
+	 * @len: Length of the range starting from @iova
+	 * @iotlb_gather: Gather struct that must be flushed on return
+	 *
+	 * unmap_range() will remove a translation created by map_range(). It
+	 * cannot subdivide a mapping created by map_range(), so it should be
+	 * called with IOVA ranges that match those passed to map_range(). The
+	 * IOVA range can aggregate contiguous map_range() calls so long as no
+	 * individual range is split.
+	 *
+	 * Context: The caller must hold a write range lock that includes
+	 * the whole range.
+	 *
+	 * Returns: Number of bytes of VA unmapped. iova + res will be the
+	 * point at which unmapping stopped.
+	 */
+	size_t (*unmap_range)(struct pt_iommu *iommu_table, dma_addr_t iova,
+			      dma_addr_t len,
+			      struct iommu_iotlb_gather *iotlb_gather);
+
+	/**
 	 * @set_dirty: Make the iova write dirty
 	 * @iommu_table: Table to manipulate
 	 * @iova: IO virtual address to start
@@ -194,14 +251,6 @@ struct pt_iommu_cfg {
 #define IOMMU_PROTOTYPES(fmt)                                                  \
 	phys_addr_t pt_iommu_##fmt##_iova_to_phys(struct iommu_domain *domain, \
 						  dma_addr_t iova);            \
-	int pt_iommu_##fmt##_map_pages(struct iommu_domain *domain,            \
-				       unsigned long iova, phys_addr_t paddr,  \
-				       size_t pgsize, size_t pgcount,          \
-				       int prot, gfp_t gfp, size_t *mapped);   \
-	size_t pt_iommu_##fmt##_unmap_pages(                                   \
-		struct iommu_domain *domain, unsigned long iova,               \
-		size_t pgsize, size_t pgcount,                                 \
-		struct iommu_iotlb_gather *iotlb_gather);                      \
 	int pt_iommu_##fmt##_read_and_clear_dirty(                             \
 		struct iommu_domain *domain, unsigned long iova, size_t size,  \
 		unsigned long flags, struct iommu_dirty_bitmap *dirty);        \
@@ -222,9 +271,7 @@ struct pt_iommu_cfg {
  * iommu_pt
  */
 #define IOMMU_PT_DOMAIN_OPS(fmt)                        \
-	.iova_to_phys = &pt_iommu_##fmt##_iova_to_phys, \
-	.map_pages = &pt_iommu_##fmt##_map_pages,       \
-	.unmap_pages = &pt_iommu_##fmt##_unmap_pages
+	.iova_to_phys = &pt_iommu_##fmt##_iova_to_phys
 #define IOMMU_PT_DIRTY_OPS(fmt) \
 	.read_and_clear_dirty = &pt_iommu_##fmt##_read_and_clear_dirty
 
@@ -275,6 +322,17 @@ struct pt_iommu_vtdss_hw_info {
 
 IOMMU_FORMAT(vtdss, vtdss_pt);
 
+struct pt_iommu_riscv_64_cfg {
+	struct pt_iommu_cfg common;
+};
+
+struct pt_iommu_riscv_64_hw_info {
+	u64 ppn;
+	u8 fsc_iosatp_mode;
+};
+
+IOMMU_FORMAT(riscv_64, riscv_64pt);
+
 struct pt_iommu_x86_64_cfg {
 	struct pt_iommu_cfg common;
 	/* 4 is a 57 bit 5 level table */
diff --git a/include/linux/iommu.h b/include/linux/iommu.h
index 54b8b48..e587d4a 100644
--- a/include/linux/iommu.h
+++ b/include/linux/iommu.h
@@ -223,6 +223,7 @@ enum iommu_domain_cookie_type {
 struct iommu_domain {
 	unsigned type;
 	enum iommu_domain_cookie_type cookie_type;
+	bool is_iommupt;
 	const struct iommu_domain_ops *ops;
 	const struct iommu_dirty_ops *dirty_ops;
 	const struct iommu_ops *owner; /* Whose domain_alloc we came from */
@@ -271,6 +272,8 @@ enum iommu_cap {
 	 */
 	IOMMU_CAP_DEFERRED_FLUSH,
 	IOMMU_CAP_DIRTY_TRACKING,	/* IOMMU supports dirty tracking */
+	/* ATS is supported and may be enabled for this device */
+	IOMMU_CAP_PCI_ATS_SUPPORTED,
 };
 
 /* These are the possible reserved region types */
@@ -980,7 +983,8 @@ static inline void iommu_flush_iotlb_all(struct iommu_domain *domain)
 static inline void iommu_iotlb_sync(struct iommu_domain *domain,
 				  struct iommu_iotlb_gather *iotlb_gather)
 {
-	if (domain->ops->iotlb_sync)
+	if (domain->ops->iotlb_sync &&
+	    likely(iotlb_gather->start < iotlb_gather->end))
 		domain->ops->iotlb_sync(domain, iotlb_gather);
 
 	iommu_iotlb_gather_init(iotlb_gather);
diff --git a/include/uapi/linux/iommufd.h b/include/uapi/linux/iommufd.h
index 1dafbc5..e998dfb 100644
--- a/include/uapi/linux/iommufd.h
+++ b/include/uapi/linux/iommufd.h
@@ -695,11 +695,15 @@ enum iommu_hw_info_type {
  * @IOMMU_HW_CAP_PCI_PASID_PRIV: Privileged Mode Supported, user ignores it
  *                               when the struct
  *                               iommu_hw_info::out_max_pasid_log2 is zero.
+ * @IOMMU_HW_CAP_PCI_ATS_NOT_SUPPORTED: ATS is not supported or cannot be used
+ *                                      on this device (absence implies ATS
+ *                                      may be enabled)
  */
 enum iommufd_hw_capabilities {
 	IOMMU_HW_CAP_DIRTY_TRACKING = 1 << 0,
 	IOMMU_HW_CAP_PCI_PASID_EXEC = 1 << 1,
 	IOMMU_HW_CAP_PCI_PASID_PRIV = 1 << 2,
+	IOMMU_HW_CAP_PCI_ATS_NOT_SUPPORTED = 1 << 3,
 };
 
 /**
@@ -1052,6 +1056,11 @@ struct iommu_fault_alloc {
 enum iommu_viommu_type {
 	IOMMU_VIOMMU_TYPE_DEFAULT = 0,
 	IOMMU_VIOMMU_TYPE_ARM_SMMUV3 = 1,
+	/*
+	 * TEGRA241_CMDQV requirements (otherwise, VCMDQs will not work)
+	 * - Kernel will allocate a VINTF (HYP_OWN=0) to back this VIOMMU. So,
+	 *   VMM must wire the HYP_OWN bit to 0 in guest VINTF_CONFIG register
+	 */
 	IOMMU_VIOMMU_TYPE_TEGRA241_CMDQV = 2,
 };
 
diff --git a/tools/testing/selftests/iommu/iommufd.c b/tools/testing/selftests/iommu/iommufd.c
index dadad277..d1fe5db 100644
--- a/tools/testing/selftests/iommu/iommufd.c
+++ b/tools/testing/selftests/iommu/iommufd.c
@@ -2275,6 +2275,33 @@ TEST_F(iommufd_dirty_tracking, set_dirty_tracking)
 	test_ioctl_destroy(hwpt_id);
 }
 
+TEST_F(iommufd_dirty_tracking, pasid_set_dirty_tracking)
+{
+	uint32_t stddev_id, ioas_id, hwpt_id, pasid = 100;
+	uint32_t dev_flags = MOCK_FLAGS_DEVICE_PASID;
+
+	/* Regular case */
+	test_cmd_hwpt_alloc(self->idev_id, self->ioas_id,
+			    IOMMU_HWPT_ALLOC_PASID | IOMMU_HWPT_ALLOC_DIRTY_TRACKING,
+			    &hwpt_id);
+	test_cmd_mock_domain_flags(hwpt_id, dev_flags, &stddev_id, NULL, NULL);
+	ASSERT_EQ(0, _test_cmd_pasid_attach(self->fd, stddev_id, pasid, hwpt_id));
+	test_cmd_set_dirty_tracking(hwpt_id, true);
+	test_cmd_set_dirty_tracking(hwpt_id, false);
+	ASSERT_EQ(0, _test_cmd_pasid_detach(self->fd, stddev_id, pasid));
+
+	test_ioctl_destroy(stddev_id);
+
+	/* IOMMU device does not support dirty tracking */
+	dev_flags |= MOCK_FLAGS_DEVICE_NO_DIRTY;
+	test_ioctl_ioas_alloc(&ioas_id);
+	test_cmd_mock_domain_flags(ioas_id, dev_flags, &stddev_id, NULL, NULL);
+	EXPECT_ERRNO(EINVAL, _test_cmd_pasid_attach(self->fd, stddev_id, pasid, hwpt_id));
+
+	test_ioctl_destroy(stddev_id);
+	test_ioctl_destroy(hwpt_id);
+}
+
 TEST_F(iommufd_dirty_tracking, device_dirty_capability)
 {
 	uint32_t caps = 0;