561 lines
28 KiB
C++
561 lines
28 KiB
C++
////////////////////////////////////////////////////////////////////////////////
|
|
// Copyright (C) 2025, MINRES Technologies GmbH
|
|
// All rights reserved.
|
|
//
|
|
// Redistribution and use in source and binary forms, with or without
|
|
// modification, are permitted provided that the following conditions are met:
|
|
//
|
|
// 1. Redistributions of source code must retain the above copyright notice,
|
|
// this list of conditions and the following disclaimer.
|
|
//
|
|
// 2. Redistributions in binary form must reproduce the above copyright notice,
|
|
// this list of conditions and the following disclaimer in the documentation
|
|
// and/or other materials provided with the distribution.
|
|
//
|
|
// 3. Neither the name of the copyright holder nor the names of its contributors
|
|
// may be used to endorse or promote products derived from this software
|
|
// without specific prior written permission.
|
|
//
|
|
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
|
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
|
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
|
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
|
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
|
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
|
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
|
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
|
// POSSIBILITY OF SUCH DAMAGE.
|
|
//
|
|
// Contributors:
|
|
// alex@minres.com - initial API and implementation
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
#pragma once
|
|
#include "vm/vector_functions.h"
|
|
#include <functional>
|
|
#include <limits>
|
|
#include <stdexcept>
|
|
#include <type_traits>
|
|
#ifndef _VM_VECTOR_FUNCTIONS_H_
|
|
#error __FILE__ should only be included from vector_functions.h
|
|
#endif
|
|
#include <boost/integer.hpp>
|
|
#include <math.h>
|
|
namespace softvector {
|
|
|
|
// Non-owning, typed view onto a region of the vector register file.
// `start` points at the first byte of the viewed register, `elem_count`
// is the number of elem_t entries that may legally be accessed.
template <typename elem_t> struct vreg_view {
    uint8_t* start;
    size_t elem_count;
    // Bounds-checked (assert) reference to the idx-th element.
    elem_t& operator[](size_t idx) {
        assert(idx < elem_count);
        return reinterpret_cast<elem_t*>(start)[idx];
    }
    // Named accessor with the same semantics as operator[]; defaults to the first element.
    inline elem_t& get(size_t idx = 0) { return (*this)[idx]; }
};
|
|
|
|
/// Creates a typed view onto vector register reg_idx of the register file.
/// @tparam VLEN vector register length in bits
/// @param V base pointer of the vector register file
/// @param reg_idx index of the first register covered by the view
/// @param elem_count number of elem_t elements accessible through the view
template <unsigned VLEN, typename elem_t> vreg_view<elem_t> get_vreg(uint8_t* V, uint8_t reg_idx, uint16_t elem_count) {
    uint8_t* reg_start = V + VLEN / 8 * reg_idx;
    // The view must end inside the register file (RFS registers of VLEN bits each).
    // Note: the check starts at reg_start, not V, so the reg_idx offset is accounted for.
    assert(reg_start + elem_count * sizeof(elem_t) <= V + VLEN * RFS / 8);
    return {reg_start, elem_count};
}
|
|
/// Creates a bit-granular mask view onto vector register reg_idx.
/// @tparam VLEN vector register length in bits
/// @param V base pointer of the vector register file
/// @param elem_count number of mask bits accessible through the view
/// @param reg_idx register holding the mask (default declared in vector_functions.h)
template <unsigned VLEN> vmask_view read_vmask(uint8_t* V, uint16_t elem_count, uint8_t reg_idx) {
    uint8_t* mask_start = V + VLEN / 8 * reg_idx;
    // Round the bit count up to whole bytes so a partial trailing byte is
    // still covered by the bounds check (elem_count / 8 would floor and
    // under-check when elem_count is not a multiple of 8).
    assert(mask_start + (elem_count + 7) / 8 <= V + VLEN * RFS / 8);
    return {mask_start, elem_count};
}
|
|
|
|
/// Mask applied to a shift amount so it stays below the element's bit width
/// (e.g. 31 for 32-bit elements), matching the RVV rule that shift amounts
/// use only the low log2(SEW) bits.
template <typename elem_t> constexpr elem_t shift_mask() {
    static_assert(std::numeric_limits<elem_t>::is_integer, "shift_mask only supports integer types");
    // Use the unsigned bit width: for signed types numeric_limits::digits is
    // width-1 (sign bit excluded), which would yield a mask one too small.
    return std::numeric_limits<std::make_unsigned_t<elem_t>>::digits - 1;
}
|
|
// funct3 major-operation categories of RVV arithmetic instructions
// (OP-V encoding; see the RISC-V "V" extension specification).
enum FUNCT3 {
    OPIVV = 0b000, // integer, vector-vector operands
    OPFVV = 0b001, // floating-point, vector-vector operands
    OPMVV = 0b010, // mask/multiply/divide ("M" category), vector-vector operands
    OPIVI = 0b011, // integer, vector-immediate operands
    OPIVX = 0b100, // integer, vector-scalar (x-register) operands
    OPFVF = 0b101, // floating-point, vector-scalar (f-register) operands
    OPMVX = 0b110, // "M" category, vector-scalar (x-register) operands
};
|
|
// Type trait mapping an integer type to the integer type of twice its bit
// width (used by widening and high-half multiply operations). The primary
// template is intentionally left undefined so that unsupported element
// types fail at compile time.
template <class, typename enable = void> struct twice;
template <> struct twice<int8_t> { using type = int16_t; };
template <> struct twice<uint8_t> { using type = uint16_t; };
template <> struct twice<int16_t> { using type = int32_t; };
template <> struct twice<uint16_t> { using type = uint32_t; };
template <> struct twice<int32_t> { using type = int64_t; };
template <> struct twice<uint32_t> { using type = uint64_t; };
// 128-bit integers are a gcc/clang extension; without it 64-bit elements
// cannot be widened and twice<.int64_t> stays undefined.
#ifdef __SIZEOF_INT128__
template <> struct twice<int64_t> { using type = __int128_t; };
template <> struct twice<uint64_t> { using type = __uint128_t; };
#endif
template <class T> using twice_t = typename twice<T>::type; // for convenience
|
|
|
|
template <typename dest_elem_t, typename src2_elem_t = dest_elem_t, typename src1_elem_t = dest_elem_t>
|
|
std::function<dest_elem_t(src2_elem_t, src1_elem_t)> get_funct(unsigned funct6, unsigned funct3) {
|
|
if(funct3 == OPIVV || funct3 == OPIVX || funct3 == OPIVI)
|
|
switch(funct6) {
|
|
case 0b000000: // VADD
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; };
|
|
case 0b000010: // VSUB
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; };
|
|
case 0b000011: // VRSUB
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 - vs2; };
|
|
case 0b000100: // VMINU
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return std::min(vs2, static_cast<src2_elem_t>(vs1)); };
|
|
case 0b000101: // VMIN
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return std::min(static_cast<std::make_signed_t<src2_elem_t>>(vs2), static_cast<std::make_signed_t<src2_elem_t>>(vs1));
|
|
};
|
|
case 0b000110: // VMAXU
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return std::max(vs2, static_cast<src2_elem_t>(vs1)); };
|
|
case 0b000111: // VMAX
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return std::max(static_cast<std::make_signed_t<src2_elem_t>>(vs2), static_cast<std::make_signed_t<src2_elem_t>>(vs1));
|
|
};
|
|
case 0b001001: // VAND
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 & vs2; };
|
|
case 0b001010: // VOR
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 | vs2; };
|
|
case 0b001011: // VXOR
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs1 ^ vs2; };
|
|
// case 0b001100: // VRGATHER
|
|
// case 0b001110: // VRGATHEREI16
|
|
// case 0b001111: // VLSLIDEDOWN
|
|
case 0b010000: // VADC
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; };
|
|
case 0b010010: // VSBC
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) -
|
|
static_cast<std::make_signed_t<src1_elem_t>>(vs1));
|
|
};
|
|
// case 0b010111: // VMERGE / VMV
|
|
// case 0b100000: // VSADDU
|
|
// case 0b100001: // VSADD
|
|
// case 0b100010: // VSSUBU
|
|
// case 0b100011: // VSSUB
|
|
case 0b100101: // VSLL
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 << (vs1 & shift_mask<src2_elem_t>()); };
|
|
// case 0b100111: // VSMUL
|
|
// case 0b100111: // VMV<NR>R
|
|
case 0b101000: // VSRL
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask<src2_elem_t>()); };
|
|
case 0b101001: // VSRA
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return static_cast<std::make_signed_t<src2_elem_t>>(vs2) >> (vs1 & shift_mask<src2_elem_t>());
|
|
};
|
|
case 0b101100: // VNSRL
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 >> (vs1 & shift_mask<src2_elem_t>()); };
|
|
case 0b101101: // VNSRA
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return static_cast<std::make_signed_t<src2_elem_t>>(vs2) >> (vs1 & shift_mask<src2_elem_t>());
|
|
};
|
|
// case 0b101110: // VNCLIPU
|
|
// case 0b101111: // VNCLIP
|
|
// case 0b110000: // VWREDSUMU
|
|
// case 0b110001: // VWREDSUM
|
|
default:
|
|
throw new std::runtime_error("Uknown funct6 in get_funct");
|
|
}
|
|
else if(funct3 == OPMVV || funct3 == OPMVX)
|
|
switch(funct6) {
|
|
// case 0b000000: // VREDSUM
|
|
// case 0b000001: // VREDAND
|
|
// case 0b000010: // VREDOR
|
|
// case 0b000011: // VREDXOR
|
|
// case 0b000100: // VREDMINU
|
|
// case 0b000101: // VREDMIN
|
|
// case 0b000110: // VREDMAXU
|
|
// case 0b000111: // VREDMAX
|
|
// case 0b001000: // VAADDU
|
|
// case 0b001001: // VAADD
|
|
// case 0b001010: // VASUBU
|
|
// case 0b001011: // VASUB
|
|
// case 0b001110: // VSLID1EUP
|
|
// case 0b001111: // VSLIDE1DOWN
|
|
// case 0b010111: // VCOMPRESS
|
|
// case 0b011000: // VMANDN
|
|
// case 0b011001: // VMAND
|
|
// case 0b011010: // VMOR
|
|
// case 0b011011: // VMXOR
|
|
// case 0b011100: // VMORN
|
|
// case 0b011101: // VMNAND
|
|
// case 0b011110: // VMNOR
|
|
// case 0b011111: // VMXNOR
|
|
case 0b100000: // VDIVU
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t {
|
|
if(vs1 == 0)
|
|
return -1;
|
|
else
|
|
return vs2 / vs1;
|
|
};
|
|
case 0b100001: // VDIV
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t {
|
|
if(vs1 == 0)
|
|
return -1;
|
|
else if(vs2 == std::numeric_limits<std::make_signed_t<src2_elem_t>>::min() &&
|
|
static_cast<std::make_signed_t<src1_elem_t>>(vs1) == -1)
|
|
return vs2;
|
|
else
|
|
return static_cast<std::make_signed_t<src2_elem_t>>(vs2) / static_cast<std::make_signed_t<src1_elem_t>>(vs1);
|
|
};
|
|
case 0b100010: // VREMU
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t {
|
|
if(vs1 == 0)
|
|
return vs2;
|
|
else
|
|
return vs2 % vs1;
|
|
};
|
|
case 0b100011: // VREM
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) -> dest_elem_t {
|
|
if(vs1 == 0)
|
|
return vs2;
|
|
else if(vs2 == std::numeric_limits<std::make_signed_t<src2_elem_t>>::min() &&
|
|
static_cast<std::make_signed_t<src1_elem_t>>(vs1) == -1)
|
|
return 0;
|
|
else
|
|
return static_cast<std::make_signed_t<src2_elem_t>>(vs2) % static_cast<std::make_signed_t<src1_elem_t>>(vs1);
|
|
};
|
|
case 0b100100: // VMULHU
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return (static_cast<twice_t<src2_elem_t>>(vs2) * static_cast<twice_t<src2_elem_t>>(vs1)) >> sizeof(dest_elem_t) * 8;
|
|
};
|
|
case 0b100101: // VMUL
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return static_cast<std::make_signed_t<src2_elem_t>>(vs2) * static_cast<std::make_signed_t<src1_elem_t>>(vs1);
|
|
};
|
|
case 0b100110: // VMULHSU
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return (static_cast<twice_t<std::make_signed_t<src2_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) *
|
|
static_cast<twice_t<src2_elem_t>>(vs1)) >>
|
|
sizeof(dest_elem_t) * 8;
|
|
};
|
|
case 0b100111: // VMULH
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return (static_cast<twice_t<std::make_signed_t<src2_elem_t>>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2)) *
|
|
static_cast<twice_t<std::make_signed_t<src2_elem_t>>>(static_cast<std::make_signed_t<src1_elem_t>>(vs1))) >>
|
|
sizeof(dest_elem_t) * 8;
|
|
};
|
|
// case 0b101001: // VMADD
|
|
// case 0b101011: // VNMSUB
|
|
// case 0b101101: // VMACC
|
|
// case 0b101111: // VNMSAC
|
|
case 0b110000: // VWADDU
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; };
|
|
case 0b110001: // VWADD
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) +
|
|
static_cast<std::make_signed_t<src1_elem_t>>(vs1));
|
|
};
|
|
case 0b110010: // VWSUBU
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; };
|
|
case 0b110011: // VWSUB
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) -
|
|
static_cast<std::make_signed_t<src1_elem_t>>(vs1));
|
|
};
|
|
case 0b110100: // VWADDU.W
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 + vs1; };
|
|
case 0b110101: // VWADD.W
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) +
|
|
static_cast<std::make_signed_t<src1_elem_t>>(vs1));
|
|
};
|
|
case 0b110110: // VWSUBU.W
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) { return vs2 - vs1; };
|
|
case 0b110111: // VWSUB.W
|
|
return [](src2_elem_t vs2, src1_elem_t vs1) {
|
|
return static_cast<std::make_signed_t<dest_elem_t>>(static_cast<std::make_signed_t<src2_elem_t>>(vs2) -
|
|
static_cast<std::make_signed_t<src1_elem_t>>(vs1));
|
|
};
|
|
// case 0b111000: // VWMULU
|
|
// case 0b111010: // VWMULSU
|
|
// case 0b111011: // VWMUL
|
|
// case 0b111100: // VWMACCU
|
|
// case 0b111101: // VWMACC
|
|
// case 0b111110: // VWMACCUS
|
|
// case 0b111111: // VWMACCSU
|
|
|
|
default:
|
|
throw new std::runtime_error("Uknown funct6 in get_funct");
|
|
}
|
|
else
|
|
throw new std::runtime_error("Unknown funct3 in get_funct");
|
|
}
|
|
/// Executes one element-wise vector-vector arithmetic instruction.
/// @tparam VLEN vector register length in bits
/// @param V pointer to the vector register file
/// @param funct6 / funct3 opcode fields selecting the operation (see get_funct)
/// @param vl current vector length in elements
/// @param vstart index of the first body element (prestart is skipped)
/// @param vtype current vtype CSR value (supplies sew/lmul/vta/vma)
/// @param vm true if the instruction is unmasked
/// @param vd / vs2 / vs1 register indices
/// @param carry selects plain, add-with-carry (VADC) or sub-with-borrow (VSBC) behavior
template <unsigned VLEN, typename dest_elem_t, typename src2_elem_t, typename src1_elem_t>
void vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
                      unsigned vs2, unsigned vs1, carry_t carry) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    // mask bits; reg_idx presumably defaults to v0 in the declaration in vector_functions.h — TODO confirm
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs1_view = get_vreg<VLEN, src1_elem_t>(V, vs1, elem_count);
    auto vs2_view = get_vreg<VLEN, src2_elem_t>(V, vs2, elem_count);
    auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
    auto fn = get_funct<dest_elem_t, src2_elem_t, src1_elem_t>(funct6, funct3);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    if(carry == carry_t::NO_CARRY) {
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            bool mask_active = vm ? 1 : mask_reg[idx];
            if(mask_active) {
                auto res = fn(vs2_view[idx], vs1_view[idx]);
                vd_view[idx] = res;
            } else {
                // masked-off element: both branches keep the old value — this is a
                // placeholder where mask-agnostic (vma) handling could differ from
                // mask-undisturbed; leaving the value in place is legal for both
                vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx];
            }
        }
    } else if(carry == carry_t::SUB_CARRY) {
        // VSBC-style: the mask bit is the incoming borrow, always applied (no masking)
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            vd_view[idx] = fn(vs2_view[idx], vs1_view[idx]) - mask_reg[idx];
        }
    } else {
        // VADC-style: the mask bit is the incoming carry
        for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
            vd_view[idx] = fn(vs2_view[idx], vs1_view[idx]) + mask_reg[idx];
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8);
    for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) {
        // tail placeholder: value kept regardless of vta (legal for tail-agnostic too)
        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
    }
    return;
}
|
|
template <unsigned VLEN, typename dest_elem_t, typename src2_elem_t, typename src1_elem_t>
|
|
void vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
|
|
unsigned vs2, typename std::make_signed<src1_elem_t>::type imm, carry_t carry) {
|
|
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
|
|
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
|
|
auto vs2_view = get_vreg<VLEN, src2_elem_t>(V, vs2, elem_count);
|
|
auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
|
|
auto fn = get_funct<dest_elem_t, src2_elem_t, src1_elem_t>(funct6, funct3);
|
|
// elements w/ index smaller than vstart are in the prestart and get skipped
|
|
// body is from vstart to min(elem_count, vl)
|
|
if(carry == carry_t::NO_CARRY) {
|
|
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
|
|
bool mask_active = vm ? 1 : mask_reg[idx];
|
|
if(mask_active) {
|
|
vd_view[idx] = fn(vs2_view[idx], imm);
|
|
} else {
|
|
vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx];
|
|
}
|
|
}
|
|
} else if(carry == carry_t::SUB_CARRY) {
|
|
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
|
|
auto val1 = fn(vs2_view[idx], imm);
|
|
auto val2 = static_cast<std::make_signed_t<dest_elem_t>>(mask_reg[idx]);
|
|
auto diff = val1 - val2;
|
|
vd_view[idx] = fn(vs2_view[idx], imm) - mask_reg[idx];
|
|
}
|
|
} else {
|
|
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
|
|
vd_view[idx] = fn(vs2_view[idx], imm) + mask_reg[idx];
|
|
}
|
|
}
|
|
// elements w/ index larger than elem_count are in the tail (fractional LMUL)
|
|
// elements w/ index larger than vl are in the tail
|
|
unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8);
|
|
for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) {
|
|
vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
|
|
}
|
|
return;
|
|
}
|
|
/// Looks up the comparison predicate for mask-producing compare instructions
/// (VMSEQ..VMSGT) selected by funct6.
/// @return callable computing the destination mask bit from (vs2, vs1)
/// @throws std::runtime_error (by value) for unknown encodings
template <typename elem_t> std::function<bool(elem_t, elem_t)> get_mask_funct(unsigned funct) {
    switch(funct) {
    case 0b011000: // VMSEQ
        return [](elem_t vs2, elem_t vs1) { return vs2 == vs1; };
    case 0b011001: // VMSNE
        return [](elem_t vs2, elem_t vs1) { return vs2 != vs1; };
    case 0b011010: // VMSLTU (unsigned compare)
        return [](elem_t vs2, elem_t vs1) { return vs2 < vs1; };
    case 0b011011: // VMSLT (signed compare)
        return [](elem_t vs2, elem_t vs1) {
            return static_cast<std::make_signed_t<elem_t>>(vs2) < static_cast<std::make_signed_t<elem_t>>(vs1);
        };
    case 0b011100: // VMSLEU
        return [](elem_t vs2, elem_t vs1) { return vs2 <= vs1; };
    case 0b011101: // VMSLE
        return [](elem_t vs2, elem_t vs1) {
            return static_cast<std::make_signed_t<elem_t>>(vs2) <= static_cast<std::make_signed_t<elem_t>>(vs1);
        };
    case 0b011110: // VMSGTU
        return [](elem_t vs2, elem_t vs1) { return vs2 > vs1; };
    case 0b011111: // VMSGT
        return [](elem_t vs2, elem_t vs1) {
            return static_cast<std::make_signed_t<elem_t>>(vs2) > static_cast<std::make_signed_t<elem_t>>(vs1);
        };
    default:
        // throw by value (not `throw new ...`): a thrown pointer leaks and is
        // not caught by handlers taking std::exception const&
        throw std::runtime_error("Unknown funct in get_mask_funct");
    }
}
|
|
template <unsigned VLEN, typename elem_t>
|
|
void mask_vector_vector_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
|
|
unsigned vs2, unsigned vs1) {
|
|
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
|
|
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
|
|
auto vs1_view = get_vreg<VLEN, elem_t>(V, vs1, elem_count);
|
|
auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
|
|
vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd);
|
|
auto fn = get_mask_funct<elem_t>(funct6);
|
|
// elements w/ index smaller than vstart are in the prestart and get skipped
|
|
// body is from vstart to min(elem_count, vl)
|
|
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
|
|
bool mask_active = vm ? 1 : mask_reg[idx];
|
|
if(mask_active) {
|
|
bool new_bit_value = fn(vs2_view[idx], vs1_view[idx]);
|
|
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
|
|
unsigned cur_bit = idx % 8;
|
|
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
|
|
} else {
|
|
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
|
|
unsigned cur_bit = idx % 8;
|
|
*cur_mask_byte_addr = vtype.vma() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
|
|
}
|
|
}
|
|
// elements w/ index larger than elem_count are in the tail (fractional LMUL)
|
|
// elements w/ index larger than vl are in the tail
|
|
for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
|
|
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
|
|
unsigned cur_bit = idx % 8;
|
|
*cur_mask_byte_addr = vtype.vta() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
|
|
}
|
|
return;
|
|
}
|
|
template <unsigned VLEN, typename elem_t>
|
|
void mask_vector_imm_op(uint8_t* V, unsigned funct6, unsigned funct3, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd,
|
|
unsigned vs2, typename std::make_signed<elem_t>::type imm) {
|
|
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
|
|
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
|
|
auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
|
|
vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd);
|
|
auto fn = get_mask_funct<elem_t>(funct6);
|
|
// elements w/ index smaller than vstart are in the prestart and get skipped
|
|
// body is from vstart to min(elem_count, vl)
|
|
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
|
|
bool mask_active = vm ? 1 : mask_reg[idx];
|
|
if(mask_active) {
|
|
bool new_bit_value = fn(vs2_view[idx], imm);
|
|
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
|
|
unsigned cur_bit = idx % 8;
|
|
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
|
|
} else {
|
|
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
|
|
unsigned cur_bit = idx % 8;
|
|
*cur_mask_byte_addr = vtype.vma() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
|
|
}
|
|
}
|
|
// elements w/ index larger than elem_count are in the tail (fractional LMUL)
|
|
// elements w/ index larger than vl are in the tail
|
|
for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
|
|
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
|
|
unsigned cur_bit = idx % 8;
|
|
*cur_mask_byte_addr = vtype.vta() ? *cur_mask_byte_addr : *cur_mask_byte_addr;
|
|
}
|
|
return;
|
|
}
|
|
/// Looks up the element conversion for unary extension instructions
/// (vsext/vzext vf2/vf4/vf8), selected by the vs1 field encoding.
/// @return callable computing one (wider) destination element from vs2
/// @throws std::runtime_error (by value) for unknown encodings
template <typename dest_elem_t, typename src2_elem_t = dest_elem_t>
std::function<dest_elem_t(src2_elem_t)> get_unary_fn(unsigned unary_op) {
    switch(unary_op) {
    case 0b00111: // vsext.vf2
    case 0b00101: // vsext.vf4
    case 0b00011: // vsext.vf8
        // sign-extend: reinterpret as signed before widening to dest_elem_t
        return [](src2_elem_t vs2) { return static_cast<std::make_signed_t<src2_elem_t>>(vs2); };
    case 0b00110: // vzext.vf2
    case 0b00100: // vzext.vf4
    case 0b00010: // vzext.vf8
        // zero-extend: plain value conversion to the wider dest_elem_t
        return [](src2_elem_t vs2) { return vs2; };
    default:
        // throw by value (not `throw new ...`): a thrown pointer leaks and is
        // not caught by handlers taking std::exception const&
        throw std::runtime_error("Unknown funct in get_unary_fn");
    }
}
|
|
/// Applies a unary widening/extension operation (vsext/vzext) element-wise.
/// @tparam VLEN vector register length in bits
/// @param V pointer to the vector register file
/// @param unary_op vs1-field encoding selecting the operation (see get_unary_fn)
/// @param vl current vector length in elements
/// @param vstart index of the first body element (prestart is skipped)
/// @param vtype current vtype CSR value (supplies sew/lmul/vta/vma)
/// @param vm true if the instruction is unmasked
/// @param vd / vs2 register indices
template <unsigned VLEN, typename dest_elem_t, typename src2_elem_t>
void vector_unary_op(uint8_t* V, unsigned unary_op, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2) {
    uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
    // mask bits; reg_idx presumably defaults to v0 in the declaration in vector_functions.h — TODO confirm
    vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
    auto vs2_view = get_vreg<VLEN, src2_elem_t>(V, vs2, elem_count);
    auto vd_view = get_vreg<VLEN, dest_elem_t>(V, vd, elem_count);
    auto fn = get_unary_fn<dest_elem_t, src2_elem_t>(unary_op);
    // elements w/ index smaller than vstart are in the prestart and get skipped
    // body is from vstart to min(elem_count, vl)
    for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
        bool mask_active = vm ? 1 : mask_reg[idx];
        if(mask_active) {
            vd_view[idx] = fn(vs2_view[idx]);
        } else {
            // masked-off element: both branches keep the old value — placeholder
            // where mask-agnostic (vma) handling could differ; legal for both policies
            vd_view[idx] = vtype.vma() ? vd_view[idx] : vd_view[idx];
        }
    }
    // elements w/ index larger than elem_count are in the tail (fractional LMUL)
    // elements w/ index larger than vl are in the tail
    unsigned maximum_elems = VLEN * vtype.lmul() / (sizeof(dest_elem_t) * 8);
    for(unsigned idx = std::min(elem_count, vl); idx < maximum_elems; idx++) {
        // tail placeholder: value kept regardless of vta (legal for tail-agnostic too)
        vd_view[idx] = vtype.vta() ? vd_view[idx] : vd_view[idx];
    }
    return;
}
|
|
/// Looks up the carry/borrow-out predicate for VMADC/VMSBC.
/// @return callable computing the carry (resp. borrow) bit of
///         vs2 + vs1 + carry (resp. vs2 - vs1 - carry) in elem_t arithmetic
/// @throws std::runtime_error (by value) for unknown encodings
template <typename elem_t> std::function<bool(elem_t, elem_t, elem_t)> get_carry_funct(unsigned funct) {
    switch(funct) {
    case 0b010001: // VMADC
        // carry-out occurred iff the wrapped sum (with or without the carry-in)
        // is smaller than the larger operand
        return [](elem_t vs2, elem_t vs1, elem_t carry) {
            return static_cast<elem_t>(vs2 + vs1 + carry) < std::max(vs1, vs2) || static_cast<elem_t>(vs2 + vs1) < std::max(vs1, vs2);
        };
    case 0b010011: // VMSBC
        // borrow-out occurred iff vs2 < vs1 + carry; the second term covers
        // the wrap-around when vs1 is already the maximum value and carry is set
        return [](elem_t vs2, elem_t vs1, elem_t carry) {
            return vs2 < static_cast<elem_t>(vs1 + carry) || (vs1 == std::numeric_limits<elem_t>::max() && carry);
        };
    default:
        // throw by value (not `throw new ...`): a thrown pointer leaks and is
        // not caught by handlers taking std::exception const&
        throw std::runtime_error("Unknown funct in get_carry_funct");
    }
}
|
|
template <unsigned VLEN, typename elem_t>
|
|
void carry_vector_vector_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2,
|
|
unsigned vs1) {
|
|
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
|
|
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
|
|
auto vs1_view = get_vreg<VLEN, elem_t>(V, vs1, elem_count);
|
|
auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
|
|
vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd);
|
|
auto fn = get_carry_funct<elem_t>(funct);
|
|
// elements w/ index smaller than vstart are in the prestart and get skipped
|
|
// body is from vstart to min(elem_count, vl)
|
|
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
|
|
elem_t carry = vm ? 0 : mask_reg[idx];
|
|
bool new_bit_value = fn(vs2_view[idx], vs1_view[idx], carry);
|
|
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
|
|
unsigned cur_bit = idx % 8;
|
|
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
|
|
}
|
|
// elements w/ index larger than elem_count are in the tail (fractional LMUL)
|
|
// elements w/ index larger than vl are in the tail
|
|
for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
|
|
// always tail agnostic
|
|
}
|
|
return;
|
|
}
|
|
template <unsigned VLEN, typename elem_t>
|
|
void carry_vector_imm_op(uint8_t* V, unsigned funct, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, unsigned vd, unsigned vs2,
|
|
typename std::make_signed<elem_t>::type imm) {
|
|
uint64_t elem_count = VLEN * vtype.lmul() / vtype.sew();
|
|
vmask_view mask_reg = read_vmask<VLEN>(V, elem_count);
|
|
auto vs2_view = get_vreg<VLEN, elem_t>(V, vs2, elem_count);
|
|
vmask_view vd_mask_view = read_vmask<VLEN>(V, elem_count, vd);
|
|
auto fn = get_carry_funct<elem_t>(funct);
|
|
// elements w/ index smaller than vstart are in the prestart and get skipped
|
|
// body is from vstart to min(elem_count, vl)
|
|
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
|
|
elem_t carry = vm ? 0 : mask_reg[idx];
|
|
bool new_bit_value = fn(vs2_view[idx], imm, carry);
|
|
uint8_t* cur_mask_byte_addr = vd_mask_view.start + idx / 8;
|
|
unsigned cur_bit = idx % 8;
|
|
*cur_mask_byte_addr = *cur_mask_byte_addr & ~(1U << cur_bit) | static_cast<unsigned>(new_bit_value) << cur_bit;
|
|
}
|
|
// elements w/ index larger than elem_count are in the tail (fractional LMUL)
|
|
// elements w/ index larger than vl are in the tail
|
|
for(unsigned idx = std::min(elem_count, vl); idx < VLEN; idx++) {
|
|
// always tail agnostic
|
|
}
|
|
return;
|
|
}
|
|
} // namespace softvector
|