diff --git a/mm/mmap.c b/mm/mmap.c index 322677f61d30..9a9933ede542 100644 --- a/mm/mmap.c +++ b/mm/mmap.c @@ -2652,7 +2652,7 @@ int do_munmap(struct mm_struct *mm, unsigned long start, size_t len, return do_mas_munmap(&mas, mm, start, len, uf, false); } -unsigned long mmap_region(struct file *file, unsigned long addr, +static unsigned long __mmap_region(struct file *file, unsigned long addr, unsigned long len, vm_flags_t vm_flags, unsigned long pgoff, struct list_head *uf) { @@ -2750,26 +2750,28 @@ cannot_expand: vma->vm_page_prot = vm_get_page_prot(vm_flags); vma->vm_pgoff = pgoff; - if (file) { - if (vm_flags & VM_SHARED) { - error = mapping_map_writable(file->f_mapping); - if (error) - goto free_vma; - } + if (mas_preallocate(&mas, vma, GFP_KERNEL)) { + error = -ENOMEM; + goto free_vma; + } + if (file) { vma->vm_file = get_file(file); error = mmap_file(file, vma); if (error) - goto unmap_and_free_vma; + goto unmap_and_free_file_vma; + + /* Drivers cannot alter the address of the VMA. */ + WARN_ON_ONCE(addr != vma->vm_start); /* - * Expansion is handled above, merging is handled below. - * Drivers should not alter the address of the VMA. + * Drivers should not permit writability when previously it was + * disallowed. */ - if (WARN_ON((addr != vma->vm_start))) { - error = -EINVAL; - goto close_and_free_vma; - } + VM_WARN_ON_ONCE(vm_flags != vma->vm_flags && + !(vm_flags & VM_MAYWRITE) && + (vma->vm_flags & VM_MAYWRITE)); + mas_reset(&mas); /* @@ -2792,7 +2794,8 @@ cannot_expand: vma = merge; /* Update vm_flags to pick up the change. */ vm_flags = vma->vm_flags; - goto unmap_writable; + mas_destroy(&mas); + goto file_expanded; } } @@ -2800,31 +2803,15 @@ cannot_expand: } else if (vm_flags & VM_SHARED) { error = shmem_zero_setup(vma); if (error) - goto free_vma; + goto free_iter_vma; } else { vma_set_anonymous(vma); } - /* Allow architectures to sanity-check the vm_flags */ - if (!arch_validate_flags(vma->vm_flags)) { - error = -EINVAL; - if (file) - goto close_and_free_vma; - else if (vma->vm_file) - goto unmap_and_free_vma; - else - goto free_vma; - } - - if (mas_preallocate(&mas, vma, GFP_KERNEL)) { - error = -ENOMEM; - if (file) - goto close_and_free_vma; - else if (vma->vm_file) - goto unmap_and_free_vma; - else - goto free_vma; - } +#ifdef CONFIG_SPARC64 + /* TODO: Fix SPARC ADI! */ + WARN_ON_ONCE(!arch_validate_flags(vm_flags)); +#endif if (vma->vm_file) i_mmap_lock_write(vma->vm_file->f_mapping); @@ -2847,10 +2834,7 @@ cannot_expand: */ khugepaged_enter_vma(vma, vma->vm_flags); - /* Once vma denies write, undo our temporary denial count */ -unmap_writable: - if (file && vm_flags & VM_SHARED) - mapping_unmap_writable(file->f_mapping); +file_expanded: file = vma->vm_file; expanded: perf_event_mmap(vma); @@ -2879,28 +2863,54 @@ expanded: vma_set_page_prot(vma); - validate_mm(mm); return addr; -close_and_free_vma: - vma_close(vma); -unmap_and_free_vma: +unmap_and_free_file_vma: fput(vma->vm_file); vma->vm_file = NULL; /* Undo any partial mapping done by a device driver. */ unmap_region(mm, mas.tree, vma, prev, next, vma->vm_start, vma->vm_end); - if (file && (vm_flags & VM_SHARED)) - mapping_unmap_writable(file->f_mapping); +free_iter_vma: + mas_destroy(&mas); free_vma: vm_area_free(vma); unacct_error: if (charged) vm_unacct_memory(charged); - validate_mm(mm); return error; } +unsigned long mmap_region(struct file *file, unsigned long addr, + unsigned long len, vm_flags_t vm_flags, unsigned long pgoff, + struct list_head *uf) +{ + unsigned long ret; + bool writable_file_mapping = false; + + /* Allow architectures to sanity-check the vm_flags. */ + if (!arch_validate_flags(vm_flags)) + return -EINVAL; + + /* Map writable and ensure this isn't a sealed memfd. */ + if (file && (vm_flags & VM_SHARED)) { + int error = mapping_map_writable(file->f_mapping); + + if (error) + return error; + writable_file_mapping = true; + } + + ret = __mmap_region(file, addr, len, vm_flags, pgoff, uf); + + /* Clear our write mapping regardless of error. */ + if (writable_file_mapping) + mapping_unmap_writable(file->f_mapping); + + validate_mm(current->mm); + return ret; +} + static int __vm_munmap(unsigned long start, size_t len, bool downgrade) { int ret;