////////////////////////////////////////////////////////////////////////////////
// Copyright (C) 2025, MINRES Technologies GmbH
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
//    this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
//    this list of conditions and the following disclaimer in the documentation
//    and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors
//    may be used to endorse or promote products derived from this software
//    without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Contributors:
//       alex@minres.com - initial API and implementation
////////////////////////////////////////////////////////////////////////////////

#pragma once

#include "vm/vector_functions.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <limits>
#ifndef _VM_VECTOR_FUNCTIONS_H_
#error __FILE__ should only be included from vector_functions.h
#endif
#include <stdexcept>
#include <type_traits>

namespace softvector {

// typed view onto one vector register inside the byte-addressed register file
template <typename elem_t> struct vreg_view {
    uint8_t* start;
    size_t elem_count;
    inline elem_t& get(size_t idx = 0) {
        assert(idx < elem_count);
        return *(reinterpret_cast<elem_t*>(start) + idx);
    }
    elem_t& operator[](size_t idx) {
        assert(idx < elem_count);
        return *(reinterpret_cast<elem_t*>(start) + idx);
    }
};

template <unsigned VLEN, typename elem_t> vreg_view<elem_t> get_vreg(uint8_t* V, uint8_t reg_idx, uint16_t elem_count) {
    assert(V + elem_count * sizeof(elem_t) <= V + VLEN * RFS / 8);
    return {V + VLEN / 8 * reg_idx, elem_count};
}

template <unsigned VLEN> vmask_view read_vmask(uint8_t* V, uint16_t elem_count, uint8_t reg_idx = 0) {
    uint8_t* mask_start = V + VLEN / 8 * reg_idx;
    assert(mask_start + elem_count / 8 <= V + VLEN * RFS / 8);
    return {mask_start, elem_count};
}

// mask to restrict shift amounts to the element width
template <typename elem_t> constexpr elem_t shift_mask() {
    static_assert(std::numeric_limits<elem_t>::is_integer, "shift_mask only supports integer types");
    return std::numeric_limits<elem_t>::digits - 1;
}

enum FUNCT3 {
    OPIVV = 0b000,
    OPFVV = 0b001,
    OPMVV = 0b010,
    OPIVI = 0b011,
    OPIVX = 0b100,
    OPFVF = 0b101,
    OPMVX = 0b110,
};

// maps an integer type to the integer type of twice its width (used by widening operations)
template <typename elem_t> struct twice;
template <> struct twice<int8_t> { using type = int16_t; };
template <> struct twice<uint8_t> { using type = uint16_t; };
template <> struct twice<int16_t> { using type = int32_t; };
template <> struct twice<uint16_t> { using type = uint32_t; };
template <> struct twice<int32_t> { using type = int64_t; };
template <> struct twice<uint32_t> { using type = uint64_t; };
#ifdef __SIZEOF_INT128__
template <> struct twice<int64_t> { using type = __int128_t; };
template <> struct twice<uint64_t> { using type = __uint128_t; };
#endif
template <class T> using twice_t = typename twice<T>::type;
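// Illustrative only: a minimal sketch of how the views above are meant to be used, assuming
// VLEN = 128 and a register file laid out as RFS consecutive registers of VLEN/8 bytes each
// (the layout vector_functions.h provides). The buffer, element type and indices are made-up
// example values, not part of this API.
//
//   uint8_t regfile[128 / 8 * RFS] = {};                  // backing store for all vector registers
//   auto v2 = get_vreg<128, uint32_t>(regfile, 2, 4);     // view v2 as four 32-bit elements
//   v2[0] = 42;                                           // writes bytes 0..3 of register v2
//   vmask_view v0 = read_vmask<128>(regfile, 4);          // mask bits live in v0, one bit per element
//   bool elem0_active = v0[0];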
// for convenience
// returns the scalar operation for the given integer funct6/funct3 encoding, applied per element
template <typename dest_elem_t, typename src2_elem_t = dest_elem_t, typename src1_elem_t = dest_elem_t>
std::function<dest_elem_t(dest_elem_t, src2_elem_t, src1_elem_t)> get_funct(unsigned funct6, unsigned funct3) {
    if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI)
        switch(funct6) {
        case 0b000000: // VADD
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; };
        case 0b000010: // VSUB
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; };
        case 0b000011: // VRSUB
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 - vs2; };
        case 0b000100: // VMINU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return std::min(vs2, static_cast<src2_elem_t>(vs1)); };
        case 0b000101: // VMIN
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return std::min(static_cast<std::make_signed_t<src2_elem_t>>(vs2), static_cast<std::make_signed_t<src2_elem_t>>(vs1));
            };
        case 0b000110: // VMAXU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return std::max(vs2, static_cast<src2_elem_t>(vs1)); };
        case 0b000111: // VMAX
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return std::max(static_cast<std::make_signed_t<src2_elem_t>>(vs2), static_cast<std::make_signed_t<src2_elem_t>>(vs1));
            };
        case 0b001001: // VAND
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 & vs2; };
        case 0b001010: // VOR
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 | vs2; };
        case 0b001011: // VXOR
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 ^ vs2; };
        // case 0b001100: // VRGATHER
        // case 0b001110: // VRGATHEREI16
        // case 0b001111: // VSLIDEDOWN
        case 0b010000: // VADC
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; };
        case 0b010010: // VSBC
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) -
                                                                    static_cast<std::make_signed_t<src1_elem_t>>(vs1));
            };
        case 0b100101: // VSLL
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 << (vs1 & shift_mask<src2_elem_t>()); };
        // case 0b100111: // VMVR
        case 0b101000: // VSRL
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask<src2_elem_t>()); };
        case 0b101001: // VSRA
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<src2_elem_t>>(vs2) >> (vs1 & shift_mask<src2_elem_t>());
            };
        case 0b101100: // VNSRL
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask<src2_elem_t>()); };
        case 0b101101: // VNSRA
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<src2_elem_t>>(vs2) >> (vs1 & shift_mask<src2_elem_t>());
            };
        default:
            throw std::runtime_error("Unknown funct6 in get_funct");
        }
    else if(funct3 == OPMVV || funct3 == OPMVX)
        switch(funct6) {
        // case 0b001110: // VSLIDE1UP
        // case 0b001111: // VSLIDE1DOWN
        // case 0b010111: // VCOMPRESS
        // case 0b011000: // VMANDN
        // case 0b011001: // VMAND
        // case 0b011010: // VMOR
        // case 0b011011: // VMXOR
        // case 0b011100: // VMORN
        // case 0b011101: // VMNAND
        // case 0b011110: // VMNOR
        // case 0b011111: // VMXNOR
        case 0b100000: // VDIVU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t {
                if(vs1 == 0)
                    return -1;
                else
                    return vs2 / vs1;
            };
        case 0b100001: // VDIV
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t {
                if(vs1 == 0)
                    return -1;
                else if(vs2 == std::numeric_limits<std::make_signed_t<src2_elem_t>>::min() &&
                        static_cast<std::make_signed_t<src1_elem_t>>(vs1) == -1)
                    return vs2;
                else
                    return static_cast<std::make_signed_t<src2_elem_t>>(vs2) / static_cast<std::make_signed_t<src1_elem_t>>(vs1);
            };
        case 0b100010: // VREMU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t {
                if(vs1 == 0)
                    return vs2;
                else
                    return vs2 % vs1;
            };
        case 0b100011: // VREM
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t {
                if(vs1 == 0)
                    return vs2;
                else if(vs2 == std::numeric_limits<std::make_signed_t<src2_elem_t>>::min() &&
                        static_cast<std::make_signed_t<src1_elem_t>>(vs1) == -1)
                    return 0;
                else
                    return static_cast<std::make_signed_t<src2_elem_t>>(vs2) % static_cast<std::make_signed_t<src1_elem_t>>(vs1);
            };
        case 0b100100: // VMULHU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return (static_cast<twice_t<src2_elem_t>>(vs2) * static_cast<twice_t<src1_elem_t>>(vs1)) >> sizeof(dest_elem_t) * 8;
            };
        case 0b100101: // VMUL
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<src2_elem_t>>(vs2) * static_cast<std::make_signed_t<src1_elem_t>>(vs1);
            };
        case 0b100110: // VMULHSU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return (static_cast<twice_t<std::make_signed_t<src2_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) *
                        static_cast<twice_t<src1_elem_t>>(vs1)) >>
                       sizeof(dest_elem_t) * 8;
            };
        case 0b100111: // VMULH
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return (static_cast<twice_t<std::make_signed_t<src2_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) *
                        static_cast<twice_t<std::make_signed_t<src1_elem_t>>>(static_cast<std::make_signed_t<src1_elem_t>>(vs1))) >>
                       sizeof(dest_elem_t) * 8;
            };
        case 0b101001: // VMADD
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 * vd + vs2; };
        case 0b101011: // VNMSUB
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return -1 * (vs1 * vd) + vs2; };
        case 0b101101: // VMACC
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 * vs2 + vd; };
        case 0b101111: // VNMSAC
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return -1 * (vs1 * vs2) + vd; };
        case 0b110000: // VWADDU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; };
        case 0b110001: // VWADD
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) +
                                                                    static_cast<std::make_signed_t<src1_elem_t>>(vs1));
            };
        case 0b110010: // VWSUBU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; };
        case 0b110011: // VWSUB
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) -
                                                                    static_cast<std::make_signed_t<src1_elem_t>>(vs1));
            };
        case 0b110100: // VWADDU.W
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; };
        case 0b110101: // VWADD.W
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) +
                                                                    static_cast<std::make_signed_t<src1_elem_t>>(vs1));
            };
        case 0b110110: // VWSUBU.W
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; };
        case 0b110111: // VWSUB.W
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) -
                                                                    static_cast<std::make_signed_t<src1_elem_t>>(vs1));
            };
        case 0b111000: // VWMULU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<twice_t<src2_elem_t>>(vs2) * static_cast<twice_t<src1_elem_t>>(vs1);
            };
        case 0b111010: // VWMULSU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<twice_t<std::make_signed_t<src2_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) *
                       static_cast<twice_t<src1_elem_t>>(vs1);
            };
        case 0b111011: // VWMUL
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<twice_t<std::make_signed_t<src2_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) *
                       static_cast<twice_t<std::make_signed_t<src1_elem_t>>>(static_cast<std::make_signed_t<src1_elem_t>>(vs1));
            };
        case 0b111100: // VWMACCU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1 * vs2 + vd; };
        case 0b111101: // VWMACC
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<src1_elem_t>>(vs1) * static_cast<std::make_signed_t<src2_elem_t>>(vs2) + vd;
            };
        case 0b111110: // VWMACCUS
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return vs1 * static_cast<std::make_signed_t<src2_elem_t>>(vs2) + vd;
            };
        case 0b111111: // VWMACCSU
            return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) {
                return static_cast<std::make_signed_t<src1_elem_t>>(vs1) * vs2 + vd;
            };
        default:
            throw std::runtime_error("Unknown funct6 in get_funct");
        }
    else
        throw std::runtime_error("Unknown funct3 in get_funct");
}

template <typename dest_elem_t> std::function<dest_elem_t(bool, dest_elem_t, dest_elem_t)> get_merge_funct(bool vm) {
    if(vm) { // VMV
        return [](bool vm, dest_elem_t vs2, dest_elem_t vs1) { return vs1; };
    } else { // VMERGE
        return [](bool vm, dest_elem_t vs2, dest_elem_t vs1) { return vm ? vs1 : vs2; };
    }
}

template <unsigned VLEN, typename dest_elem_t, typename src2_elem_t = dest_elem_t, typename src1_elem_t = dest_elem_t>
void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
                      unsigned vs2, unsigned vs1, carry_t carry, bool merge) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs1_view = get_vreg<VLEN, src1_elem_t>(V, vs1, elem_count);
    auto vs2_view = get_vreg<VLEN, src2_elem_t>(V, vs2, elem_count);
    auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
    auto fn = get_funct<dest_elem_t, src2_elem_t, src1_elem_t>(funct6, funct3);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    if(merge) {
        auto merge_fn = get_merge_funct<dest_elem_t>(vm);
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            vd_view[idx] = merge_fn(mask_reg[idx], vs2_view[idx], vs1_view[idx]);
        }
    } else if(carry == carry_t::NO_CARRY) {
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            bool mask_active = vm ? 1 : mask_reg[idx];
            if(mask_active) {
                vd_view[idx] = fn(vd_view[idx], vs2_view[idx], vs1_view[idx]);
            } else {
                // masked-off element: mask-agnostic and mask-undisturbed are both modeled by leaving it unchanged
                vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx];
            }
        }
    } else if(carry == carry_t::SUB_CARRY) {
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            vd_view[idx] = fn(vd_view[idx], vs2_view[idx], vs1_view[idx]) - mask_reg[idx];
        }
    } else {
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            vd_view[idx] = fn(vd_view[idx], vs2_view[idx], vs1_view[idx]) + mask_reg[idx];
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8);
    for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) {
        // tail element: tail-agnostic and tail-undisturbed are both modeled by leaving it unchanged
        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
    }
    return;
}

template <unsigned VLEN, typename dest_elem_t, typename src2_elem_t = dest_elem_t, typename src1_elem_t = dest_elem_t>
void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
                   unsigned vs2, typename std::make_signed<src1_elem_t>::type imm, carry_t carry, bool merge) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs2_view = get_vreg<VLEN, src2_elem_t>(V, vs2, elem_count);
    auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
    auto fn = get_funct<dest_elem_t, src2_elem_t, src1_elem_t>(funct6, funct3);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    if(merge) {
        auto merge_fn = get_merge_funct<dest_elem_t>(vm);
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            vd_view[idx] = merge_fn(mask_reg[idx], vs2_view[idx], imm);
        }
    } else if(carry == carry_t::NO_CARRY) {
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            bool mask_active = vm ? 1 : mask_reg[idx];
            if(mask_active) {
                vd_view[idx] = fn(vd_view[idx], vs2_view[idx], imm);
            } else {
                // masked-off element: mask-agnostic and mask-undisturbed are both modeled by leaving it unchanged
                vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx];
            }
        }
    } else if(carry == carry_t::SUB_CARRY) {
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            vd_view[idx] = fn(vd_view[idx], vs2_view[idx], imm) - mask_reg[idx];
        }
    } else {
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            vd_view[idx] = fn(vd_view[idx], vs2_view[idx], imm) + mask_reg[idx];
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8);
    for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) {
        // tail element: tail-agnostic and tail-undisturbed are both modeled by leaving it unchanged
        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
    }
    return;
}

// comparison operations producing one mask bit per element
template <typename elem_t> std::function<bool(elem_t, elem_t)> get_mask_funct(unsigned funct) {
    switch(funct) {
    case 0b011000: // VMSEQ
        return [](elem_t vs2, elem_t vs1) { return vs2 == vs1; };
    case 0b011001: // VMSNE
        return [](elem_t vs2, elem_t vs1) { return vs2 != vs1; };
    case 0b011010: // VMSLTU
        return [](elem_t vs2, elem_t vs1) { return vs2 < vs1; };
    case 0b011011: // VMSLT
        return [](elem_t vs2, elem_t vs1) {
            return static_cast<std::make_signed_t<elem_t>>(vs2) < static_cast<std::make_signed_t<elem_t>>(vs1);
        };
    case 0b011100: // VMSLEU
        return [](elem_t vs2, elem_t vs1) { return vs2 <= vs1; };
    case 0b011101: // VMSLE
        return [](elem_t vs2, elem_t vs1) {
            return static_cast<std::make_signed_t<elem_t>>(vs2) <= static_cast<std::make_signed_t<elem_t>>(vs1);
        };
    case 0b011110: // VMSGTU
        return [](elem_t vs2, elem_t vs1) { return vs2 > vs1; };
    case 0b011111: // VMSGT
        return [](elem_t vs2, elem_t vs1) {
            return static_cast<std::make_signed_t<elem_t>>(vs2) > static_cast<std::make_signed_t<elem_t>>(vs1);
        };
    default:
        throw std::runtime_error("Unknown funct in get_mask_funct");
    }
}

template <unsigned VLEN, typename elem_t>
void mask_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
                           unsigned vs2, unsigned vs1) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs1_view = get_vreg<VLEN, elem_t>(V, vs1, elem_count);
    auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
    vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd);
    auto fn = get_mask_funct<elem_t>(funct6);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        bool mask_active = vm ? 1 : mask_reg[idx];
        if(mask_active) {
            bool new_bit_value = fn(vs2_view[idx], vs1_view[idx]);
            uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
            unsigned cur_bit = idx % 8;
            *cur_mask_byte_addr = (*cur_mask_byte_addr & ~(1U << cur_bit)) | static_cast<uint8_t>(new_bit_value) << cur_bit;
        } else {
            uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
            unsigned cur_bit = idx % 8;
            // masked-off bit: mask-agnostic and mask-undisturbed are both modeled by leaving it unchanged
            *cur_mask_byte_addr = vtype.vma() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
        uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
        unsigned cur_bit = idx % 8;
        // tail bit: tail-agnostic and tail-undisturbed are both modeled by leaving it unchanged
        *cur_mask_byte_addr = vtype.vta() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
    }
    return;
}

template <unsigned VLEN, typename elem_t>
void mask_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
                        unsigned vs2, typename std::make_signed<elem_t>::type imm) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
    vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd);
    auto fn = get_mask_funct<elem_t>(funct6);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        bool mask_active = vm ? 1 : mask_reg[idx];
        if(mask_active) {
            bool new_bit_value = fn(vs2_view[idx], imm);
            uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
            unsigned cur_bit = idx % 8;
            *cur_mask_byte_addr = (*cur_mask_byte_addr & ~(1U << cur_bit)) | static_cast<uint8_t>(new_bit_value) << cur_bit;
        } else {
            uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
            unsigned cur_bit = idx % 8;
            // masked-off bit: mask-agnostic and mask-undisturbed are both modeled by leaving it unchanged
            *cur_mask_byte_addr = vtype.vma() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
        uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
        unsigned cur_bit = idx % 8;
        // tail bit: tail-agnostic and tail-undisturbed are both modeled by leaving it unchanged
        *cur_mask_byte_addr = vtype.vta() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
    }
    return;
}

// sign- and zero-extension for vsext/vzext; the widening happens through the conversion to dest_elem_t
template <typename dest_elem_t, typename src2_elem_t> std::function<dest_elem_t(src2_elem_t)> get_unary_fn(unsigned unary_op) {
    switch(unary_op) {
    case 0b00111: // vsext.vf2
    case 0b00101: // vsext.vf4
    case 0b00011: // vsext.vf8
        return [](src2_elem_t vs2) { return static_cast<std::make_signed_t<src2_elem_t>>(vs2); };
    case 0b00110: // vzext.vf2
    case 0b00100: // vzext.vf4
    case 0b00010: // vzext.vf8
        return [](src2_elem_t vs2) { return vs2; };
    default:
        throw std::runtime_error("Unknown funct in get_unary_fn");
    }
}

template <unsigned VLEN, typename dest_elem_t, typename src2_elem_t>
void vector_unary_op(uint8_t* V, unsigned unary_op, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs2_view = get_vreg<VLEN, src2_elem_t>(V, vs2, elem_count);
    auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
    auto fn = get_unary_fn<dest_elem_t, src2_elem_t>(unary_op);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        bool mask_active = vm ? 1 : mask_reg[idx];
        if(mask_active) {
            vd_view[idx] = fn(vs2_view[idx]);
        } else {
            // masked-off element: mask-agnostic and mask-undisturbed are both modeled by leaving it unchanged
            vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx];
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8);
    for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) {
        // tail element: tail-agnostic and tail-undisturbed are both modeled by leaving it unchanged
        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
    }
    return;
}

// carry/borrow-out predicates for vmadc/vmsbc
template <typename elem_t> std::function<bool(elem_t, elem_t, elem_t)> get_carry_funct(unsigned funct) {
    switch(funct) {
    case 0b010001: // VMADC
        return [](elem_t vs2, elem_t vs1, elem_t carry) {
            return static_cast<elem_t>(vs2 + vs1 + carry) < std::max(vs1, vs2) || static_cast<elem_t>(vs2 + vs1) < std::max(vs1, vs2);
        };
    case 0b010011: // VMSBC
        return [](elem_t vs2, elem_t vs1, elem_t carry) {
            return vs2 < static_cast<elem_t>(vs1 + carry) || (vs1 == std::numeric_limits<elem_t>::max() && carry);
        };
    default:
        throw std::runtime_error("Unknown funct in get_carry_funct");
    }
}

template <unsigned VLEN, typename elem_t>
void carry_vector_vector_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2,
                            unsigned vs1) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs1_view = get_vreg<VLEN, elem_t>(V, vs1, elem_count);
    auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
    vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd);
    auto fn = get_carry_funct<elem_t>(funct);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        elem_t carry = vm ? 0 : mask_reg[idx];
        bool new_bit_value = fn(vs2_view[idx], vs1_view[idx], carry);
        uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
        unsigned cur_bit = idx % 8;
        *cur_mask_byte_addr = (*cur_mask_byte_addr & ~(1U << cur_bit)) | static_cast<uint8_t>(new_bit_value) << cur_bit;
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
        // always tail agnostic
    }
    return;
}

template <unsigned VLEN, typename elem_t>
void carry_vector_imm_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2,
                         typename std::make_signed<elem_t>::type imm) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
    vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd);
    auto fn = get_carry_funct<elem_t>(funct);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        elem_t carry = vm ? 0 : mask_reg[idx];
        bool new_bit_value = fn(vs2_view[idx], imm, carry);
        uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
        unsigned cur_bit = idx % 8;
        *cur_mask_byte_addr = (*cur_mask_byte_addr & ~(1U << cur_bit)) | static_cast<uint8_t>(new_bit_value) << cur_bit;
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
        // always tail agnostic
    }
    return;
}

// rounding increment for a value v whose lower d bits are about to be shifted out, as selected by vxrm
template <typename T> bool get_rounding_increment(T v, uint64_t d, int64_t vxrm) {
    if(d == 0)
        return 0;
    switch(vxrm & 0b11) { // mask to ensure only the lower 2 bits are used
    case 0b00: // rnu: round-to-nearest-up (add +0.5 LSB)
        return (v >> (d - 1)) & 1;
    case 0b01: // rne: round-to-nearest-even
        return ((v >> (d - 1)) & 1) && (((v & ((static_cast<T>(1) << (d - 1)) - 1)) != 0) || ((v >> d) & 1));
    case 0b10: // rdn: round-down (truncate)
        return false;
    case 0b11: // rod: round-to-odd (jam)
        return !(v & (static_cast<T>(1) << d)) && ((v & ((static_cast<T>(1) << d) - 1)) != 0);
    }
    return false;
}

// shift v right by d bits and apply the rounding increment,
// e.g. with rnu (vxrm = 0): roundoff<uint8_t>(0b10111, 3, 0) shifts out 0b111 and rounds up, giving 0b11
template <typename T> T roundoff(T v, uint64_t d, int64_t vxrm) {
    unsigned r = get_rounding_increment(v, d, vxrm);
    return (v >> d) + r;
}

// saturating, averaging and narrowing fixed-point operations; each returns whether the result saturated
template <typename dest_elem_t, typename src2_elem_t = dest_elem_t, typename src1_elem_t = dest_elem_t>
std::function<bool(uint64_t, vtype_t, dest_elem_t&, src2_elem_t, src1_elem_t)> get_sat_funct(unsigned funct6, unsigned funct3) {
    if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI)
        switch(funct6) {
        case 0b100000: // VSADDU
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = static_cast<twice_t<dest_elem_t>>(vs2) + static_cast<twice_t<dest_elem_t>>(vs1);
                if(res > std::numeric_limits<dest_elem_t>::max()) {
                    vd = std::numeric_limits<dest_elem_t>::max();
                    return 1;
                } else {
                    vd = res;
                    return 0;
                }
            };
        case 0b100001: // VSADD
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) +
                           static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src1_elem_t>>(vs1));
                if(res < std::numeric_limits<std::make_signed_t<dest_elem_t>>::min()) {
                    vd = std::numeric_limits<std::make_signed_t<dest_elem_t>>::min();
                    return 1;
                } else if(res > std::numeric_limits<std::make_signed_t<dest_elem_t>>::max()) {
                    vd = std::numeric_limits<std::make_signed_t<dest_elem_t>>::max();
                    return 1;
                } else {
                    vd = res;
                    return 0;
                }
            };
        case 0b100010: // VSSUBU
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                if(vs2 < vs1) {
                    vd = 0;
                    return 1;
                } else {
                    vd = vs2 - vs1;
                    return 0;
                }
            };
        case 0b100011: // VSSUB
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) -
                           static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src1_elem_t>>(vs1));
                if(res < std::numeric_limits<std::make_signed_t<dest_elem_t>>::min()) {
                    vd = std::numeric_limits<std::make_signed_t<dest_elem_t>>::min();
                    return 1;
                } else if(res > std::numeric_limits<std::make_signed_t<dest_elem_t>>::max()) {
                    vd = std::numeric_limits<std::make_signed_t<dest_elem_t>>::max();
                    return 1;
                } else {
                    vd = res;
                    return 0;
                }
            };
        case 0b100111: // VSMUL
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto big_val = static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) *
                               static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src1_elem_t>>(vs1));
                auto res = roundoff(big_val, vtype.sew() - 1, vxrm);
                if(res < std::numeric_limits<std::make_signed_t<dest_elem_t>>::min()) {
                    vd = std::numeric_limits<std::make_signed_t<dest_elem_t>>::min();
                    return 1;
                } else if(res > std::numeric_limits<std::make_signed_t<dest_elem_t>>::max()) {
                    vd = std::numeric_limits<std::make_signed_t<dest_elem_t>>::max();
                    return 1;
                } else {
                    vd = res;
                    return 0;
                }
            };
        case 0b101010: // VSSRL
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                vd = roundoff(vs2, vs1 & shift_mask<src2_elem_t>(), vxrm);
                return 0;
            };
        case 0b101011: // VSSRA
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                vd = roundoff(static_cast<std::make_signed_t<src2_elem_t>>(vs2), vs1 & shift_mask<src2_elem_t>(), vxrm);
                return 0;
            };
        case 0b101110: // VNCLIPU
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = roundoff(vs2, vs1 & shift_mask<src2_elem_t>(), vxrm);
                if(res > std::numeric_limits<dest_elem_t>::max()) {
                    vd = std::numeric_limits<dest_elem_t>::max();
                    return 1;
                } else {
                    vd = res;
                    return 0;
                }
            };
        case 0b101111: // VNCLIP
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = roundoff(static_cast<std::make_signed_t<src2_elem_t>>(vs2), vs1 & shift_mask<src2_elem_t>(), vxrm);
                if(res < std::numeric_limits<std::make_signed_t<dest_elem_t>>::min()) {
                    vd = std::numeric_limits<std::make_signed_t<dest_elem_t>>::min();
                    return 1;
                } else if(res > std::numeric_limits<std::make_signed_t<dest_elem_t>>::max()) {
                    vd = std::numeric_limits<std::make_signed_t<dest_elem_t>>::max();
                    return 1;
                } else {
                    vd = res;
                    return 0;
                }
            };
        default:
            throw std::runtime_error("Unknown funct6 in get_sat_funct");
        }
    else if(funct3 == OPMVV || funct3 == OPMVX)
        switch(funct6) {
        case 0b001000: // VAADDU
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = static_cast<twice_t<dest_elem_t>>(vs2) + static_cast<twice_t<dest_elem_t>>(vs1);
                vd = roundoff(res, 1, vxrm);
                return 0;
            };
        case 0b001001: // VAADD
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) +
                           static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src1_elem_t>>(vs1));
                vd = roundoff(res, 1, vxrm);
                return 0;
            };
        case 0b001010: // VASUBU
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = static_cast<twice_t<dest_elem_t>>(vs2) - static_cast<twice_t<dest_elem_t>>(vs1);
                vd = roundoff(res, 1, vxrm);
                return 0;
            };
        case 0b001011: // VASUB
            return [](uint64_t vxrm, vtype_t vtype, dest_elem_t& vd, src2_elem_t vs2, src1_elem_t vs1) {
                auto res = static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) -
                           static_cast<twice_t<std::make_signed_t<dest_elem_t>>>(static_cast<std::make_signed_t<src1_elem_t>>(vs1));
                vd = roundoff(res, 1, vxrm);
                return 0;
            };
        default:
            throw std::runtime_error("Unknown funct6 in get_sat_funct");
        }
    else
        throw std::runtime_error("Unknown funct3 in get_sat_funct");
}

template <unsigned VLEN, typename dest_elem_t, typename src2_elem_t = dest_elem_t, typename src1_elem_t = dest_elem_t>
bool sat_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm,
                          unsigned vd, unsigned vs2, unsigned vs1) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    bool saturated = false;
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs1_view = get_vreg<VLEN, src1_elem_t>(V, vs1, elem_count);
    auto vs2_view = get_vreg<VLEN, src2_elem_t>(V, vs2, elem_count);
    auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
    auto fn = get_sat_funct<dest_elem_t, src2_elem_t, src1_elem_t>(funct6, funct3);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        bool mask_active = vm ? 1 : mask_reg[idx];
        if(mask_active) {
            saturated |= fn(vxrm, vtype, vd_view[idx], vs2_view[idx], vs1_view[idx]);
        } else {
            // masked-off element: mask-agnostic and mask-undisturbed are both modeled by leaving it unchanged
            vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx];
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8);
    for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) {
        // tail element: tail-agnostic and tail-undisturbed are both modeled by leaving it unchanged
        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
    }
    return saturated;
}

template <unsigned VLEN, typename dest_elem_t, typename src2_elem_t = dest_elem_t, typename src1_elem_t = dest_elem_t>
bool sat_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm,
                       unsigned vd, unsigned vs2, typename std::make_signed<src1_elem_t>::type imm) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    bool saturated = false;
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs2_view = get_vreg<VLEN, src2_elem_t>(V, vs2, elem_count);
    auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
    auto fn = get_sat_funct<dest_elem_t, src2_elem_t, src1_elem_t>(funct6, funct3);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        bool mask_active = vm ? 1 : mask_reg[idx];
        if(mask_active) {
            saturated |= fn(vxrm, vtype, vd_view[idx], vs2_view[idx], imm);
        } else {
            // masked-off element: mask-agnostic and mask-undisturbed are both modeled by leaving it unchanged
            vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx];
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8);
    for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) {
        // tail element: tail-agnostic and tail-undisturbed are both modeled by leaving it unchanged
        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
    }
    return saturated;
}

// reduction operations accumulating vs2 elements into a running total
template <typename dest_elem_t, typename src_elem_t = dest_elem_t>
std::function<void(dest_elem_t&, src_elem_t)> get_red_funct(unsigned funct6, unsigned funct3) {
    if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI)
        switch(funct6) {
        // case 0b110000: // VWREDSUMU
        // case 0b110001: // VWREDSUM
        default:
            throw std::runtime_error("Unknown funct6 in get_red_funct");
        }
    else if(funct3 == OPMVV || funct3 == OPMVX)
        switch(funct6) {
        case 0b000000: // VREDSUM
            return [](dest_elem_t& running_total, src_elem_t vs2) { running_total += vs2; };
        case 0b000001: // VREDAND
            return [](dest_elem_t& running_total, src_elem_t vs2) { running_total &= vs2; };
        case 0b000010: // VREDOR
            return [](dest_elem_t& running_total, src_elem_t vs2) { running_total |= vs2; };
        case 0b000011: // VREDXOR
            return [](dest_elem_t& running_total, src_elem_t vs2) { running_total ^= vs2; };
        case 0b000100: // VREDMINU
            return [](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::min(running_total, vs2); };
        case 0b000101: // VREDMIN
            return [](dest_elem_t& running_total, src_elem_t vs2) {
                running_total = std::min(static_cast<std::make_signed_t<dest_elem_t>>(running_total),
                                         static_cast<std::make_signed_t<dest_elem_t>>(vs2));
            };
        case 0b000110: // VREDMAXU
            return [](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::max(running_total, vs2); };
        case 0b000111: // VREDMAX
            return [](dest_elem_t& running_total, src_elem_t vs2) {
                running_total = std::max(static_cast<std::make_signed_t<dest_elem_t>>(running_total),
                                         static_cast<std::make_signed_t<dest_elem_t>>(vs2));
            };
        default:
            throw std::runtime_error("Unknown funct6 in get_red_funct");
        }
    else
        throw std::runtime_error("Unknown funct3 in get_red_funct");
}

template <unsigned VLEN, typename dest_elem_t, typename src_elem_t = dest_elem_t>
void vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
                   unsigned vs2, unsigned vs1) {
    if(vl == 0)
        return;
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs1_elem = get_vreg<VLEN, dest_elem_t>(V, vs1, elem_count)[0];
    auto vs2_view = get_vreg<VLEN, src_elem_t>(V, vs2, elem_count);
    auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
    auto fn = get_red_funct<dest_elem_t, src_elem_t>(funct6, funct3);
    dest_elem_t running_total = vs1_elem;
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        bool mask_active = vm ? 1 : mask_reg[idx];
        if(mask_active) {
            fn(running_total, vs2_view[idx]);
        }
    }
    vd_view[0] = running_total;
    // the tail is all elements of the destination register beyond the first one
    for(unsigned idx = 1; idx < VLEN / vtype.sew(); idx++) {
        // tail element: tail-agnostic and tail-undisturbed are both modeled by leaving it unchanged
        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
    }
    return;
}
} // namespace softvector
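// Illustrative only: a rough sketch of how an instruction-level caller might dispatch into
// vector_vector_op for a vadd.vv with SEW = 32 (funct6 = 0b000000, funct3 = OPIVV). VLEN, the
// register indices and the vl/vstart/vtype/carry values below are made-up example parameters
// taken from the hart's CSR state; the real decode and dispatch live in vector_functions.h and
// the generated ISS code, not in this header.
//
//   softvector::vector_vector_op<128, uint32_t>(V,                 // pointer to the vector register file
//                                               0b000000,          // funct6: VADD
//                                               softvector::OPIVV, // funct3: vector-vector integer op
//                                               vl, vstart, vtype, // current CSR state
//                                               true,              // vm: unmasked
//                                               /*vd*/ 1, /*vs2*/ 2, /*vs1*/ 3,
//                                               softvector::carry_t::NO_CARRY, /*merge*/ false);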