diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp index e71b178..ae902af 100644 --- a/src/vm/vector_functions.hpp +++ b/src/vm/vector_functions.hpp @@ -40,6 +40,7 @@ #ifndef _VM_VECTOR_FUNCTIONS_H_ #error __FILE__ should only be included from vector_functions.h #endif +#include #include namespace softvector { @@ -79,6 +80,19 @@ enum FUNCT3 { OPFVF = 0b101, OPMVX = 0b110, }; +template struct twice; +template <> struct twice { using type = int16_t; }; +template <> struct twice { using type = uint16_t; }; +template <> struct twice { using type = int32_t; }; +template <> struct twice { using type = uint32_t; }; +template <> struct twice { using type = int64_t; }; +template <> struct twice { using type = uint64_t; }; +#ifdef __SIZEOF_INT128__ +template <> struct twice { using type = __int128_t; }; +template <> struct twice { using type = __uint128_t; }; +#endif +template using twice_t = typename twice::type; // for convenience + template std::function get_funct(unsigned funct6, unsigned funct3) { if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI) @@ -93,15 +107,13 @@ std::function get_funct(unsigned funct6, return [](src2_elem_t vs2, src1_elem_t vs1) { return std::min(vs2, static_cast(vs1)); }; case 0b000101: // VMIN return [](src2_elem_t vs2, src1_elem_t vs1) { - return std::min(static_cast>(vs2), - static_cast>(vs1)); + return std::min(static_cast>(vs2), static_cast>(vs1)); }; case 0b000110: // VMAXU return [](src2_elem_t vs2, src1_elem_t vs1) { return std::max(vs2, static_cast(vs1)); }; case 0b000111: // VMAX return [](src2_elem_t vs2, src1_elem_t vs1) { - return std::max(static_cast>(vs2), - static_cast>(vs1)); + return std::max(static_cast>(vs2), static_cast>(vs1)); }; case 0b001001: // VAND return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 & vs2; }; @@ -116,8 +128,8 @@ std::function get_funct(unsigned funct6, return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; case 0b010010: // VSBC return [](src2_elem_t vs2, src1_elem_t vs1) { - return static_cast>(static_cast>(vs2) - - static_cast>(vs1)); + return static_cast>(static_cast>(vs2) - + static_cast>(vs1)); }; // case 0b010111: // VMERGE / VMV // case 0b100000: // VSADDU @@ -132,13 +144,13 @@ std::function get_funct(unsigned funct6, return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask()); }; case 0b101001: // VSRA return [](src2_elem_t vs2, src1_elem_t vs1) { - return static_cast>(vs2) >> (vs1 & shift_mask()); + return static_cast>(vs2) >> (vs1 & shift_mask()); }; case 0b101100: // VNSRL return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask()); }; case 0b101101: // VNSRA return [](src2_elem_t vs2, src1_elem_t vs1) { - return static_cast>(vs2) >> (vs1 & shift_mask()); + return static_cast>(vs2) >> (vs1 & shift_mask()); }; // case 0b101110: // VNCLIPU // case 0b101111: // VNCLIP @@ -149,37 +161,53 @@ std::function get_funct(unsigned funct6, } else if(funct3 == OPMVV || funct3 == OPMVX) switch(funct6) { - // case 0b000000: // VREDSUM - // case 0b000001: // VREDAND - // case 0b000010: // VREDOR - // case 0b000011: // VREDXOR - // case 0b000100: // VREDMINU - // case 0b000101: // VREDMIN - // case 0b000110: // VREDMAXU - // case 0b000111: // VREDMAX - // case 0b001000: // VAADDU - // case 0b001001: // VAADD - // case 0b001010: // VASUBU - // case 0b001011: // VASUB - // case 0b001110: // VSLID1EUP - // case 0b001111: // VSLIDE1DOWN - // case 0b010111: // VCOMPRESS - // case 0b011000: // VMANDN - // case 0b011001: // VMAND - // case 0b011010: // VMOR - // case 0b011011: // VMXOR - // case 0b011100: // VMORN - // case 0b011101: // VMNAND - // case 0b011110: // VMNOR - // case 0b011111: // VMXNOR - // case 0b100000: // VDIVU - // case 0b100001: // VDIV - // case 0b100010: // VREMU - // case 0b100011: // VREM - // case 0b100100: // VMULHU - // case 0b100101: // VMUL - // case 0b100110: // VMULHSU - // case 0b100111: // VMULH + // case 0b000000: // VREDSUM + // case 0b000001: // VREDAND + // case 0b000010: // VREDOR + // case 0b000011: // VREDXOR + // case 0b000100: // VREDMINU + // case 0b000101: // VREDMIN + // case 0b000110: // VREDMAXU + // case 0b000111: // VREDMAX + // case 0b001000: // VAADDU + // case 0b001001: // VAADD + // case 0b001010: // VASUBU + // case 0b001011: // VASUB + // case 0b001110: // VSLID1EUP + // case 0b001111: // VSLIDE1DOWN + // case 0b010111: // VCOMPRESS + // case 0b011000: // VMANDN + // case 0b011001: // VMAND + // case 0b011010: // VMOR + // case 0b011011: // VMXOR + // case 0b011100: // VMORN + // case 0b011101: // VMNAND + // case 0b011110: // VMNOR + // case 0b011111: // VMXNOR + // case 0b100000: // VDIVU + // case 0b100001: // VDIV + // case 0b100010: // VREMU + // case 0b100011: // VREM + case 0b100100: // VMULHU + return [](src2_elem_t vs2, src1_elem_t vs1) { + return (static_cast>(vs2) * static_cast>(vs1)) >> sizeof(dest_elem_t) * 8; + }; + case 0b100101: // VMUL + return [](src2_elem_t vs2, src1_elem_t vs1) { + return static_cast>(vs2) * static_cast>(vs1); + }; + case 0b100110: // VMULHSU + return [](src2_elem_t vs2, src1_elem_t vs1) { + return (static_cast>>(static_cast>(vs2)) * + static_cast>(vs1)) >> + sizeof(dest_elem_t) * 8; + }; + case 0b100111: // VMULH + return [](src2_elem_t vs2, src1_elem_t vs1) { + return (static_cast>>(static_cast>(vs2)) * + static_cast>>(static_cast>(vs1))) >> + sizeof(dest_elem_t) * 8; + }; // case 0b101001: // VMADD // case 0b101011: // VNMSUB // case 0b101101: // VMACC @@ -188,29 +216,29 @@ std::function get_funct(unsigned funct6, return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; case 0b110001: // VWADD return [](src2_elem_t vs2, src1_elem_t vs1) { - return static_cast>(static_cast>(vs2) + - static_cast>(vs1)); + return static_cast>(static_cast>(vs2) + + static_cast>(vs1)); }; case 0b110010: // VWSUBU return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; }; case 0b110011: // VWSUB return [](src2_elem_t vs2, src1_elem_t vs1) { - return static_cast>(static_cast>(vs2) - - static_cast>(vs1)); + return static_cast>(static_cast>(vs2) - + static_cast>(vs1)); }; case 0b110100: // VWADDU.W return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; }; case 0b110101: // VWADD.W return [](src2_elem_t vs2, src1_elem_t vs1) { - return static_cast>(static_cast>(vs2) + - static_cast>(vs1)); + return static_cast>(static_cast>(vs2) + + static_cast>(vs1)); }; case 0b110110: // VWSUBU.W return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; }; case 0b110111: // VWSUB.W return [](src2_elem_t vs2, src1_elem_t vs1) { - return static_cast>(static_cast>(vs2) - - static_cast>(vs1)); + return static_cast>(static_cast>(vs2) - + static_cast>(vs1)); }; // case 0b111000: // VWMULU // case 0b111010: // VWMULSU @@ -241,7 +269,8 @@ void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { bool mask_active = vm ? 1 : mask_reg[idx]; if(mask_active) { - vd_view[idx] = fn(vs2_view[idx], vs1_view[idx]); + auto res = fn(vs2_view[idx], vs1_view[idx]); + vd_view[idx] = res; } else { vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx]; } @@ -302,20 +331,6 @@ void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui } return; } -template std::function get_carry_mask_funct(unsigned funct) { - switch(funct) { - case 0b010001: // VMADC - return [](elem_t vs2, elem_t vs1, elem_t carry) { - return static_cast(vs2 + vs1 + carry) < std::max(vs1, vs2) || static_cast(vs2 + vs1) < std::max(vs1, vs2); - }; - case 0b010011: // VMSBC - return [](elem_t vs2, elem_t vs1, elem_t carry) { - return vs2 < static_cast(vs1 + carry) || (vs1 == std::numeric_limits::max() && carry); - }; - default: - throw new std::runtime_error("Uknown funct in get_carry_mask_funct"); - } -} template std::function get_mask_funct(unsigned funct) { switch(funct) { case 0b011000: // VMSEQ @@ -416,7 +431,7 @@ std::function get_unary_fn(unsigned unary_op) { case 0b00111: // vsext.vf2 case 0b00101: // vsext.vf4 case 0b00011: // vsext.vf8 - return [](src2_elem_t vs2) { return static_cast>(vs2); }; + return [](src2_elem_t vs2) { return static_cast>(vs2); }; case 0b00110: // vzext.vf2 case 0b00100: // vzext.vf4 case 0b00010: // vzext.vf8 @@ -450,6 +465,20 @@ void vector_unary_op(uint8_t* V, unsigned unary_op, uint64_t vl, uint64_t vstart } return; } +template std::function get_carry_funct(unsigned funct) { + switch(funct) { + case 0b010001: // VMADC + return [](elem_t vs2, elem_t vs1, elem_t carry) { + return static_cast(vs2 + vs1 + carry) < std::max(vs1, vs2) || static_cast(vs2 + vs1) < std::max(vs1, vs2); + }; + case 0b010011: // VMSBC + return [](elem_t vs2, elem_t vs1, elem_t carry) { + return vs2 < static_cast(vs1 + carry) || (vs1 == std::numeric_limits::max() && carry); + }; + default: + throw new std::runtime_error("Uknown funct in get_carry_funct"); + } +} template void carry_vector_vector_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, unsigned vs1) { @@ -458,7 +487,7 @@ void carry_vector_vector_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vs auto vs1_view = get_vreg(V, vs1, elem_count); auto vs2_view = get_vreg(V, vs2, elem_count); vmask_view vd_mask_view = read_vmask(V, elem_count, vd); - auto fn = get_carry_mask_funct(funct); + auto fn = get_carry_funct(funct); // elements w/ index smaller than vstart are in the prestart and get skipped // body is from vstart to min(elem_count, vl) for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { @@ -482,7 +511,7 @@ void carry_vector_imm_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstar vmask_view mask_reg = read_vmask(V, elem_count); auto vs2_view = get_vreg(V, vs2, elem_count); vmask_view vd_mask_view = read_vmask(V, elem_count, vd); - auto fn = get_carry_mask_funct(funct); + auto fn = get_carry_funct(funct); // elements w/ index smaller than vstart are in the prestart and get skipped // body is from vstart to min(elem_count, vl) for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {