/*
 * Patch overview. The text below is a unified git diff: a leading "+" marks
 * an added line, a leading "-" a removed line, anything else is context.
 *
 * Files touched: arch/arm64/mm/fault.c, arch/powerpc/mm/fault.c,
 * arch/x86/mm/fault.c, include/linux/mm.h, include/linux/mm_types.h,
 * kernel/fork.c and mm/memory.c.
 *
 * Speculative page-fault path (CONFIG_SPECULATIVE_PAGE_FAULT) in the three
 * architecture fault handlers:
 *   - removed: the rcu_read_lock() / __find_vma() lookup, the on-stack copy
 *     "pvma = *vma", and the get_file() / fput() pair on vma->vm_file;
 *   - added: a single get_vma(mm, addr) lookup, with put_vma(vma) on each
 *     abort path taken after a successful lookup and once more after
 *     do_handle_mm_fault() returns.
 *
 * Reference counting, as introduced by the hunks below:
 *   - vma_init() sets vma->file_ref_count to 0.
 *   - get_vma() runs __find_vma() under rcu_read_lock() and returns NULL when
 *     no VMA is found, when vma->vm_start > addr, or when
 *     atomic_inc_unless_negative(&vma->file_ref_count) fails.
 *   - put_vma() does atomic_dec_return() on the counter and calls
 *     vm_area_free_no_check() when the result is negative.
 *   - vm_area_free() (speculative config only) does the same decrement and
 *     returns early while the result is still >= 0; otherwise, and always in
 *     the non-speculative config, it calls vm_area_free_no_check().
 *   - vm_area_free_no_check() calls free_anon_vma_name(), fput()s vm_file if
 *     set, then free_vm_area_struct().
 *   - free_vm_area_struct() defers kmem_cache_free() through call_rcu() in the
 *     speculative config and frees directly otherwise. The removed code used
 *     call_rcu() only when mm_users was greater than 1.
 *
 * NOTE(review): the original line breaks of this diff are not preserved in
 * this text, so the "@@ -a,b +c,d @@" hunk counts cannot be checked against
 * the hunk bodies from here; presumably one diff line per text line must be
 * restored before the patch can be applied -- confirm against the original.
 */
diff --git a/arch/arm64/mm/fault.c b/arch/arm64/mm/fault.c index 92cc26d07d5c..02182305bf10 100644 --- a/arch/arm64/mm/fault.c +++ b/arch/arm64/mm/fault.c @@ -542,9 +542,7 @@ static int __kprobes do_page_fault(unsigned long far, unsigned int esr, unsigned int mm_flags = FAULT_FLAG_DEFAULT; unsigned long addr = untagged_addr(far); #ifdef CONFIG_SPECULATIVE_PAGE_FAULT - struct file *orig_file = NULL; struct vm_area_struct *vma; - struct vm_area_struct pvma; unsigned long seq; #endif @@ -618,38 +616,29 @@ static int __kprobes do_page_fault(unsigned long far, unsigned int esr, count_vm_spf_event(SPF_ABORT_ODD); goto spf_abort; } - rcu_read_lock(); - vma = __find_vma(mm, addr); - if (!vma || vma->vm_start > addr) { - rcu_read_unlock(); + vma = get_vma(mm, addr); + if (!vma) { count_vm_spf_event(SPF_ABORT_UNMAPPED); goto spf_abort; } if (!vma_can_speculate(vma, mm_flags)) { - rcu_read_unlock(); + put_vma(vma); count_vm_spf_event(SPF_ABORT_NO_SPECULATE); goto spf_abort; } - if (vma->vm_file) - orig_file = get_file(vma->vm_file); - pvma = *vma; - rcu_read_unlock(); + if (!mmap_seq_read_check(mm, seq, SPF_ABORT_VMA_COPY)) { - if (orig_file) - fput(orig_file); + put_vma(vma); goto spf_abort; } - vma = &pvma; if (!(vma->vm_flags & vm_flags)) { - if (orig_file) - fput(orig_file); + put_vma(vma); count_vm_spf_event(SPF_ABORT_ACCESS_ERROR); goto spf_abort; } fault = do_handle_mm_fault(vma, addr & PAGE_MASK, mm_flags | FAULT_FLAG_SPECULATIVE, seq, regs); - if (orig_file) - fput(orig_file); + put_vma(vma); /* Quick path to respond to signals */ if (fault_signal_pending(fault, regs)) { diff --git a/arch/powerpc/mm/fault.c b/arch/powerpc/mm/fault.c index 0799a058b6b9..888c12f405c2 100644 --- a/arch/powerpc/mm/fault.c +++ b/arch/powerpc/mm/fault.c @@ -395,8 +395,6 @@ static int ___do_page_fault(struct pt_regs *regs, unsigned long address, vm_fault_t fault, major = 0; bool kprobe_fault = kprobe_page_fault(regs, 11); #ifdef CONFIG_SPECULATIVE_PAGE_FAULT - struct file *orig_file = NULL; 
- struct vm_area_struct pvma; unsigned long seq; #endif @@ -469,47 +467,37 @@ static int ___do_page_fault(struct pt_regs *regs, unsigned long address, count_vm_spf_event(SPF_ABORT_ODD); goto spf_abort; } - rcu_read_lock(); - vma = __find_vma(mm, address); - if (!vma || vma->vm_start > address) { - rcu_read_unlock(); + vma = get_vma(mm, address); + if (!vma) { count_vm_spf_event(SPF_ABORT_UNMAPPED); goto spf_abort; } if (!vma_can_speculate(vma, flags)) { - rcu_read_unlock(); + put_vma(vma); count_vm_spf_event(SPF_ABORT_NO_SPECULATE); goto spf_abort; } - if (vma->vm_file) - orig_file = get_file(vma->vm_file); - pvma = *vma; - rcu_read_unlock(); + if (!mmap_seq_read_check(mm, seq, SPF_ABORT_VMA_COPY)) { - if (orig_file) - fput(orig_file); + put_vma(vma); goto spf_abort; } - vma = &pvma; #ifdef CONFIG_PPC_MEM_KEYS if (unlikely(access_pkey_error(is_write, is_exec, (error_code & DSISR_KEYFAULT), vma))) { - if (orig_file) - fput(orig_file); + put_vma(vma); count_vm_spf_event(SPF_ABORT_ACCESS_ERROR); goto spf_abort; } #endif /* CONFIG_PPC_MEM_KEYS */ if (unlikely(access_error(is_write, is_exec, vma))) { - if (orig_file) - fput(orig_file); + put_vma(vma); count_vm_spf_event(SPF_ABORT_ACCESS_ERROR); goto spf_abort; } fault = do_handle_mm_fault(vma, address, - flags | FAULT_FLAG_SPECULATIVE, seq, regs); - if (orig_file) - fput(orig_file); + flags | FAULT_FLAG_SPECULATIVE, seq, regs); + put_vma(vma); major |= fault & VM_FAULT_MAJOR; if (fault_signal_pending(fault, regs)) diff --git a/arch/x86/mm/fault.c b/arch/x86/mm/fault.c index 7b05f6da6616..83e07cbaa95a 100644 --- a/arch/x86/mm/fault.c +++ b/arch/x86/mm/fault.c @@ -1227,8 +1227,6 @@ void do_user_addr_fault(struct pt_regs *regs, vm_fault_t fault; unsigned int flags = FAULT_FLAG_DEFAULT; #ifdef CONFIG_SPECULATIVE_PAGE_FAULT - struct file *orig_file = NULL; - struct vm_area_struct pvma; unsigned long seq; #endif @@ -1342,38 +1340,30 @@ void do_user_addr_fault(struct pt_regs *regs, count_vm_spf_event(SPF_ABORT_ODD); goto 
spf_abort; } - rcu_read_lock(); - vma = __find_vma(mm, address); - if (!vma || vma->vm_start > address) { - rcu_read_unlock(); + vma = get_vma(mm, address); + if (!vma) { count_vm_spf_event(SPF_ABORT_UNMAPPED); goto spf_abort; } + if (!vma_can_speculate(vma, flags)) { - rcu_read_unlock(); + put_vma(vma); count_vm_spf_event(SPF_ABORT_NO_SPECULATE); goto spf_abort; } - if (vma->vm_file) - orig_file = get_file(vma->vm_file); - pvma = *vma; - rcu_read_unlock(); + if (!mmap_seq_read_check(mm, seq, SPF_ABORT_VMA_COPY)) { - if (orig_file) - fput(orig_file); + put_vma(vma); goto spf_abort; } - vma = &pvma; if (unlikely(access_error(error_code, vma))) { - if (orig_file) - fput(orig_file); + put_vma(vma); count_vm_spf_event(SPF_ABORT_ACCESS_ERROR); goto spf_abort; } fault = do_handle_mm_fault(vma, address, - flags | FAULT_FLAG_SPECULATIVE, seq, regs); - if (orig_file) - fput(orig_file); + flags | FAULT_FLAG_SPECULATIVE, seq, regs); + put_vma(vma); if (!(fault & VM_FAULT_RETRY)) goto done; diff --git a/include/linux/mm.h b/include/linux/mm.h index 7710684b00b1..21c8954d4249 100644 --- a/include/linux/mm.h +++ b/include/linux/mm.h @@ -253,6 +253,7 @@ void setup_initial_init_mm(void *start_code, void *end_code, struct vm_area_struct *vm_area_alloc(struct mm_struct *); struct vm_area_struct *vm_area_dup(struct vm_area_struct *); +void vm_area_free_no_check(struct vm_area_struct *); void vm_area_free(struct vm_area_struct *); #ifndef CONFIG_MMU @@ -685,6 +686,10 @@ static inline void vma_init(struct vm_area_struct *vma, struct mm_struct *mm) memset(vma, 0, sizeof(*vma)); vma->vm_mm = mm; vma->vm_ops = &dummy_vm_ops; +#ifdef CONFIG_SPECULATIVE_PAGE_FAULT + /* Start from 0 to use atomic_inc_unless_negative() in get_vma() */ + atomic_set(&vma->file_ref_count, 0); +#endif INIT_LIST_HEAD(&vma->anon_vma_chain); } @@ -3383,6 +3388,9 @@ static inline bool pte_spinlock(struct vm_fault *vmf) return __pte_map_lock(vmf); } +struct vm_area_struct *get_vma(struct mm_struct *mm, unsigned long 
addr); +void put_vma(struct vm_area_struct *vma); + #else /* !CONFIG_SPECULATIVE_PAGE_FAULT */ #define pte_map_lock(___vmf) \ diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h index 3142ce952db6..7127dc54e5b3 100644 --- a/include/linux/mm_types.h +++ b/include/linux/mm_types.h @@ -419,6 +419,11 @@ struct vm_area_struct { #endif struct vm_userfaultfd_ctx vm_userfaultfd_ctx; #ifdef CONFIG_SPECULATIVE_PAGE_FAULT + /* + * The name does not reflect the usage and is not renamed to keep + * the ABI intact. + * This is used to refcount VMA in get_vma/put_vma. + */ atomic_t file_ref_count; #endif diff --git a/kernel/fork.c b/kernel/fork.c index ad812821a0fd..0a3c24723ff8 100644 --- a/kernel/fork.c +++ b/kernel/fork.c @@ -381,32 +381,41 @@ struct vm_area_struct *vm_area_dup(struct vm_area_struct *orig) return new; } -static inline void ____vm_area_free(struct vm_area_struct *vma) -{ - kmem_cache_free(vm_area_cachep, vma); -} - #ifdef CONFIG_SPECULATIVE_PAGE_FAULT -static void __vm_area_free(struct rcu_head *head) +static void __free_vm_area_struct(struct rcu_head *head) { struct vm_area_struct *vma = container_of(head, struct vm_area_struct, vm_rcu); - ____vm_area_free(vma); + kmem_cache_free(vm_area_cachep, vma); +} + +static inline void free_vm_area_struct(struct vm_area_struct *vma) +{ + call_rcu(&vma->vm_rcu, __free_vm_area_struct); +} +#else +static inline void free_vm_area_struct(struct vm_area_struct *vma) +{ + kmem_cache_free(vm_area_cachep, vma); } #endif -void vm_area_free(struct vm_area_struct *vma) +void vm_area_free_no_check(struct vm_area_struct *vma) { free_anon_vma_name(vma); if (vma->vm_file) fput(vma->vm_file); + free_vm_area_struct(vma); +} + +void vm_area_free(struct vm_area_struct *vma) +{ #ifdef CONFIG_SPECULATIVE_PAGE_FAULT - if (atomic_read(&vma->vm_mm->mm_users) > 1) { - call_rcu(&vma->vm_rcu, __vm_area_free); + /* Free only after refcount dropped to negative */ + if (atomic_dec_return(&vma->file_ref_count) >= 0) return; - } #endif - 
____vm_area_free(vma); + vm_area_free_no_check(vma); } static void account_kernel_stack(struct task_struct *tsk, int account) diff --git a/mm/memory.c b/mm/memory.c index 17a03bf164e5..6bc34d79cfa2 100644 --- a/mm/memory.c +++ b/mm/memory.c @@ -209,6 +209,35 @@ static void check_sync_rss_stat(struct task_struct *task) #endif /* SPLIT_RSS_COUNTING */ +#ifdef CONFIG_SPECULATIVE_PAGE_FAULT + +struct vm_area_struct *get_vma(struct mm_struct *mm, unsigned long addr) +{ + struct vm_area_struct *vma; + + rcu_read_lock(); + vma = __find_vma(mm, addr); + if (vma) { + if (vma->vm_start > addr || + !atomic_inc_unless_negative(&vma->file_ref_count)) + vma = NULL; + } + rcu_read_unlock(); + + return vma; +} + +void put_vma(struct vm_area_struct *vma) +{ + int new_ref_count; + + new_ref_count = atomic_dec_return(&vma->file_ref_count); + if (new_ref_count < 0) + vm_area_free_no_check(vma); +} + +#endif /* CONFIG_SPECULATIVE_PAGE_FAULT */ + /* * Note: this doesn't free the actual pages themselves. That * has been handled earlier when unmapping all the memory regions.