adds widening reductions

This commit is contained in:
2025-02-19 15:02:30 +01:00
parent f049d8cbb3
commit 63889b02e7
2 changed files with 30 additions and 9 deletions

View File

@ -798,8 +798,14 @@ template <typename dest_elem_t, typename src_elem_t>
std::function<void(dest_elem_t&, src_elem_t)> get_red_funct(unsigned funct6, unsigned funct3) {
if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI)
switch(funct6) {
// case 0b110000: // VWREDSUMU
// case 0b110001: // VWREDSUM
case 0b110000: // VWREDSUMU
return [](dest_elem_t& running_total, src_elem_t vs2) { return running_total += static_cast<dest_elem_t>(vs2); };
case 0b110001: // VWREDSUM
return [](dest_elem_t& running_total, src_elem_t vs2) {
// reinterpret vs2 as signed and sign-extend it to the destination width, then convert back to dest_elem_t so the addition wraps around on overflow
return running_total += static_cast<dest_elem_t>(
static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src_elem_t>>(vs2)));
};
default:
throw new std::runtime_error("Unknown funct6 in get_funct");
}
@ -814,18 +820,20 @@ std::function<void(dest_elem_t&, src_elem_t)> get_red_funct(unsigned funct6, uns
case 0b000011: // VREDXOR
return [](dest_elem_t& running_total, src_elem_t vs2) { running_total ^= vs2; };
case 0b000100: // VREDMINU
return [](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::min(running_total, vs2); };
return
[](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::min(running_total, static_cast<dest_elem_t>(vs2)); };
case 0b000101: // VREDMIN
return [](dest_elem_t& running_total, src_elem_t vs2) {
running_total =
std::min(static_cast<std::make_signed_t<dest_elem_t>>(running_total), static_cast<std::make_signed_t<src_elem_t>>(vs2));
running_total = std::min(static_cast<std::make_signed_t<dest_elem_t>>(running_total),
static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src_elem_t>>(vs2)));
};
case 0b000110: // VREDMAXU
return [](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::max(running_total, vs2); };
return
[](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::max(running_total, static_cast<dest_elem_t>(vs2)); };
case 0b000111: // VREDMAX
return [](dest_elem_t& running_total, src_elem_t vs2) {
running_total =
std::max(static_cast<std::make_signed_t<dest_elem_t>>(running_total), static_cast<std::make_signed_t<src_elem_t>>(vs2));
running_total = std::max(static_cast<std::make_signed_t<dest_elem_t>>(running_total),
static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src_elem_t>>(vs2)));
};
default:
throw new std::runtime_error("Unknown funct6 in get_funct");
@ -840,7 +848,7 @@ void vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui
return;
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
auto vs1_elem = get_vreg<VLEN, src_elem_t>(V, vs1, elem_count)[0];
auto vs1_elem = get_vreg<VLEN, dest_elem_t>(V, vs1, elem_count)[0];
auto vs2_view = get_vreg<VLEN, src_elem_t>(V, vs2, elem_count);
auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
auto fn = get_red_funct<dest_elem_t, src_elem_t>(funct6, funct3);