From 28ac169cfea1384cb44c4a414acc2f59cff0e9ca Mon Sep 17 00:00:00 2001 From: Eyck-Alexander Jentzsch Date: Wed, 19 Feb 2025 10:10:41 +0100 Subject: [PATCH] adds narrowing fixed point instructions --- gen_input/templates/interp/CORENAME.cpp.gtl | 26 ++++ src/vm/vector_functions.h | 6 +- src/vm/vector_functions.hpp | 127 ++++++++++++-------- 3 files changed, 104 insertions(+), 55 deletions(-) diff --git a/gen_input/templates/interp/CORENAME.cpp.gtl b/gen_input/templates/interp/CORENAME.cpp.gtl index df29e80..234adbc 100644 --- a/gen_input/templates/interp/CORENAME.cpp.gtl +++ b/gen_input/templates/interp/CORENAME.cpp.gtl @@ -441,6 +441,32 @@ if(vector != null) {%> throw new std::runtime_error("Unsupported sew bit value"); } } + bool sat_vector_vector_vw(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, uint64_t vxrm, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t sew_val){ + switch(sew_val){ + case 0b000: + return softvector::sat_vector_vector_op<${vlen}, uint8_t, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, vs1); + case 0b001: + return softvector::sat_vector_vector_op<${vlen}, uint16_t, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, vs1); + case 0b010: + return softvector::sat_vector_vector_op<${vlen}, uint32_t, uint64_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, vs1); + case 0b011: // would require 128 bits vs2 value + default: + throw new std::runtime_error("Unsupported sew bit value"); + } + } + bool sat_vector_imm_vw(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, uint64_t vxrm, bool vm, uint8_t vd, uint8_t vs2, int64_t imm, uint8_t sew_val){ + switch(sew_val){ + case 0b000: + return softvector::sat_vector_imm_op<${vlen}, uint8_t, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, imm); + case 0b001: + return softvector::sat_vector_imm_op<${vlen}, uint16_t, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, imm); + case 0b010: + return softvector::sat_vector_imm_op<${vlen}, uint32_t, uint64_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, imm); + case 0b011: // would require 128 bits vs2 value + default: + throw new std::runtime_error("Unsupported sew bit value"); + } + } <%}%> uint64_t fetch_count{0}; uint64_t tval{0}; diff --git a/src/vm/vector_functions.h b/src/vm/vector_functions.h index 6416c97..65ea400 100644 --- a/src/vm/vector_functions.h +++ b/src/vm/vector_functions.h @@ -91,12 +91,12 @@ void carry_vector_vector_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vs template void carry_vector_imm_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, typename std::make_signed::type imm); -template +template bool sat_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm, unsigned vd, unsigned vs2, unsigned vs1); -template +template bool sat_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm, - unsigned vd, unsigned vs2, typename std::make_signed::type imm); + unsigned vd, unsigned vs2, typename std::make_signed::type imm); } // namespace softvector #include "vm/vector_functions.hpp" #endif /* _VM_VECTOR_FUNCTIONS_H_ */ diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp index d6c496c..e196989 100644 --- a/src/vm/vector_functions.hpp +++ b/src/vm/vector_functions.hpp @@ -606,15 +606,15 @@ template T roundoff(T v, uint64_t d, int64_t vxrm) { unsigned r = get_rounding_increment(v, d, vxrm); return (v >> d) + r; } -template -std::function get_sat_funct(unsigned funct6, unsigned funct3) { +template +std::function get_sat_funct(unsigned funct6, unsigned funct3) { if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI) switch(funct6) { case 0b100000: // VSADDU - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - auto res = static_cast>(vs2) + static_cast>(vs1); - if(res > std::numeric_limits::max()) { - vd = std::numeric_limits::max(); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = static_cast>(vs2) + static_cast>(vs1); + if(res > std::numeric_limits::max()) { + vd = std::numeric_limits::max(); return 1; } else { vd = res; @@ -622,14 +622,14 @@ std::function get } }; case 0b100001: // VSADD - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - auto res = static_cast>>(static_cast>(vs2)) + - static_cast>>(static_cast>(vs1)); - if(res < std::numeric_limits>::min()) { - vd = std::numeric_limits>::min(); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = static_cast>>(static_cast>(vs2)) + + static_cast>>(static_cast>(vs1)); + if(res < std::numeric_limits>::min()) { + vd = std::numeric_limits>::min(); return 1; - } else if(res > std::numeric_limits>::max()) { - vd = std::numeric_limits>::max(); + } else if(res > std::numeric_limits>::max()) { + vd = std::numeric_limits>::max(); return 1; } else { vd = res; @@ -637,7 +637,7 @@ std::function get } }; case 0b100010: // VSSUBU - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { if(vs2 < vs1) { vd = 0; return 1; @@ -647,14 +647,14 @@ std::function get } }; case 0b100011: // VSSUB - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - auto res = static_cast>>(static_cast>(vs2)) - - static_cast>>(static_cast>(vs1)); - if(res < std::numeric_limits>::min()) { - vd = std::numeric_limits>::min(); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = static_cast>>(static_cast>(vs2)) - + static_cast>>(static_cast>(vs1)); + if(res < std::numeric_limits>::min()) { + vd = std::numeric_limits>::min(); return 1; - } else if(res > std::numeric_limits>::max()) { - vd = std::numeric_limits>::max(); + } else if(res > std::numeric_limits>::max()) { + vd = std::numeric_limits>::max(); return 1; } else { vd = res; @@ -662,15 +662,15 @@ std::function get } }; case 0b100111: // VSMUL - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - auto big_val = static_cast>>(static_cast>(vs2)) * - static_cast>>(static_cast>(vs1)); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto big_val = static_cast>>(static_cast>(vs2)) * + static_cast>>(static_cast>(vs1)); auto res = roundoff(big_val, vtype.sew() - 1, vxrm); - if(res < std::numeric_limits>::min()) { - vd = std::numeric_limits>::min(); + if(res < std::numeric_limits>::min()) { + vd = std::numeric_limits>::min(); return 1; - } else if(res > std::numeric_limits>::max()) { - vd = std::numeric_limits>::max(); + } else if(res > std::numeric_limits>::max()) { + vd = std::numeric_limits>::max(); return 1; } else { vd = res; @@ -678,45 +678,68 @@ std::function get } }; case 0b101010: // VSSRL - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - vd = roundoff(vs2, vs1 & shift_mask(), vxrm); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + vd = roundoff(vs2, vs1 & shift_mask(), vxrm); return 0; }; case 0b101011: // VSSRA - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - vd = roundoff(static_cast>(vs2), vs1 & shift_mask(), vxrm); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + vd = roundoff(static_cast>(vs2), vs1 & shift_mask(), vxrm); return 0; }; - // case 0b101110: // VNCLIPU - // case 0b101111: // VNCLIP + case 0b101110: // VNCLIPU + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = roundoff(vs2, vs1 & shift_mask(), vxrm); + if(res > std::numeric_limits::max()) { + vd = std::numeric_limits::max(); + return 1; + } else { + vd = res; + return 0; + } + }; + case 0b101111: // VNCLIP + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = roundoff(static_cast>(vs2), vs1 & shift_mask(), vxrm); + if(res < std::numeric_limits>::min()) { + vd = std::numeric_limits>::min(); + return 1; + } else if(res > std::numeric_limits>::max()) { + vd = std::numeric_limits>::max(); + return 1; + } else { + vd = res; + return 0; + } + }; default: throw new std::runtime_error("Unknown funct6 in get_sat_funct"); } else if(funct3 == OPMVV || funct3 == OPMVX) switch(funct6) { case 0b001000: // VAADDU - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - auto res = static_cast>(vs2) + static_cast>(vs1); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = static_cast>(vs2) + static_cast>(vs1); vd = roundoff(res, 1, vxrm); return 0; }; case 0b001001: // VAADD - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - auto res = static_cast>>(static_cast>(vs2)) + - static_cast>>(static_cast>(vs1)); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = static_cast>>(static_cast>(vs2)) + + static_cast>>(static_cast>(vs1)); vd = roundoff(res, 1, vxrm); return 0; }; case 0b001010: // VASUBU - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - auto res = static_cast>(vs2) - static_cast>(vs1); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = static_cast>(vs2) - static_cast>(vs1); vd = roundoff(res, 1, vxrm); return 0; }; case 0b001011: // VASUB - return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { - auto res = static_cast>>(static_cast>(vs2)) - - static_cast>>(static_cast>(vs1)); + return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_T vs1) { + auto res = static_cast>>(static_cast>(vs2)) - + static_cast>>(static_cast>(vs1)); vd = roundoff(res, 1, vxrm); return 0; }; @@ -726,16 +749,16 @@ std::function get else throw new std::runtime_error("Unknown funct3 in get_sat_funct"); } -template +template bool sat_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm, unsigned vd, unsigned vs2, unsigned vs1) { uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); bool saturated = false; vmask_view mask_reg = read_vmask(V, elem_count); - auto vs1_view = get_vreg(V, vs1, elem_count); - auto vs2_view = get_vreg(V, vs2, elem_count); + auto vs1_view = get_vreg(V, vs1, elem_count); + auto vs2_view = get_vreg(V, vs2, elem_count); auto vd_view = get_vreg(V, vd, elem_count); - auto fn = get_sat_funct(funct6, funct3); + auto fn = get_sat_funct(funct6, funct3); // elements w/ index smaller than vstart are in the prestart and get skipped // body is from vstart to min(elem_count, vl) for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { @@ -754,15 +777,15 @@ bool sat_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t } return saturated; } -template +template bool sat_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm, - unsigned vd, unsigned vs2, typename std::make_signed::type imm) { + unsigned vd, unsigned vs2, typename std::make_signed::type imm) { uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); bool saturated = false; vmask_view mask_reg = read_vmask(V, elem_count); - auto vs2_view = get_vreg(V, vs2, elem_count); + auto vs2_view = get_vreg(V, vs2, elem_count); auto vd_view = get_vreg(V, vd, elem_count); - auto fn = get_sat_funct(funct6, funct3); + auto fn = get_sat_funct(funct6, funct3); // elements w/ index smaller than vstart are in the prestart and get skipped // body is from vstart to min(elem_count, vl) for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {