From 08280a094f0a0c6ef9bf988bcff797a1467bde8d Mon Sep 17 00:00:00 2001
From: Eyck-Alexander Jentzsch
Date: Mon, 3 Mar 2025 20:33:23 +0100
Subject: [PATCH] allows assigning to mask_view elements

---
 src/vm/vector_functions.cpp | 18 +++++++++--
 src/vm/vector_functions.h   | 12 +++++++-
 src/vm/vector_functions.hpp | 60 +++++++++----------------------------
 3 files changed, 41 insertions(+), 49 deletions(-)

diff --git a/src/vm/vector_functions.cpp b/src/vm/vector_functions.cpp
index 34bb545..b8b3b24 100644
--- a/src/vm/vector_functions.cpp
+++ b/src/vm/vector_functions.cpp
@@ -42,6 +42,7 @@
 #include
 #include
 #include
+#include
 
 namespace softvector {
 
@@ -74,9 +75,22 @@ double vtype_t::lmul() {
     int8_t signed_vlmul = (vlmul >> 2) ? 0b11111000 | vlmul : vlmul;
     return pow(2, signed_vlmul);
 }
-bool vmask_view::operator[](size_t idx) const {
+
+mask_bit_reference& mask_bit_reference::operator=(const bool new_value) {
+    *start = *start & ~(1U << pos) | static_cast<uint8_t>(new_value) << pos;
+    return *this;
+}
+
+mask_bit_reference::mask_bit_reference(uint8_t* start, uint8_t pos)
+: start(start)
+, pos(pos) {
+    assert(pos < 8 && "Bit reference can only be initialized for bytes");
+};
+mask_bit_reference::operator bool() const { return *(start) & (1U << (pos)); }
+
+mask_bit_reference vmask_view::operator[](size_t idx) const {
     assert(idx < elem_count);
-    return *(start + idx / 8) & (1U << (idx % 8));
+    return {start + idx / 8, static_cast<uint8_t>(idx % 8)};
 }
 
 vmask_view read_vmask(uint8_t* V, uint16_t VLEN, uint16_t elem_count, uint8_t reg_idx) {
diff --git a/src/vm/vector_functions.h b/src/vm/vector_functions.h
index c36dea3..38f0759 100644
--- a/src/vm/vector_functions.h
+++ b/src/vm/vector_functions.h
@@ -53,10 +53,20 @@ struct vtype_t {
     bool vma();
     bool vta();
 };
+class mask_bit_reference {
+    uint8_t* start;
+    uint8_t pos;
+
+public:
+    mask_bit_reference& operator=(const bool new_value);
+    mask_bit_reference(uint8_t* start, uint8_t pos);
+    operator bool() const;
+};
+
 struct vmask_view {
     uint8_t* start;
     size_t elem_count;
-    bool operator[](size_t) const;
+    mask_bit_reference operator[](size_t) const;
 };
 enum class carry_t { NO_CARRY = 0, ADD_CARRY = 1, SUB_CARRY = 2 };
 vmask_view read_vmask(uint8_t* V, uint16_t VLEN, uint16_t elem_count, uint8_t reg_idx = 0);
diff --git a/src/vm/vector_functions.hpp b/src/vm/vector_functions.hpp
index 2727031..6cc1dcc 100644
--- a/src/vm/vector_functions.hpp
+++ b/src/vm/vector_functions.hpp
@@ -490,29 +490,22 @@ void mask_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_
     vmask_view mask_reg = read_vmask(V, elem_count);
     auto vs1_view = get_vreg(V, vs1, elem_count);
     auto vs2_view = get_vreg(V, vs2, elem_count);
-    vmask_view vd_mask_view = read_vmask(V, elem_count, vd);
+    vmask_view vd_mask_view = read_vmask(V, VLEN, vd);
     auto fn = get_mask_funct(funct6, funct3);
     // elements w/ index smaller than vstart are in the prestart and get skipped
     // body is from vstart to min(elem_count, vl)
     for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
         bool mask_active = vm ? 1 : mask_reg[idx];
         if(mask_active) {
-            bool new_bit_value = fn(vs2_view[idx], vs1_view[idx]);
-            uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
-            unsigned cur_bit = idx % 8;
-            *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<uint8_t>(new_bit_value) << cur_bit;
+            vd_mask_view[idx] = fn(vs2_view[idx], vs1_view[idx]);
         } else {
-            uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
-            unsigned cur_bit = idx % 8;
-            *cur_mask_byte_addr = vtype.vma() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
+            vd_mask_view[idx] = vtype.vma() ? vd_mask_view[idx] : vd_mask_view[idx];
         }
     }
     // elements w/ index larger than elem_count are in the tail (fractional LMUL)
     // elements w/ index larger than vl are in the tail
     for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
-        uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
-        unsigned cur_bit = idx % 8;
-        *cur_mask_byte_addr = vtype.vta() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
+        vd_mask_view[idx] = vtype.vta() ? vd_mask_view[idx] : vd_mask_view[idx];
     }
     return;
 }
@@ -522,29 +515,22 @@ void mask_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t v
     uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
     vmask_view mask_reg = read_vmask(V, elem_count);
     auto vs2_view = get_vreg(V, vs2, elem_count);
-    vmask_view vd_mask_view = read_vmask(V, elem_count, vd);
+    vmask_view vd_mask_view = read_vmask(V, VLEN, vd);
     auto fn = get_mask_funct(funct6, funct3);
     // elements w/ index smaller than vstart are in the prestart and get skipped
     // body is from vstart to min(elem_count, vl)
     for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
         bool mask_active = vm ? 1 : mask_reg[idx];
         if(mask_active) {
-            bool new_bit_value = fn(vs2_view[idx], imm);
-            uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
-            unsigned cur_bit = idx % 8;
-            *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<uint8_t>(new_bit_value) << cur_bit;
+            vd_mask_view[idx] = fn(vs2_view[idx], imm);
         } else {
-            uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
-            unsigned cur_bit = idx % 8;
-            *cur_mask_byte_addr = vtype.vma() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
+            vd_mask_view[idx] = vtype.vma() ? vd_mask_view[idx] : vd_mask_view[idx];
         }
     }
     // elements w/ index larger than elem_count are in the tail (fractional LMUL)
     // elements w/ index larger than vl are in the tail
     for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
-        uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
-        unsigned cur_bit = idx % 8;
-        *cur_mask_byte_addr = vtype.vta() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
+        vd_mask_view[idx] = vtype.vta() ? vd_mask_view[idx] : vd_mask_view[idx];
     }
     return;
 }
@@ -615,10 +601,7 @@ void carry_vector_vector_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vs
     // body is from vstart to min(elem_count, vl)
     for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
         elem_t carry = vm ? 0 : mask_reg[idx];
-        bool new_bit_value = fn(vs2_view[idx], vs1_view[idx], carry);
-        uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
-        unsigned cur_bit = idx % 8;
-        *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<uint8_t>(new_bit_value) << cur_bit;
+        vd_mask_view[idx] = fn(vs2_view[idx], vs1_view[idx], carry);
     }
     // elements w/ index larger than elem_count are in the tail (fractional LMUL)
     // elements w/ index larger than vl are in the tail
@@ -639,10 +622,7 @@ void carry_vector_imm_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstar
     // body is from vstart to min(elem_count, vl)
     for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
         elem_t carry = vm ? 0 : mask_reg[idx];
-        bool new_bit_value = fn(vs2_view[idx], imm, carry);
-        uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
-        unsigned cur_bit = idx % 8;
-        *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<uint8_t>(new_bit_value) << cur_bit;
+        vd_mask_view[idx] = fn(vs2_view[idx], imm, carry);
     }
     // elements w/ index larger than elem_count are in the tail (fractional LMUL)
     // elements w/ index larger than vl are in the tail
@@ -1442,19 +1422,13 @@ void mask_mask_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uin
     auto vd_view = read_vmask(V, elem_count, vd);
     auto fn = get_mask_funct(funct6, funct3); // could be bool, but would break the make_signed_t in get_mask_funct
     for(unsigned idx = vstart; idx < vl; idx++) {
-        unsigned new_bit_value = fn(vs2_view[idx], vs1_view[idx]);
-        uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8;
-        unsigned cur_bit = idx % 8;
-        *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<uint8_t>(new_bit_value) << cur_bit;
+        vd_view[idx] = fn(vs2_view[idx], vs1_view[idx]);
     }
     // the tail is all elements of the destination register beyond the first one
     for(unsigned idx = 1; idx < VLEN; idx++) {
         // always tail agnostic
         // this is a nop, placeholder for vta behavior
-        unsigned new_bit_value = vd_view[idx];
-        uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8;
-        unsigned cur_bit = idx % 8;
-        *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<uint8_t>(new_bit_value) << cur_bit;
+        vd_view[idx] = vd_view[idx];
    }
     return;
 }
@@ -1528,20 +1502,14 @@ template void mask_set_op(uint8_t* V, unsigned enc, uint64_t vl,
     for(unsigned idx = vstart; idx < vl; idx++) {
         bool mask_active = vm ? 1 : mask_reg[idx];
         if(mask_active) {
-            unsigned new_bit_value = fn(marker, vs2_view[idx]);
-            uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8;
-            unsigned cur_bit = idx % 8;
-            *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<uint8_t>(new_bit_value) << cur_bit;
+            vd_view[idx] = fn(marker, vs2_view[idx]);
         }
     }
     // the tail is all elements of the destination register beyond the first one
     for(unsigned idx = vl; idx < VLEN; idx++) {
         // always tail agnostic
         // this is a nop, placeholder for vta behavior
-        unsigned new_bit_value = vd_view[idx];
-        uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8;
-        unsigned cur_bit = idx % 8;
-        *cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<uint8_t>(new_bit_value) << cur_bit;
+        vd_view[idx] = vd_view[idx];
     }
 }
 template
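
Note on the pattern (illustration, not part of the patch): vmask_view::operator[] can no longer return a bool& because a packed mask bit has no address of its own, so it now returns a mask_bit_reference proxy, the same idea behind std::bitset::reference and std::vector<bool>. The proxy remembers the byte address and bit position, converts to bool on reads, and performs the clear-and-OR read-modify-write on assignment, which is what lets the loops above write vd_mask_view[idx] = ... directly. Below is a minimal, self-contained sketch of that pattern; the class and member names follow the patch, while the headers, the main() driver, and the printed values are assumptions made for the example only.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Proxy referring to one bit inside a byte array.
    class mask_bit_reference {
        uint8_t* start; // byte that holds the bit
        uint8_t pos;    // bit position within that byte, 0..7

    public:
        mask_bit_reference(uint8_t* start, uint8_t pos)
        : start(start)
        , pos(pos) {
            assert(pos < 8 && "Bit reference can only be initialized for bytes");
        }
        // write: clear the target bit, then OR in the new value
        mask_bit_reference& operator=(bool new_value) {
            *start = (*start & ~(1U << pos)) | (static_cast<uint8_t>(new_value) << pos);
            return *this;
        }
        // read: test the target bit
        operator bool() const { return *start & (1U << pos); }
    };

    // View over a packed mask; operator[] hands out the proxy instead of a bool.
    struct vmask_view {
        uint8_t* start;
        size_t elem_count;
        mask_bit_reference operator[](size_t idx) const {
            assert(idx < elem_count);
            return {start + idx / 8, static_cast<uint8_t>(idx % 8)};
        }
    };

    int main() {
        uint8_t bytes[2] = {0, 0};
        vmask_view view{bytes, 16};
        view[3] = true;        // read-modify-write of bit 3 in bytes[0]
        bool copied = view[3]; // read back through operator bool()
        view[9] = copied;      // sets bit 1 of bytes[1]
        std::printf("%02x %02x\n", (unsigned)bytes[0], (unsigned)bytes[1]); // prints: 08 02
        return 0;
    }

One caveat with proxies declared this way: assigning one proxy directly to another (view[a] = view[b]) selects the implicitly generated copy assignment rather than operator=(bool), so reads are best routed through an explicit bool as in the example.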