mm: don't expose page to fast gup before it's ready

We don't want to expose a page before it's properly set up.  During page
setup, we may call page_add_new_anon_rmap(), which uses non-atomic bit
ops.  If the page is exposed before setup is done, we could overwrite
page flags that are set by get_user_pages_fast() or its callers.  Here
is a non-fatal scenario (there might be other fatal problems that I
didn't look into):

	CPU 0				CPU 1
set_pte_at()			get_user_pages_fast()
page_add_new_anon_rmap()		gup_pte_range()
	__SetPageSwapBacked()			SetPageReferenced()

Fix the problem by delaying set_pte_at() in do_swap_page() and
unuse_pte() until the page is ready.

I didn't observe the race directly, but I did get a few crashes when
trying to access the mem_cgroup of pages returned by
get_user_pages_fast().  Those pages were charged, and they showed a
valid mem_cgroup in kdumps.  This led me to think the problem came from
a premature set_pte_at().

I think nobody has complained about this problem because the race only
happens when using ksm+swap, and even then it might not cause any fatal
problem.  Nevertheless, it's nice to have set_pte_at() done consistently
after the rmap is added and the page is charged.

Link: http://lkml.kernel.org/r/20180108225632.16332-1-yuzhao@google.com
Signed-off-by: Yu Zhao <yuzhao@google.com>
Cc: Jan Kara <jack@suse.cz>
Cc: Minchan Kim <minchan@kernel.org>
Cc: Johannes Weiner <hannes@cmpxchg.org>
Cc: Vladimir Davydov <vdavydov.dev@gmail.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
diff --git a/mm/memory.c b/mm/memory.c
index 7206a63..a784838 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -3044,7 +3044,6 @@ int do_swap_page(struct vm_fault *vmf)
 	flush_icache_page(vma, page);
 	if (pte_swp_soft_dirty(vmf->orig_pte))
 		pte = pte_mksoft_dirty(pte);
-	set_pte_at(vma->vm_mm, vmf->address, vmf->pte, pte);
 	arch_do_swap_page(vma->vm_mm, vma, vmf->address, pte, vmf->orig_pte);
 	vmf->orig_pte = pte;
 
@@ -3058,6 +3057,7 @@ int do_swap_page(struct vm_fault *vmf)
 		mem_cgroup_commit_charge(page, memcg, true, false);
 		activate_page(page);
 	}
+	set_pte_at(vma->vm_mm, vmf->address, vmf->pte, pte);
 
 	swap_free(entry);
 	if (mem_cgroup_swap_full(page) ||
diff --git a/mm/swapfile.c b/mm/swapfile.c
index 2cc2972..917392c 100644
--- a/mm/swapfile.c
+++ b/mm/swapfile.c
@@ -1800,8 +1800,6 @@ static int unuse_pte(struct vm_area_struct *vma, pmd_t *pmd,
 	dec_mm_counter(vma->vm_mm, MM_SWAPENTS);
 	inc_mm_counter(vma->vm_mm, MM_ANONPAGES);
 	get_page(page);
-	set_pte_at(vma->vm_mm, addr, pte,
-		   pte_mkold(mk_pte(page, vma->vm_page_prot)));
 	if (page == swapcache) {
 		page_add_anon_rmap(page, vma, addr, false);
 		mem_cgroup_commit_charge(page, memcg, true, false);
@@ -1810,6 +1808,8 @@ static int unuse_pte(struct vm_area_struct *vma, pmd_t *pmd,
 		mem_cgroup_commit_charge(page, memcg, false, false);
 		lru_cache_add_active_or_unevictable(page, vma);
 	}
+	set_pte_at(vma->vm_mm, addr, pte,
+		   pte_mkold(mk_pte(page, vma->vm_page_prot)));
 	swap_free(entry);
 	/*
 	 * Move the page to the active list so it is not