diff --git a/arch/arm64/include/asm/kvm_host.h b/arch/arm64/include/asm/kvm_host.h index 8888316043e5..4a9dc8bc8124 100644 --- a/arch/arm64/include/asm/kvm_host.h +++ b/arch/arm64/include/asm/kvm_host.h @@ -401,6 +401,7 @@ int pkvm_iommu_register(struct device *dev, struct pkvm_iommu_driver *drv, int pkvm_iommu_suspend(struct device *dev); int pkvm_iommu_resume(struct device *dev); +int pkvm_iommu_s2mpu_init(u32 version); int pkvm_iommu_s2mpu_register(struct device *dev, phys_addr_t pa); int pkvm_iommu_sysmmu_sync_register(struct device *dev, phys_addr_t pa, struct device *parent); diff --git a/arch/arm64/include/asm/kvm_s2mpu.h b/arch/arm64/include/asm/kvm_s2mpu.h index 2d9cd5509b16..fa3cc25f1080 100644 --- a/arch/arm64/include/asm/kvm_s2mpu.h +++ b/arch/arm64/include/asm/kvm_s2mpu.h @@ -217,6 +217,7 @@ struct fmpt { struct mpt { struct fmpt fmpt[NR_GIGABYTES]; + enum s2mpu_version version; }; #endif /* __ARM64_KVM_S2MPU_H__ */ diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c b/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c index b76d2cef6919..fbaa950e5856 100644 --- a/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c +++ b/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c @@ -487,12 +487,7 @@ static int s2mpu_init(void *data, size_t size) /* The host can concurrently modify 'data'. Copy it to avoid TOCTOU. */ memcpy(&in_mpt, data, sizeof(in_mpt)); - /* - * Only v8/v9 are supported at this point so hardcode the version - * as there is not way to get the version required from the kernel yet, - * v8/v9 are compatible so using any of them will work. - */ - cfg.version = S2MPU_VERSION_8; + cfg.version = in_mpt.version; /* Get page table operations for this version. */ mpt_ops = s2mpu_get_mpt_ops(cfg); /* If version is wrong return. */ diff --git a/arch/arm64/kvm/iommu/s2mpu.c b/arch/arm64/kvm/iommu/s2mpu.c index cbe28dd8b660..fd63ce1fc6e4 100644 --- a/arch/arm64/kvm/iommu/s2mpu.c +++ b/arch/arm64/kvm/iommu/s2mpu.c @@ -12,7 +12,7 @@ /* For an nvhe symbol get the kernel linear address of it. */ #define ksym_ref_addr_nvhe(x) kvm_ksym_ref(&kvm_nvhe_sym(x)) -static int init_s2mpu_driver(void) +static int init_s2mpu_driver(u32 version) { static DEFINE_MUTEX(lock); static bool init_done; @@ -46,6 +46,7 @@ static int init_s2mpu_driver(void) } mpt->fmpt[gb].smpt = (u32 *)addr; } + mpt->version = version; /* Share MPT descriptor with hyp. */ pfn = __pa(mpt) >> PAGE_SHIFT; @@ -74,17 +75,19 @@ out: mutex_unlock(&lock); return ret; } - -int pkvm_iommu_s2mpu_register(struct device *dev, phys_addr_t addr) +int pkvm_iommu_s2mpu_init(u32 version) { - int ret; - if (!is_protected_kvm_enabled()) return -ENODEV; - ret = init_s2mpu_driver(); - if (ret) - return ret; + return init_s2mpu_driver(version); +} +EXPORT_SYMBOL_GPL(pkvm_iommu_s2mpu_init); + +int pkvm_iommu_s2mpu_register(struct device *dev, phys_addr_t addr) +{ + if (!is_protected_kvm_enabled()) + return -ENODEV; return pkvm_iommu_register(dev, ksym_ref_addr_nvhe(pkvm_s2mpu_driver), addr, S2MPU_MMIO_SIZE, NULL);