From ac1322d66b6c1375013a3b89b3f8b2376f30a0de Mon Sep 17 00:00:00 2001 From: Eyck-Alexander Jentzsch Date: Mon, 17 Feb 2025 15:48:30 +0100 Subject: [PATCH] changes to ternary functions for Multiply-Add Instructions --- gen_input/templates/interp/CORENAME.cpp.gtl | 1 - src/vm/vector_functions.hpp | 102 ++++++++++---------- 2 files changed, 51 insertions(+), 52 deletions(-) diff --git a/gen_input/templates/interp/CORENAME.cpp.gtl b/gen_input/templates/interp/CORENAME.cpp.gtl index e8ce241..96ee429 100644 --- a/gen_input/templates/interp/CORENAME.cpp.gtl +++ b/gen_input/templates/interp/CORENAME.cpp.gtl @@ -385,7 +385,6 @@ if(vector != null) {%> throw new std::runtime_error("Unsupported sew bit value"); } } - <%}%> uint64_t fetch_count{0}; uint64_t tval{0}; diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp index d00624c..c35740f 100644 --- a/src/vm/vector_functions.hpp +++ b/src/vm/vector_functions.hpp @@ -94,40 +94,40 @@ template <> struct twice { using type = __uint128_t; }; template using twice_t = typename twice::type; // for convenience template -std::function get_funct(unsigned funct6, unsigned funct3) { +std::function get_funct(unsigned funct6, unsigned funct3) { if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI) switch(funct6) { case 0b000000: // VADD - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; case 0b000010: // VSUB - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; }; case 0b000011: // VRSUB - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 - vs2; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 - vs2; }; case 0b000100: // VMINU - return [](src2_elem_t vs2, src1_elem_t vs1) { return std::min(vs2, static_cast(vs1)); }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return std::min(vs2, static_cast(vs1)); }; case 0b000101: // VMIN - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return std::min(static_cast>(vs2), static_cast>(vs1)); }; case 0b000110: // VMAXU - return [](src2_elem_t vs2, src1_elem_t vs1) { return std::max(vs2, static_cast(vs1)); }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return std::max(vs2, static_cast(vs1)); }; case 0b000111: // VMAX - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return std::max(static_cast>(vs2), static_cast>(vs1)); }; case 0b001001: // VAND - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 & vs2; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 & vs2; }; case 0b001010: // VOR - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 | vs2; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 | vs2; }; case 0b001011: // VXOR - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 ^ vs2; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 ^ vs2; }; // case 0b001100: // VRGATHER // case 0b001110: // VRGATHEREI16 // case 0b001111: // VLSLIDEDOWN case 0b010000: // VADC - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; case 0b010010: // VSBC - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(static_cast>(vs2) - static_cast>(vs1)); }; @@ -137,19 +137,19 @@ std::function get_funct(unsigned funct6, // case 0b100010: // VSSUBU // case 0b100011: // VSSUB case 0b100101: // VSLL - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 << (vs1 & shift_mask()); }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 << (vs1 & shift_mask()); }; // case 0b100111: // VSMUL // case 0b100111: // VMVR case 0b101000: // VSRL - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask()); }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask()); }; case 0b101001: // VSRA - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(vs2) >> (vs1 & shift_mask()); }; case 0b101100: // VNSRL - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask()); }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask()); }; case 0b101101: // VNSRA - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(vs2) >> (vs1 & shift_mask()); }; // case 0b101110: // VNCLIPU @@ -185,14 +185,14 @@ std::function get_funct(unsigned funct6, // case 0b011110: // VMNOR // case 0b011111: // VMXNOR case 0b100000: // VDIVU - return [](src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { if(vs1 == 0) return -1; else return vs2 / vs1; }; case 0b100001: // VDIV - return [](src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { if(vs1 == 0) return -1; else if(vs2 == std::numeric_limits>::min() && @@ -202,14 +202,14 @@ std::function get_funct(unsigned funct6, return static_cast>(vs2) / static_cast>(vs1); }; case 0b100010: // VREMU - return [](src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { if(vs1 == 0) return vs2; else return vs2 % vs1; }; case 0b100011: // VREM - return [](src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { if(vs1 == 0) return vs2; else if(vs2 == std::numeric_limits>::min() && @@ -219,68 +219,72 @@ std::function get_funct(unsigned funct6, return static_cast>(vs2) % static_cast>(vs1); }; case 0b100100: // VMULHU - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return (static_cast>(vs2) * static_cast>(vs1)) >> sizeof(dest_elem_t) * 8; }; case 0b100101: // VMUL - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(vs2) * static_cast>(vs1); }; case 0b100110: // VMULHSU - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return (static_cast>>(static_cast>(vs2)) * static_cast>(vs1)) >> sizeof(dest_elem_t) * 8; }; case 0b100111: // VMULH - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return (static_cast>>(static_cast>(vs2)) * static_cast>>(static_cast>(vs1))) >> sizeof(dest_elem_t) * 8; }; - // case 0b101001: // VMADD - // case 0b101011: // VNMSUB - // case 0b101101: // VMACC - // case 0b101111: // VNMSAC + case 0b101001: // VMADD + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 * vd + vs2; }; + case 0b101011: // VNMSUB + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return -1 * (vs1 * vd) + vs2; }; + case 0b101101: // VMACC + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 * vs2 + vd; }; + case 0b101111: // VNMSAC + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return -1 * (vs1 * vs2) + vd; }; case 0b110000: // VWADDU - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; case 0b110001: // VWADD - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(static_cast>(vs2) + static_cast>(vs1)); }; case 0b110010: // VWSUBU - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; }; case 0b110011: // VWSUB - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(static_cast>(vs2) - static_cast>(vs1)); }; case 0b110100: // VWADDU.W - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; case 0b110101: // VWADD.W - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(static_cast>(vs2) + static_cast>(vs1)); }; case 0b110110: // VWSUBU.W - return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; }; + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; }; case 0b110111: // VWSUB.W - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(static_cast>(vs2) - static_cast>(vs1)); }; case 0b111000: // VWMULU - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return (static_cast>(vs2) * static_cast>(vs1)); }; case 0b111010: // VWMULSU - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return (static_cast>>(static_cast>(vs2)) * static_cast>(vs1)); }; case 0b111011: // VWMUL - return [](src2_elem_t vs2, src1_elem_t vs1) { + return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return (static_cast>>(static_cast>(vs2)) * static_cast>>(static_cast>(vs1))); }; @@ -310,19 +314,18 @@ void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { bool mask_active = vm ? 1 : mask_reg[idx]; if(mask_active) { - auto res = fn(vs2_view[idx], vs1_view[idx]); - vd_view[idx] = res; + vd_view[idx] = fn(vd_view[idx], vs2_view[idx], vs1_view[idx]); } else { vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx]; } } } else if(carry == carry_t::SUB_CARRY) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { - vd_view[idx] = fn(vs2_view[idx], vs1_view[idx]) - mask_reg[idx]; + vd_view[idx] = fn(vd_view[idx], vs2_view[idx], vs1_view[idx]) - mask_reg[idx]; } } else { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { - vd_view[idx] = fn(vs2_view[idx], vs1_view[idx]) + mask_reg[idx]; + vd_view[idx] = fn(vd_view[idx], vs2_view[idx], vs1_view[idx]) + mask_reg[idx]; } } // elements w/ index larger than elem_count are in the tail (fractional LMUL) @@ -347,21 +350,18 @@ void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { bool mask_active = vm ? 1 : mask_reg[idx]; if(mask_active) { - vd_view[idx] = fn(vs2_view[idx], imm); + vd_view[idx] = fn(vd_view[idx], vs2_view[idx], imm); } else { vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx]; } } } else if(carry == carry_t::SUB_CARRY) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { - auto val1 = fn(vs2_view[idx], imm); - auto val2 = static_cast>(mask_reg[idx]); - auto diff = val1 - val2; - vd_view[idx] = fn(vs2_view[idx], imm) - mask_reg[idx]; + vd_view[idx] = fn(vd_view[idx], vs2_view[idx], imm) - mask_reg[idx]; } } else { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { - vd_view[idx] = fn(vs2_view[idx], imm) + mask_reg[idx]; + vd_view[idx] = fn(vd_view[idx], vs2_view[idx], imm) + mask_reg[idx]; } } // elements w/ index larger than elem_count are in the tail (fractional LMUL)