From 75d96bf18d4ea8bd68c0a18e48b323c2dfade089 Mon Sep 17 00:00:00 2001 From: Eyck-Alexander Jentzsch Date: Tue, 18 Feb 2025 21:13:40 +0100 Subject: [PATCH] small cleanup, adds first fixed point instrs --- gen_input/templates/interp/CORENAME.cpp.gtl | 28 +++ src/vm/vector_functions.h | 6 + src/vm/vector_functions.hpp | 214 +++++++++++++++++--- 3 files changed, 220 insertions(+), 28 deletions(-) diff --git a/gen_input/templates/interp/CORENAME.cpp.gtl b/gen_input/templates/interp/CORENAME.cpp.gtl index b412a2a..df29e80 100644 --- a/gen_input/templates/interp/CORENAME.cpp.gtl +++ b/gen_input/templates/interp/CORENAME.cpp.gtl @@ -413,6 +413,34 @@ if(vector != null) {%> throw new std::runtime_error("Unsupported sew bit value"); } } + bool sat_vector_vector_op(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, uint64_t vxrm, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t sew_val){ + switch(sew_val){ + case 0b000: + return softvector::sat_vector_vector_op<${vlen}, uint8_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, vs1); + case 0b001: + return softvector::sat_vector_vector_op<${vlen}, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, vs1); + case 0b010: + return softvector::sat_vector_vector_op<${vlen}, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, vs1); + case 0b011: + return softvector::sat_vector_vector_op<${vlen}, uint64_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, vs1); + default: + throw new std::runtime_error("Unsupported sew bit value"); + } + } + bool sat_vector_imm_op(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, uint64_t vxrm, bool vm, uint8_t vd, uint8_t vs2, int64_t imm, uint8_t sew_val){ + switch(sew_val){ + case 0b000: + return softvector::sat_vector_imm_op<${vlen}, uint8_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, imm); + case 0b001: + return softvector::sat_vector_imm_op<${vlen}, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, imm); + case 0b010: + return softvector::sat_vector_imm_op<${vlen}, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, imm); + case 0b011: + return softvector::sat_vector_imm_op<${vlen}, uint64_t>(V, funct6, funct3, vl, vstart, vtype, vxrm, vm, vd, vs2, imm); + default: + throw new std::runtime_error("Unsupported sew bit value"); + } + } <%}%> uint64_t fetch_count{0}; uint64_t tval{0}; diff --git a/src/vm/vector_functions.h b/src/vm/vector_functions.h index 23fd2ee..6416c97 100644 --- a/src/vm/vector_functions.h +++ b/src/vm/vector_functions.h @@ -91,6 +91,12 @@ void carry_vector_vector_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vs template void carry_vector_imm_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, typename std::make_signed::type imm); +template +bool sat_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm, + unsigned vd, unsigned vs2, unsigned vs1); +template +bool sat_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm, + unsigned vd, unsigned vs2, typename std::make_signed::type imm); } // namespace softvector #include "vm/vector_functions.hpp" #endif /* _VM_VECTOR_FUNCTIONS_H_ */ diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp index 5b3093b..3499562 100644 --- a/src/vm/vector_functions.hpp +++ b/src/vm/vector_functions.hpp @@ -33,6 +33,7 @@ //////////////////////////////////////////////////////////////////////////////// #pragma once #include "vm/vector_functions.h" +#include #include #include #include @@ -131,13 +132,8 @@ std::function get_funct(unsi return static_cast>(static_cast>(vs2) - static_cast>(vs1)); }; - // case 0b100000: // VSADDU - // case 0b100001: // VSADD - // case 0b100010: // VSSUBU - // case 0b100011: // VSSUB case 0b100101: // VSLL return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 << (vs1 & shift_mask()); }; - // case 0b100111: // VSMUL // case 0b100111: // VMVR case 0b101000: // VSRL return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask()); }; @@ -151,12 +147,10 @@ std::function get_funct(unsi return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return static_cast>(vs2) >> (vs1 & shift_mask()); }; - // case 0b101110: // VNCLIPU - // case 0b101111: // VNCLIP // case 0b110000: // VWREDSUMU // case 0b110001: // VWREDSUM default: - throw new std::runtime_error("Uknown funct6 in get_funct"); + throw new std::runtime_error("Unknown funct6 in get_funct"); } else if(funct3 == OPMVV || funct3 == OPMVX) switch(funct6) { @@ -168,10 +162,6 @@ std::function get_funct(unsi // case 0b000101: // VREDMIN // case 0b000110: // VREDMAXU // case 0b000111: // VREDMAX - // case 0b001000: // VAADDU - // case 0b001001: // VAADD - // case 0b001010: // VASUBU - // case 0b001011: // VASUB // case 0b001110: // VSLID1EUP // case 0b001111: // VSLIDE1DOWN // case 0b010111: // VCOMPRESS @@ -302,17 +292,16 @@ std::function get_funct(unsi return static_cast>(vs1) * vs2 + vd; }; default: - throw new std::runtime_error("Uknown funct6 in get_funct"); + throw new std::runtime_error("Unknown funct6 in get_funct"); } else throw new std::runtime_error("Unknown funct3 in get_funct"); } -template -std::function get_merge_funct(bool vm) { +template std::function get_merge_funct(bool vm) { if(vm) { // VMV - return [](bool vm, dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vs1; }; + return [](bool vm, dest_elem_t vs2, dest_elem_t vs1) { return vs1; }; } else { // VMERGE - return [](bool vm, dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { return vm ? vs1 : vs2; }; + return [](bool vm, dest_elem_t vs2, dest_elem_t vs1) { return vm ? vs1 : vs2; }; } }; template @@ -328,12 +317,8 @@ void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, // body is from vstart to min(elem_count, vl) if(merge) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { - auto merge_fn = get_merge_funct(vm); - auto cur_mask = mask_reg[idx]; - auto vd_val = vd_view[idx]; - auto vs2_val = vs2_view[idx]; - auto vs1_val = vs1_view[idx]; - vd_view[idx] = merge_fn(mask_reg[idx], vd_view[idx], vs2_view[idx], vs1_view[idx]); + auto merge_fn = get_merge_funct(vm); + vd_view[idx] = merge_fn(mask_reg[idx], vs2_view[idx], vs1_view[idx]); } } else if(carry == carry_t::NO_CARRY) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { @@ -376,8 +361,8 @@ void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui auto cur_mask = mask_reg[idx]; auto vd_val = vd_view[idx]; auto vs2_val = vs2_view[idx]; - auto merge_fn = get_merge_funct(vm); - vd_view[idx] = merge_fn(mask_reg[idx], vd_view[idx], vs2_view[idx], imm); + auto merge_fn = get_merge_funct(vm); + vd_view[idx] = merge_fn(mask_reg[idx], vs2_view[idx], imm); } } else if(carry == carry_t::NO_CARRY) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { @@ -431,7 +416,7 @@ template std::function get_mask_funct(un }; default: - throw new std::runtime_error("Uknown funct in get_mask_funct"); + throw new std::runtime_error("Unknown funct in get_mask_funct"); } } template @@ -511,7 +496,7 @@ std::function get_unary_fn(unsigned unary_op) { case 0b00010: // vzext.vf8 return [](src2_elem_t vs2) { return vs2; }; default: - throw new std::runtime_error("Uknown funct in get_unary_fn"); + throw new std::runtime_error("Unknown funct in get_unary_fn"); } } template @@ -550,7 +535,7 @@ template std::function get_carry return vs2 < static_cast(vs1 + carry) || (vs1 == std::numeric_limits::max() && carry); }; default: - throw new std::runtime_error("Uknown funct in get_carry_funct"); + throw new std::runtime_error("Unknown funct in get_carry_funct"); } } template @@ -602,4 +587,177 @@ void carry_vector_imm_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstar } return; } +template bool get_rounding_increment(T v, uint64_t d, int64_t vxrm) { + switch(vxrm & 0b11) { // Mask to ensure only lower 2 bits are used + case 0b00: // rnu: round-to-nearest-up (add +0.5 LSB) + return (v >> (d - 1)) & 1; + case 0b01: // rne: round-to-nearest-even + return ((v >> (d - 1)) & 1) && (((v & ((1 << (d - 1)) - 1)) != 0) || ((v >> d) & 1)); + case 0b10: // rdn: round-down (truncate) + return false; + case 0b11: // rod: round-to-odd (jam) + return (!(v & (1 << d)) && ((v & ((1 << d) - 1)) != 0)); + } + return false; +} +template T roundoff_unsigned(T v, uint64_t d, int64_t vxrm) { + unsigned r = get_rounding_increment(v, d, vxrm); + return (v >> d) + r; +} +template T roundoff_signed(T v, uint64_t d, int64_t vxrm) { + unsigned r = get_rounding_increment(v, d, vxrm); + return (v >> d) + r; +} +template +std::function get_sat_funct(unsigned funct6, unsigned funct3) { + if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI) + switch(funct6) { + case 0b100000: // VSADDU + return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + auto res = static_cast>(vs2) + static_cast>(vs1); + if(res > std::numeric_limits::max()) { + vd = std::numeric_limits::max(); + return 1; + } else { + vd = res; + return 0; + } + }; + case 0b100001: // VSADD + return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + auto res = static_cast>>(static_cast>(vs2)) + + static_cast>>(static_cast>(vs1)); + if(res < std::numeric_limits>::min()) { + vd = std::numeric_limits>::min(); + return 1; + } else if(res > std::numeric_limits>::max()) { + vd = std::numeric_limits>::max(); + return 1; + } else { + vd = res; + return 0; + } + }; + case 0b100010: // VSSUBU + return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + if(vs2 < vs1) { + vd = 0; + return 1; + } else { + vd = vs2 - vs1; + return 0; + } + }; + case 0b100011: // VSSUB + return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + auto res = static_cast>>(static_cast>(vs2)) - + static_cast>>(static_cast>(vs1)); + if(res < std::numeric_limits>::min()) { + vd = std::numeric_limits>::min(); + return 1; + } else if(res > std::numeric_limits>::max()) { + vd = std::numeric_limits>::max(); + return 1; + } else { + vd = res; + return 0; + } + }; + // case 0b100111: // VSMUL + // case 0b101010: // VSSRL + // case 0b101011: // VSSRA + // case 0b101110: // VNCLIPU + // case 0b101111: // VNCLIP + default: + throw new std::runtime_error("Unknown funct6 in get_sat_funct"); + } + else if(funct3 == OPMVV || funct3 == OPMVX) + switch(funct6) { + case 0b001000: // VAADDU + return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + auto res = static_cast>(vs2) + static_cast>(vs1); + vd = roundoff_unsigned(res, 1, vxrm); + return 0; + }; + case 0b001001: // VAADD + return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + auto res = static_cast>>(static_cast>(vs2)) + + static_cast>>(static_cast>(vs1)); + vd = roundoff_signed(res, 1, vxrm); + return 0; + }; + case 0b001010: // VASUBU + return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + auto res = static_cast>(vs2) - static_cast>(vs1); + vd = roundoff_unsigned(res, 1, vxrm); + return 0; + }; + case 0b001011: // VASUB + return [](uint64_t vxrm, dest_elem_t& vd, src_elem_t vs2, src_elem_t vs1) { + auto res = static_cast>>(static_cast>(vs2)) - + static_cast>>(static_cast>(vs1)); + vd = roundoff_signed(res, 1, vxrm); + return 0; + }; + default: + throw new std::runtime_error("Unknown funct6 in get_sat_funct"); + } + else + throw new std::runtime_error("Unknown funct3 in get_sat_funct"); +} +template +bool sat_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm, + unsigned vd, unsigned vs2, unsigned vs1) { + uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); + bool saturated = false; + vmask_view mask_reg = read_vmask(V, elem_count); + auto vs1_view = get_vreg(V, vs1, elem_count); + auto vs2_view = get_vreg(V, vs2, elem_count); + auto vd_view = get_vreg(V, vd, elem_count); + auto fn = get_sat_funct(funct6, funct3); + // elements w/ index smaller than vstart are in the prestart and get skipped + // body is from vstart to min(elem_count, vl) + for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { + bool mask_active = vm ? 1 : mask_reg[idx]; + if(mask_active) { + saturated |= fn(vxrm, vd_view[idx], vs2_view[idx], vs1_view[idx]); + } else { + vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx]; + } + } + // elements w/ index larger than elem_count are in the tail (fractional LMUL) + // elements w/ index larger than vl are in the tail + unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8); + for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) { + vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx]; + } + return saturated; +} +template +bool sat_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, int64_t vxrm, bool vm, + unsigned vd, unsigned vs2, typename std::make_signed::type imm) { + uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); + bool saturated = false; + vmask_view mask_reg = read_vmask(V, elem_count); + auto vs2_view = get_vreg(V, vs2, elem_count); + auto vd_view = get_vreg(V, vd, elem_count); + auto fn = get_sat_funct(funct6, funct3); + // elements w/ index smaller than vstart are in the prestart and get skipped + // body is from vstart to min(elem_count, vl) + for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { + bool mask_active = vm ? 1 : mask_reg[idx]; + if(mask_active) { + saturated |= fn(vxrm, vd_view[idx], vs2_view[idx], imm); + } else { + vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx]; + } + } + // elements w/ index larger than elem_count are in the tail (fractional LMUL) + // elements w/ index larger than vl are in the tail + unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8); + for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) { + vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx]; + } + return saturated; +} } // namespace softvector \ No newline at end of file