adds widening reductions
This commit is contained in:
parent
f049d8cbb3
commit
63889b02e7
@ -481,6 +481,19 @@ if(vector != null) {%>
|
|||||||
throw new std::runtime_error("Unsupported sew bit value");
|
throw new std::runtime_error("Unsupported sew bit value");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
void vector_red_wv(uint8_t* V, uint8_t funct6, uint8_t funct3, uint64_t vl, uint64_t vstart, softvector::vtype_t vtype, bool vm, uint8_t vd, uint8_t vs2, uint8_t vs1, uint8_t sew_val){
|
||||||
|
switch(sew_val){
|
||||||
|
case 0b000:
|
||||||
|
return softvector::vector_red_op<${vlen}, uint16_t, uint8_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1);
|
||||||
|
case 0b001:
|
||||||
|
return softvector::vector_red_op<${vlen}, uint32_t, uint16_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1);
|
||||||
|
case 0b010:
|
||||||
|
return softvector::vector_red_op<${vlen}, uint64_t, uint32_t>(V, funct6, funct3, vl, vstart, vtype, vm, vd, vs2, vs1);
|
||||||
|
case 0b011: // would require 128 bits vs2 value
|
||||||
|
default:
|
||||||
|
throw new std::runtime_error("Unsupported sew bit value");
|
||||||
|
}
|
||||||
|
}
|
||||||
<%}%>
|
<%}%>
|
||||||
uint64_t fetch_count{0};
|
uint64_t fetch_count{0};
|
||||||
uint64_t tval{0};
|
uint64_t tval{0};
|
||||||
|
@ -798,8 +798,14 @@ template <typename dest_elem_t, typename src_elem_t>
|
|||||||
std::function<void(dest_elem_t&, src_elem_t)> get_red_funct(unsigned funct6, unsigned funct3) {
|
std::function<void(dest_elem_t&, src_elem_t)> get_red_funct(unsigned funct6, unsigned funct3) {
|
||||||
if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI)
|
if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI)
|
||||||
switch(funct6) {
|
switch(funct6) {
|
||||||
// case 0b110000: // VWREDSUMU
|
case 0b110000: // VWREDSUMU
|
||||||
// case 0b110001: // VWREDSUM
|
return [](dest_elem_t& running_total, src_elem_t vs2) { return running_total += static_cast<dest_elem_t>(vs2); };
|
||||||
|
case 0b110001: // VWREDSUM
|
||||||
|
return [](dest_elem_t& running_total, src_elem_t vs2) {
|
||||||
|
// cast the signed vs2 elem to unsigned to enable wraparound on overflow
|
||||||
|
return running_total += static_cast<dest_elem_t>(
|
||||||
|
static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src_elem_t>>(vs2)));
|
||||||
|
};
|
||||||
default:
|
default:
|
||||||
throw new std::runtime_error("Unknown funct6 in get_funct");
|
throw new std::runtime_error("Unknown funct6 in get_funct");
|
||||||
}
|
}
|
||||||
@ -814,18 +820,20 @@ std::function<void(dest_elem_t&, src_elem_t)> get_red_funct(unsigned funct6, uns
|
|||||||
case 0b000011: // VREDXOR
|
case 0b000011: // VREDXOR
|
||||||
return [](dest_elem_t& running_total, src_elem_t vs2) { running_total ^= vs2; };
|
return [](dest_elem_t& running_total, src_elem_t vs2) { running_total ^= vs2; };
|
||||||
case 0b000100: // VREDMINU
|
case 0b000100: // VREDMINU
|
||||||
return [](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::min(running_total, vs2); };
|
return
|
||||||
|
[](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::min(running_total, static_cast<dest_elem_t>(vs2)); };
|
||||||
case 0b000101: // VREDMIN
|
case 0b000101: // VREDMIN
|
||||||
return [](dest_elem_t& running_total, src_elem_t vs2) {
|
return [](dest_elem_t& running_total, src_elem_t vs2) {
|
||||||
running_total =
|
running_total = std::min(static_cast<std::make_signed_t<dest_elem_t>>(running_total),
|
||||||
std::min(static_cast<std::make_signed_t<dest_elem_t>>(running_total), static_cast<std::make_signed_t<src_elem_t>>(vs2));
|
static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src_elem_t>>(vs2)));
|
||||||
};
|
};
|
||||||
case 0b000110: // VREDMAXU
|
case 0b000110: // VREDMAXU
|
||||||
return [](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::max(running_total, vs2); };
|
return
|
||||||
|
[](dest_elem_t& running_total, src_elem_t vs2) { running_total = std::max(running_total, static_cast<dest_elem_t>(vs2)); };
|
||||||
case 0b000111: // VREDMAX
|
case 0b000111: // VREDMAX
|
||||||
return [](dest_elem_t& running_total, src_elem_t vs2) {
|
return [](dest_elem_t& running_total, src_elem_t vs2) {
|
||||||
running_total =
|
running_total = std::max(static_cast<std::make_signed_t<dest_elem_t>>(running_total),
|
||||||
std::max(static_cast<std::make_signed_t<dest_elem_t>>(running_total), static_cast<std::make_signed_t<src_elem_t>>(vs2));
|
static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src_elem_t>>(vs2)));
|
||||||
};
|
};
|
||||||
default:
|
default:
|
||||||
throw new std::runtime_error("Unknown funct6 in get_funct");
|
throw new std::runtime_error("Unknown funct6 in get_funct");
|
||||||
@ -840,7 +848,7 @@ void vector_red_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, ui
|
|||||||
return;
|
return;
|
||||||
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
|
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
|
||||||
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
|
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
|
||||||
auto vs1_elem = get_vreg<VLEN, src_elem_t>(V, vs1, elem_count)[0];
|
auto vs1_elem = get_vreg<VLEN, dest_elem_t>(V, vs1, elem_count)[0];
|
||||||
auto vs2_view = get_vreg<VLEN, src_elem_t>(V, vs2, elem_count);
|
auto vs2_view = get_vreg<VLEN, src_elem_t>(V, vs2, elem_count);
|
||||||
auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
|
auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
|
||||||
auto fn = get_red_funct<dest_elem_t, src_elem_t>(funct6, funct3);
|
auto fn = get_red_funct<dest_elem_t, src_elem_t>(funct6, funct3);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user