diff --git a/arch/arm64/mm/fault.c b/arch/arm64/mm/fault.c index e34b46785150..b6ac6caeb662 100644 --- a/arch/arm64/mm/fault.c +++ b/arch/arm64/mm/fault.c @@ -612,7 +612,8 @@ static int __kprobes do_page_fault(unsigned long far, unsigned long esr, goto lock_mmap; } fault = handle_mm_fault(vma, addr, mm_flags | FAULT_FLAG_VMA_LOCK, regs); - vma_end_read(vma); + if (!(fault & (VM_FAULT_RETRY | VM_FAULT_COMPLETED))) + vma_end_read(vma); if (!(fault & VM_FAULT_RETRY)) { count_vm_vma_lock_event(VMA_LOCK_SUCCESS); diff --git a/arch/powerpc/mm/fault.c b/arch/powerpc/mm/fault.c index 5bfdf6ecfa96..82954d0e6906 100644 --- a/arch/powerpc/mm/fault.c +++ b/arch/powerpc/mm/fault.c @@ -489,7 +489,8 @@ static int ___do_page_fault(struct pt_regs *regs, unsigned long address, } fault = handle_mm_fault(vma, address, flags | FAULT_FLAG_VMA_LOCK, regs); - vma_end_read(vma); + if (!(fault & (VM_FAULT_RETRY | VM_FAULT_COMPLETED))) + vma_end_read(vma); if (!(fault & VM_FAULT_RETRY)) { count_vm_vma_lock_event(VMA_LOCK_SUCCESS); diff --git a/arch/riscv/mm/fault.c b/arch/riscv/mm/fault.c index 209cb580a3b0..88b70983bfb7 100644 --- a/arch/riscv/mm/fault.c +++ b/arch/riscv/mm/fault.c @@ -303,7 +303,8 @@ asmlinkage void do_page_fault(struct pt_regs *regs) } fault = handle_mm_fault(vma, addr, flags | FAULT_FLAG_VMA_LOCK, regs); - vma_end_read(vma); + if (!(fault & (VM_FAULT_RETRY | VM_FAULT_COMPLETED))) + vma_end_read(vma); if (!(fault & VM_FAULT_RETRY)) { count_vm_vma_lock_event(VMA_LOCK_SUCCESS); diff --git a/arch/s390/mm/fault.c b/arch/s390/mm/fault.c index 98a0091bb097..4e1f2790e6ff 100644 --- a/arch/s390/mm/fault.c +++ b/arch/s390/mm/fault.c @@ -414,7 +414,8 @@ static inline vm_fault_t do_exception(struct pt_regs *regs, int access) goto lock_mmap; } fault = handle_mm_fault(vma, address, flags | FAULT_FLAG_VMA_LOCK, regs); - vma_end_read(vma); + if (!(fault & (VM_FAULT_RETRY | VM_FAULT_COMPLETED))) + vma_end_read(vma); if (!(fault & VM_FAULT_RETRY)) { count_vm_vma_lock_event(VMA_LOCK_SUCCESS); goto out; diff --git a/arch/x86/mm/fault.c b/arch/x86/mm/fault.c index 8a74089d9f2e..72044a9342d0 100644 --- a/arch/x86/mm/fault.c +++ b/arch/x86/mm/fault.c @@ -1362,7 +1362,8 @@ void do_user_addr_fault(struct pt_regs *regs, goto lock_mmap; } fault = handle_mm_fault(vma, address, flags | FAULT_FLAG_VMA_LOCK, regs); - vma_end_read(vma); + if (!(fault & (VM_FAULT_RETRY | VM_FAULT_COMPLETED))) + vma_end_read(vma); if (!(fault & VM_FAULT_RETRY)) { count_vm_vma_lock_event(VMA_LOCK_SUCCESS); diff --git a/mm/memory.c b/mm/memory.c index a9ffa95e2386..d1df4eac9d13 100644 --- a/mm/memory.c +++ b/mm/memory.c @@ -3771,6 +3771,7 @@ vm_fault_t do_swap_page(struct vm_fault *vmf) if (vmf->flags & FAULT_FLAG_VMA_LOCK) { ret = VM_FAULT_RETRY; + vma_end_read(vma); goto out; } @@ -5243,6 +5244,17 @@ vm_fault_t handle_mm_fault(struct vm_area_struct *vma, unsigned long address, __set_current_state(TASK_RUNNING); +#ifdef CONFIG_PER_VMA_LOCK + /* + * Per-VMA locks can't be used with FAULT_FLAG_RETRY_NOWAIT because of + * the assumption that lock is dropped on VM_FAULT_RETRY. + */ + if (WARN_ON_ONCE((flags & + (FAULT_FLAG_VMA_LOCK | FAULT_FLAG_RETRY_NOWAIT)) == + (FAULT_FLAG_VMA_LOCK | FAULT_FLAG_RETRY_NOWAIT))) + return VM_FAULT_SIGSEGV; +#endif + /* do counter updates before entering really critical section. */ check_sync_rss_stat(current);