diff --git a/arch/arm64/include/asm/io-mpt-s2mpu.h b/arch/arm64/include/asm/io-mpt-s2mpu.h index 382422b26ed6..0dfff4c08ec8 100644 --- a/arch/arm64/include/asm/io-mpt-s2mpu.h +++ b/arch/arm64/include/asm/io-mpt-s2mpu.h @@ -15,11 +15,13 @@ struct s2mpu_mpt_cfg { }; struct s2mpu_mpt_ops { + u32 (*smpt_size)(void); void (*init_with_prot)(void *dev_va, enum mpt_prot prot); void (*init_with_mpt)(void *dev_va, struct mpt *mpt); void (*apply_range)(void *dev_va, struct mpt *mpt, u32 first_gb, u32 last_gb); void (*prepare_range)(struct mpt *mpt, phys_addr_t first_byte, phys_addr_t last_byte, enum mpt_prot prot); + int (*pte_from_addr_smpt)(u32 *smpt, u64 addr); }; const struct s2mpu_mpt_ops *s2mpu_get_mpt_ops(struct s2mpu_mpt_cfg cfg); diff --git a/arch/arm64/include/asm/kvm_s2mpu.h b/arch/arm64/include/asm/kvm_s2mpu.h index 13104e2f18d7..674963c879d1 100644 --- a/arch/arm64/include/asm/kvm_s2mpu.h +++ b/arch/arm64/include/asm/kvm_s2mpu.h @@ -197,10 +197,32 @@ #define V9_MAX_PTLB_NUM 0x100 #define V9_MAX_STLB_NUM 0x100 -#define V9_L1ENTRY_ATTR_GRAN_MASK BIT(3) -#define V9_MPT_PROT_BITS 4 +#define V9_CTRL0_DIS_CHK_S1L1PTW_MASK BIT(0) +#define V9_CTRL0_DIS_CHK_S1L2PTW_MASK BIT(1) +#define V9_CTRL0_DIS_CHK_USR_MARCHED_REQ_MASK BIT(3) +#define V9_CTRL0_FAULT_MODE_MASK BIT(4) +#define V9_CTRL0_ENF_FLT_MODE_S1_NONSEC_MASK BIT(5) +#define V9_CTRL0_DESTRUCTIVE_AP_CHK_MODE_MASK BIT(6) +#define V9_CTRL0_MASK (V9_CTRL0_DIS_CHK_S1L1PTW_MASK | \ + V9_CTRL0_DESTRUCTIVE_AP_CHK_MODE_MASK | \ + V9_CTRL0_DIS_CHK_USR_MARCHED_REQ_MASK | \ + V9_CTRL0_DIS_CHK_S1L2PTW_MASK | \ + V9_CTRL0_ENF_FLT_MODE_S1_NONSEC_MASK | \ + V9_CTRL0_FAULT_MODE_MASK) + +/* + * S2MPU V9 specific values (some new and some different from old versions) + * to avoid any confusion all names are prefixed with V9. + */ +#define V9_L1ENTRY_ATTR_GRAN_MASK BIT(3) +#define V9_MPT_PROT_BITS 4 #define V9_MPT_ACCESS_SHIFT 2 +/* V1,V2 variants. */ +#define MPT_ACCESS_SHIFT 0 +#define L1ENTRY_ATTR_GRAN_MASK GENMASK(5, 4) +#define MPT_PROT_BITS 2 + #define REG_NS_CTRL0 0x0 #define REG_NS_CTRL1 0x4 #define REG_NS_CFG 0x10 @@ -316,12 +338,11 @@ #define L1ENTRY_ATTR_GRAN_4K 0x0 #define L1ENTRY_ATTR_GRAN_64K 0x1 #define L1ENTRY_ATTR_GRAN_2M 0x2 +#define L1ENTRY_ATTR_GRAN(gran, msk) FIELD_PREP(msk, gran) #define L1ENTRY_ATTR_PROT_MASK GENMASK(2, 1) -#define L1ENTRY_ATTR_GRAN_MASK GENMASK(5, 4) #define L1ENTRY_ATTR_PROT(prot) FIELD_PREP(L1ENTRY_ATTR_PROT_MASK, prot) -#define L1ENTRY_ATTR_GRAN(gran) FIELD_PREP(L1ENTRY_ATTR_GRAN_MASK, gran) #define L1ENTRY_ATTR_1G(prot) L1ENTRY_ATTR_PROT(prot) -#define L1ENTRY_ATTR_L2(gran) (L1ENTRY_ATTR_GRAN(gran) | \ +#define L1ENTRY_ATTR_L2(gran, msk) (L1ENTRY_ATTR_GRAN(gran, msk) | \ L1ENTRY_ATTR_L2TABLE_EN) #define NR_GIGABYTES 64 @@ -339,16 +360,19 @@ #endif static_assert(SMPT_GRAN <= PAGE_SIZE); -#define MPT_PROT_BITS 2 + #define SMPT_WORD_SIZE sizeof(u32) -#define SMPT_ELEMS_PER_BYTE (BITS_PER_BYTE / MPT_PROT_BITS) -#define SMPT_ELEMS_PER_WORD (SMPT_WORD_SIZE * SMPT_ELEMS_PER_BYTE) -#define SMPT_WORD_BYTE_RANGE (SMPT_GRAN * SMPT_ELEMS_PER_WORD) +#define SMPT_ELEMS_PER_BYTE(prot_bits) (BITS_PER_BYTE / (prot_bits)) +#define SMPT_ELEMS_PER_WORD(prot_bits) (SMPT_WORD_SIZE * SMPT_ELEMS_PER_BYTE(prot_bits)) +#define SMPT_WORD_BYTE_RANGE(prot_bits) (SMPT_GRAN * SMPT_ELEMS_PER_WORD(prot_bits)) #define SMPT_NUM_ELEMS (SZ_1G / SMPT_GRAN) -#define SMPT_SIZE (SMPT_NUM_ELEMS / SMPT_ELEMS_PER_BYTE) -#define SMPT_NUM_WORDS (SMPT_SIZE / SMPT_WORD_SIZE) -#define SMPT_NUM_PAGES (SMPT_SIZE / PAGE_SIZE) -#define SMPT_ORDER get_order(SMPT_SIZE) +#define SMPT_SIZE(prot_bits) (SMPT_NUM_ELEMS / SMPT_ELEMS_PER_BYTE(prot_bits)) +#define SMPT_NUM_WORDS(prot_bits) (SMPT_SIZE(prot_bits) / SMPT_WORD_SIZE) +#define SMPT_NUM_PAGES(prot_bits) (SMPT_SIZE(prot_bits) / PAGE_SIZE) +#define SMPT_ORDER(prot_bits) get_order(SMPT_SIZE(prot_bits)) + + +#define SMPT_GRAN_MASK GENMASK(1, 0) /* SysMMU_SYNC registers, relative to SYSMMU_SYNC_S2_OFFSET. */ #define REG_NS_SYNC_CMD 0x0 @@ -375,6 +399,15 @@ enum s2mpu_version { S2MPU_VERSION_9 = 0x90000000, }; +static inline int smpt_order_from_version(enum s2mpu_version version) +{ + if (version == S2MPU_VERSION_9) + return SMPT_ORDER(V9_MPT_PROT_BITS); + else if ((version == S2MPU_VERSION_1) || (version == S2MPU_VERSION_2)) + return SMPT_ORDER(MPT_PROT_BITS); + BUG(); +} + enum mpt_prot { MPT_PROT_NONE = 0, MPT_PROT_R = BIT(0), diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c b/arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c index 832368171e05..e101c4c4c0b4 100644 --- a/arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c +++ b/arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c @@ -5,6 +5,37 @@ #include +#define GRAN_BYTE(gran) ((gran << V9_MPT_PROT_BITS) | (gran)) +#define GRAN_HWORD(gran) ((GRAN_BYTE(gran) << 8) | (GRAN_BYTE(gran))) +#define GRAN_WORD(gran) (((u32)(GRAN_HWORD(gran) << 16) | (GRAN_HWORD(gran)))) +#define GRAN_DWORD(gran) ((u64)((u64)GRAN_WORD(gran) << 32) | (u64)(GRAN_WORD(gran))) + +#define SMPT_NUM_TO_BYTE(x) ((x) / SMPT_GRAN / SMPT_ELEMS_PER_BYTE(config_prot_bits)) +#define BYTE_TO_SMPT_INDEX(x) ((x) / SMPT_WORD_BYTE_RANGE(config_prot_bits)) + + +/* + * MPT table ops can be configured only for one version at runtime, + * these variables will hold version specific data set a run time init, to avoid + * having duplicate code or unnessery check during operations. + */ +static u32 config_prot_bits; +static u32 config_access_shift; +static const u64 *config_lut_prot; +static u32 config_gran_mask; +static u32 this_version; + +/* + * page table entries for different protection look up table + * granularity is compile time config, so we can do this also for + * this array without having duplicate arrays + */ +static const u64 v9_mpt_prot_doubleword[] = { + [MPT_PROT_NONE] = 0x0000000000000000 | GRAN_DWORD(SMPT_GRAN_ATTR), + [MPT_PROT_R] = 0x4444444444444444 | GRAN_DWORD(SMPT_GRAN_ATTR), + [MPT_PROT_W] = 0x8888888888888888 | GRAN_DWORD(SMPT_GRAN_ATTR), + [MPT_PROT_RW] = 0xcccccccccccccccc | GRAN_DWORD(SMPT_GRAN_ATTR), +}; static const u64 mpt_prot_doubleword[] = { [MPT_PROT_NONE] = 0x0000000000000000, [MPT_PROT_R] = 0x5555555555555555, @@ -12,6 +43,25 @@ static const u64 mpt_prot_doubleword[] = { [MPT_PROT_RW] = 0xffffffffffffffff, }; +static inline int pte_from_addr_smpt(u32 *smpt, u64 addr) +{ + u32 word_idx, idx, pte, val; + + word_idx = BYTE_TO_SMPT_INDEX(addr); + val = READ_ONCE(smpt[word_idx]); + idx = (addr / SMPT_GRAN) % SMPT_ELEMS_PER_WORD(config_prot_bits); + + pte = (val >> (idx * config_prot_bits)) & ((1 << config_prot_bits)-1); + return pte; +} + +static inline int prot_from_addr_smpt(u32 *smpt, u64 addr) +{ + int pte = pte_from_addr_smpt(smpt, addr); + + return (pte >> config_access_shift); +} + /* Set protection bits of SMPT in a given range without using memset. */ static void __set_smpt_range_slow(u32 *smpt, size_t start_gb_byte, size_t end_gb_byte, enum mpt_prot prot) @@ -23,20 +73,21 @@ static void __set_smpt_range_slow(u32 *smpt, size_t start_gb_byte, start_word_byte = start_gb_byte; while (start_word_byte < end_gb_byte) { /* Determine the range of bytes covered by this word. */ - word_idx = start_word_byte / SMPT_WORD_BYTE_RANGE; + word_idx = BYTE_TO_SMPT_INDEX(start_word_byte); end_word_byte = min( - ALIGN(start_word_byte + 1, SMPT_WORD_BYTE_RANGE), + ALIGN(start_word_byte + 1, SMPT_WORD_BYTE_RANGE(config_prot_bits)), end_gb_byte); /* Identify protection bit offsets within the word. */ - first_elem = (start_word_byte / SMPT_GRAN) % SMPT_ELEMS_PER_WORD; - last_elem = ((end_word_byte - 1) / SMPT_GRAN) % SMPT_ELEMS_PER_WORD; + first_elem = (start_word_byte / SMPT_GRAN) % SMPT_ELEMS_PER_WORD(config_prot_bits); + last_elem = + ((end_word_byte - 1) / SMPT_GRAN) % SMPT_ELEMS_PER_WORD(config_prot_bits); /* Modify the corresponding word. */ val = READ_ONCE(smpt[word_idx]); for (i = first_elem; i <= last_elem; i++) { - val &= ~(MPT_PROT_MASK << (i * MPT_PROT_BITS)); - val |= prot << (i * MPT_PROT_BITS); + val &= ~(MPT_PROT_MASK << (i * config_prot_bits + config_access_shift)); + val |= prot << (i * config_prot_bits + config_access_shift); } WRITE_ONCE(smpt[word_idx], val); @@ -49,25 +100,33 @@ static void __set_smpt_range(u32 *smpt, size_t start_gb_byte, size_t end_gb_byte, enum mpt_prot prot) { size_t interlude_start, interlude_end, interlude_bytes, word_idx; - char prot_byte = (char)mpt_prot_doubleword[prot]; + + char prot_byte = (char)config_lut_prot[prot]; if (start_gb_byte >= end_gb_byte) return; /* Check if range spans at least one full u32 word. */ - interlude_start = ALIGN(start_gb_byte, SMPT_WORD_BYTE_RANGE); - interlude_end = ALIGN_DOWN(end_gb_byte, SMPT_WORD_BYTE_RANGE); + interlude_start = ALIGN(start_gb_byte, SMPT_WORD_BYTE_RANGE(config_prot_bits)); + interlude_end = ALIGN_DOWN(end_gb_byte, SMPT_WORD_BYTE_RANGE(config_prot_bits)); - /* If not, fall back to editing bits in the given range. */ + /* + * If not, fall back to editing bits in the given range. + * sets bit for PTEs that are in less than 32 bits (can't be done by memset) + */ if (interlude_start >= interlude_end) { __set_smpt_range_slow(smpt, start_gb_byte, end_gb_byte, prot); return; } /* Use bit-editing for prologue/epilogue, memset for interlude. */ - word_idx = interlude_start / SMPT_WORD_BYTE_RANGE; - interlude_bytes = (interlude_end - interlude_start) / SMPT_GRAN / SMPT_ELEMS_PER_BYTE; + word_idx = BYTE_TO_SMPT_INDEX(interlude_start); + interlude_bytes = SMPT_NUM_TO_BYTE(interlude_end - interlude_start); + /* + * These are pages in the start and at then end that are + * not part of full 32 bit SMPT word. + */ __set_smpt_range_slow(smpt, start_gb_byte, interlude_start, prot); memset(&smpt[word_idx], prot_byte, interlude_bytes); __set_smpt_range_slow(smpt, interlude_end, end_gb_byte, prot); @@ -79,8 +138,8 @@ static bool __is_smpt_uniform(u32 *smpt, enum mpt_prot prot) size_t i; u64 *doublewords = (u64 *)smpt; - for (i = 0; i < SMPT_NUM_WORDS / 2; i++) { - if (doublewords[i] != mpt_prot_doubleword[prot]) + for (i = 0; i < SMPT_NUM_WORDS(config_prot_bits) / 2; i++) { + if (doublewords[i] != config_lut_prot[prot]) return false; } return true; @@ -140,6 +199,11 @@ static void __set_fmpt_range(struct fmpt *fmpt, size_t start_gb_byte, fmpt->flags = MPT_UPDATE_L1; } +static u32 smpt_size(void) +{ + return SMPT_SIZE(config_prot_bits); +} + static void __set_l1entry_attr_with_prot(void *dev_va, unsigned int gb, unsigned int vid, enum mpt_prot prot) { @@ -154,7 +218,7 @@ static void __set_l1entry_attr_with_fmpt(void *dev_va, unsigned int gb, __set_l1entry_attr_with_prot(dev_va, gb, vid, fmpt->prot); } else { /* Order against writes to the SMPT. */ - writel(L1ENTRY_ATTR_L2(SMPT_GRAN_ATTR), + writel(config_gran_mask | L1ENTRY_ATTR_L2TABLE_EN, dev_va + REG_NS_L1ENTRY_ATTR(vid, gb)); } } @@ -218,21 +282,40 @@ static void prepare_range(struct mpt *mpt, phys_addr_t first_byte, __set_fmpt_range(fmpt, start_gb_byte, end_gb_byte, prot); if (fmpt->flags & MPT_UPDATE_L2) - kvm_flush_dcache_to_poc(fmpt->smpt, SMPT_SIZE); + kvm_flush_dcache_to_poc(fmpt->smpt, smpt_size()); } } static const struct s2mpu_mpt_ops this_ops = { + .smpt_size = smpt_size, .init_with_prot = init_with_prot, .init_with_mpt = init_with_mpt, .apply_range = apply_range, .prepare_range = prepare_range, + .pte_from_addr_smpt = pte_from_addr_smpt, }; const struct s2mpu_mpt_ops *s2mpu_get_mpt_ops(struct s2mpu_mpt_cfg cfg) { - if ((cfg.version == S2MPU_VERSION_1) || (cfg.version == S2MPU_VERSION_2)) - return &this_ops; + /* If called before with different version return NULL. */ + if (WARN_ON(this_version && (this_version != cfg.version))) + return NULL; + /* 2MB granularity not supported in V9 */ + if ((cfg.version == S2MPU_VERSION_9) && (SMPT_GRAN_ATTR != L1ENTRY_ATTR_GRAN_2M)) { + config_prot_bits = V9_MPT_PROT_BITS; + config_access_shift = V9_MPT_ACCESS_SHIFT; + config_lut_prot = v9_mpt_prot_doubleword; + config_gran_mask = L1ENTRY_ATTR_GRAN(SMPT_GRAN_ATTR, V9_L1ENTRY_ATTR_GRAN_MASK); + this_version = cfg.version; + return &this_ops; + } else if ((cfg.version == S2MPU_VERSION_2) || (cfg.version == S2MPU_VERSION_1)) { + config_prot_bits = MPT_PROT_BITS; + config_access_shift = MPT_ACCESS_SHIFT; + config_lut_prot = mpt_prot_doubleword; + config_gran_mask = L1ENTRY_ATTR_GRAN(SMPT_GRAN_ATTR, L1ENTRY_ATTR_GRAN_MASK); + this_version = cfg.version; + return &this_ops; + } return NULL; } diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c b/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c index f86a525344e8..d52890b1c5b2 100644 --- a/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c +++ b/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c @@ -174,6 +174,23 @@ static void __set_control_regs(struct pkvm_iommu *dev) writel_relaxed(0, dev->va + REG_NS_CTRL1); writel_relaxed(ctrl0, dev->va + REG_NS_CTRL0); } +static void __set_control_regs_v9(struct pkvm_iommu *dev) +{ + /* Return DECERR to device on permission fault. */ + writel_relaxed(ALL_VIDS_BITMAP, + dev->va + REG_NS_V9_CTRL_ERR_RESP_T_PER_VID_SET); + /* + * Enable interrupts on fault for all VIDs. The IRQ must also be + * specified in DT to get unmasked in the GIC. + */ + writel_relaxed(ALL_VIDS_BITMAP, + dev->va + REG_NS_INTERRUPT_ENABLE_PER_VID_SET); + writel_relaxed(0, dev->va + REG_NS_CTRL0); + /* Enable the S2MPU, otherwise all traffic would be allowed through. */ + writel_relaxed(ALL_VIDS_BITMAP, + dev->va + REG_NS_V9_CTRL_PROT_EN_PER_VID_SET); + writel_relaxed(0, dev->va + REG_NS_V9_CFG_MPTW_ATTRIBUTE); +} /* * Poll the given SFR until its value has all bits of a given mask set. @@ -248,8 +265,8 @@ static void __invalidation_barrier_complete(struct pkvm_iommu *dev) __invalidation_barrier_slow(sync); } - /* Must not access SFRs while S2MPU is busy invalidating (v2 only). */ - if (is_version(dev, S2MPU_VERSION_2)) { + /* Must not access SFRs while S2MPU is busy invalidating */ + if (is_version(dev, S2MPU_VERSION_2) || is_version(dev, S2MPU_VERSION_9)) { __wait_while(dev->va + REG_NS_STATUS, STATUS_BUSY | STATUS_ON_INVALIDATING); } @@ -401,6 +418,64 @@ static int s2mpu_suspend(struct pkvm_iommu *dev) return initialize_with_prot(dev, MPT_PROT_NONE); } +static u32 host_mmio_reg_access_mask_v9(size_t off, bool is_write) +{ + const u32 no_access = 0; + const u32 read_write = (u32)(-1); + const u32 read_only = is_write ? no_access : read_write; + const u32 write_only = is_write ? read_write : no_access; + + switch (off) { + /* Allow reading control registers for debugging. */ + case REG_NS_CTRL0: + return read_only & V9_CTRL0_MASK; + case REG_NS_V9_CTRL_ERR_RESP_T_PER_VID_SET: + return read_only & ALL_VIDS_BITMAP; + case REG_NS_V9_CTRL_PROT_EN_PER_VID_SET: + return read_only & ALL_VIDS_BITMAP; + case REG_NS_V9_READ_STLB: + return write_only & (V9_READ_STLB_MASK_TYPEA|V9_READ_STLB_MASK_TYPEB); + case REG_NS_V9_READ_STLB_TPN: + return read_only & V9_READ_STLB_TPN_MASK; + case REG_NS_V9_READ_STLB_TAG_PPN: + return read_only & V9_READ_STLB_TAG_PPN_MASK; + case REG_NS_V9_READ_STLB_TAG_OTHERS: + return read_only & V9_READ_STLB_TAG_OTHERS_MASK; + case REG_NS_V9_READ_STLB_DATA: + return read_only; + case REG_NS_V9_MPTC_INFO: + return read_only & V9_READ_MPTC_INFO_MASK; + case REG_NS_V9_READ_MPTC: + return write_only & V9_READ_MPTC_MASK; + case REG_NS_V9_READ_MPTC_TAG_PPN: + return read_only & V9_READ_MPTC_TAG_PPN_MASK; + case REG_NS_V9_READ_MPTC_TAG_OTHERS: + return read_only & V9_READ_MPTC_TAG_OTHERS_MASK; + case REG_NS_V9_READ_MPTC_DATA: + return read_only; + case REG_NS_V9_PMMU_INFO: + return read_only & V9_READ_PMMU_INFO_MASK; + case REG_NS_V9_READ_PTLB: + return write_only & V9_READ_PTLB_MASK; + case REG_NS_V9_READ_PTLB_TAG: + return read_only & V9_READ_PTLB_TAG_MASK; + case REG_NS_V9_READ_PTLB_DATA_S1_EN_PPN_AP: + return read_only & V9_READ_PTLB_DATA_S1_ENABLE_PPN_AP_MASK; + case REG_NS_V9_READ_PTLB_DATA_S1_DIS_AP_LIST: + return read_only; + case REG_NS_V9_PMMU_INDICATOR: + return read_only & V9_READ_PMMU_INDICATOR_MASK; + case REG_NS_V9_SWALKER_INFO: + return read_only&V9_SWALKER_INFO_MASK; + }; + if (off >= REG_NS_V9_PMMU_PTLB_INFO(0) && off < REG_NS_V9_PMMU_PTLB_INFO(V9_MAX_PTLB_NUM)) + return read_only&V9_READ_PMMU_PTLB_INFO_MASK; + if (off >= REG_NS_V9_STLB_INFO(0) && off < REG_NS_V9_STLB_INFO(V9_MAX_STLB_NUM)) + return read_only&V9_READ_SLTB_INFO_MASK; + + return no_access; +} + static u32 host_mmio_reg_access_mask_v1_v2(size_t off, bool is_write) { const u32 no_access = 0; @@ -491,12 +566,20 @@ static bool s2mpu_host_dabt_handler(struct pkvm_iommu *dev, cpu_reg(host_ctxt, rd) = readl_relaxed(dev->va + off) & mask; return true; } - -const struct s2mpu_reg_ops ops_v1_v2 = { +/* + * Operations that differ between versions. We need to maintain + * old behaviour were v1 and v2 can be used together. + */ +const struct s2mpu_reg_ops ops_v1_v2 = { .init = __initialize, .host_mmio_reg_access_mask = host_mmio_reg_access_mask_v1_v2, .set_control_regs = __set_control_regs, }; +const struct s2mpu_reg_ops ops_v9 = { + .init = __initialize_v2, + .host_mmio_reg_access_mask = host_mmio_reg_access_mask_v9, + .set_control_regs = __set_control_regs_v9, +}; static int s2mpu_init(void *data, size_t size) { @@ -505,6 +588,7 @@ static int s2mpu_init(void *data, size_t size) phys_addr_t pa; unsigned int gb; int ret = 0; + int smpt_nr_pages, smpt_size; struct s2mpu_mpt_cfg cfg; if (size != sizeof(in_mpt)) @@ -514,8 +598,11 @@ static int s2mpu_init(void *data, size_t size) memcpy(&in_mpt, data, sizeof(in_mpt)); cfg.version = in_mpt.version; + /* Make sure the version sent is supported by the driver. */ if ((cfg.version == S2MPU_VERSION_1) || (cfg.version == S2MPU_VERSION_2)) reg_ops = &ops_v1_v2; + else if (cfg.version == S2MPU_VERSION_9) + reg_ops = &ops_v9; else return -ENODEV; @@ -525,17 +612,20 @@ static int s2mpu_init(void *data, size_t size) if (!mpt_ops) return -EINVAL; + smpt_size = mpt_ops->smpt_size(); + smpt_nr_pages = smpt_size / PAGE_SIZE; + /* Take ownership of all SMPT buffers. This will also map them in. */ for_each_gb(gb) { smpt = kern_hyp_va(in_mpt.fmpt[gb].smpt); pa = __hyp_pa(smpt); - if (!IS_ALIGNED(pa, SMPT_SIZE)) { + if (!IS_ALIGNED(pa, smpt_size)) { ret = -EINVAL; break; } - ret = __pkvm_host_donate_hyp(pa >> PAGE_SHIFT, SMPT_NUM_PAGES); + ret = __pkvm_host_donate_hyp(pa >> PAGE_SHIFT, smpt_nr_pages); if (ret) break; @@ -554,7 +644,7 @@ static int s2mpu_init(void *data, size_t size) break; WARN_ON(__pkvm_hyp_donate_host(__hyp_pa(smpt) >> PAGE_SHIFT, - SMPT_NUM_PAGES)); + smpt_nr_pages)); } memset(&host_mpt, 0, sizeof(host_mpt)); } diff --git a/arch/arm64/kvm/iommu/s2mpu.c b/arch/arm64/kvm/iommu/s2mpu.c index fd63ce1fc6e4..b9dcc3469f06 100644 --- a/arch/arm64/kvm/iommu/s2mpu.c +++ b/arch/arm64/kvm/iommu/s2mpu.c @@ -22,6 +22,7 @@ static int init_s2mpu_driver(u32 version) unsigned long addr; u64 pfn; int ret = 0; + const int smpt_order = smpt_order_from_version(version); mutex_lock(&lock); if (init_done) @@ -39,7 +40,7 @@ static int init_s2mpu_driver(u32 version) /* Allocate SMPT buffers. */ for_each_gb(gb) { - addr = __get_free_pages(GFP_KERNEL, SMPT_ORDER); + addr = __get_free_pages(GFP_KERNEL, smpt_order); if (!addr) { ret = -ENOMEM; goto out_free; @@ -68,7 +69,7 @@ out_free: /* TODO - will driver return the memory? */ if (ret) { for_each_gb(gb) - free_pages((unsigned long)mpt->fmpt[gb].smpt, SMPT_ORDER); + free_pages((unsigned long)mpt->fmpt[gb].smpt, smpt_order); free_page((unsigned long)mpt); } out: