allows assigning to mask_view elements

This commit is contained in:
Eyck-Alexander Jentzsch 2025-03-03 20:33:23 +01:00
parent ae90adc854
commit 08280a094f
3 changed files with 41 additions and 49 deletions

View File

@ -42,6 +42,7 @@
#include <limits> #include <limits>
#include <math.h> #include <math.h>
#include <stdexcept> #include <stdexcept>
#include <vector>
namespace softvector { namespace softvector {
@ -74,9 +75,22 @@ double vtype_t::lmul() {
int8_t signed_vlmul = (vlmul >> 2) ? 0b11111000 | vlmul : vlmul; int8_t signed_vlmul = (vlmul >> 2) ? 0b11111000 | vlmul : vlmul;
return pow(2, signed_vlmul); return pow(2, signed_vlmul);
} }
bool vmask_view::operator[](size_t idx) const {
mask_bit_reference& mask_bit_reference::operator=(const bool new_value) {
*start = *start & ~(1U << pos) | static_cast<unsigned>(new_value) << pos;
return *this;
}
mask_bit_reference::mask_bit_reference(uint8_t* start, uint8_t pos)
: start(start)
, pos(pos) {
assert(pos < 8 && "Bit reference can only be initialized for bytes");
};
mask_bit_reference::operator bool() const { return *(start) & (1U << (pos)); }
mask_bit_reference vmask_view::operator[](size_t idx) const {
assert(idx < elem_count); assert(idx < elem_count);
return *(start + idx / 8) & (1U << (idx % 8)); return {start + idx / 8, static_cast<uint8_t>(idx % 8)};
} }
vmask_view read_vmask(uint8_t* V, uint16_t VLEN, uint16_t elem_count, uint8_t reg_idx) { vmask_view read_vmask(uint8_t* V, uint16_t VLEN, uint16_t elem_count, uint8_t reg_idx) {

View File

@ -53,10 +53,20 @@ struct vtype_t {
bool vma(); bool vma();
bool vta(); bool vta();
}; };
class mask_bit_reference {
uint8_t* start;
uint8_t pos;
public:
mask_bit_reference& operator=(const bool new_value);
mask_bit_reference(uint8_t* start, uint8_t pos);
operator bool() const;
};
struct vmask_view { struct vmask_view {
uint8_t* start; uint8_t* start;
size_t elem_count; size_t elem_count;
bool operator[](size_t) const; mask_bit_reference operator[](size_t) const;
}; };
enum class carry_t { NO_CARRY = 0, ADD_CARRY = 1, SUB_CARRY = 2 }; enum class carry_t { NO_CARRY = 0, ADD_CARRY = 1, SUB_CARRY = 2 };
vmask_view read_vmask(uint8_t* V, uint16_t VLEN, uint16_t elem_count, uint8_t reg_idx = 0); vmask_view read_vmask(uint8_t* V, uint16_t VLEN, uint16_t elem_count, uint8_t reg_idx = 0);

View File

@ -490,29 +490,22 @@ void mask_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count); vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
auto vs1_view = get_vreg<VLEN, elem_t>(V, vs1, elem_count); auto vs1_view = get_vreg<VLEN, elem_t>(V, vs1, elem_count);
auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count); auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd); vmask_view vd_mask_view = read_vmask<VLEN>(V, VLEN, vd);
auto fn = get_mask_funct<elem_t>(funct6, funct3); auto fn = get_mask_funct<elem_t>(funct6, funct3);
// elements w/ index smaller than vstart are in the prestart and get skipped // elements w/ index smaller than vstart are in the prestart and get skipped
// body is from vstart to min(elem_count, vl) // body is from vstart to min(elem_count, vl)
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
bool mask_active = vm ? 1 : mask_reg[idx]; bool mask_active = vm ? 1 : mask_reg[idx];
if(mask_active) { if(mask_active) {
bool new_bit_value = fn(vs2_view[idx], vs1_view[idx]); vd_mask_view[idx] = fn(vs2_view[idx], vs1_view[idx]);
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
} else { } else {
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8; vd_mask_view[idx] = vtype.vma() ? vd_mask_view[idx] : vd_mask_view[idx];
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = vtype.vma() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
} }
} }
// elements w/ index larger than elem_count are in the tail (fractional LMUL) // elements w/ index larger than elem_count are in the tail (fractional LMUL)
// elements w/ index larger than vl are in the tail // elements w/ index larger than vl are in the tail
for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) { for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8; vd_mask_view[idx] = vtype.vta() ? vd_mask_view[idx] : vd_mask_view[idx];
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = vtype.vta() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
} }
return; return;
} }
@ -522,29 +515,22 @@ void mask_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t v
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew(); uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count); vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count); auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd); vmask_view vd_mask_view = read_vmask<VLEN>(V, VLEN, vd);
auto fn = get_mask_funct<elem_t>(funct6, funct3); auto fn = get_mask_funct<elem_t>(funct6, funct3);
// elements w/ index smaller than vstart are in the prestart and get skipped // elements w/ index smaller than vstart are in the prestart and get skipped
// body is from vstart to min(elem_count, vl) // body is from vstart to min(elem_count, vl)
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
bool mask_active = vm ? 1 : mask_reg[idx]; bool mask_active = vm ? 1 : mask_reg[idx];
if(mask_active) { if(mask_active) {
bool new_bit_value = fn(vs2_view[idx], imm); vd_mask_view[idx] = fn(vs2_view[idx], imm);
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
} else { } else {
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8; vd_mask_view[idx] = vtype.vma() ? vd_mask_view[idx] : vd_mask_view[idx];
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = vtype.vma() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
} }
} }
// elements w/ index larger than elem_count are in the tail (fractional LMUL) // elements w/ index larger than elem_count are in the tail (fractional LMUL)
// elements w/ index larger than vl are in the tail // elements w/ index larger than vl are in the tail
for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) { for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8; vd_mask_view[idx] = vtype.vta() ? vd_mask_view[idx] : vd_mask_view[idx];
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = vtype.vta() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
} }
return; return;
} }
@ -615,10 +601,7 @@ void carry_vector_vector_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vs
// body is from vstart to min(elem_count, vl) // body is from vstart to min(elem_count, vl)
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
elem_t carry = vm ? 0 : mask_reg[idx]; elem_t carry = vm ? 0 : mask_reg[idx];
bool new_bit_value = fn(vs2_view[idx], vs1_view[idx], carry); vd_mask_view[idx] = fn(vs2_view[idx], vs1_view[idx], carry);
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
} }
// elements w/ index larger than elem_count are in the tail (fractional LMUL) // elements w/ index larger than elem_count are in the tail (fractional LMUL)
// elements w/ index larger than vl are in the tail // elements w/ index larger than vl are in the tail
@ -639,10 +622,7 @@ void carry_vector_imm_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstar
// body is from vstart to min(elem_count, vl) // body is from vstart to min(elem_count, vl)
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) { for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
elem_t carry = vm ? 0 : mask_reg[idx]; elem_t carry = vm ? 0 : mask_reg[idx];
bool new_bit_value = fn(vs2_view[idx], imm, carry); vd_mask_view[idx] = fn(vs2_view[idx], imm, carry);
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
} }
// elements w/ index larger than elem_count are in the tail (fractional LMUL) // elements w/ index larger than elem_count are in the tail (fractional LMUL)
// elements w/ index larger than vl are in the tail // elements w/ index larger than vl are in the tail
@ -1442,19 +1422,13 @@ void mask_mask_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uin
auto vd_view = read_vmask<VLEN>(V, elem_count, vd); auto vd_view = read_vmask<VLEN>(V, elem_count, vd);
auto fn = get_mask_funct<unsigned>(funct6, funct3); // could be bool, but would break the make_signed_t in get_mask_funct auto fn = get_mask_funct<unsigned>(funct6, funct3); // could be bool, but would break the make_signed_t in get_mask_funct
for(unsigned idx = vstart; idx < vl; idx++) { for(unsigned idx = vstart; idx < vl; idx++) {
unsigned new_bit_value = fn(vs2_view[idx], vs1_view[idx]); vd_view[idx] = fn(vs2_view[idx], vs1_view[idx]);
uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8;
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
} }
// the tail is all elements of the destination register beyond the first one // the tail is all elements of the destination register beyond the first one
for(unsigned idx = 1; idx < VLEN; idx++) { for(unsigned idx = 1; idx < VLEN; idx++) {
// always tail agnostic // always tail agnostic
// this is a nop, placeholder for vta behavior // this is a nop, placeholder for vta behavior
unsigned new_bit_value = vd_view[idx]; vd_view[idx] = vd_view[idx];
uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8;
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
} }
return; return;
} }
@ -1528,20 +1502,14 @@ template <unsigned VLEN> void mask_set_op(uint8_t* V, unsigned enc, uint64_t vl,
for(unsigned idx = vstart; idx < vl; idx++) { for(unsigned idx = vstart; idx < vl; idx++) {
bool mask_active = vm ? 1 : mask_reg[idx]; bool mask_active = vm ? 1 : mask_reg[idx];
if(mask_active) { if(mask_active) {
unsigned new_bit_value = fn(marker, vs2_view[idx]); vd_view[idx] = fn(marker, vs2_view[idx]);
uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8;
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
} }
} }
// the tail is all elements of the destination register beyond the first one // the tail is all elements of the destination register beyond the first one
for(unsigned idx = vl; idx < VLEN; idx++) { for(unsigned idx = vl; idx < VLEN; idx++) {
// always tail agnostic // always tail agnostic
// this is a nop, placeholder for vta behavior // this is a nop, placeholder for vta behavior
unsigned new_bit_value = vd_view[idx]; vd_view[idx] = vd_view[idx];
uint8_t* cur_mask_byte_addr = vd_view.start + idx / 8;
unsigned cur_bit = idx % 8;
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
} }
} }
template <unsigned VLEN, typename src_elem_t> template <unsigned VLEN, typename src_elem_t>