| From: "Liam R. Howlett" <Liam.Howlett@oracle.com> |
| Subject: mm/mempolicy: fix use-after-free of VMA iterator |
| Date: Mon, 10 Apr 2023 11:22:05 -0400 |
| |
| set_mempolicy_home_node() iterates over a list of VMAs and calls |
| mbind_range() on each VMA, which also iterates over the singular list of |
| the VMA passed in and potentially splits the VMA. Since the VMA iterator |
| is not passed through, set_mempolicy_home_node() may now point to a stale |
| node in the VMA tree. This can result in a UAF as reported by syzbot. |
| |
| Avoid the stale maple tree node by passing the VMA iterator through to the |
| underlying call to split_vma(). |
| |
| mbind_range() is also overly complicated, since there are two calling |
| functions and one already handles iterating over the VMAs. Simplify |
| mbind_range() to only handle merging and splitting of the VMAs. |
| |
| Align the new loop in do_mbind() and existing loop in |
| set_mempolicy_home_node() to use the reduced mbind_range() function. This |
| allows for a single location of the range calculation and avoids |
| constantly looking up the previous VMA (since this is a loop over the |
| VMAs). |
| |
| Link: https://lore.kernel.org/linux-mm/000000000000c93feb05f87e24ad@google.com/ |
| Fixes: 66850be55e8e ("mm/mempolicy: use vma iterator & maple state instead of vma linked list") |
| Signed-off-by: Liam R. Howlett <Liam.Howlett@oracle.com> |
| Reported-by: syzbot+a7c1ec5b1d71ceaa5186@syzkaller.appspotmail.com |
| Link: https://lkml.kernel.org/r/20230410152205.2294819-1-Liam.Howlett@oracle.com |
| Tested-by: syzbot+a7c1ec5b1d71ceaa5186@syzkaller.appspotmail.com |
| Cc: <stable@vger.kernel.org> |
| Signed-off-by: Andrew Morton <akpm@linux-foundation.org> |
| --- |
| |
| mm/mempolicy.c | 102 ++++++++++++++++++++++------------------------- |
| 1 file changed, 48 insertions(+), 54 deletions(-) |
| |
| --- a/mm/mempolicy.c~mm-mempolicy-fix-use-after-free-of-vma-iterator |
| +++ a/mm/mempolicy.c |
| @@ -790,61 +790,50 @@ static int vma_replace_policy(struct vm_ |
| return err; |
| } |
| |
| -/* Step 2: apply policy to a range and do splits. */ |
| -static int mbind_range(struct mm_struct *mm, unsigned long start, |
| - unsigned long end, struct mempolicy *new_pol) |
| +/* Split or merge the VMA (if required) and apply the new policy */ |
| +static int mbind_range(struct vma_iterator *vmi, struct vm_area_struct *vma, |
| + struct vm_area_struct **prev, unsigned long start, |
| + unsigned long end, struct mempolicy *new_pol) |
| { |
| - VMA_ITERATOR(vmi, mm, start); |
| - struct vm_area_struct *prev; |
| - struct vm_area_struct *vma; |
| - int err = 0; |
| + struct vm_area_struct *merged; |
| + unsigned long vmstart, vmend; |
| pgoff_t pgoff; |
| + int err; |
| |
| - prev = vma_prev(&vmi); |
| - vma = vma_find(&vmi, end); |
| - if (WARN_ON(!vma)) |
| + vmend = min(end, vma->vm_end); |
| + if (start > vma->vm_start) { |
| + *prev = vma; |
| + vmstart = start; |
| + } else { |
| + vmstart = vma->vm_start; |
| + } |
| + |
| + if (mpol_equal(vma_policy(vma), new_pol)) |
| return 0; |
| |
| - if (start > vma->vm_start) |
| - prev = vma; |
| + pgoff = vma->vm_pgoff + ((vmstart - vma->vm_start) >> PAGE_SHIFT); |
| + merged = vma_merge(vmi, vma->vm_mm, *prev, vmstart, vmend, vma->vm_flags, |
| + vma->anon_vma, vma->vm_file, pgoff, new_pol, |
| + vma->vm_userfaultfd_ctx, anon_vma_name(vma)); |
| + if (merged) { |
| + *prev = merged; |
| + return vma_replace_policy(merged, new_pol); |
| + } |
| |
| - do { |
| - unsigned long vmstart = max(start, vma->vm_start); |
| - unsigned long vmend = min(end, vma->vm_end); |
| - |
| - if (mpol_equal(vma_policy(vma), new_pol)) |
| - goto next; |
| - |
| - pgoff = vma->vm_pgoff + |
| - ((vmstart - vma->vm_start) >> PAGE_SHIFT); |
| - prev = vma_merge(&vmi, mm, prev, vmstart, vmend, vma->vm_flags, |
| - vma->anon_vma, vma->vm_file, pgoff, |
| - new_pol, vma->vm_userfaultfd_ctx, |
| - anon_vma_name(vma)); |
| - if (prev) { |
| - vma = prev; |
| - goto replace; |
| - } |
| - if (vma->vm_start != vmstart) { |
| - err = split_vma(&vmi, vma, vmstart, 1); |
| - if (err) |
| - goto out; |
| - } |
| - if (vma->vm_end != vmend) { |
| - err = split_vma(&vmi, vma, vmend, 0); |
| - if (err) |
| - goto out; |
| - } |
| -replace: |
| - err = vma_replace_policy(vma, new_pol); |
| + if (vma->vm_start != vmstart) { |
| + err = split_vma(vmi, vma, vmstart, 1); |
| if (err) |
| - goto out; |
| -next: |
| - prev = vma; |
| - } for_each_vma_range(vmi, vma, end); |
| + return err; |
| + } |
| |
| -out: |
| - return err; |
| + if (vma->vm_end != vmend) { |
| + err = split_vma(vmi, vma, vmend, 0); |
| + if (err) |
| + return err; |
| + } |
| + |
| + *prev = vma; |
| + return vma_replace_policy(vma, new_pol); |
| } |
| |
| /* Set the process memory policy */ |
| @@ -1259,6 +1248,8 @@ static long do_mbind(unsigned long start |
| nodemask_t *nmask, unsigned long flags) |
| { |
| struct mm_struct *mm = current->mm; |
| + struct vm_area_struct *vma, *prev; |
| + struct vma_iterator vmi; |
| struct mempolicy *new; |
| unsigned long end; |
| int err; |
| @@ -1328,7 +1319,13 @@ static long do_mbind(unsigned long start |
| goto up_out; |
| } |
| |
| - err = mbind_range(mm, start, end, new); |
| + vma_iter_init(&vmi, mm, start); |
| + prev = vma_prev(&vmi); |
| + for_each_vma_range(vmi, vma, end) { |
| + err = mbind_range(&vmi, vma, &prev, start, end, new); |
| + if (err) |
| + break; |
| + } |
| |
| if (!err) { |
| int nr_failed = 0; |
| @@ -1489,10 +1486,8 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, |
| unsigned long, home_node, unsigned long, flags) |
| { |
| struct mm_struct *mm = current->mm; |
| - struct vm_area_struct *vma; |
| + struct vm_area_struct *vma, *prev; |
| struct mempolicy *new, *old; |
| - unsigned long vmstart; |
| - unsigned long vmend; |
| unsigned long end; |
| int err = -ENOENT; |
| VMA_ITERATOR(vmi, mm, start); |
| @@ -1521,6 +1516,7 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, |
| if (end == start) |
| return 0; |
| mmap_write_lock(mm); |
| + prev = vma_prev(&vmi); |
| for_each_vma_range(vmi, vma, end) { |
| /* |
| * If any vma in the range got policy other than MPOL_BIND |
| @@ -1541,9 +1537,7 @@ SYSCALL_DEFINE4(set_mempolicy_home_node, |
| } |
| |
| new->home_node = home_node; |
| - vmstart = max(start, vma->vm_start); |
| - vmend = min(end, vma->vm_end); |
| - err = mbind_range(mm, vmstart, vmend, new); |
| + err = mbind_range(&vmi, vma, &prev, start, end, new); |
| mpol_put(new); |
| if (err) |
| break; |
| _ |