diff --git a/arch/arm64/include/asm/kvm_pkvm_module.h b/arch/arm64/include/asm/kvm_pkvm_module.h index 074077995a21..de0015f77985 100644 --- a/arch/arm64/include/asm/kvm_pkvm_module.h +++ b/arch/arm64/include/asm/kvm_pkvm_module.h @@ -20,7 +20,7 @@ struct pkvm_module_ops { enum kvm_pgtable_prot prot, unsigned long *haddr); void *(*alloc_module_va)(u64 nr_pages); - int (*map_module_page)(u64 pfn, void *va, enum kvm_pgtable_prot prot); + int (*map_module_page)(u64 pfn, void *va, enum kvm_pgtable_prot prot, bool is_protected); int (*register_serial_driver)(void (*hyp_putc_cb)(char)); void (*puts)(const char *str); void (*putx64)(u64 num); diff --git a/arch/arm64/kvm/hyp/include/nvhe/mm.h b/arch/arm64/kvm/hyp/include/nvhe/mm.h index 6afc4b1e5d04..66a6d71bd4e0 100644 --- a/arch/arm64/kvm/hyp/include/nvhe/mm.h +++ b/arch/arm64/kvm/hyp/include/nvhe/mm.h @@ -31,7 +31,7 @@ int __pkvm_create_private_mapping(phys_addr_t phys, size_t size, int pkvm_alloc_private_va_range(size_t size, unsigned long *haddr); void pkvm_remove_mappings(void *from, void *to); -int __pkvm_map_module_page(u64 pfn, void *va, enum kvm_pgtable_prot prot); +int __pkvm_map_module_page(u64 pfn, void *va, enum kvm_pgtable_prot prot, bool is_protected); void __pkvm_unmap_module_page(u64 pfn, void *va); void *__pkvm_alloc_module_va(u64 nr_pages); #endif /* __KVM_HYP_MM_H */ diff --git a/arch/arm64/kvm/hyp/nvhe/hyp-main.c b/arch/arm64/kvm/hyp/nvhe/hyp-main.c index 0f5c8dc7d3a7..d67b02fe7ca5 100644 --- a/arch/arm64/kvm/hyp/nvhe/hyp-main.c +++ b/arch/arm64/kvm/hyp/nvhe/hyp-main.c @@ -1177,7 +1177,7 @@ static void handle___pkvm_map_module_page(struct kvm_cpu_context *host_ctxt) DECLARE_REG(void *, va, host_ctxt, 2); DECLARE_REG(enum kvm_pgtable_prot, prot, host_ctxt, 3); - cpu_reg(host_ctxt, 1) = (u64)__pkvm_map_module_page(pfn, va, prot); + cpu_reg(host_ctxt, 1) = (u64)__pkvm_map_module_page(pfn, va, prot, false); } static void handle___pkvm_unmap_module_page(struct kvm_cpu_context *host_ctxt) diff --git a/arch/arm64/kvm/hyp/nvhe/mm.c b/arch/arm64/kvm/hyp/nvhe/mm.c index 8f7c983924f1..76f26dc73cc7 100644 --- a/arch/arm64/kvm/hyp/nvhe/mm.c +++ b/arch/arm64/kvm/hyp/nvhe/mm.c @@ -142,24 +142,24 @@ void *__pkvm_alloc_module_va(u64 nr_pages) return (void *)addr; } -int __pkvm_map_module_page(u64 pfn, void *va, enum kvm_pgtable_prot prot) +int __pkvm_map_module_page(u64 pfn, void *va, enum kvm_pgtable_prot prot, bool is_protected) { unsigned long addr = (unsigned long)va; int ret; assert_in_mod_range(addr); - ret = __pkvm_host_donate_hyp(pfn, 1); - if (ret) - return ret; - - ret = __pkvm_create_mappings(addr, PAGE_SIZE, hyp_pfn_to_phys(pfn), prot); - if (ret) { - WARN_ON(__pkvm_hyp_donate_host(pfn, 1)); - return ret; + if (!is_protected) { + ret = __pkvm_host_donate_hyp(pfn, 1); + if (ret) + return ret; } - return 0; + ret = __pkvm_create_mappings(addr, PAGE_SIZE, hyp_pfn_to_phys(pfn), prot); + if (ret && !is_protected) + WARN_ON(__pkvm_hyp_donate_host(pfn, 1)); + + return ret; } void __pkvm_unmap_module_page(u64 pfn, void *va)