From e3262b342c7dead031ac7c5792a8a358cffa7ab8 Mon Sep 17 00:00:00 2001 From: Mostafa Saleh Date: Tue, 15 Nov 2022 13:03:36 +0000 Subject: [PATCH] ANDROID: KVM: arm64: s2mpu: S2MPU V9 code Add S2MPU V9 code with current page table ops and version ops. Most SMPT_* macros are now function of protection bits To keep logic modification minimal and avoid duplicate code SMPT and FMPT function are kept the same and the values that changed between S2MPU versions are used as variables instead of macros Bug: 255731794 Change-Id: I2a1b8bab630032d8c923c23e96e1182ce5f734ff Signed-off-by: Mostafa Saleh Signed-off-by: Quentin Perret --- arch/arm64/include/asm/io-mpt-s2mpu.h | 2 + arch/arm64/include/asm/kvm_s2mpu.h | 59 +++++++-- arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c | 119 ++++++++++++++++--- arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c | 104 ++++++++++++++-- arch/arm64/kvm/iommu/s2mpu.c | 5 +- 5 files changed, 249 insertions(+), 40 deletions(-) diff --git a/arch/arm64/include/asm/io-mpt-s2mpu.h b/arch/arm64/include/asm/io-mpt-s2mpu.h index 382422b26ed6..0dfff4c08ec8 100644 --- a/arch/arm64/include/asm/io-mpt-s2mpu.h +++ b/arch/arm64/include/asm/io-mpt-s2mpu.h @@ -15,11 +15,13 @@ struct s2mpu_mpt_cfg { }; struct s2mpu_mpt_ops { + u32 (*smpt_size)(void); void (*init_with_prot)(void *dev_va, enum mpt_prot prot); void (*init_with_mpt)(void *dev_va, struct mpt *mpt); void (*apply_range)(void *dev_va, struct mpt *mpt, u32 first_gb, u32 last_gb); void (*prepare_range)(struct mpt *mpt, phys_addr_t first_byte, phys_addr_t last_byte, enum mpt_prot prot); + int (*pte_from_addr_smpt)(u32 *smpt, u64 addr); }; const struct s2mpu_mpt_ops *s2mpu_get_mpt_ops(struct s2mpu_mpt_cfg cfg); diff --git a/arch/arm64/include/asm/kvm_s2mpu.h b/arch/arm64/include/asm/kvm_s2mpu.h index 13104e2f18d7..674963c879d1 100644 --- a/arch/arm64/include/asm/kvm_s2mpu.h +++ b/arch/arm64/include/asm/kvm_s2mpu.h @@ -197,10 +197,32 @@ #define V9_MAX_PTLB_NUM 0x100 #define V9_MAX_STLB_NUM 0x100 -#define V9_L1ENTRY_ATTR_GRAN_MASK BIT(3) -#define V9_MPT_PROT_BITS 4 +#define V9_CTRL0_DIS_CHK_S1L1PTW_MASK BIT(0) +#define V9_CTRL0_DIS_CHK_S1L2PTW_MASK BIT(1) +#define V9_CTRL0_DIS_CHK_USR_MARCHED_REQ_MASK BIT(3) +#define V9_CTRL0_FAULT_MODE_MASK BIT(4) +#define V9_CTRL0_ENF_FLT_MODE_S1_NONSEC_MASK BIT(5) +#define V9_CTRL0_DESTRUCTIVE_AP_CHK_MODE_MASK BIT(6) +#define V9_CTRL0_MASK (V9_CTRL0_DIS_CHK_S1L1PTW_MASK | \ + V9_CTRL0_DESTRUCTIVE_AP_CHK_MODE_MASK | \ + V9_CTRL0_DIS_CHK_USR_MARCHED_REQ_MASK | \ + V9_CTRL0_DIS_CHK_S1L2PTW_MASK | \ + V9_CTRL0_ENF_FLT_MODE_S1_NONSEC_MASK | \ + V9_CTRL0_FAULT_MODE_MASK) + +/* + * S2MPU V9 specific values (some new and some different from old versions) + * to avoid any confusion all names are prefixed with V9. + */ +#define V9_L1ENTRY_ATTR_GRAN_MASK BIT(3) +#define V9_MPT_PROT_BITS 4 #define V9_MPT_ACCESS_SHIFT 2 +/* V1,V2 variants. */ +#define MPT_ACCESS_SHIFT 0 +#define L1ENTRY_ATTR_GRAN_MASK GENMASK(5, 4) +#define MPT_PROT_BITS 2 + #define REG_NS_CTRL0 0x0 #define REG_NS_CTRL1 0x4 #define REG_NS_CFG 0x10 @@ -316,12 +338,11 @@ #define L1ENTRY_ATTR_GRAN_4K 0x0 #define L1ENTRY_ATTR_GRAN_64K 0x1 #define L1ENTRY_ATTR_GRAN_2M 0x2 +#define L1ENTRY_ATTR_GRAN(gran, msk) FIELD_PREP(msk, gran) #define L1ENTRY_ATTR_PROT_MASK GENMASK(2, 1) -#define L1ENTRY_ATTR_GRAN_MASK GENMASK(5, 4) #define L1ENTRY_ATTR_PROT(prot) FIELD_PREP(L1ENTRY_ATTR_PROT_MASK, prot) -#define L1ENTRY_ATTR_GRAN(gran) FIELD_PREP(L1ENTRY_ATTR_GRAN_MASK, gran) #define L1ENTRY_ATTR_1G(prot) L1ENTRY_ATTR_PROT(prot) -#define L1ENTRY_ATTR_L2(gran) (L1ENTRY_ATTR_GRAN(gran) | \ +#define L1ENTRY_ATTR_L2(gran, msk) (L1ENTRY_ATTR_GRAN(gran, msk) | \ L1ENTRY_ATTR_L2TABLE_EN) #define NR_GIGABYTES 64 @@ -339,16 +360,19 @@ #endif static_assert(SMPT_GRAN <= PAGE_SIZE); -#define MPT_PROT_BITS 2 + #define SMPT_WORD_SIZE sizeof(u32) -#define SMPT_ELEMS_PER_BYTE (BITS_PER_BYTE / MPT_PROT_BITS) -#define SMPT_ELEMS_PER_WORD (SMPT_WORD_SIZE * SMPT_ELEMS_PER_BYTE) -#define SMPT_WORD_BYTE_RANGE (SMPT_GRAN * SMPT_ELEMS_PER_WORD) +#define SMPT_ELEMS_PER_BYTE(prot_bits) (BITS_PER_BYTE / (prot_bits)) +#define SMPT_ELEMS_PER_WORD(prot_bits) (SMPT_WORD_SIZE * SMPT_ELEMS_PER_BYTE(prot_bits)) +#define SMPT_WORD_BYTE_RANGE(prot_bits) (SMPT_GRAN * SMPT_ELEMS_PER_WORD(prot_bits)) #define SMPT_NUM_ELEMS (SZ_1G / SMPT_GRAN) -#define SMPT_SIZE (SMPT_NUM_ELEMS / SMPT_ELEMS_PER_BYTE) -#define SMPT_NUM_WORDS (SMPT_SIZE / SMPT_WORD_SIZE) -#define SMPT_NUM_PAGES (SMPT_SIZE / PAGE_SIZE) -#define SMPT_ORDER get_order(SMPT_SIZE) +#define SMPT_SIZE(prot_bits) (SMPT_NUM_ELEMS / SMPT_ELEMS_PER_BYTE(prot_bits)) +#define SMPT_NUM_WORDS(prot_bits) (SMPT_SIZE(prot_bits) / SMPT_WORD_SIZE) +#define SMPT_NUM_PAGES(prot_bits) (SMPT_SIZE(prot_bits) / PAGE_SIZE) +#define SMPT_ORDER(prot_bits) get_order(SMPT_SIZE(prot_bits)) + + +#define SMPT_GRAN_MASK GENMASK(1, 0) /* SysMMU_SYNC registers, relative to SYSMMU_SYNC_S2_OFFSET. */ #define REG_NS_SYNC_CMD 0x0 @@ -375,6 +399,15 @@ enum s2mpu_version { S2MPU_VERSION_9 = 0x90000000, }; +static inline int smpt_order_from_version(enum s2mpu_version version) +{ + if (version == S2MPU_VERSION_9) + return SMPT_ORDER(V9_MPT_PROT_BITS); + else if ((version == S2MPU_VERSION_1) || (version == S2MPU_VERSION_2)) + return SMPT_ORDER(MPT_PROT_BITS); + BUG(); +} + enum mpt_prot { MPT_PROT_NONE = 0, MPT_PROT_R = BIT(0), diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c b/arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c index 832368171e05..e101c4c4c0b4 100644 --- a/arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c +++ b/arch/arm64/kvm/hyp/nvhe/iommu/io-mpt-s2mpu.c @@ -5,6 +5,37 @@ #include +#define GRAN_BYTE(gran) ((gran << V9_MPT_PROT_BITS) | (gran)) +#define GRAN_HWORD(gran) ((GRAN_BYTE(gran) << 8) | (GRAN_BYTE(gran))) +#define GRAN_WORD(gran) (((u32)(GRAN_HWORD(gran) << 16) | (GRAN_HWORD(gran)))) +#define GRAN_DWORD(gran) ((u64)((u64)GRAN_WORD(gran) << 32) | (u64)(GRAN_WORD(gran))) + +#define SMPT_NUM_TO_BYTE(x) ((x) / SMPT_GRAN / SMPT_ELEMS_PER_BYTE(config_prot_bits)) +#define BYTE_TO_SMPT_INDEX(x) ((x) / SMPT_WORD_BYTE_RANGE(config_prot_bits)) + + +/* + * MPT table ops can be configured only for one version at runtime, + * these variables will hold version specific data set a run time init, to avoid + * having duplicate code or unnessery check during operations. + */ +static u32 config_prot_bits; +static u32 config_access_shift; +static const u64 *config_lut_prot; +static u32 config_gran_mask; +static u32 this_version; + +/* + * page table entries for different protection look up table + * granularity is compile time config, so we can do this also for + * this array without having duplicate arrays + */ +static const u64 v9_mpt_prot_doubleword[] = { + [MPT_PROT_NONE] = 0x0000000000000000 | GRAN_DWORD(SMPT_GRAN_ATTR), + [MPT_PROT_R] = 0x4444444444444444 | GRAN_DWORD(SMPT_GRAN_ATTR), + [MPT_PROT_W] = 0x8888888888888888 | GRAN_DWORD(SMPT_GRAN_ATTR), + [MPT_PROT_RW] = 0xcccccccccccccccc | GRAN_DWORD(SMPT_GRAN_ATTR), +}; static const u64 mpt_prot_doubleword[] = { [MPT_PROT_NONE] = 0x0000000000000000, [MPT_PROT_R] = 0x5555555555555555, @@ -12,6 +43,25 @@ static const u64 mpt_prot_doubleword[] = { [MPT_PROT_RW] = 0xffffffffffffffff, }; +static inline int pte_from_addr_smpt(u32 *smpt, u64 addr) +{ + u32 word_idx, idx, pte, val; + + word_idx = BYTE_TO_SMPT_INDEX(addr); + val = READ_ONCE(smpt[word_idx]); + idx = (addr / SMPT_GRAN) % SMPT_ELEMS_PER_WORD(config_prot_bits); + + pte = (val >> (idx * config_prot_bits)) & ((1 << config_prot_bits)-1); + return pte; +} + +static inline int prot_from_addr_smpt(u32 *smpt, u64 addr) +{ + int pte = pte_from_addr_smpt(smpt, addr); + + return (pte >> config_access_shift); +} + /* Set protection bits of SMPT in a given range without using memset. */ static void __set_smpt_range_slow(u32 *smpt, size_t start_gb_byte, size_t end_gb_byte, enum mpt_prot prot) @@ -23,20 +73,21 @@ static void __set_smpt_range_slow(u32 *smpt, size_t start_gb_byte, start_word_byte = start_gb_byte; while (start_word_byte < end_gb_byte) { /* Determine the range of bytes covered by this word. */ - word_idx = start_word_byte / SMPT_WORD_BYTE_RANGE; + word_idx = BYTE_TO_SMPT_INDEX(start_word_byte); end_word_byte = min( - ALIGN(start_word_byte + 1, SMPT_WORD_BYTE_RANGE), + ALIGN(start_word_byte + 1, SMPT_WORD_BYTE_RANGE(config_prot_bits)), end_gb_byte); /* Identify protection bit offsets within the word. */ - first_elem = (start_word_byte / SMPT_GRAN) % SMPT_ELEMS_PER_WORD; - last_elem = ((end_word_byte - 1) / SMPT_GRAN) % SMPT_ELEMS_PER_WORD; + first_elem = (start_word_byte / SMPT_GRAN) % SMPT_ELEMS_PER_WORD(config_prot_bits); + last_elem = + ((end_word_byte - 1) / SMPT_GRAN) % SMPT_ELEMS_PER_WORD(config_prot_bits); /* Modify the corresponding word. */ val = READ_ONCE(smpt[word_idx]); for (i = first_elem; i <= last_elem; i++) { - val &= ~(MPT_PROT_MASK << (i * MPT_PROT_BITS)); - val |= prot << (i * MPT_PROT_BITS); + val &= ~(MPT_PROT_MASK << (i * config_prot_bits + config_access_shift)); + val |= prot << (i * config_prot_bits + config_access_shift); } WRITE_ONCE(smpt[word_idx], val); @@ -49,25 +100,33 @@ static void __set_smpt_range(u32 *smpt, size_t start_gb_byte, size_t end_gb_byte, enum mpt_prot prot) { size_t interlude_start, interlude_end, interlude_bytes, word_idx; - char prot_byte = (char)mpt_prot_doubleword[prot]; + + char prot_byte = (char)config_lut_prot[prot]; if (start_gb_byte >= end_gb_byte) return; /* Check if range spans at least one full u32 word. */ - interlude_start = ALIGN(start_gb_byte, SMPT_WORD_BYTE_RANGE); - interlude_end = ALIGN_DOWN(end_gb_byte, SMPT_WORD_BYTE_RANGE); + interlude_start = ALIGN(start_gb_byte, SMPT_WORD_BYTE_RANGE(config_prot_bits)); + interlude_end = ALIGN_DOWN(end_gb_byte, SMPT_WORD_BYTE_RANGE(config_prot_bits)); - /* If not, fall back to editing bits in the given range. */ + /* + * If not, fall back to editing bits in the given range. + * sets bit for PTEs that are in less than 32 bits (can't be done by memset) + */ if (interlude_start >= interlude_end) { __set_smpt_range_slow(smpt, start_gb_byte, end_gb_byte, prot); return; } /* Use bit-editing for prologue/epilogue, memset for interlude. */ - word_idx = interlude_start / SMPT_WORD_BYTE_RANGE; - interlude_bytes = (interlude_end - interlude_start) / SMPT_GRAN / SMPT_ELEMS_PER_BYTE; + word_idx = BYTE_TO_SMPT_INDEX(interlude_start); + interlude_bytes = SMPT_NUM_TO_BYTE(interlude_end - interlude_start); + /* + * These are pages in the start and at then end that are + * not part of full 32 bit SMPT word. + */ __set_smpt_range_slow(smpt, start_gb_byte, interlude_start, prot); memset(&smpt[word_idx], prot_byte, interlude_bytes); __set_smpt_range_slow(smpt, interlude_end, end_gb_byte, prot); @@ -79,8 +138,8 @@ static bool __is_smpt_uniform(u32 *smpt, enum mpt_prot prot) size_t i; u64 *doublewords = (u64 *)smpt; - for (i = 0; i < SMPT_NUM_WORDS / 2; i++) { - if (doublewords[i] != mpt_prot_doubleword[prot]) + for (i = 0; i < SMPT_NUM_WORDS(config_prot_bits) / 2; i++) { + if (doublewords[i] != config_lut_prot[prot]) return false; } return true; @@ -140,6 +199,11 @@ static void __set_fmpt_range(struct fmpt *fmpt, size_t start_gb_byte, fmpt->flags = MPT_UPDATE_L1; } +static u32 smpt_size(void) +{ + return SMPT_SIZE(config_prot_bits); +} + static void __set_l1entry_attr_with_prot(void *dev_va, unsigned int gb, unsigned int vid, enum mpt_prot prot) { @@ -154,7 +218,7 @@ static void __set_l1entry_attr_with_fmpt(void *dev_va, unsigned int gb, __set_l1entry_attr_with_prot(dev_va, gb, vid, fmpt->prot); } else { /* Order against writes to the SMPT. */ - writel(L1ENTRY_ATTR_L2(SMPT_GRAN_ATTR), + writel(config_gran_mask | L1ENTRY_ATTR_L2TABLE_EN, dev_va + REG_NS_L1ENTRY_ATTR(vid, gb)); } } @@ -218,21 +282,40 @@ static void prepare_range(struct mpt *mpt, phys_addr_t first_byte, __set_fmpt_range(fmpt, start_gb_byte, end_gb_byte, prot); if (fmpt->flags & MPT_UPDATE_L2) - kvm_flush_dcache_to_poc(fmpt->smpt, SMPT_SIZE); + kvm_flush_dcache_to_poc(fmpt->smpt, smpt_size()); } } static const struct s2mpu_mpt_ops this_ops = { + .smpt_size = smpt_size, .init_with_prot = init_with_prot, .init_with_mpt = init_with_mpt, .apply_range = apply_range, .prepare_range = prepare_range, + .pte_from_addr_smpt = pte_from_addr_smpt, }; const struct s2mpu_mpt_ops *s2mpu_get_mpt_ops(struct s2mpu_mpt_cfg cfg) { - if ((cfg.version == S2MPU_VERSION_1) || (cfg.version == S2MPU_VERSION_2)) - return &this_ops; + /* If called before with different version return NULL. */ + if (WARN_ON(this_version && (this_version != cfg.version))) + return NULL; + /* 2MB granularity not supported in V9 */ + if ((cfg.version == S2MPU_VERSION_9) && (SMPT_GRAN_ATTR != L1ENTRY_ATTR_GRAN_2M)) { + config_prot_bits = V9_MPT_PROT_BITS; + config_access_shift = V9_MPT_ACCESS_SHIFT; + config_lut_prot = v9_mpt_prot_doubleword; + config_gran_mask = L1ENTRY_ATTR_GRAN(SMPT_GRAN_ATTR, V9_L1ENTRY_ATTR_GRAN_MASK); + this_version = cfg.version; + return &this_ops; + } else if ((cfg.version == S2MPU_VERSION_2) || (cfg.version == S2MPU_VERSION_1)) { + config_prot_bits = MPT_PROT_BITS; + config_access_shift = MPT_ACCESS_SHIFT; + config_lut_prot = mpt_prot_doubleword; + config_gran_mask = L1ENTRY_ATTR_GRAN(SMPT_GRAN_ATTR, L1ENTRY_ATTR_GRAN_MASK); + this_version = cfg.version; + return &this_ops; + } return NULL; } diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c b/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c index f86a525344e8..d52890b1c5b2 100644 --- a/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c +++ b/arch/arm64/kvm/hyp/nvhe/iommu/s2mpu.c @@ -174,6 +174,23 @@ static void __set_control_regs(struct pkvm_iommu *dev) writel_relaxed(0, dev->va + REG_NS_CTRL1); writel_relaxed(ctrl0, dev->va + REG_NS_CTRL0); } +static void __set_control_regs_v9(struct pkvm_iommu *dev) +{ + /* Return DECERR to device on permission fault. */ + writel_relaxed(ALL_VIDS_BITMAP, + dev->va + REG_NS_V9_CTRL_ERR_RESP_T_PER_VID_SET); + /* + * Enable interrupts on fault for all VIDs. The IRQ must also be + * specified in DT to get unmasked in the GIC. + */ + writel_relaxed(ALL_VIDS_BITMAP, + dev->va + REG_NS_INTERRUPT_ENABLE_PER_VID_SET); + writel_relaxed(0, dev->va + REG_NS_CTRL0); + /* Enable the S2MPU, otherwise all traffic would be allowed through. */ + writel_relaxed(ALL_VIDS_BITMAP, + dev->va + REG_NS_V9_CTRL_PROT_EN_PER_VID_SET); + writel_relaxed(0, dev->va + REG_NS_V9_CFG_MPTW_ATTRIBUTE); +} /* * Poll the given SFR until its value has all bits of a given mask set. @@ -248,8 +265,8 @@ static void __invalidation_barrier_complete(struct pkvm_iommu *dev) __invalidation_barrier_slow(sync); } - /* Must not access SFRs while S2MPU is busy invalidating (v2 only). */ - if (is_version(dev, S2MPU_VERSION_2)) { + /* Must not access SFRs while S2MPU is busy invalidating */ + if (is_version(dev, S2MPU_VERSION_2) || is_version(dev, S2MPU_VERSION_9)) { __wait_while(dev->va + REG_NS_STATUS, STATUS_BUSY | STATUS_ON_INVALIDATING); } @@ -401,6 +418,64 @@ static int s2mpu_suspend(struct pkvm_iommu *dev) return initialize_with_prot(dev, MPT_PROT_NONE); } +static u32 host_mmio_reg_access_mask_v9(size_t off, bool is_write) +{ + const u32 no_access = 0; + const u32 read_write = (u32)(-1); + const u32 read_only = is_write ? no_access : read_write; + const u32 write_only = is_write ? read_write : no_access; + + switch (off) { + /* Allow reading control registers for debugging. */ + case REG_NS_CTRL0: + return read_only & V9_CTRL0_MASK; + case REG_NS_V9_CTRL_ERR_RESP_T_PER_VID_SET: + return read_only & ALL_VIDS_BITMAP; + case REG_NS_V9_CTRL_PROT_EN_PER_VID_SET: + return read_only & ALL_VIDS_BITMAP; + case REG_NS_V9_READ_STLB: + return write_only & (V9_READ_STLB_MASK_TYPEA|V9_READ_STLB_MASK_TYPEB); + case REG_NS_V9_READ_STLB_TPN: + return read_only & V9_READ_STLB_TPN_MASK; + case REG_NS_V9_READ_STLB_TAG_PPN: + return read_only & V9_READ_STLB_TAG_PPN_MASK; + case REG_NS_V9_READ_STLB_TAG_OTHERS: + return read_only & V9_READ_STLB_TAG_OTHERS_MASK; + case REG_NS_V9_READ_STLB_DATA: + return read_only; + case REG_NS_V9_MPTC_INFO: + return read_only & V9_READ_MPTC_INFO_MASK; + case REG_NS_V9_READ_MPTC: + return write_only & V9_READ_MPTC_MASK; + case REG_NS_V9_READ_MPTC_TAG_PPN: + return read_only & V9_READ_MPTC_TAG_PPN_MASK; + case REG_NS_V9_READ_MPTC_TAG_OTHERS: + return read_only & V9_READ_MPTC_TAG_OTHERS_MASK; + case REG_NS_V9_READ_MPTC_DATA: + return read_only; + case REG_NS_V9_PMMU_INFO: + return read_only & V9_READ_PMMU_INFO_MASK; + case REG_NS_V9_READ_PTLB: + return write_only & V9_READ_PTLB_MASK; + case REG_NS_V9_READ_PTLB_TAG: + return read_only & V9_READ_PTLB_TAG_MASK; + case REG_NS_V9_READ_PTLB_DATA_S1_EN_PPN_AP: + return read_only & V9_READ_PTLB_DATA_S1_ENABLE_PPN_AP_MASK; + case REG_NS_V9_READ_PTLB_DATA_S1_DIS_AP_LIST: + return read_only; + case REG_NS_V9_PMMU_INDICATOR: + return read_only & V9_READ_PMMU_INDICATOR_MASK; + case REG_NS_V9_SWALKER_INFO: + return read_only&V9_SWALKER_INFO_MASK; + }; + if (off >= REG_NS_V9_PMMU_PTLB_INFO(0) && off < REG_NS_V9_PMMU_PTLB_INFO(V9_MAX_PTLB_NUM)) + return read_only&V9_READ_PMMU_PTLB_INFO_MASK; + if (off >= REG_NS_V9_STLB_INFO(0) && off < REG_NS_V9_STLB_INFO(V9_MAX_STLB_NUM)) + return read_only&V9_READ_SLTB_INFO_MASK; + + return no_access; +} + static u32 host_mmio_reg_access_mask_v1_v2(size_t off, bool is_write) { const u32 no_access = 0; @@ -491,12 +566,20 @@ static bool s2mpu_host_dabt_handler(struct pkvm_iommu *dev, cpu_reg(host_ctxt, rd) = readl_relaxed(dev->va + off) & mask; return true; } - -const struct s2mpu_reg_ops ops_v1_v2 = { +/* + * Operations that differ between versions. We need to maintain + * old behaviour were v1 and v2 can be used together. + */ +const struct s2mpu_reg_ops ops_v1_v2 = { .init = __initialize, .host_mmio_reg_access_mask = host_mmio_reg_access_mask_v1_v2, .set_control_regs = __set_control_regs, }; +const struct s2mpu_reg_ops ops_v9 = { + .init = __initialize_v2, + .host_mmio_reg_access_mask = host_mmio_reg_access_mask_v9, + .set_control_regs = __set_control_regs_v9, +}; static int s2mpu_init(void *data, size_t size) { @@ -505,6 +588,7 @@ static int s2mpu_init(void *data, size_t size) phys_addr_t pa; unsigned int gb; int ret = 0; + int smpt_nr_pages, smpt_size; struct s2mpu_mpt_cfg cfg; if (size != sizeof(in_mpt)) @@ -514,8 +598,11 @@ static int s2mpu_init(void *data, size_t size) memcpy(&in_mpt, data, sizeof(in_mpt)); cfg.version = in_mpt.version; + /* Make sure the version sent is supported by the driver. */ if ((cfg.version == S2MPU_VERSION_1) || (cfg.version == S2MPU_VERSION_2)) reg_ops = &ops_v1_v2; + else if (cfg.version == S2MPU_VERSION_9) + reg_ops = &ops_v9; else return -ENODEV; @@ -525,17 +612,20 @@ static int s2mpu_init(void *data, size_t size) if (!mpt_ops) return -EINVAL; + smpt_size = mpt_ops->smpt_size(); + smpt_nr_pages = smpt_size / PAGE_SIZE; + /* Take ownership of all SMPT buffers. This will also map them in. */ for_each_gb(gb) { smpt = kern_hyp_va(in_mpt.fmpt[gb].smpt); pa = __hyp_pa(smpt); - if (!IS_ALIGNED(pa, SMPT_SIZE)) { + if (!IS_ALIGNED(pa, smpt_size)) { ret = -EINVAL; break; } - ret = __pkvm_host_donate_hyp(pa >> PAGE_SHIFT, SMPT_NUM_PAGES); + ret = __pkvm_host_donate_hyp(pa >> PAGE_SHIFT, smpt_nr_pages); if (ret) break; @@ -554,7 +644,7 @@ static int s2mpu_init(void *data, size_t size) break; WARN_ON(__pkvm_hyp_donate_host(__hyp_pa(smpt) >> PAGE_SHIFT, - SMPT_NUM_PAGES)); + smpt_nr_pages)); } memset(&host_mpt, 0, sizeof(host_mpt)); } diff --git a/arch/arm64/kvm/iommu/s2mpu.c b/arch/arm64/kvm/iommu/s2mpu.c index fd63ce1fc6e4..b9dcc3469f06 100644 --- a/arch/arm64/kvm/iommu/s2mpu.c +++ b/arch/arm64/kvm/iommu/s2mpu.c @@ -22,6 +22,7 @@ static int init_s2mpu_driver(u32 version) unsigned long addr; u64 pfn; int ret = 0; + const int smpt_order = smpt_order_from_version(version); mutex_lock(&lock); if (init_done) @@ -39,7 +40,7 @@ static int init_s2mpu_driver(u32 version) /* Allocate SMPT buffers. */ for_each_gb(gb) { - addr = __get_free_pages(GFP_KERNEL, SMPT_ORDER); + addr = __get_free_pages(GFP_KERNEL, smpt_order); if (!addr) { ret = -ENOMEM; goto out_free; @@ -68,7 +69,7 @@ out_free: /* TODO - will driver return the memory? */ if (ret) { for_each_gb(gb) - free_pages((unsigned long)mpt->fmpt[gb].smpt, SMPT_ORDER); + free_pages((unsigned long)mpt->fmpt[gb].smpt, smpt_order); free_page((unsigned long)mpt); } out: