From c1277b6528360bcc407efcd79941cab13f1de249 Mon Sep 17 00:00:00 2001 From: Eyck-Alexander Jentzsch Date: Wed, 19 Feb 2025 19:46:33 +0100 Subject: [PATCH] adds mask_mask logical instructions --- gen_input/templates/interp/CORENAME.cpp.gtl | 3 + src/vm/vector_functions.h | 2 + src/vm/vector_functions.hpp | 114 +++++++++++++------- 3 files changed, 82 insertions(+), 37 deletions(-) diff --git a/gen_input/templates/interp/CORENAME.cpp.gtl b/gen_input/templates/interp/CORENAME.cpp.gtl index 4b3aed8..890888a 100644 --- a/gen_input/templates/interp/CORENAME.cpp.gtl +++ b/gen_input/templates/interp/CORENAME.cpp.gtl @@ -494,6 +494,9 @@ if(vector != null) {%> throw new std::runtime_error("Unsupported sew bit value"); } } + void mask_mask_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, unsigned vd, unsigned vs2, unsigned vs1){ + return softvector::mask_mask_op<${vlen}>(V, funct6, funct3, vl, vstart, vd, vs2, vs1); + } <%}%> uint64_t fetch_count{0}; uint64_t tval{0}; diff --git a/src/vm/vector_functions.h b/src/vm/vector_functions.h index d006fad..19b8cc7 100644 --- a/src/vm/vector_functions.h +++ b/src/vm/vector_functions.h @@ -100,6 +100,8 @@ bool sat_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl template void vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2, unsigned vs1); +template +void mask_mask_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, unsigned vd, unsigned vs2, unsigned vs1); } // namespace softvector #include "vm/vector_functions.hpp" #endif /* _VM_VECTOR_FUNCTIONS_H_ */ diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp index c3ece59..cba2bc9 100644 --- a/src/vm/vector_functions.hpp +++ b/src/vm/vector_functions.hpp @@ -155,14 +155,6 @@ std::function get_funct(unsi // case 0b001110: // VSLID1EUP // case 0b001111: // VSLIDE1DOWN // case 0b010111: // VCOMPRESS - // case 0b011000: // VMANDN - // case 0b011001: // VMAND - // case 0b011010: // VMOR - // case 0b011011: // VMXOR - // case 0b011100: // VMORN - // case 0b011101: // VMNAND - // case 0b011110: // VMNOR - // case 0b011111: // VMXNOR case 0b100000: // VDIVU return [](dest_elem_t vd, src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t { if(vs1 == 0) @@ -380,34 +372,58 @@ void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui } return; } -template std::function get_mask_funct(unsigned funct) { - switch(funct) { - case 0b011000: // VMSEQ - return [](elem_t vs2, elem_t vs1) { return vs2 == vs1; }; - case 0b011001: // VMSNE - return [](elem_t vs2, elem_t vs1) { return vs2 != vs1; }; - case 0b011010: // VMSLTU - return [](elem_t vs2, elem_t vs1) { return vs2 < vs1; }; - case 0b011011: // VMSLT - return [](elem_t vs2, elem_t vs1) { - return static_cast>(vs2) < static_cast>(vs1); - }; - case 0b011100: // VMSLEU - return [](elem_t vs2, elem_t vs1) { return vs2 <= vs1; }; - case 0b011101: // VMSLE - return [](elem_t vs2, elem_t vs1) { - return static_cast>(vs2) <= static_cast>(vs1); - }; - case 0b011110: // VMSGTU - return [](elem_t vs2, elem_t vs1) { return vs2 > vs1; }; - case 0b011111: // VMSGT - return [](elem_t vs2, elem_t vs1) { - return static_cast>(vs2) > static_cast>(vs1); - }; +template std::function get_mask_funct(unsigned funct6, unsigned funct3) { + if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI) + switch(funct6) { + case 0b011000: // VMSEQ + return [](elem_t vs2, elem_t vs1) { return vs2 == vs1; }; + case 0b011001: // VMSNE + return [](elem_t vs2, elem_t vs1) { return vs2 != vs1; }; + case 0b011010: // VMSLTU + return [](elem_t vs2, elem_t vs1) { return vs2 < vs1; }; + case 0b011011: // VMSLT + return [](elem_t vs2, elem_t vs1) { + return static_cast>(vs2) < static_cast>(vs1); + }; + case 0b011100: // VMSLEU + return [](elem_t vs2, elem_t vs1) { return vs2 <= vs1; }; + case 0b011101: // VMSLE + return [](elem_t vs2, elem_t vs1) { + return static_cast>(vs2) <= static_cast>(vs1); + }; + case 0b011110: // VMSGTU + return [](elem_t vs2, elem_t vs1) { return vs2 > vs1; }; + case 0b011111: // VMSGT + return [](elem_t vs2, elem_t vs1) { + return static_cast>(vs2) > static_cast>(vs1); + }; - default: - throw new std::runtime_error("Unknown funct in get_mask_funct"); - } + default: + throw new std::runtime_error("Unknown funct6 in get_mask_funct"); + } + else if(funct3 == OPMVV || funct3 == OPMVX) + switch(funct6) { + case 0b011000: // VMANDN + return [](elem_t vs2, elem_t vs1) { return vs2 & !vs1; }; + case 0b011001: // VMAND + return [](elem_t vs2, elem_t vs1) { return vs2 & vs1; }; + case 0b011010: // VMOR + return [](elem_t vs2, elem_t vs1) { return vs2 | vs1; }; + case 0b011011: // VMXOR + return [](elem_t vs2, elem_t vs1) { return vs2 ^ vs1; }; + case 0b011100: // VMORN + return [](elem_t vs2, elem_t vs1) { return vs2 | !vs1; }; + case 0b011101: // VMNAND + return [](elem_t vs2, elem_t vs1) { return !(vs2 & vs1); }; + case 0b011110: // VMNOR + return [](elem_t vs2, elem_t vs1) { return !(vs2 | vs1); }; + case 0b011111: // VMXNOR + return [](elem_t vs2, elem_t vs1) { return !(vs2 ^ vs1); }; + default: + throw new std::runtime_error("Unknown funct6 in get_mask_funct"); + } + else + throw new std::runtime_error("Unknown funct3 in get_mask_funct"); } template void mask_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, @@ -417,7 +433,7 @@ void mask_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_ auto vs1_view = get_vreg(V, vs1, elem_count); auto vs2_view = get_vreg(V, vs2, elem_count); vmask_view vd_mask_view = read_vmask(V, elem_count, vd); - auto fn = get_mask_funct(funct6); + auto fn = get_mask_funct(funct6, funct3); // elements w/ index smaller than vstart are in the prestart and get skipped // body is from vstart to min(elem_count, vl) for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { @@ -449,7 +465,7 @@ void mask_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t v vmask_view mask_reg = read_vmask(V, elem_count); auto vs2_view = get_vreg(V, vs2, elem_count); vmask_view vd_mask_view = read_vmask(V, elem_count, vd); - auto fn = get_mask_funct(funct6); + auto fn = get_mask_funct(funct6, funct3); // elements w/ index smaller than vstart are in the prestart and get skipped // body is from vstart to min(elem_count, vl) for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { @@ -866,4 +882,28 @@ void vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui } return; } +template +void mask_mask_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, unsigned vd, unsigned vs2, unsigned vs1) { + uint64_t elem_count = VLEN; + auto vs1_view = read_vmask(V, elem_count, vs1); + auto vs2_view = read_vmask(V, elem_count, vs2); + auto vd_view = read_vmask(V, elem_count, vd); + auto fn = get_mask_funct(funct6, funct3); // could be bool, but would break the make_signed_t in get_mask_funct + for(unsigned idx = vstart; idx < std::min(vl, elem_count); idx++) { + unsigned new_bit_value = fn(vs2_view[idx], vs1_view[idx]); + uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8; + unsigned cur_bit = idx % 8; + *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast(new_bit_value) << cur_bit; + } + // the tail is all elements of the destination register beyond the first one + for(unsigned idx = 1; idx < VLEN; idx++) { + // always tail agnostic + // this is a nop, placeholder for vta behavior + unsigned new_bit_value = vd_view[idx]; + uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8; + unsigned cur_bit = idx % 8; + *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast(new_bit_value) << cur_bit; + } + return; +} } // namespace softvector \ No newline at end of file