DBT-RISE-TGC/src/vm/vector_functions.cpp

////////////////////////////////////////////////////////////////////////////////
// Copyright (C) 2025, MINRES Technologies GmbH
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// 1. Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// 3. Neither the name of the copyright holder nor the names of its contributors
// may be used to endorse or promote products derived from this software
// without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Contributors:
// alex@minres.com - initial API and implementation
////////////////////////////////////////////////////////////////////////////////
#include "vector_functions.h"
#include "iss/vm_types.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <math.h>
#include <stdexcept>
namespace softvector {
unsigned RFS = 32;
bool softvec_read(void* core, uint64_t addr, uint64_t length, uint8_t* data) {
// Read length bytes from addr into *data
iss::status status = static_cast<iss::arch_if*>(core)->read(iss::address_type::PHYSICAL, iss::access_type::READ,
0 /*traits<ARCH>::MEM*/, addr, length, data);
return status == iss::Ok;
}
bool softvec_write(void* core, uint64_t addr, uint64_t length, uint8_t* data) {
// Write length bytes from *data to addr
iss::status status = static_cast<iss::arch_if*>(core)->write(iss::address_type::PHYSICAL, iss::access_type::WRITE,
0 /*traits<ARCH>::MEM*/, addr, length, data);
return status == iss::Ok;
}
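// Usage sketch (illustrative, not part of this file's API surface): reading one
// 32-bit word from guest memory into a host-side buffer, assuming 'core' is the
// iss::arch_if* of the simulated hart and 'addr' a valid guest address:
//   uint32_t word;
//   if(!softvec_read(core, addr, sizeof(word), reinterpret_cast<uint8_t*>(&word)))
//       ; // handle the failed access, e.g. raise a load access fault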
using vlen_t = uint64_t;
struct vreg_view {
uint8_t* start;
size_t size;
template <typename T> T& get(size_t idx = 0) {
assert((idx + 1) * sizeof(T) <= size); // element idx must lie fully inside the view
return *(reinterpret_cast<T*>(start) + idx);
}
};
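// Usage sketch (illustrative): viewing the first vector register of a VLEN=128
// register file V and reading its third 32-bit element:
//   vreg_view v0{V, 128 / 8};
//   uint32_t elem2 = v0.get<uint32_t>(2);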
// move the vill bit (bit 31 of a 32-bit vtype) up to bit 63 so both constructors store it in the same place
vtype_t::vtype_t(uint32_t vtype_val) { underlying = static_cast<uint64_t>(vtype_val & 0x80000000) << 32 | (vtype_val & ~0x80000000); }
vtype_t::vtype_t(uint64_t vtype_val) { underlying = vtype_val; }
bool vtype_t::vill() { return underlying >> 31; }
bool vtype_t::vma() { return (underlying >> 7) & 1; }
bool vtype_t::vta() { return (underlying >> 6) & 1; }
unsigned vtype_t::sew() {
uint8_t vsew = (underlying >> 3) & 0b111;
// pow(2, 3 + vsew);
return 1 << (3 + vsew);
}
double vtype_t::lmul() {
uint8_t vlmul = underlying & 0b111;
assert(vlmul != 0b100); // reserved encoding
int8_t signed_vlmul = (vlmul >> 2) ? 0b11111000 | vlmul : vlmul;
return pow(2, signed_vlmul);
}
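// Decoding example: vtype_val = 0x51 encodes vlmul=001, vsew=010, vta=1, vma=0,
// so sew() returns 32, lmul() returns 2.0, vta() is true and vma() is false.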
vreg_view read_vmask(uint8_t* V, uint16_t VLEN, uint16_t elem_count, uint8_t reg_idx) {
uint8_t* mask_start = V + VLEN / 8 * reg_idx;
return {mask_start, (elem_count + 7) / 8u}; // round up so that even elem_count < 8 yields a non-empty view
}
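// e.g. with VLEN=128 and elem_count=16, read_vmask(V, 128, 16, 0) yields a
// 2-byte view of v0; bit (idx % 8) of byte (idx / 8) is the mask bit of element idx.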
uint64_t vector_load_store(void* core, std::function<bool(void*, uint64_t, uint64_t, uint8_t*)> load_store_fn, uint8_t* V, uint16_t VLEN,
uint8_t addressed_register, uint64_t base_addr, uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm,
uint8_t elem_size_byte, uint64_t elem_count, int8_t EMUL_pow, uint8_t segment_size, int64_t stride) {
// eew = elem_size_byte * 8
assert(pow(2, EMUL_pow) * segment_size <= 8);
assert(segment_size > 0);
assert((elem_count & (elem_count - 1)) == 0); // check that elem_count is power of 2
assert(elem_count <= VLEN * RFS / 8);
unsigned emul_stride = EMUL_pow <= 0 ? 1 : pow(2, EMUL_pow);
assert(emul_stride * segment_size <= 8);
assert(!(addressed_register % emul_stride));
vreg_view mask_view = read_vmask(V, VLEN, elem_count, 0);
// elements w/ index smaller than vstart are in the prestart and get skipped
// body is from vstart to min(elem_count, vl)
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
unsigned trap_idx = idx;
uint8_t current_mask_byte = mask_view.get<uint8_t>(idx / 8);
bool mask_active = vm ? 1 : current_mask_byte & (1 << idx % 8);
if(mask_active) {
for(unsigned s_idx = 0; s_idx < segment_size; s_idx++) {
// base + selected register + current_elem + current_segment
uint8_t* addressed_elem = V + (addressed_register * VLEN / 8) + (elem_size_byte * idx) + (VLEN / 8 * s_idx * emul_stride);
assert(addressed_elem + elem_size_byte <= V + VLEN * RFS / 8); // writing elem_size_byte bytes must stay inside the register file
uint64_t addr = base_addr + (elem_size_byte) * (idx * segment_size + s_idx) * stride;
if(!load_store_fn(core, addr, elem_size_byte, addressed_elem))
return trap_idx;
}
} else {
// Elements with an inactive mask bit may either be left undisturbed (vma=0) or
// overwritten with all ones (vma=1); leaving them unchanged satisfies both
// policies, so no write is performed here.
}
}
// elements w/ index >= min(elem_count, vl) are in the tail; under fractional
// LMUL this also covers the unused upper part of the register. Tail elements
// may be left undisturbed (vta=0) or overwritten with all ones (vta=1);
// leaving them unchanged satisfies both policies, so no write is performed.
return 0;
}
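// Usage sketch (illustrative values): a unit-stride 32-bit load (vle32.v) on a
// VLEN=128 hart with vl=4 could be dispatched as
//   vector_load_store(core, softvec_read, V, /*VLEN=*/128, /*vd=*/1, base_addr,
//                     /*vl=*/4, /*vstart=*/0, vtype, /*vm=*/true,
//                     /*elem_size_byte=*/4, /*elem_count=*/4, /*EMUL_pow=*/0,
//                     /*segment_size=*/1, /*stride=*/1);
// where 'core', 'V', 'base_addr' and 'vtype' come from the surrounding VM state
// and stride=1 reduces the address formula above to consecutive elements.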
int64_t read_n_bits(uint8_t* V, unsigned n) {
switch(n) {
case 8:
return static_cast<int64_t>(*reinterpret_cast<int8_t*>(V));
case 16:
return static_cast<int64_t>(*reinterpret_cast<int16_t*>(V));
case 32:
return static_cast<int64_t>(*reinterpret_cast<int32_t*>(V));
case 64:
return static_cast<int64_t>(*reinterpret_cast<int64_t*>(V));
default:
throw std::invalid_argument("Invalid arg in read_n_bits");
}
}
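// e.g. on a little-endian host, for V pointing at bytes {0xFF, 0x7F},
// read_n_bits(V, 16) reinterprets them as int16_t and returns 0x7FFF, while
// bytes {0xFF, 0xFF} would return -1 (sign-extended to int64_t).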
// this function behaves similarly to vector_load_store(...), with the key difference that the SEW and LMUL from the parameters apply
// to the index registers (instead of the data registers), while the SEW and LMUL encoded in vtype apply to the data registers
uint64_t vector_load_store_index(void* core, std::function<bool(void*, uint64_t, uint64_t, uint8_t*)> load_store_fn, uint8_t* V,
uint16_t VLEN, uint8_t XLEN, uint8_t addressed_register, uint8_t index_register, uint64_t base_addr,
uint64_t vl, uint64_t vstart, vtype_t vtype, bool vm, uint8_t index_elem_size_byte, uint64_t elem_count,
uint8_t segment_size, bool ordered) {
// index_eew = index_elem_size_byte * 8
// for now the ordered parameter is ignored, as all indexed operations are implemented as ordered
assert(segment_size > 0);
assert((elem_count & (elem_count - 1)) == 0); // check that elem_count is power of 2
assert(elem_count <= VLEN * RFS / 8);
unsigned data_emul_stride = vtype.lmul() < 1 ? 1 : vtype.lmul(); // fractional LMUL maps to a register stride of 1, as in vector_load_store
assert(data_emul_stride * segment_size <= 8);
unsigned data_elem_size_byte = vtype.sew() / 8;
assert(!(addressed_register % data_emul_stride));
vreg_view mask_view = read_vmask(V, VLEN, elem_count, 0);
// elements w/ index smaller than vstart are in the prestart and get skipped
// body is from vstart to min(elem_count, vl)
for(unsigned idx = vstart; idx < std::min(elem_count, vl); idx++) {
unsigned trap_idx = idx;
uint8_t current_mask_byte = mask_view.get<uint8_t>(idx / 8);
bool mask_active = vm ? 1 : current_mask_byte & (1 << idx % 8);
if(mask_active) {
uint8_t* offset_elem = V + (index_register * VLEN / 8) + (index_elem_size_byte * idx);
assert(offset_elem <= (V + VLEN * RFS / 8 - index_elem_size_byte)); // ensure reading index_elem_size_bytes is legal
// read sew bits from offset_elem truncate / extend to XLEN bits
int64_t offset_val = read_n_bits(offset_elem, index_elem_size_byte * 8);
assert(XLEN == 64 || XLEN == 32);
uint64_t mask = XLEN == 64 ? std::numeric_limits<uint64_t>::max() : std::numeric_limits<uint32_t>::max();
uint64_t index_offset = offset_val & mask; // keep the full XLEN-wide offset; 'unsigned' would truncate on RV64
for(unsigned s_idx = 0; s_idx < segment_size; s_idx++) {
// base + selected register + current_elem + current_segment
uint8_t* addressed_elem =
V + (addressed_register * VLEN / 8) + (data_elem_size_byte * idx) + (VLEN / 8 * s_idx * data_emul_stride);
assert(addressed_elem + data_elem_size_byte <= V + VLEN * RFS / 8); // writing data_elem_size_byte bytes must stay inside the register file
// base + offset + current_segment
uint64_t addr = base_addr + index_offset + s_idx * data_elem_size_byte;
if(!load_store_fn(core, addr, data_elem_size_byte, addressed_elem))
return trap_idx;
}
} else {
// Elements with an inactive mask bit may either be left undisturbed (vma=0) or
// overwritten with all ones (vma=1); leaving them unchanged satisfies both
// policies, so no write is performed here.
}
}
// elements w/ index >= min(elem_count, vl) are in the tail; under fractional
// LMUL this also covers the unused upper part of the register. Tail elements
// may be left undisturbed (vta=0) or overwritten with all ones (vta=1);
// leaving them unchanged satisfies both policies, so no write is performed.
return 0;
}
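// Usage sketch (illustrative values): an ordered indexed 32-bit load
// (vloxei32.v) on a VLEN=128, XLEN=32 hart could be dispatched as
//   vector_load_store_index(core, softvec_read, V, /*VLEN=*/128, /*XLEN=*/32,
//                           /*vd=*/1, /*vs2=*/2, base_addr, /*vl=*/4,
//                           /*vstart=*/0, vtype, /*vm=*/true,
//                           /*index_elem_size_byte=*/4, /*elem_count=*/4,
//                           /*segment_size=*/1, /*ordered=*/true);
// with the data SEW/LMUL taken from 'vtype' and the index EEW from the parameters.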
} // namespace softvector