userfaultfd: WP: add UFFDIO_COPY_MODE_WP

This allows UFFDIO_COPY to map pages wrprotected.

Signed-off-by: Andrea Arcangeli <aarcange@redhat.com>
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index fa52170..ffbe2c7 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -1686,11 +1686,11 @@
 	ret = -EINVAL;
 	if (uffdio_copy.src + uffdio_copy.len <= uffdio_copy.src)
 		goto out;
-	if (uffdio_copy.mode & ~UFFDIO_COPY_MODE_DONTWAKE)
+	if (uffdio_copy.mode & ~(UFFDIO_COPY_MODE_DONTWAKE|UFFDIO_COPY_MODE_WP))
 		goto out;
 	if (mmget_not_zero(ctx->mm)) {
 		ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src,
-				   uffdio_copy.len);
+				   uffdio_copy.len, uffdio_copy.mode);
 		mmput(ctx->mm);
 	} else {
 		return -ESRCH;
diff --git a/include/linux/userfaultfd_k.h b/include/linux/userfaultfd_k.h
index 356997f..67e6361 100644
--- a/include/linux/userfaultfd_k.h
+++ b/include/linux/userfaultfd_k.h
@@ -33,7 +33,8 @@
 extern int handle_userfault(struct vm_fault *vmf, unsigned long reason);
 
 extern ssize_t mcopy_atomic(struct mm_struct *dst_mm, unsigned long dst_start,
-			    unsigned long src_start, unsigned long len);
+			    unsigned long src_start, unsigned long len,
+			    __u64 mode);
 extern ssize_t mfill_zeropage(struct mm_struct *dst_mm,
 			      unsigned long dst_start,
 			      unsigned long len);
diff --git a/include/uapi/linux/userfaultfd.h b/include/uapi/linux/userfaultfd.h
index 25c0876..e8f8a5b 100644
--- a/include/uapi/linux/userfaultfd.h
+++ b/include/uapi/linux/userfaultfd.h
@@ -212,13 +212,14 @@
 	__u64 dst;
 	__u64 src;
 	__u64 len;
-	/*
-	 * There will be a wrprotection flag later that allows to map
-	 * pages wrprotected on the fly. And such a flag will be
-	 * available if the wrprotection ioctl are implemented for the
-	 * range according to the uffdio_register.ioctls.
-	 */
 #define UFFDIO_COPY_MODE_DONTWAKE		((__u64)1<<0)
+	/*
+	 * UFFDIO_COPY_MODE_WP will map the page wrprotected on the
+	 * fly. UFFDIO_COPY_MODE_WP is available only if the
+	 * wrprotection ioctl are implemented for the range according
+	 * to the uffdio_register.ioctls.
+	 */
+#define UFFDIO_COPY_MODE_WP			((__u64)1<<1)
 	__u64 mode;
 
 	/*
diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
index fe29057..648e817 100644
--- a/mm/userfaultfd.c
+++ b/mm/userfaultfd.c
@@ -25,7 +25,8 @@
 			    struct vm_area_struct *dst_vma,
 			    unsigned long dst_addr,
 			    unsigned long src_addr,
-			    struct page **pagep)
+			    struct page **pagep,
+			    bool wp_copy)
 {
 	struct mem_cgroup *memcg;
 	pte_t _dst_pte, *dst_pte;
@@ -69,9 +70,9 @@
 	if (mem_cgroup_try_charge(page, dst_mm, GFP_KERNEL, &memcg, false))
 		goto out_release;
 
-	_dst_pte = mk_pte(page, dst_vma->vm_page_prot);
-	if (dst_vma->vm_flags & VM_WRITE)
-		_dst_pte = pte_mkwrite(pte_mkdirty(_dst_pte));
+	_dst_pte = pte_mkdirty(mk_pte(page, dst_vma->vm_page_prot));
+	if (dst_vma->vm_flags & VM_WRITE && !wp_copy)
+		_dst_pte = pte_mkwrite(_dst_pte);
 
 	ret = -EEXIST;
 	dst_pte = pte_offset_map_lock(dst_mm, dst_pmd, dst_addr, &ptl);
@@ -376,18 +377,21 @@
 						unsigned long dst_addr,
 						unsigned long src_addr,
 						struct page **page,
-						bool zeropage)
+						bool zeropage,
+						bool wp_copy)
 {
 	ssize_t err;
 
 	if (vma_is_anonymous(dst_vma)) {
 		if (!zeropage)
 			err = mcopy_atomic_pte(dst_mm, dst_pmd, dst_vma,
-					       dst_addr, src_addr, page);
+					       dst_addr, src_addr, page,
+					       wp_copy);
 		else
 			err = mfill_zeropage_pte(dst_mm, dst_pmd,
 						 dst_vma, dst_addr);
 	} else {
+		VM_WARN_ON(wp_copy); /* WP only available for anon */
 		if (!zeropage)
 			err = shmem_mcopy_atomic_pte(dst_mm, dst_pmd,
 						     dst_vma, dst_addr,
@@ -404,7 +408,7 @@
 					      unsigned long dst_start,
 					      unsigned long src_start,
 					      unsigned long len,
-					      bool zeropage)
+					      bool zeropage, __u64 mode)
 {
 	struct vm_area_struct *dst_vma;
 	ssize_t err;
@@ -412,6 +416,7 @@
 	unsigned long src_addr, dst_addr;
 	long copied;
 	struct page *page;
+	bool wp_copy;
 
 	/*
 	 * Sanitize the command parameters:
@@ -464,6 +469,14 @@
 		goto out_unlock;
 
 	/*
+	 * validate 'mode' now that we know the dst_vma: don't allow
+	 * a wrprotect copy if the userfaultfd didn't register as WP.
+	 */
+	wp_copy = mode & UFFDIO_COPY_MODE_WP;
+	if (wp_copy && !(dst_vma->vm_flags & VM_UFFD_WP))
+		goto out_unlock;
+
+	/*
 	 * If this is a HUGETLB vma, pass off to appropriate routine
 	 */
 	if (is_vm_hugetlb_page(dst_vma))
@@ -517,7 +530,7 @@
 		BUG_ON(pmd_trans_huge(*dst_pmd));
 
 		err = mfill_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
-				       src_addr, &page, zeropage);
+				       src_addr, &page, zeropage, wp_copy);
 		cond_resched();
 
 		if (unlikely(err == -EFAULT)) {
@@ -563,15 +576,15 @@
 }
 
 ssize_t mcopy_atomic(struct mm_struct *dst_mm, unsigned long dst_start,
-		     unsigned long src_start, unsigned long len)
+		     unsigned long src_start, unsigned long len, __u64 mode)
 {
-	return __mcopy_atomic(dst_mm, dst_start, src_start, len, false);
+	return __mcopy_atomic(dst_mm, dst_start, src_start, len, false, mode);
 }
 
 ssize_t mfill_zeropage(struct mm_struct *dst_mm, unsigned long start,
 		       unsigned long len)
 {
-	return __mcopy_atomic(dst_mm, start, 0, len, true);
+	return __mcopy_atomic(dst_mm, start, 0, len, true, 0);
 }
 
 int mwriteprotect_range(struct mm_struct *dst_mm, unsigned long start,