From 8746003d3effd9e3dc1b777ff29b7e7239c715c3 Mon Sep 17 00:00:00 2001
From: Eyck-Alexander Jentzsch
Date: Wed, 26 Feb 2025 18:53:14 +0100
Subject: [PATCH] adds floating point reduction instrs, widening are untested

---
 gen_input/templates/interp/CORENAME.cpp.gtl |  28 +++-
 src/vm/vector_functions.h                   |   3 +
 src/vm/vector_functions.hpp                 | 171 +++++++++++++++++++-
 3 files changed, 198 insertions(+), 4 deletions(-)

diff --git a/gen_input/templates/interp/CORENAME.cpp.gtl b/gen_input/templates/interp/CORENAME.cpp.gtl
index 74ac525..227a7fa 100644
--- a/gen_input/templates/interp/CORENAME.cpp.gtl
+++ b/gen_input/templates/interp/CORENAME.cpp.gtl
@@ -750,7 +750,33 @@ if(vector != null) {%>
     void fp_vector_slide1down(uint8_t* V, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, unsigned vd, unsigned vs2, uint64_t imm, uint8_t sew_val) {
         return vector_slide1down(V, vl, vstart, vtype, vm, vd, vs2, imm, sew_val);
     }
-
+    void fp_vector_red_op(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t rm, uint8_t sew_val){
+        switch(sew_val){ // pick the element type of the softvector implementation from the sew encoding
+            case 0b000:
+                throw new std::runtime_error("Unsupported sew bit value");
+            case 0b001:
+                throw new std::runtime_error("Unsupported sew bit value");
+            case 0b010: // 32-bit elements
+                return softvector::fp_vector_red_op<${vlen}, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1, rm);
+            case 0b011: // 64-bit elements
+                return softvector::fp_vector_red_op<${vlen}, uint64_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1, rm);
+            default:
+                throw new std::runtime_error("Unsupported sew bit value");
+        }
+    }
+    void fp_vector_red_wv(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t rm, uint8_t sew_val){
+        switch(sew_val){ // widening variant: the accumulator type is twice as wide as the source element type
+            case 0b000:
+                throw new std::runtime_error("Unsupported sew bit value");
+            case 0b001: // 16-bit source elements, 32-bit accumulator
+                return softvector::fp_vector_red_op<${vlen}, uint32_t, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1, rm);
+            case 0b010: // 32-bit source elements, 64-bit accumulator
+                return softvector::fp_vector_red_op<${vlen}, uint64_t, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1, rm);
+            case 0b011: // would require 128 bits vs2 value
+            default:
+                throw new std::runtime_error("Unsupported sew bit value");
+        }
+    }
 <%}%>
     uint64_t fetch_count{0};
     uint64_t tval{0};
diff --git a/src/vm/vector_functions.h b/src/vm/vector_functions.h
index d99c023..10a9d5e 100644
--- a/src/vm/vector_functions.h
+++ b/src/vm/vector_functions.h
@@ -122,6 +122,9 @@ void vector_imm_gather(uint8_t* V, uint64_t vl, uint64_t vstart, vtype_t vtype,
 template
 void vector_compress(uint8_t* V, uint64_t vl, uint64_t vstart, vtype_t vtype, unsigned vd, unsigned vs2, unsigned vs1);
 template void vector_whole_move(uint8_t* V, unsigned vd, unsigned vs2, unsigned count);
+template
+void fp_vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
+                      unsigned vs2, unsigned vs1, uint8_t rm);
 } // namespace softvector
 #include "vm/vector_functions.hpp"
 #endif /* _VM_VECTOR_FUNCTIONS_H_ */
diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp
index 78a9f02..d50a277 100644
--- a/src/vm/vector_functions.hpp
+++ b/src/vm/vector_functions.hpp
@@ -32,7 +32,13 @@
 //       alex@minres.com - initial API and implementation
 ////////////////////////////////////////////////////////////////////////////////
 #pragma once
+#include "softfloat.h"
+#include "softfloat_types.h"
+#include "specialize.h"
+#include "vm/fp_functions.h"
 #include "vm/vector_functions.h"
+#include
+#include
 #include
 #include
 #include
@@ -873,7 +879,7 @@ std::function get_red_funct(unsigned funct6, uns
                            static_cast>(static_cast>(vs2)));
             };
         default:
-            throw new std::runtime_error("Unknown funct6 in get_funct");
+            throw new std::runtime_error("Unknown funct6 in get_red_funct");
         }
     else if(funct3 == OPMVV || funct3 == OPMVX)
         switch(funct6) {
@@ -902,10 +908,10 @@ std::function get_red_funct(unsigned funct6, uns
                            static_cast>(static_cast>(vs2)));
             };
         default:
-            throw new std::runtime_error("Unknown funct6 in get_funct");
+            throw new std::runtime_error("Unknown funct6 in get_red_funct");
         }
     else
-        throw new std::runtime_error("Unknown funct3 in get_funct");
+        throw new std::runtime_error("Unknown funct3 in get_red_funct");
 }
 template
 void vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
@@ -932,6 +938,165 @@ void vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui
     }
     return;
 }
+// might be that these exist somewhere in softfloat
+template constexpr bool isNaN(src_elem_t x);
+template <> constexpr bool isNaN(uint32_t x) { return ((x & 0x7F800000) == 0x7F800000) && ((x & 0x007FFFFF) != 0); }
+template <> constexpr bool isNaN(uint64_t x) {
+    return ((x & 0x7FF0000000000000) == 0x7FF0000000000000) && ((x & 0x000FFFFFFFFFFFFF) != 0);
+}
+template constexpr bool isNegZero(src_elem_t x);
+template <> constexpr bool isNegZero(uint32_t x) { return x == 0x80000000; }
+template <> constexpr bool isNegZero(uint64_t x) { return x == 0x8000000000000000; }
+template constexpr bool isPosZero(src_elem_t x);
+template <> constexpr bool isPosZero(uint32_t x) { return x == 0x00000000; }
+template <> constexpr bool isPosZero(uint64_t x) { return x == 0x0000000000000000; }
+// fp_add forwards to fadd_s / fadd_d with the given rounding mode; widen_float only instantiates for 4-byte -> 8-byte types
+template elem_size_t fp_add(uint8_t, elem_size_t, elem_size_t);
+template <> inline uint32_t fp_add(uint8_t mode, uint32_t v1, uint32_t v2) { return fadd_s(v1, v2, mode); }
+template <> inline uint64_t fp_add(uint8_t mode, uint64_t v1, uint64_t v2) { return fadd_d(v1, v2, mode); }
+template dest_elem_t widen_float(src_elem_t val) {
+    static_assert(sizeof(dest_elem_t) == 8 && sizeof(src_elem_t) == 4, "");
+    return static_cast(static_cast(val)); // NOTE(review): confirm this widens the encoded float, not the raw integer value
+};
+// fp_min / fp_max: a single NaN operand yields the other operand, two NaNs yield the default NaN, and -0 orders below +0
+template elem_size_t fp_min(elem_size_t, elem_size_t);
+template <> inline uint32_t fp_min(uint32_t v1, uint32_t v2) {
+    bool v1_lt_v2 = fcmp_s(v1, v2, 2);
+    if(isNaN(v1) && isNaN(v2))
+        return defaultNaNF32UI;
+    else if(isNaN(v1))
+        return v2;
+    else if(isNaN(v2))
+        return v1;
+    else if(isNegZero(v1) && isPosZero(v2)) // the compare sees -0 and +0 as equal, so pick -0 explicitly
+        return v1;
+    else if(isNegZero(v2) && isPosZero(v1))
+        return v2;
+    else if(v1_lt_v2)
+        return v1;
+    else
+        return v2;
+}
+template <> inline uint64_t fp_min(uint64_t v1, uint64_t v2) {
+    bool v1_lt_v2 = fcmp_d(v1, v2, 2);
+    if(isNaN(v1) && isNaN(v2))
+        return defaultNaNF64UI; // 64-bit result needs the double-precision default NaN
+    else if(isNaN(v1))
+        return v2;
+    else if(isNaN(v2))
+        return v1;
+    else if(isNegZero(v1) && isPosZero(v2)) // the compare sees -0 and +0 as equal, so pick -0 explicitly
+        return v1;
+    else if(isNegZero(v2) && isPosZero(v1))
+        return v2;
+    else if(v1_lt_v2)
+        return v1;
+    else
+        return v2;
+}
+template elem_size_t fp_max(elem_size_t, elem_size_t);
+template <> inline uint32_t fp_max(uint32_t v1, uint32_t v2) {
+    bool v1_lt_v2 = fcmp_s(v1, v2, 2);
+    if(isNaN(v1) && isNaN(v2))
+        return defaultNaNF32UI;
+    else if(isNaN(v1))
+        return v2;
+    else if(isNaN(v2))
+        return v1;
+    else if(isNegZero(v1) && isPosZero(v2)) // the compare sees -0 and +0 as equal, so pick +0 explicitly
+        return v2;
+    else if(isNegZero(v2) && isPosZero(v1))
+        return v1;
+    else if(v1_lt_v2)
+        return v2;
+    else
+        return v1;
+}
+template <> inline uint64_t fp_max(uint64_t v1, uint64_t v2) {
+    bool v1_lt_v2 = fcmp_d(v1, v2, 2);
+    if(isNaN(v1) && isNaN(v2))
+        return defaultNaNF64UI; // 64-bit result needs the double-precision default NaN
+    else if(isNaN(v1))
+        return v2;
+    else if(isNaN(v2))
+        return v1;
+    else if(isNegZero(v1) && isPosZero(v2)) // the compare sees -0 and +0 as equal, so pick +0 explicitly
+        return v2;
+    else if(isNegZero(v2) && isPosZero(v1))
+        return v1;
+    else if(v1_lt_v2)
+        return v2;
+    else
+        return v1;
+}
+// get_fp_red_funct: returns the per-element accumulation step for a funct6 encoding; throws on unknown funct6 / funct3
+template
+std::function get_fp_red_funct(unsigned funct6, unsigned funct3) {
+    if(funct3 == OPFVV || funct3 == OPFVF)
+        switch(funct6) {
+        case 0b000001: // VFREDUSUM
+            return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) {
+                running_total = fp_add(rm, running_total, vs2);
+                accrued_flags |= softfloat_exceptionFlags;
+            };
+        case 0b000011: // VFREDOSUM
+            return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) {
+                running_total = fp_add(rm, running_total, vs2);
+                accrued_flags |= softfloat_exceptionFlags;
+            };
+        case 0b000101: // VFREDMIN
+            return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) {
+                running_total = fp_min(running_total, vs2);
+                accrued_flags |= softfloat_exceptionFlags;
+            };
+        case 0b000111: // VFREDMAX
+            return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) {
+                running_total = fp_max(running_total, vs2);
+                accrued_flags |= softfloat_exceptionFlags;
+            };
+        case 0b110001: // VFWREDUSUM
+            return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) {
+                running_total = fp_add(rm, running_total, widen_float(vs2));
+                accrued_flags |= softfloat_exceptionFlags;
+            };
+        case 0b110011: // VFWREDOSUM
+            return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) {
+                running_total = fp_add(rm, running_total, widen_float(vs2));
+                accrued_flags |= softfloat_exceptionFlags;
+            };
+        default:
+            throw new std::runtime_error("Unknown funct6 in get_fp_red_funct");
+        }
+    else
+        throw new std::runtime_error("Unknown funct3 in get_fp_red_funct");
+}
+template
+void fp_vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
+                      unsigned vs2, unsigned vs1, uint8_t rm) {
+    if(vl == 0) // nothing to reduce: vd is left untouched
+        return;
+    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
+    vmask_view mask_reg = read_vmask(V, elem_count);
+    auto vs1_elem = get_vreg(V, vs1, elem_count)[0]; // element 0 of vs1 seeds the reduction
+    auto vs2_view = get_vreg(V, vs2, elem_count);
+    auto vd_view = get_vreg(V, vd, elem_count);
+    auto fn = get_fp_red_funct(funct6, funct3);
+    dest_elem_t& running_total = {vs1_elem}; // refers to the local seed, updated in place by fn
+    uint_fast8_t accrued_flags = 0;
+    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
+        bool mask_active = vm ? 1 : mask_reg[idx];
+        if(mask_active) {
+            fn(rm, accrued_flags, running_total, vs2_view[idx]);
+        }
+    }
+    vd_view[0] = running_total;
+    softfloat_exceptionFlags = accrued_flags; // write back the flags OR-ed together over all steps
+    // the tail is all elements of the destination register beyond the first one
+    for(unsigned idx = 1; idx < VLEN / (vtype.sew() * RFS); idx++) {
+        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx]; // both arms keep the previous value
+    }
+    return;
+}
 template
 void mask_mask_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, unsigned vd, unsigned vs2,
                   unsigned vs1) {
     uint64_t elem_count = VLEN;