diff --git a/gen_input/templates/interp/CORENAME.cpp.gtl b/gen_input/templates/interp/CORENAME.cpp.gtl index a8c4718..a2f359d 100644 --- a/gen_input/templates/interp/CORENAME.cpp.gtl +++ b/gen_input/templates/interp/CORENAME.cpp.gtl @@ -279,11 +279,11 @@ if(vector != null) {%> void vector_vector_wv(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t sew_val){ switch(sew_val){ case 0b000: - return softvector::vector_vector_op<${vlen}, uint16_t, uint8_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1 ); + return softvector::vector_vector_op<${vlen}, uint16_t, uint8_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b001: - return softvector::vector_vector_op<${vlen}, uint32_t, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1 ); + return softvector::vector_vector_op<${vlen}, uint32_t, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b010: - return softvector::vector_vector_op<${vlen}, uint64_t, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1 ); + return softvector::vector_vector_op<${vlen}, uint64_t, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b011: // would widen to 128 bits default: throw new std::runtime_error("Unsupported sew bit value"); @@ -422,11 +422,11 @@ if(vector != null) {%> void vector_vector_vw(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t sew_val){ switch(sew_val){ case 0b000: - return softvector::vector_vector_op<${vlen}, uint8_t, uint16_t, uint8_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1 ); + return softvector::vector_vector_op<${vlen}, uint8_t, uint16_t, uint8_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b001: - return softvector::vector_vector_op<${vlen}, uint16_t, uint32_t, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1 ); + return softvector::vector_vector_op<${vlen}, uint16_t, uint32_t, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b010: - return softvector::vector_vector_op<${vlen}, uint32_t, uint64_t, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1 ); + return softvector::vector_vector_op<${vlen}, uint32_t, uint64_t, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b011: // would require 128 bits vs2 value default: throw new std::runtime_error("Unsupported sew bit value"); @@ -448,13 +448,13 @@ if(vector != null) {%> void vector_vector_merge(uint8_t* V, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t sew_val){ switch(sew_val){ case 0b000: - return softvector::vector_vector_op<${vlen}, uint8_t>(V, 0, 0, vl, vstart, vtype, vm, vd, vs2, vs1, softvector::carry_t::NO_CARRY, true); + return softvector::vector_vector_merge<${vlen}, uint8_t>(V, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b001: - return softvector::vector_vector_op<${vlen}, uint16_t>(V, 0, 0, vl, vstart, vtype, vm, vd, vs2, vs1, softvector::carry_t::NO_CARRY, true); + return softvector::vector_vector_merge<${vlen}, uint16_t>(V, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b010: - return softvector::vector_vector_op<${vlen}, uint32_t>(V, 0, 0, vl, vstart, vtype, vm, vd, vs2, vs1, softvector::carry_t::NO_CARRY, true); + return softvector::vector_vector_merge<${vlen}, uint32_t>(V, vl, vstart, vtype, vm, vd, vs2, vs1); case 0b011: - return softvector::vector_vector_op<${vlen}, uint64_t>(V, 0, 0, vl, vstart, vtype, vm, vd, vs2, vs1, softvector::carry_t::NO_CARRY, true); + return softvector::vector_vector_merge<${vlen}, uint64_t>(V, vl, vstart, vtype, vm, vd, vs2, vs1); default: throw new std::runtime_error("Unsupported sew bit value"); } @@ -462,13 +462,13 @@ if(vector != null) {%> void vector_imm_merge(uint8_t* V, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, int64_t imm, uint8_t sew_val){ switch(sew_val){ case 0b000: - return softvector::vector_imm_op<${vlen}, uint8_t>(V, 0, 0, vl, vstart, vtype, vm, vd, vs2, imm, softvector::carry_t::NO_CARRY, true); + return softvector::vector_imm_merge<${vlen}, uint8_t>(V, vl, vstart, vtype, vm, vd, vs2, imm); case 0b001: - return softvector::vector_imm_op<${vlen}, uint16_t>(V, 0, 0, vl, vstart, vtype, vm, vd, vs2, imm, softvector::carry_t::NO_CARRY, true); + return softvector::vector_imm_merge<${vlen}, uint16_t>(V, vl, vstart, vtype, vm, vd, vs2, imm); case 0b010: - return softvector::vector_imm_op<${vlen}, uint32_t>(V, 0, 0, vl, vstart, vtype, vm, vd, vs2, imm, softvector::carry_t::NO_CARRY, true); + return softvector::vector_imm_merge<${vlen}, uint32_t>(V, vl, vstart, vtype, vm, vd, vs2, imm); case 0b011: - return softvector::vector_imm_op<${vlen}, uint64_t>(V, 0, 0, vl, vstart, vtype, vm, vd, vs2, imm, softvector::carry_t::NO_CARRY, true); + return softvector::vector_imm_merge<${vlen}, uint64_t>(V, vl, vstart, vtype, vm, vd, vs2, imm); default: throw new std::runtime_error("Unsupported sew bit value"); } @@ -871,6 +871,37 @@ if(vector != null) {%> throw new std::runtime_error("Unsupported sew bit value"); } } + void mask_fp_vector_vector_op(uint8_t* V, uint8_t funct6, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t rm, uint8_t sew_val){ + switch(sew_val){ + case 0b000: + throw new std::runtime_error("Unsupported sew bit value"); + case 0b001: + throw new std::runtime_error("Unsupported sew bit value"); + case 0b010: + return softvector::mask_fp_vector_vector_op<${vlen}, uint32_t>(V, funct6, vl, vstart, vtype, vm, vd, vs2, vs1, rm); + case 0b011: + return softvector::mask_fp_vector_vector_op<${vlen}, uint64_t>(V, funct6, vl, vstart, vtype, vm, vd, vs2, vs1, rm); + default: + throw new std::runtime_error("Unsupported sew bit value"); + } + } + void mask_fp_vector_imm_op(uint8_t* V, uint8_t funct6, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, int64_t imm, uint8_t rm, uint8_t sew_val){ + switch(sew_val){ + case 0b000: + throw new std::runtime_error("Unsupported sew bit value"); + case 0b001: + throw new std::runtime_error("Unsupported sew bit value"); + case 0b010: + return softvector::mask_fp_vector_imm_op<${vlen}, uint32_t>(V, funct6, vl, vstart, vtype, vm, vd, vs2, imm, rm); + case 0b011: + return softvector::mask_fp_vector_imm_op<${vlen}, uint64_t>(V, funct6, vl, vstart, vtype, vm, vd, vs2, imm, rm); + default: + throw new std::runtime_error("Unsupported sew bit value"); + } + } + void fp_vector_imm_merge(uint8_t* V, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, int64_t imm, uint8_t sew_val){ + vector_imm_merge(V, vl, vstart, vtype, vm, vd, vs2, imm, sew_val); + } <%}%> uint64_t fetch_count{0}; uint64_t tval{0}; diff --git a/src/vm/vector_functions.h b/src/vm/vector_functions.h index 38f0759..6d0c51f 100644 --- a/src/vm/vector_functions.h +++ b/src/vm/vector_functions.h @@ -84,10 +84,14 @@ uint64_t vector_load_store_index(void* core, std::function void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, - unsigned vs2, unsigned vs1, carry_t carry = carry_t::NO_CARRY, bool merge = false); + unsigned vs2, unsigned vs1, carry_t carry = carry_t::NO_CARRY); template void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, - unsigned vs2, typename std::make_signed::type imm, carry_t carry = carry_t::NO_CARRY, bool merge = false); + unsigned vs2, typename std::make_signed::type imm, carry_t carry = carry_t::NO_CARRY); +template +void vector_vector_merge(uint8_t* V, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, unsigned vs1); +template +void vector_imm_merge(uint8_t* V, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, uint64_t imm); template void vector_unary_op(uint8_t* V, unsigned unary_op, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2); template @@ -144,6 +148,12 @@ void fp_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, template void fp_vector_unary_op(uint8_t* V, unsigned unary_op, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, uint8_t rm); +template +void mask_fp_vector_vector_op(uint8_t* V, unsigned funct6, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, + unsigned vs1, uint8_t rm); +template +void mask_fp_vector_imm_op(uint8_t* V, unsigned funct6, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, + elem_t imm, uint8_t rm); } // namespace softvector #include "vm/vector_functions.hpp" #endif /* _VM_VECTOR_FUNCTIONS_H_ */ diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp index 6cc1dcc..3e0e584 100644 --- a/src/vm/vector_functions.hpp +++ b/src/vm/vector_functions.hpp @@ -337,16 +337,9 @@ std::function get_funct(unsi else throw new std::runtime_error("Unknown funct3 in get_funct"); } -template std::function get_merge_funct(bool vm) { - if(vm) { // VMV - return [](bool vm, dest_elem_t vs2, dest_elem_t vs1) { return vs1; }; - } else { // VMERGE - return [](bool vm, dest_elem_t vs2, dest_elem_t vs1) { return vm ? vs1 : vs2; }; - } -}; template void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, - unsigned vs2, unsigned vs1, carry_t carry, bool merge) { + unsigned vs2, unsigned vs1, carry_t carry) { uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); vmask_view mask_reg = read_vmask(V, elem_count); auto vs1_view = get_vreg(V, vs1, elem_count); @@ -355,12 +348,7 @@ void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, auto fn = get_funct(funct6, funct3); // elements w/ index smaller than vstart are in the prestart and get skipped // body is from vstart to min(elem_count, vl) - if(merge) { - for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { - auto merge_fn = get_merge_funct(vm); - vd_view[idx] = merge_fn(mask_reg[idx], vs2_view[idx], vs1_view[idx]); - } - } else if(carry == carry_t::NO_CARRY) { + if(carry == carry_t::NO_CARRY) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { bool mask_active = vm ? 1 : mask_reg[idx]; if(mask_active) { @@ -388,7 +376,7 @@ void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, } template void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, - unsigned vs2, typename std::make_signed::type imm, carry_t carry, bool merge) { + unsigned vs2, typename std::make_signed::type imm, carry_t carry) { uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); vmask_view mask_reg = read_vmask(V, elem_count); auto vs2_view = get_vreg(V, vs2, elem_count); @@ -396,15 +384,7 @@ void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui auto fn = get_funct(funct6, funct3); // elements w/ index smaller than vstart are in the prestart and get skipped // body is from vstart to min(elem_count, vl) - if(merge) { - for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { - auto cur_mask = mask_reg[idx]; - auto vd_val = vd_view[idx]; - auto vs2_val = vs2_view[idx]; - auto merge_fn = get_merge_funct(vm); - vd_view[idx] = merge_fn(mask_reg[idx], vs2_view[idx], imm); - } - } else if(carry == carry_t::NO_CARRY) { + if(carry == carry_t::NO_CARRY) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { bool mask_active = vm ? 1 : mask_reg[idx]; if(mask_active) { @@ -430,6 +410,35 @@ void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui } return; } +template +void vector_vector_merge(uint8_t* V, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, unsigned vs1) { + uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); + vmask_view mask_reg = read_vmask(V, elem_count); + auto vs1_view = get_vreg(V, vs1, elem_count); + auto vs2_view = get_vreg(V, vs2, elem_count); + auto vd_view = get_vreg(V, vd, elem_count); + for(unsigned idx = vstart; idx < vl; idx++) { + bool mask_active = vm ? 1 : mask_reg[idx]; + if(mask_active) + vd_view[idx] = vs1_view[idx]; + else + vd_view[idx] = vs2_view[idx]; + } +} +template +void vector_imm_merge(uint8_t* V, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, uint64_t imm) { + uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); + vmask_view mask_reg = read_vmask(V, elem_count); + auto vs2_view = get_vreg(V, vs2, elem_count); + auto vd_view = get_vreg(V, vd, elem_count); + for(unsigned idx = vstart; idx < vl; idx++) { + bool mask_active = vm ? 1 : mask_reg[idx]; + if(mask_active) + vd_view[idx] = imm; + else + vd_view[idx] = vs2_view[idx]; + } +} template std::function get_mask_funct(unsigned funct6, unsigned funct3) { if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI) switch(funct6) { @@ -1055,15 +1064,11 @@ std::function(vs2, vs1); - accrued_flags |= softfloat_exceptionFlags; - return val; + return fp_min(vs2, vs1); }; case 0b000110: // VFMAX return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { - dest_elem_t val = fp_max(vs2, vs1); - accrued_flags |= softfloat_exceptionFlags; - return val; + return fp_max(vs2, vs1); }; case 0b100000: // VFDIV return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) { @@ -1210,12 +1215,6 @@ std::function get_fp_red case 0b000101: // VFREDMIN return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) { running_total = fp_min(running_total, vs2); - accrued_flags |= softfloat_exceptionFlags; }; case 0b000111: // VFREDMAX return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) { running_total = fp_max(running_total, vs2); - accrued_flags |= softfloat_exceptionFlags; }; case 0b110001: // VFWREDUSUM return [](uint8_t rm, uint_fast8_t& accrued_flags, dest_elem_t& running_total, src_elem_t vs2) { @@ -1414,6 +1411,104 @@ void fp_vector_unary_op(uint8_t* V, unsigned unary_op, uint64_t vl, uint64_t vst } return; } +template bool fp_eq(elem_size_t, elem_size_t); +template <> inline bool fp_eq(uint32_t v2, uint32_t v1) { return fcmp_s(v2, v1, 0); } +template <> inline bool fp_eq(uint64_t v2, uint64_t v1) { return fcmp_d(v2, v1, 0); } +template bool fp_le(elem_size_t, elem_size_t); +template <> inline bool fp_le(uint32_t v2, uint32_t v1) { return fcmp_s(v2, v1, 1); } +template <> inline bool fp_le(uint64_t v2, uint64_t v1) { return fcmp_d(v2, v1, 1); } +template bool fp_lt(elem_size_t, elem_size_t); +template <> inline bool fp_lt(uint32_t v2, uint32_t v1) { return fcmp_s(v2, v1, 2); } +template <> inline bool fp_lt(uint64_t v2, uint64_t v1) { return fcmp_d(v2, v1, 2); } +template std::function get_fp_mask_funct(unsigned funct6) { + switch(funct6) { + case 0b011000: // VMFEQ + return [](uint8_t rm, uint_fast8_t& accrued_flags, elem_t vs2, elem_t vs1) { + elem_t val = fp_eq(vs2, vs1); + accrued_flags |= softfloat_exceptionFlags; + return val; + }; + case 0b011001: // VMFLE + return [](uint8_t rm, uint_fast8_t& accrued_flags, elem_t vs2, elem_t vs1) { + elem_t val = fp_le(vs2, vs1); + accrued_flags |= softfloat_exceptionFlags; + return val; + }; + case 0b011011: // VMFLT + return [](uint8_t rm, uint_fast8_t& accrued_flags, elem_t vs2, elem_t vs1) { + elem_t val = fp_lt(vs2, vs1); + accrued_flags |= softfloat_exceptionFlags; + return val; + }; + case 0b011100: // VMFNE + return [](uint8_t rm, uint_fast8_t& accrued_flags, elem_t vs2, elem_t vs1) { + elem_t val = !fp_eq(vs2, vs1); + accrued_flags |= softfloat_exceptionFlags; + return val; + }; + case 0b011101: // VMFGT + return [](uint8_t rm, uint_fast8_t& accrued_flags, elem_t vs2, elem_t vs1) { + elem_t val = fp_lt(vs1, vs2); + accrued_flags |= softfloat_exceptionFlags; + return val; + }; + case 0b011111: // VMFGE + return [](uint8_t rm, uint_fast8_t& accrued_flags, elem_t vs2, elem_t vs1) { + elem_t val = fp_le(vs1, vs2); + accrued_flags |= softfloat_exceptionFlags; + return val; + }; + default: + throw new std::runtime_error("Unknown funct6 in get_fp_mask_funct"); + } +} +template +void mask_fp_vector_vector_op(uint8_t* V, unsigned funct6, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, + unsigned vs1, uint8_t rm) { + uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); + vmask_view mask_reg = read_vmask(V, elem_count); + auto vs1_view = get_vreg(V, vs1, elem_count); + auto vs2_view = get_vreg(V, vs2, elem_count); + vmask_view vd_mask_view = read_vmask(V, VLEN, vd); + auto fn = get_fp_mask_funct(funct6); + uint_fast8_t accrued_flags = 0; + for(unsigned idx = vstart; idx < vl; idx++) { + bool mask_active = vm ? 1 : mask_reg[idx]; + if(mask_active) { + vd_mask_view[idx] = fn(rm, accrued_flags, vs2_view[idx], vs1_view[idx]); + } else { + vd_mask_view[idx] = vtype.vma() ? vd_mask_view[idx] : vd_mask_view[idx]; + } + } + softfloat_exceptionFlags = accrued_flags; + for(unsigned idx = vl; idx < VLEN; idx++) { + vd_mask_view[idx] = vtype.vta() ? vd_mask_view[idx] : vd_mask_view[idx]; + } + return; +} +template +void mask_fp_vector_imm_op(uint8_t* V, unsigned funct6, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, + elem_t imm, uint8_t rm) { + uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); + vmask_view mask_reg = read_vmask(V, elem_count); + auto vs2_view = get_vreg(V, vs2, elem_count); + vmask_view vd_mask_view = read_vmask(V, VLEN, vd); + auto fn = get_fp_mask_funct(funct6); + uint_fast8_t accrued_flags = 0; + for(unsigned idx = vstart; idx < vl; idx++) { + bool mask_active = vm ? 1 : mask_reg[idx]; + if(mask_active) { + vd_mask_view[idx] = fn(rm, accrued_flags, vs2_view[idx], imm); + } else { + vd_mask_view[idx] = vtype.vma() ? vd_mask_view[idx] : vd_mask_view[idx]; + } + } + softfloat_exceptionFlags = accrued_flags; + for(unsigned idx = vl; idx < VLEN; idx++) { + vd_mask_view[idx] = vtype.vta() ? vd_mask_view[idx] : vd_mask_view[idx]; + } + return; +} template void mask_mask_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, unsigned vd, unsigned vs2, unsigned vs1) { uint64_t elem_count = VLEN;