Merge tag 'libnvdimm-for-4.19_dax-memory-failure' of gitolite.kernel.org:pub/scm/linux/kernel/git/nvdimm/nvdimm

Pull libnvdimm memory-failure update from Dave Jiang:
 "As it stands, memory_failure() gets thoroughly confused by dev_pagemap
  backed mappings. The recovery code has specific enabling for several
  possible page states and needs new enabling to handle poison in dax
  mappings.

  In order to support reliable reverse mapping of user space addresses:

   1/ Add new locking in the memory_failure() rmap path to prevent races
      that would typically be handled by the page lock.

   2/ Since dev_pagemap pages are hidden from the page allocator and the
      "compound page" accounting machinery, add a mechanism to determine
      the size of the mapping that encompasses a given poisoned pfn.

   3/ Given pmem errors can be repaired, change the speculative-access
      poison protection, mce_unmap_kpfn(), to be reversible and
      otherwise allow ongoing access from the kernel (a condensed
      sketch of the replacement call pattern follows this list).

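  Item 3/ boils down to swapping the old one-way set_memory_np() unmap
  for a cacheability change that can be undone once the driver repairs
  the error. Condensed from the mce.c and pmem.c hunks below (locking
  and error handling elided), the call pattern is roughly:

     /* MCE path: keep the poisoned pfn mapped, just mark it UC */
     if (mce_usable_address(mce) && mce->severity == MCE_AO_SEVERITY) {
             unsigned long pfn = mce->addr >> PAGE_SHIFT;

             if (!memory_failure(pfn, 0))
                     set_mce_nospec(pfn);   /* remap UC via decoy alias */
     }

     /* pmem driver, after the error has been cleared in media */
     if (test_and_clear_pmem_poison(page))
             clear_mce_nospec(pfn);         /* restore write-back caching */
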
  A side effect of this enabling is that MADV_HWPOISON becomes usable
  for dax mappings; however, the primary motivation is to allow the
  system to survive userspace consumption of hardware poison via dax (a
  minimal userspace sketch follows the example logs below).
  Specifically, the current behavior is:

     mce: Uncorrected hardware memory error in user-access at af34214200
     {1}[Hardware Error]: It has been corrected by h/w and requires no further action
     mce: [Hardware Error]: Machine check events logged
     {1}[Hardware Error]: event severity: corrected
     Memory failure: 0xaf34214: reserved kernel page still referenced by 1 users
     [..]
     Memory failure: 0xaf34214: recovery action for reserved kernel page: Failed
     mce: Memory error not recovered
     <reboot>

  ...and with these changes:

     Injecting memory failure for pfn 0x20cb00 at process virtual address 0x7f763dd00000
     Memory failure: 0x20cb00: Killing dax-pmd:5421 due to hardware memory corruption
     Memory failure: 0x20cb00: recovery action for dax page: Recovered

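  For completeness, here is a minimal userspace sketch of the new flow;
  the /mnt/pmem0/data path is hypothetical, MADV_HWPOISON requires
  CAP_SYS_ADMIN and CONFIG_MEMORY_FAILURE, and the SIGBUS is expected
  to arrive as madvise() returns because the dax path forces
  MF_MUST_KILL. The si_addr_lsb value reported to the handler is the
  mapping size collected per item 2/ above:

     #define _GNU_SOURCE
     #include <fcntl.h>
     #include <signal.h>
     #include <stdio.h>
     #include <string.h>
     #include <sys/mman.h>
     #include <unistd.h>

     static void mce_handler(int sig, siginfo_t *si, void *ctx)
     {
             fprintf(stderr, "SIGBUS (%s) at %p, blast radius %zu bytes\n",
                     si->si_code == BUS_MCEERR_AR ? "AR" : "AO",
                     si->si_addr, (size_t)1 << si->si_addr_lsb);
             _exit(1);
     }

     int main(void)
     {
             size_t len = 2UL << 20;        /* one PMD-sized dax extent */
             struct sigaction sa;
             char *buf;
             int fd;

             memset(&sa, 0, sizeof(sa));
             sa.sa_sigaction = mce_handler;
             sa.sa_flags = SA_SIGINFO;
             sigaction(SIGBUS, &sa, NULL);

             fd = open("/mnt/pmem0/data", O_RDWR); /* dax-mounted file */
             if (fd < 0)
                     return 1;
             buf = mmap(NULL, len, PROT_READ | PROT_WRITE, MAP_SHARED,
                        fd, 0);
             if (buf == MAP_FAILED)
                     return 1;
             memset(buf, 0, len);                  /* fault the pages in */

             if (madvise(buf, 4096, MADV_HWPOISON)) /* inject poison */
                     perror("madvise");
             pause();                      /* the handler exits for us */
             return 0;
     }
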
  Given all the cross-dependencies, I propose taking this through
  nvdimm.git with acks from Naoya, x86/core, x86/RAS, and of course the
  dax folks"

* tag 'libnvdimm-for-4.19_dax-memory-failure' of gitolite.kernel.org:pub/scm/linux/kernel/git/nvdimm/nvdimm:
  libnvdimm, pmem: Restore page attributes when clearing errors
  x86/memory_failure: Introduce {set, clear}_mce_nospec()
  x86/mm/pat: Prepare {reserve, free}_memtype() for "decoy" addresses
  mm, memory_failure: Teach memory_failure() about dev_pagemap pages
  filesystem-dax: Introduce dax_lock_mapping_entry()
  mm, memory_failure: Collect mapping size in collect_procs()
  mm, madvise_inject_error: Let memory_failure() optionally take a page reference
  mm, dev_pagemap: Do not clear ->mapping on final put
  mm, madvise_inject_error: Disable MADV_SOFT_OFFLINE for ZONE_DEVICE pages
  filesystem-dax: Set page->index
  device-dax: Set page->index
  device-dax: Enable page_mapping()
  device-dax: Convert to vmf_insert_mixed and vm_fault_t
diff --git a/arch/x86/include/asm/set_memory.h b/arch/x86/include/asm/set_memory.h
index 34cffce..07a2575 100644
--- a/arch/x86/include/asm/set_memory.h
+++ b/arch/x86/include/asm/set_memory.h
@@ -89,4 +89,46 @@
 void set_kernel_text_rw(void);
 void set_kernel_text_ro(void);
 
+#ifdef CONFIG_X86_64
+static inline int set_mce_nospec(unsigned long pfn)
+{
+	unsigned long decoy_addr;
+	int rc;
+
+	/*
+	 * Mark the linear address as UC to make sure we don't log more
+	 * errors because of speculative access to the page.
+	 * We would like to just call:
+	 *      set_memory_uc((unsigned long)pfn_to_kaddr(pfn), 1);
+	 * but doing that would radically increase the odds of a
+	 * speculative access to the poison page because we'd have
+	 * the virtual address of the kernel 1:1 mapping sitting
+	 * around in registers.
+	 * Instead we get tricky.  We create a non-canonical address
+	 * that looks just like the one we want, but has bit 63 flipped.
+	 * This relies on set_memory_uc() properly sanitizing any __pa()
+	 * results with __PHYSICAL_MASK or PTE_PFN_MASK.
+	 */
+	decoy_addr = (pfn << PAGE_SHIFT) + (PAGE_OFFSET ^ BIT(63));
+
+	rc = set_memory_uc(decoy_addr, 1);
+	if (rc)
+		pr_warn("Could not invalidate pfn=0x%lx from 1:1 map\n", pfn);
+	return rc;
+}
+#define set_mce_nospec set_mce_nospec
+
+/* Restore full speculative operation to the pfn. */
+static inline int clear_mce_nospec(unsigned long pfn)
+{
+	return set_memory_wb((unsigned long) pfn_to_kaddr(pfn), 1);
+}
+#define clear_mce_nospec clear_mce_nospec
+#else
+/*
+ * Few people would run a 32-bit kernel on a machine that supports
+ * recoverable errors because they have too much memory to boot 32-bit.
+ */
+#endif
+
 #endif /* _ASM_X86_SET_MEMORY_H */
diff --git a/arch/x86/kernel/cpu/mcheck/mce-internal.h b/arch/x86/kernel/cpu/mcheck/mce-internal.h
index 374d1aa..ceb67cd 100644
--- a/arch/x86/kernel/cpu/mcheck/mce-internal.h
+++ b/arch/x86/kernel/cpu/mcheck/mce-internal.h
@@ -113,21 +113,6 @@
 static inline void mce_unregister_injector_chain(struct notifier_block *nb)	{ }
 #endif
 
-#ifndef CONFIG_X86_64
-/*
- * On 32-bit systems it would be difficult to safely unmap a poison page
- * from the kernel 1:1 map because there are no non-canonical addresses that
- * we can use to refer to the address without risking a speculative access.
- * However, this isn't much of an issue because:
- * 1) Few unmappable pages are in the 1:1 map. Most are in HIGHMEM which
- *    are only mapped into the kernel as needed
- * 2) Few people would run a 32-bit kernel on a machine that supports
- *    recoverable errors because they have too much memory to boot 32-bit.
- */
-static inline void mce_unmap_kpfn(unsigned long pfn) {}
-#define mce_unmap_kpfn mce_unmap_kpfn
-#endif
-
 struct mca_config {
 	bool dont_log_ce;
 	bool cmci_disabled;
diff --git a/arch/x86/kernel/cpu/mcheck/mce.c b/arch/x86/kernel/cpu/mcheck/mce.c
index 4b76728..953b3ce 100644
--- a/arch/x86/kernel/cpu/mcheck/mce.c
+++ b/arch/x86/kernel/cpu/mcheck/mce.c
@@ -42,6 +42,7 @@
 #include <linux/irq_work.h>
 #include <linux/export.h>
 #include <linux/jump_label.h>
+#include <linux/set_memory.h>
 
 #include <asm/intel-family.h>
 #include <asm/processor.h>
@@ -50,7 +51,6 @@
 #include <asm/mce.h>
 #include <asm/msr.h>
 #include <asm/reboot.h>
-#include <asm/set_memory.h>
 
 #include "mce-internal.h"
 
@@ -108,10 +108,6 @@
 
 static void (*quirk_no_way_out)(int bank, struct mce *m, struct pt_regs *regs);
 
-#ifndef mce_unmap_kpfn
-static void mce_unmap_kpfn(unsigned long pfn);
-#endif
-
 /*
  * CPU/chipset specific EDAC code can register a notifier call here to print
  * MCE errors in a human-readable form.
@@ -602,7 +598,7 @@
 	if (mce_usable_address(mce) && (mce->severity == MCE_AO_SEVERITY)) {
 		pfn = mce->addr >> PAGE_SHIFT;
 		if (!memory_failure(pfn, 0))
-			mce_unmap_kpfn(pfn);
+			set_mce_nospec(pfn);
 	}
 
 	return NOTIFY_OK;
@@ -1072,38 +1068,10 @@
 	if (ret)
 		pr_err("Memory error not recovered");
 	else
-		mce_unmap_kpfn(m->addr >> PAGE_SHIFT);
+		set_mce_nospec(m->addr >> PAGE_SHIFT);
 	return ret;
 }
 
-#ifndef mce_unmap_kpfn
-static void mce_unmap_kpfn(unsigned long pfn)
-{
-	unsigned long decoy_addr;
-
-	/*
-	 * Unmap this page from the kernel 1:1 mappings to make sure
-	 * we don't log more errors because of speculative access to
-	 * the page.
-	 * We would like to just call:
-	 *	set_memory_np((unsigned long)pfn_to_kaddr(pfn), 1);
-	 * but doing that would radically increase the odds of a
-	 * speculative access to the poison page because we'd have
-	 * the virtual address of the kernel 1:1 mapping sitting
-	 * around in registers.
-	 * Instead we get tricky.  We create a non-canonical address
-	 * that looks just like the one we want, but has bit 63 flipped.
-	 * This relies on set_memory_np() not checking whether we passed
-	 * a legal address.
-	 */
-
-	decoy_addr = (pfn << PAGE_SHIFT) + (PAGE_OFFSET ^ BIT(63));
-
-	if (set_memory_np(decoy_addr, 1))
-		pr_warn("Could not invalidate pfn=0x%lx from 1:1 map\n", pfn);
-}
-#endif
-
 
 /*
  * Cases where we avoid rendezvous handler timeout:
diff --git a/arch/x86/mm/pat.c b/arch/x86/mm/pat.c
index 1555bd7..3d0c83e 100644
--- a/arch/x86/mm/pat.c
+++ b/arch/x86/mm/pat.c
@@ -512,6 +512,17 @@
 	return 0;
 }
 
+static u64 sanitize_phys(u64 address)
+{
+	/*
+	 * When changing the memtype for pages containing poison allow
+	 * for a "decoy" virtual address (bit 63 clear) passed to
+	 * set_memory_X(). __pa() on a "decoy" address results in a
+	 * physical address with bit 63 set.
+	 */
+	return address & __PHYSICAL_MASK;
+}
+
 /*
  * req_type typically has one of the:
  * - _PAGE_CACHE_MODE_WB
@@ -533,6 +544,8 @@
 	int is_range_ram;
 	int err = 0;
 
+	start = sanitize_phys(start);
+	end = sanitize_phys(end);
 	BUG_ON(start >= end); /* end is exclusive */
 
 	if (!pat_enabled()) {
@@ -609,6 +622,9 @@
 	if (!pat_enabled())
 		return 0;
 
+	start = sanitize_phys(start);
+	end = sanitize_phys(end);
+
 	/* Low ISA region is always mapped WB. No need to track */
 	if (x86_platform.is_untracked_pat_range(start, end))
 		return 0;
diff --git a/drivers/dax/device.c b/drivers/dax/device.c
index 0a2acd7..6fd4608 100644
--- a/drivers/dax/device.c
+++ b/drivers/dax/device.c
@@ -248,13 +248,12 @@
 	return -1;
 }
 
-static int __dev_dax_pte_fault(struct dev_dax *dev_dax, struct vm_fault *vmf)
+static vm_fault_t __dev_dax_pte_fault(struct dev_dax *dev_dax,
+				struct vm_fault *vmf, pfn_t *pfn)
 {
 	struct device *dev = &dev_dax->dev;
 	struct dax_region *dax_region;
-	int rc = VM_FAULT_SIGBUS;
 	phys_addr_t phys;
-	pfn_t pfn;
 	unsigned int fault_size = PAGE_SIZE;
 
 	if (check_vma(dev_dax, vmf->vma, __func__))
@@ -276,26 +275,19 @@
 		return VM_FAULT_SIGBUS;
 	}
 
-	pfn = phys_to_pfn_t(phys, dax_region->pfn_flags);
+	*pfn = phys_to_pfn_t(phys, dax_region->pfn_flags);
 
-	rc = vm_insert_mixed(vmf->vma, vmf->address, pfn);
-
-	if (rc == -ENOMEM)
-		return VM_FAULT_OOM;
-	if (rc < 0 && rc != -EBUSY)
-		return VM_FAULT_SIGBUS;
-
-	return VM_FAULT_NOPAGE;
+	return vmf_insert_mixed(vmf->vma, vmf->address, *pfn);
 }
 
-static int __dev_dax_pmd_fault(struct dev_dax *dev_dax, struct vm_fault *vmf)
+static vm_fault_t __dev_dax_pmd_fault(struct dev_dax *dev_dax,
+				struct vm_fault *vmf, pfn_t *pfn)
 {
 	unsigned long pmd_addr = vmf->address & PMD_MASK;
 	struct device *dev = &dev_dax->dev;
 	struct dax_region *dax_region;
 	phys_addr_t phys;
 	pgoff_t pgoff;
-	pfn_t pfn;
 	unsigned int fault_size = PMD_SIZE;
 
 	if (check_vma(dev_dax, vmf->vma, __func__))
@@ -331,21 +323,21 @@
 		return VM_FAULT_SIGBUS;
 	}
 
-	pfn = phys_to_pfn_t(phys, dax_region->pfn_flags);
+	*pfn = phys_to_pfn_t(phys, dax_region->pfn_flags);
 
-	return vmf_insert_pfn_pmd(vmf->vma, vmf->address, vmf->pmd, pfn,
+	return vmf_insert_pfn_pmd(vmf->vma, vmf->address, vmf->pmd, *pfn,
 			vmf->flags & FAULT_FLAG_WRITE);
 }
 
 #ifdef CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD
-static int __dev_dax_pud_fault(struct dev_dax *dev_dax, struct vm_fault *vmf)
+static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
+				struct vm_fault *vmf, pfn_t *pfn)
 {
 	unsigned long pud_addr = vmf->address & PUD_MASK;
 	struct device *dev = &dev_dax->dev;
 	struct dax_region *dax_region;
 	phys_addr_t phys;
 	pgoff_t pgoff;
-	pfn_t pfn;
 	unsigned int fault_size = PUD_SIZE;
 
 
@@ -382,23 +374,26 @@
 		return VM_FAULT_SIGBUS;
 	}
 
-	pfn = phys_to_pfn_t(phys, dax_region->pfn_flags);
+	*pfn = phys_to_pfn_t(phys, dax_region->pfn_flags);
 
-	return vmf_insert_pfn_pud(vmf->vma, vmf->address, vmf->pud, pfn,
+	return vmf_insert_pfn_pud(vmf->vma, vmf->address, vmf->pud, *pfn,
 			vmf->flags & FAULT_FLAG_WRITE);
 }
 #else
-static int __dev_dax_pud_fault(struct dev_dax *dev_dax, struct vm_fault *vmf)
+static vm_fault_t __dev_dax_pud_fault(struct dev_dax *dev_dax,
+				struct vm_fault *vmf, pfn_t *pfn)
 {
 	return VM_FAULT_FALLBACK;
 }
 #endif /* !CONFIG_HAVE_ARCH_TRANSPARENT_HUGEPAGE_PUD */
 
-static int dev_dax_huge_fault(struct vm_fault *vmf,
+static vm_fault_t dev_dax_huge_fault(struct vm_fault *vmf,
 		enum page_entry_size pe_size)
 {
-	int rc, id;
 	struct file *filp = vmf->vma->vm_file;
+	unsigned long fault_size;
+	int rc, id;
+	pfn_t pfn;
 	struct dev_dax *dev_dax = filp->private_data;
 
 	dev_dbg(&dev_dax->dev, "%s: %s (%#lx - %#lx) size = %d\n", current->comm,
@@ -408,23 +403,49 @@
 	id = dax_read_lock();
 	switch (pe_size) {
 	case PE_SIZE_PTE:
-		rc = __dev_dax_pte_fault(dev_dax, vmf);
+		fault_size = PAGE_SIZE;
+		rc = __dev_dax_pte_fault(dev_dax, vmf, &pfn);
 		break;
 	case PE_SIZE_PMD:
-		rc = __dev_dax_pmd_fault(dev_dax, vmf);
+		fault_size = PMD_SIZE;
+		rc = __dev_dax_pmd_fault(dev_dax, vmf, &pfn);
 		break;
 	case PE_SIZE_PUD:
-		rc = __dev_dax_pud_fault(dev_dax, vmf);
+		fault_size = PUD_SIZE;
+		rc = __dev_dax_pud_fault(dev_dax, vmf, &pfn);
 		break;
 	default:
 		rc = VM_FAULT_SIGBUS;
 	}
+
+	if (rc == VM_FAULT_NOPAGE) {
+		unsigned long i;
+		pgoff_t pgoff;
+
+		/*
+		 * In the device-dax case the only possibility for a
+		 * VM_FAULT_NOPAGE result is when device-dax capacity is
+		 * mapped. No need to consider the zero page, or racing
+		 * conflicting mappings.
+		 */
+		pgoff = linear_page_index(vmf->vma, vmf->address
+				& ~(fault_size - 1));
+		for (i = 0; i < fault_size / PAGE_SIZE; i++) {
+			struct page *page;
+
+			page = pfn_to_page(pfn_t_to_pfn(pfn) + i);
+			if (page->mapping)
+				continue;
+			page->mapping = filp->f_mapping;
+			page->index = pgoff + i;
+		}
+	}
 	dax_read_unlock(id);
 
 	return rc;
 }
 
-static int dev_dax_fault(struct vm_fault *vmf)
+static vm_fault_t dev_dax_fault(struct vm_fault *vmf)
 {
 	return dev_dax_huge_fault(vmf, PE_SIZE_PTE);
 }
diff --git a/drivers/nvdimm/pmem.c b/drivers/nvdimm/pmem.c
index c236498..6071e29 100644
--- a/drivers/nvdimm/pmem.c
+++ b/drivers/nvdimm/pmem.c
@@ -20,6 +20,7 @@
 #include <linux/hdreg.h>
 #include <linux/init.h>
 #include <linux/platform_device.h>
+#include <linux/set_memory.h>
 #include <linux/module.h>
 #include <linux/moduleparam.h>
 #include <linux/badblocks.h>
@@ -51,6 +52,30 @@
 	return to_nd_region(to_dev(pmem)->parent);
 }
 
+static void hwpoison_clear(struct pmem_device *pmem,
+		phys_addr_t phys, unsigned int len)
+{
+	unsigned long pfn_start, pfn_end, pfn;
+
+	/* only pmem in the linear map supports HWPoison */
+	if (is_vmalloc_addr(pmem->virt_addr))
+		return;
+
+	pfn_start = PHYS_PFN(phys);
+	pfn_end = pfn_start + PHYS_PFN(len);
+	for (pfn = pfn_start; pfn < pfn_end; pfn++) {
+		struct page *page = pfn_to_page(pfn);
+
+		/*
+		 * Note, no need to hold a get_dev_pagemap() reference
+		 * here since we're in the driver I/O path and
+		 * outstanding I/O requests pin the dev_pagemap.
+		 */
+		if (test_and_clear_pmem_poison(page))
+			clear_mce_nospec(pfn);
+	}
+}
+
 static blk_status_t pmem_clear_poison(struct pmem_device *pmem,
 		phys_addr_t offset, unsigned int len)
 {
@@ -65,6 +90,7 @@
 	if (cleared < len)
 		rc = BLK_STS_IOERR;
 	if (cleared > 0 && cleared / 512) {
+		hwpoison_clear(pmem, pmem->phys_addr + offset, cleared);
 		cleared /= 512;
 		dev_dbg(dev, "%#llx clear %ld sector%s\n",
 				(unsigned long long) sector, cleared,
diff --git a/drivers/nvdimm/pmem.h b/drivers/nvdimm/pmem.h
index a64ebc7..59cfe13 100644
--- a/drivers/nvdimm/pmem.h
+++ b/drivers/nvdimm/pmem.h
@@ -1,6 +1,7 @@
 /* SPDX-License-Identifier: GPL-2.0 */
 #ifndef __NVDIMM_PMEM_H__
 #define __NVDIMM_PMEM_H__
+#include <linux/page-flags.h>
 #include <linux/badblocks.h>
 #include <linux/types.h>
 #include <linux/pfn_t.h>
@@ -27,4 +28,16 @@
 
 long __pmem_direct_access(struct pmem_device *pmem, pgoff_t pgoff,
 		long nr_pages, void **kaddr, pfn_t *pfn);
+
+#ifdef CONFIG_MEMORY_FAILURE
+static inline bool test_and_clear_pmem_poison(struct page *page)
+{
+	return TestClearPageHWPoison(page);
+}
+#else
+static inline bool test_and_clear_pmem_poison(struct page *page)
+{
+	return false;
+}
+#endif
 #endif /* __NVDIMM_PMEM_H__ */
diff --git a/fs/dax.c b/fs/dax.c
index f767241..f32d712 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -226,8 +226,8 @@
  *
  * Must be called with the i_pages lock held.
  */
-static void *get_unlocked_mapping_entry(struct address_space *mapping,
-					pgoff_t index, void ***slotp)
+static void *__get_unlocked_mapping_entry(struct address_space *mapping,
+		pgoff_t index, void ***slotp, bool (*wait_fn)(void))
 {
 	void *entry, **slot;
 	struct wait_exceptional_entry_queue ewait;
@@ -237,6 +237,8 @@
 	ewait.wait.func = wake_exceptional_entry_func;
 
 	for (;;) {
+		bool revalidate;
+
 		entry = __radix_tree_lookup(&mapping->i_pages, index, NULL,
 					  &slot);
 		if (!entry ||
@@ -251,14 +253,31 @@
 		prepare_to_wait_exclusive(wq, &ewait.wait,
 					  TASK_UNINTERRUPTIBLE);
 		xa_unlock_irq(&mapping->i_pages);
-		schedule();
+		revalidate = wait_fn();
 		finish_wait(wq, &ewait.wait);
 		xa_lock_irq(&mapping->i_pages);
+		if (revalidate)
+			return ERR_PTR(-EAGAIN);
 	}
 }
 
-static void dax_unlock_mapping_entry(struct address_space *mapping,
-				     pgoff_t index)
+static bool entry_wait(void)
+{
+	schedule();
+	/*
+	 * Never return an ERR_PTR() from
+	 * __get_unlocked_mapping_entry(), just keep looping.
+	 */
+	return false;
+}
+
+static void *get_unlocked_mapping_entry(struct address_space *mapping,
+		pgoff_t index, void ***slotp)
+{
+	return __get_unlocked_mapping_entry(mapping, index, slotp, entry_wait);
+}
+
+static void unlock_mapping_entry(struct address_space *mapping, pgoff_t index)
 {
 	void *entry, **slot;
 
@@ -277,7 +296,7 @@
 static void put_locked_mapping_entry(struct address_space *mapping,
 		pgoff_t index)
 {
-	dax_unlock_mapping_entry(mapping, index);
+	unlock_mapping_entry(mapping, index);
 }
 
 /*
@@ -319,18 +338,27 @@
 	for (pfn = dax_radix_pfn(entry); \
 			pfn < dax_radix_end_pfn(entry); pfn++)
 
-static void dax_associate_entry(void *entry, struct address_space *mapping)
+/*
+ * TODO: for reflink+dax we need a way to associate a single page with
+ * multiple address_space instances at different linear_page_index()
+ * offsets.
+ */
+static void dax_associate_entry(void *entry, struct address_space *mapping,
+		struct vm_area_struct *vma, unsigned long address)
 {
-	unsigned long pfn;
+	unsigned long size = dax_entry_size(entry), pfn, index;
+	int i = 0;
 
 	if (IS_ENABLED(CONFIG_FS_DAX_LIMITED))
 		return;
 
+	index = linear_page_index(vma, address & ~(size - 1));
 	for_each_mapped_pfn(entry, pfn) {
 		struct page *page = pfn_to_page(pfn);
 
 		WARN_ON_ONCE(page->mapping);
 		page->mapping = mapping;
+		page->index = index + i++;
 	}
 }
 
@@ -348,6 +376,7 @@
 		WARN_ON_ONCE(trunc && page_ref_count(page) > 1);
 		WARN_ON_ONCE(page->mapping && page->mapping != mapping);
 		page->mapping = NULL;
+		page->index = 0;
 	}
 }
 
@@ -364,6 +393,84 @@
 	return NULL;
 }
 
+static bool entry_wait_revalidate(void)
+{
+	rcu_read_unlock();
+	schedule();
+	rcu_read_lock();
+
+	/*
+	 * Tell __get_unlocked_mapping_entry() to take a break, we need
+	 * to revalidate page->mapping after dropping locks
+	 */
+	return true;
+}
+
+bool dax_lock_mapping_entry(struct page *page)
+{
+	pgoff_t index;
+	struct inode *inode;
+	bool did_lock = false;
+	void *entry = NULL, **slot;
+	struct address_space *mapping;
+
+	rcu_read_lock();
+	for (;;) {
+		mapping = READ_ONCE(page->mapping);
+
+		if (!dax_mapping(mapping))
+			break;
+
+		/*
+		 * In the device-dax case there's no need to lock, a
+		 * struct dev_pagemap pin is sufficient to keep the
+		 * inode alive, and we assume we have dev_pagemap pin
+		 * otherwise we would not have a valid pfn_to_page()
+		 * translation.
+		 */
+		inode = mapping->host;
+		if (S_ISCHR(inode->i_mode)) {
+			did_lock = true;
+			break;
+		}
+
+		xa_lock_irq(&mapping->i_pages);
+		if (mapping != page->mapping) {
+			xa_unlock_irq(&mapping->i_pages);
+			continue;
+		}
+		index = page->index;
+
+		entry = __get_unlocked_mapping_entry(mapping, index, &slot,
+				entry_wait_revalidate);
+		if (!entry) {
+			xa_unlock_irq(&mapping->i_pages);
+			break;
+		} else if (IS_ERR(entry)) {
+			WARN_ON_ONCE(PTR_ERR(entry) != -EAGAIN);
+			continue;
+		}
+		lock_slot(mapping, slot);
+		did_lock = true;
+		xa_unlock_irq(&mapping->i_pages);
+		break;
+	}
+	rcu_read_unlock();
+
+	return did_lock;
+}
+
+void dax_unlock_mapping_entry(struct page *page)
+{
+	struct address_space *mapping = page->mapping;
+	struct inode *inode = mapping->host;
+
+	if (S_ISCHR(inode->i_mode))
+		return;
+
+	unlock_mapping_entry(mapping, page->index);
+}
+
 /*
  * Find radix tree entry at given index. If it points to an exceptional entry,
  * return it with the radix tree entry locked. If the radix tree doesn't
@@ -708,7 +815,7 @@
 	new_entry = dax_radix_locked_entry(pfn, flags);
 	if (dax_entry_size(entry) != dax_entry_size(new_entry)) {
 		dax_disassociate_entry(entry, mapping, false);
-		dax_associate_entry(new_entry, mapping);
+		dax_associate_entry(new_entry, mapping, vmf->vma, vmf->address);
 	}
 
 	if (dax_is_zero_entry(entry) || dax_is_empty_entry(entry)) {
diff --git a/include/linux/dax.h b/include/linux/dax.h
index deb0f66..450b28d 100644
--- a/include/linux/dax.h
+++ b/include/linux/dax.h
@@ -88,6 +88,8 @@
 		struct block_device *bdev, struct writeback_control *wbc);
 
 struct page *dax_layout_busy_page(struct address_space *mapping);
+bool dax_lock_mapping_entry(struct page *page);
+void dax_unlock_mapping_entry(struct page *page);
 #else
 static inline bool bdev_dax_supported(struct block_device *bdev,
 		int blocksize)
@@ -119,6 +121,17 @@
 {
 	return -EOPNOTSUPP;
 }
+
+static inline bool dax_lock_mapping_entry(struct page *page)
+{
+	if (IS_DAX(page->mapping->host))
+		return true;
+	return false;
+}
+
+static inline void dax_unlock_mapping_entry(struct page *page)
+{
+}
 #endif
 
 int dax_read_lock(void);
diff --git a/include/linux/huge_mm.h b/include/linux/huge_mm.h
index 27e3e32..99c19b0 100644
--- a/include/linux/huge_mm.h
+++ b/include/linux/huge_mm.h
@@ -3,6 +3,7 @@
 #define _LINUX_HUGE_MM_H
 
 #include <linux/sched/coredump.h>
+#include <linux/mm_types.h>
 
 #include <linux/fs.h> /* only for vma_is_dax() */
 
@@ -46,9 +47,9 @@
 extern int change_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
 			unsigned long addr, pgprot_t newprot,
 			int prot_numa);
-int vmf_insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
+vm_fault_t vmf_insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
 			pmd_t *pmd, pfn_t pfn, bool write);
-int vmf_insert_pfn_pud(struct vm_area_struct *vma, unsigned long addr,
+vm_fault_t vmf_insert_pfn_pud(struct vm_area_struct *vma, unsigned long addr,
 			pud_t *pud, pfn_t pfn, bool write);
 enum transparent_hugepage_flag {
 	TRANSPARENT_HUGEPAGE_FLAG,
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 8fcc366..a61ebe8 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -2731,6 +2731,7 @@
 	MF_MSG_TRUNCATED_LRU,
 	MF_MSG_BUDDY,
 	MF_MSG_BUDDY_2ND,
+	MF_MSG_DAX,
 	MF_MSG_UNKNOWN,
 };
 
diff --git a/include/linux/set_memory.h b/include/linux/set_memory.h
index da51782..2a986d2 100644
--- a/include/linux/set_memory.h
+++ b/include/linux/set_memory.h
@@ -17,6 +17,20 @@
 static inline int set_memory_nx(unsigned long addr, int numpages) { return 0; }
 #endif
 
+#ifndef set_mce_nospec
+static inline int set_mce_nospec(unsigned long pfn)
+{
+	return 0;
+}
+#endif
+
+#ifndef clear_mce_nospec
+static inline int clear_mce_nospec(unsigned long pfn)
+{
+	return 0;
+}
+#endif
+
 #ifndef CONFIG_ARCH_HAS_MEM_ENCRYPT
 static inline int set_memory_encrypted(unsigned long addr, int numpages)
 {
diff --git a/kernel/memremap.c b/kernel/memremap.c
index d57d58f..5b8600d 100644
--- a/kernel/memremap.c
+++ b/kernel/memremap.c
@@ -365,7 +365,6 @@
 		__ClearPageActive(page);
 		__ClearPageWaiters(page);
 
-		page->mapping = NULL;
 		mem_cgroup_uncharge(page);
 
 		page->pgmap->page_free(page, page->pgmap->data);
diff --git a/mm/hmm.c b/mm/hmm.c
index 0b05545..c968e49 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -968,6 +968,8 @@
 {
 	struct hmm_devmem *devmem = data;
 
+	page->mapping = NULL;
+
 	devmem->ops->free(devmem, page);
 }
 
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index 08b5443..c3bc7e9 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -752,7 +752,7 @@
 	spin_unlock(ptl);
 }
 
-int vmf_insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
+vm_fault_t vmf_insert_pfn_pmd(struct vm_area_struct *vma, unsigned long addr,
 			pmd_t *pmd, pfn_t pfn, bool write)
 {
 	pgprot_t pgprot = vma->vm_page_prot;
@@ -812,7 +812,7 @@
 	spin_unlock(ptl);
 }
 
-int vmf_insert_pfn_pud(struct vm_area_struct *vma, unsigned long addr,
+vm_fault_t vmf_insert_pfn_pud(struct vm_area_struct *vma, unsigned long addr,
 			pud_t *pud, pfn_t pfn, bool write)
 {
 	pgprot_t pgprot = vma->vm_page_prot;
diff --git a/mm/madvise.c b/mm/madvise.c
index 4d3c922..972a9ea 100644
--- a/mm/madvise.c
+++ b/mm/madvise.c
@@ -631,11 +631,13 @@
 
 
 	for (; start < end; start += PAGE_SIZE << order) {
+		unsigned long pfn;
 		int ret;
 
 		ret = get_user_pages_fast(start, 1, 0, &page);
 		if (ret != 1)
 			return ret;
+		pfn = page_to_pfn(page);
 
 		/*
 		 * When soft offlining hugepages, after migrating the page
@@ -651,17 +653,25 @@
 
 		if (behavior == MADV_SOFT_OFFLINE) {
 			pr_info("Soft offlining pfn %#lx at process virtual address %#lx\n",
-						page_to_pfn(page), start);
+					pfn, start);
 
 			ret = soft_offline_page(page, MF_COUNT_INCREASED);
 			if (ret)
 				return ret;
 			continue;
 		}
-		pr_info("Injecting memory failure for pfn %#lx at process virtual address %#lx\n",
-						page_to_pfn(page), start);
 
-		ret = memory_failure(page_to_pfn(page), MF_COUNT_INCREASED);
+		pr_info("Injecting memory failure for pfn %#lx at process virtual address %#lx\n",
+				pfn, start);
+
+		/*
+		 * Drop the page reference taken by get_user_pages_fast(). In
+		 * the absence of MF_COUNT_INCREASED the memory_failure()
+		 * routine is responsible for pinning the page to prevent it
+		 * from being released back to the page allocator.
+		 */
+		put_page(page);
+		ret = memory_failure(pfn, 0);
 		if (ret)
 			return ret;
 	}
diff --git a/mm/memory-failure.c b/mm/memory-failure.c
index 192d0bb..0cd3de3 100644
--- a/mm/memory-failure.c
+++ b/mm/memory-failure.c
@@ -55,6 +55,7 @@
 #include <linux/hugetlb.h>
 #include <linux/memory_hotplug.h>
 #include <linux/mm_inline.h>
+#include <linux/memremap.h>
 #include <linux/kfifo.h>
 #include <linux/ratelimit.h>
 #include <linux/page-isolation.h>
@@ -175,22 +176,51 @@
 EXPORT_SYMBOL_GPL(hwpoison_filter);
 
 /*
+ * Kill all processes that have a poisoned page mapped and then isolate
+ * the page.
+ *
+ * General strategy:
+ * Find all processes having the page mapped and kill them.
+ * But we keep a page reference around so that the page is not
+ * actually freed yet.
+ * Then stash the page away
+ *
+ * There's no convenient way to get back to mapped processes
+ * from the VMAs. So do a brute-force search over all
+ * running processes.
+ *
+ * Remember that machine checks are not common (or rather
+ * if they are common you have other problems), so this shouldn't
+ * be a performance issue.
+ *
+ * Also there are some races possible while we get from the
+ * error detection to actually handle it.
+ */
+
+struct to_kill {
+	struct list_head nd;
+	struct task_struct *tsk;
+	unsigned long addr;
+	short size_shift;
+	char addr_valid;
+};
+
+/*
  * Send all the processes who have the page mapped a signal.
  * ``action optional'' if they are not immediately affected by the error
  * ``action required'' if error happened in current execution context
  */
-static int kill_proc(struct task_struct *t, unsigned long addr,
-			unsigned long pfn, struct page *page, int flags)
+static int kill_proc(struct to_kill *tk, unsigned long pfn, int flags)
 {
-	short addr_lsb;
+	struct task_struct *t = tk->tsk;
+	short addr_lsb = tk->size_shift;
 	int ret;
 
 	pr_err("Memory failure: %#lx: Killing %s:%d due to hardware memory corruption\n",
 		pfn, t->comm, t->pid);
-	addr_lsb = compound_order(compound_head(page)) + PAGE_SHIFT;
 
 	if ((flags & MF_ACTION_REQUIRED) && t->mm == current->mm) {
-		ret = force_sig_mceerr(BUS_MCEERR_AR, (void __user *)addr,
+		ret = force_sig_mceerr(BUS_MCEERR_AR, (void __user *)tk->addr,
 				       addr_lsb, current);
 	} else {
 		/*
@@ -199,7 +229,7 @@
 		 * This could cause a loop when the user sets SIGBUS
 		 * to SIG_IGN, but hopefully no one will do that?
 		 */
-		ret = send_sig_mceerr(BUS_MCEERR_AO, (void __user *)addr,
+		ret = send_sig_mceerr(BUS_MCEERR_AO, (void __user *)tk->addr,
 				      addr_lsb, t);  /* synchronous? */
 	}
 	if (ret < 0)
@@ -235,34 +265,39 @@
 }
 EXPORT_SYMBOL_GPL(shake_page);
 
-/*
- * Kill all processes that have a poisoned page mapped and then isolate
- * the page.
- *
- * General strategy:
- * Find all processes having the page mapped and kill them.
- * But we keep a page reference around so that the page is not
- * actually freed yet.
- * Then stash the page away
- *
- * There's no convenient way to get back to mapped processes
- * from the VMAs. So do a brute-force search over all
- * running processes.
- *
- * Remember that machine checks are not common (or rather
- * if they are common you have other problems), so this shouldn't
- * be a performance issue.
- *
- * Also there are some races possible while we get from the
- * error detection to actually handle it.
- */
+static unsigned long dev_pagemap_mapping_shift(struct page *page,
+		struct vm_area_struct *vma)
+{
+	unsigned long address = vma_address(page, vma);
+	pgd_t *pgd;
+	p4d_t *p4d;
+	pud_t *pud;
+	pmd_t *pmd;
+	pte_t *pte;
 
-struct to_kill {
-	struct list_head nd;
-	struct task_struct *tsk;
-	unsigned long addr;
-	char addr_valid;
-};
+	pgd = pgd_offset(vma->vm_mm, address);
+	if (!pgd_present(*pgd))
+		return 0;
+	p4d = p4d_offset(pgd, address);
+	if (!p4d_present(*p4d))
+		return 0;
+	pud = pud_offset(p4d, address);
+	if (!pud_present(*pud))
+		return 0;
+	if (pud_devmap(*pud))
+		return PUD_SHIFT;
+	pmd = pmd_offset(pud, address);
+	if (!pmd_present(*pmd))
+		return 0;
+	if (pmd_devmap(*pmd))
+		return PMD_SHIFT;
+	pte = pte_offset_map(pmd, address);
+	if (!pte_present(*pte))
+		return 0;
+	if (pte_devmap(*pte))
+		return PAGE_SHIFT;
+	return 0;
+}
 
 /*
  * Failure handling: if we can't find or can't kill a process there's
@@ -293,6 +328,10 @@
 	}
 	tk->addr = page_address_in_vma(p, vma);
 	tk->addr_valid = 1;
+	if (is_zone_device_page(p))
+		tk->size_shift = dev_pagemap_mapping_shift(p, vma);
+	else
+		tk->size_shift = compound_order(compound_head(p)) + PAGE_SHIFT;
 
 	/*
 	 * In theory we don't have to kill when the page was
@@ -300,7 +339,7 @@
 	 * likely very rare kill anyways just out of paranoia, but use
 	 * a SIGKILL because the error is not contained anymore.
 	 */
-	if (tk->addr == -EFAULT) {
+	if (tk->addr == -EFAULT || tk->size_shift == 0) {
 		pr_info("Memory failure: Unable to find user space address %lx in %s\n",
 			page_to_pfn(p), tsk->comm);
 		tk->addr_valid = 0;
@@ -318,9 +357,8 @@
  * Also when FAIL is set do a force kill because something went
  * wrong earlier.
  */
-static void kill_procs(struct list_head *to_kill, int forcekill,
-			  bool fail, struct page *page, unsigned long pfn,
-			  int flags)
+static void kill_procs(struct list_head *to_kill, int forcekill, bool fail,
+		unsigned long pfn, int flags)
 {
 	struct to_kill *tk, *next;
 
@@ -343,8 +381,7 @@
 			 * check for that, but we need to tell the
 			 * process anyways.
 			 */
-			else if (kill_proc(tk->tsk, tk->addr,
-					      pfn, page, flags) < 0)
+			else if (kill_proc(tk, pfn, flags) < 0)
 				pr_err("Memory failure: %#lx: Cannot send advisory machine check signal to %s:%d\n",
 				       pfn, tk->tsk->comm, tk->tsk->pid);
 		}
@@ -516,6 +553,7 @@
 	[MF_MSG_TRUNCATED_LRU]		= "already truncated LRU page",
 	[MF_MSG_BUDDY]			= "free buddy page",
 	[MF_MSG_BUDDY_2ND]		= "free buddy page (2nd try)",
+	[MF_MSG_DAX]			= "dax page",
 	[MF_MSG_UNKNOWN]		= "unknown page",
 };
 
@@ -1013,7 +1051,7 @@
 	 * any accesses to the poisoned memory.
 	 */
 	forcekill = PageDirty(hpage) || (flags & MF_MUST_KILL);
-	kill_procs(&tokill, forcekill, !unmap_success, p, pfn, flags);
+	kill_procs(&tokill, forcekill, !unmap_success, pfn, flags);
 
 	return unmap_success;
 }
@@ -1113,6 +1151,83 @@
 	return res;
 }
 
+static int memory_failure_dev_pagemap(unsigned long pfn, int flags,
+		struct dev_pagemap *pgmap)
+{
+	struct page *page = pfn_to_page(pfn);
+	const bool unmap_success = true;
+	unsigned long size = 0;
+	struct to_kill *tk;
+	LIST_HEAD(tokill);
+	int rc = -EBUSY;
+	loff_t start;
+
+	/*
+	 * Prevent the inode from being freed while we are interrogating
+	 * the address_space, typically this would be handled by
+	 * lock_page(), but dax pages do not use the page lock. This
+	 * also prevents changes to the mapping of this pfn until
+	 * poison signaling is complete.
+	 */
+	if (!dax_lock_mapping_entry(page))
+		goto out;
+
+	if (hwpoison_filter(page)) {
+		rc = 0;
+		goto unlock;
+	}
+
+	switch (pgmap->type) {
+	case MEMORY_DEVICE_PRIVATE:
+	case MEMORY_DEVICE_PUBLIC:
+		/*
+		 * TODO: Handle HMM pages which may need coordination
+		 * with device-side memory.
+		 */
+		goto unlock;
+	default:
+		break;
+	}
+
+	/*
+	 * Use this flag as an indication that the dax page has been
+	 * remapped UC to prevent speculative consumption of poison.
+	 */
+	SetPageHWPoison(page);
+
+	/*
+	 * Unlike System-RAM there is no possibility to swap in a
+	 * different physical page at a given virtual address, so all
+	 * userspace consumption of ZONE_DEVICE memory necessitates
+	 * SIGBUS (i.e. MF_MUST_KILL)
+	 */
+	flags |= MF_ACTION_REQUIRED | MF_MUST_KILL;
+	collect_procs(page, &tokill, flags & MF_ACTION_REQUIRED);
+
+	list_for_each_entry(tk, &tokill, nd)
+		if (tk->size_shift)
+			size = max(size, 1UL << tk->size_shift);
+	if (size) {
+		/*
+		 * Unmap the largest mapping to avoid breaking up
+		 * device-dax mappings which are constant size. The
+		 * actual size of the mapping being torn down is
+		 * communicated in siginfo, see kill_proc()
+		 */
+		start = (page->index << PAGE_SHIFT) & ~(size - 1);
+		unmap_mapping_range(page->mapping, start, start + size, 0);
+	}
+	kill_procs(&tokill, flags & MF_MUST_KILL, !unmap_success, pfn, flags);
+	rc = 0;
+unlock:
+	dax_unlock_mapping_entry(page);
+out:
+	/* drop pgmap ref acquired in caller */
+	put_dev_pagemap(pgmap);
+	action_result(pfn, MF_MSG_DAX, rc ? MF_FAILED : MF_RECOVERED);
+	return rc;
+}
+
 /**
  * memory_failure - Handle memory failure of a page.
  * @pfn: Page Number of the corrupted page
@@ -1135,6 +1250,7 @@
 	struct page *p;
 	struct page *hpage;
 	struct page *orig_head;
+	struct dev_pagemap *pgmap;
 	int res;
 	unsigned long page_flags;
 
@@ -1147,6 +1263,10 @@
 		return -ENXIO;
 	}
 
+	pgmap = get_dev_pagemap(pfn, NULL);
+	if (pgmap)
+		return memory_failure_dev_pagemap(pfn, flags, pgmap);
+
 	p = pfn_to_page(pfn);
 	if (PageHuge(p))
 		return memory_failure_hugetlb(pfn, flags);
@@ -1777,6 +1897,14 @@
 	int ret;
 	unsigned long pfn = page_to_pfn(page);
 
+	if (is_zone_device_page(page)) {
+		pr_debug_ratelimited("soft_offline: %#lx page is device page\n",
+				pfn);
+		if (flags & MF_COUNT_INCREASED)
+			put_page(page);
+		return -EIO;
+	}
+
 	if (PageHWPoison(page)) {
 		pr_info("soft offline: %#lx page already poisoned\n", pfn);
 		if (flags & MF_COUNT_INCREASED)