From e1911bc4503b37234704a52375e0ae0266e93c20 Mon Sep 17 00:00:00 2001 From: Eyck-Alexander Jentzsch Date: Tue, 18 Feb 2025 21:45:16 +0100 Subject: [PATCH] adds vsmul, widens functions parameters for sat_vector operations --- src/vm/vector_functions.hpp | 55 ++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp index 3499562..32d3160 100644 --- a/src/vm/vector_functions.hpp +++ b/src/vm/vector_functions.hpp @@ -596,24 +596,20 @@ template bool get_rounding_increment(T v, uint64_t d, int64_t vxrm) case 0b10: // rdn: round-down (truncate) return false; case 0b11: // rod: round-to-odd (jam) - return (!(v & (1 << d)) && ((v & ((1 << d) - 1)) != 0)); + return (!(v & (static_cast(1) << d)) && ((v & ((static_cast(1) << d) - 1)) != 0)); } return false; } -template T roundoff_unsigned(T v, uint64_t d, int64_t vxrm) { - unsigned r = get_rounding_increment(v, d, vxrm); - return (v >> d) + r; -} -template T roundoff_signed(T v, uint64_t d, int64_t vxrm) { +template T roundoff(T v, uint64_t d, int64_t vxrm) { unsigned r = get_rounding_increment(v, d, vxrm); return (v >> d) + r; } template -std::function get_sat_funct(unsigned funct6, unsigned funct3) { +std::function get_sat_funct(unsigned funct6, unsigned funct3) { if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI) switch(funct6) { case 0b100000: // VSADDU - return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { auto res = static_cast>(vs2) + static_cast>(vs1); if(res > std::numeric_limits::max()) { vd = std::numeric_limits::max(); @@ -624,7 +620,7 @@ std::function get_sat_func } }; case 0b100001: // VSADD - return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { auto res = static_cast>>(static_cast>(vs2)) + static_cast>>(static_cast>(vs1)); if(res < std::numeric_limits>::min()) { @@ -639,7 +635,7 @@ std::function get_sat_func } }; case 0b100010: // VSSUBU - return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { if(vs2 < vs1) { vd = 0; return 1; @@ -649,7 +645,7 @@ std::function get_sat_func } }; case 0b100011: // VSSUB - return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { auto res = static_cast>>(static_cast>(vs2)) - static_cast>>(static_cast>(vs1)); if(res < std::numeric_limits>::min()) { @@ -663,7 +659,22 @@ std::function get_sat_func return 0; } }; - // case 0b100111: // VSMUL + case 0b100111: // VSMUL + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + auto big_val = static_cast>>(static_cast>(vs2)) * + static_cast>>(static_cast>(vs1)); + auto res = roundoff(big_val, vtype.sew() - 1, vxrm); + if(res < std::numeric_limits>::min()) { + vd = std::numeric_limits>::min(); + return 1; + } else if(res > std::numeric_limits>::max()) { + vd = std::numeric_limits>::max(); + return 1; + } else { + vd = res; + return 0; + } + }; // case 0b101010: // VSSRL // case 0b101011: // VSSRA // case 0b101110: // VNCLIPU @@ -674,29 +685,29 @@ std::function get_sat_func else if(funct3 == OPMVV || funct3 == OPMVX) switch(funct6) { case 0b001000: // VAADDU - return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { auto res = static_cast>(vs2) + static_cast>(vs1); - vd = roundoff_unsigned(res, 1, vxrm); + vd = roundoff(res, 1, vxrm); return 0; }; case 0b001001: // VAADD - return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { auto res = static_cast>>(static_cast>(vs2)) + static_cast>>(static_cast>(vs1)); - vd = roundoff_signed(res, 1, vxrm); + vd = roundoff(res, 1, vxrm); return 0; }; case 0b001010: // VASUBU - return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { auto res = static_cast>(vs2) - static_cast>(vs1); - vd = roundoff_unsigned(res, 1, vxrm); + vd = roundoff(res, 1, vxrm); return 0; }; case 0b001011: // VASUB - return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { auto res = static_cast>>(static_cast>(vs2)) - static_cast>>(static_cast>(vs1)); - vd = roundoff_signed(res, 1, vxrm); + vd = roundoff(res, 1, vxrm); return 0; }; default: @@ -720,7 +731,7 @@ bool sat_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { bool mask_active = vm ? 1 : mask_reg[idx]; if(mask_active) { - saturated |= fn(vxrm, vd_view[idx], vs2_view[idx], vs1_view[idx]); + saturated |= fn(vxrm, vtype, vd_view[idx], vs2_view[idx], vs1_view[idx]); } else { vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx]; } @@ -747,7 +758,7 @@ bool sat_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { bool mask_active = vm ? 1 : mask_reg[idx]; if(mask_active) { - saturated |= fn(vxrm, vd_view[idx], vs2_view[idx], imm); + saturated |= fn(vxrm, vtype, vd_view[idx], vs2_view[idx], imm); } else { vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx]; }