diff --git a/lib/sbi/sbi_pmu.c b/lib/sbi/sbi_pmu.c index 5983a784..6ca4efdb 100644 --- a/lib/sbi/sbi_pmu.c +++ b/lib/sbi/sbi_pmu.c @@ -206,6 +206,12 @@ static int pmu_ctr_validate(struct sbi_pmu_hart_state *phs, return event_idx_type; } +static bool pmu_ctr_idx_validate(unsigned long cbase, unsigned long cmask) +{ + /* Do a basic sanity check of counter base & mask */ + return cmask && cbase + sbi_fls(cmask) < total_ctrs; +} + int sbi_pmu_ctr_fw_read(uint32_t cidx, uint64_t *cval) { int event_idx_type; @@ -472,7 +478,7 @@ int sbi_pmu_ctr_start(unsigned long cbase, unsigned long cmask, int i, cidx; uint64_t edata; - if ((cbase + sbi_fls(cmask)) >= total_ctrs) + if (!pmu_ctr_idx_validate(cbase, cmask)) return ret; if (flags & SBI_PMU_STOP_FLAG_TAKE_SNAPSHOT) @@ -577,8 +583,8 @@ int sbi_pmu_ctr_stop(unsigned long cbase, unsigned long cmask, uint32_t event_code; int i, cidx; - if ((cbase + sbi_fls(cmask)) >= total_ctrs) - return SBI_EINVAL; + if (!pmu_ctr_idx_validate(cbase, cmask)) + return ret; if (flag & SBI_PMU_STOP_FLAG_TAKE_SNAPSHOT) return SBI_ENO_SHMEM; @@ -839,8 +845,7 @@ int sbi_pmu_ctr_cfg_match(unsigned long cidx_base, unsigned long cidx_mask, int ret, event_type, ctr_idx = SBI_ENOTSUPP; u32 event_code; - /* Do a basic sanity check of counter base & mask */ - if ((cidx_base + sbi_fls(cidx_mask)) >= total_ctrs) + if (!pmu_ctr_idx_validate(cidx_base, cidx_mask)) return SBI_EINVAL; event_type = pmu_event_validate(phs, event_idx, event_data);