1#ifndef included_AMP_CSRMatrixSpGEMMDefault
2#define included_AMP_CSRMatrixSpGEMMDefault
4#include "AMP/matrices/data/CSRMatrixCommunicator.h"
5#include "AMP/matrices/data/CSRMatrixData.h"
6#include "AMP/utils/AMP_MPI.h"
14template<
typename Config>
22 using lidx_t =
typename Config::lidx_t;
23 using gidx_t =
typename Config::gidx_t;
26 static_assert( std::is_same_v<typename allocator_type::value_type, void> );
44 std::shared_ptr<matrixdata_t> B_,
45 std::shared_ptr<matrixdata_t> C_,
64 "CSRMatrixSpGEMMDefault: All three matrices must have the same communicator" );
82 template<Mode mode_t, BlockType block_t>
83 void multiply( std::shared_ptr<localmatrixdata_t> A_data,
84 std::shared_ptr<localmatrixdata_t> B_data,
85 std::shared_ptr<localmatrixdata_t> C_data );
87 template<Mode mode_t, BlockType block_t>
89 std::shared_ptr<localmatrixdata_t> BR_data,
90 std::shared_ptr<localmatrixdata_t> C_data );
92 template<BlockType block_t>
94 std::shared_ptr<localmatrixdata_t> B_data,
95 std::shared_ptr<localmatrixdata_t> C_data );
108 std::shared_ptr<matrixdata_t>
A;
109 std::shared_ptr<matrixdata_t>
B;
110 std::shared_ptr<matrixdata_t>
C;
113 std::shared_ptr<localmatrixdata_t>
A_diag;
114 std::shared_ptr<localmatrixdata_t>
A_offd;
115 std::shared_ptr<localmatrixdata_t>
B_diag;
116 std::shared_ptr<localmatrixdata_t>
B_offd;
123 std::shared_ptr<localmatrixdata_t>
C_diag;
124 std::shared_ptr<localmatrixdata_t>
C_offd;
172 template<
typename col_t>
185 static_assert( std::is_same_v<col_t, gidx_t> || std::is_same_v<col_t, lidx_t> );
194 static constexpr bool IsGlobal = std::is_same_v<gidx_t, col_t>;
209 template<
typename col_t>
223 static_assert( std::is_same_v<col_t, gidx_t> || std::is_same_v<col_t, lidx_t> );
226 uint16_t
hash( col_t col_idx )
const;
233 static constexpr bool IsGlobal = std::is_same_v<gidx_t, col_t>;
Provides C++ wrapper around MPI routines.
CSRLocalMatrixData< Config > localmatrixdata_t
typename Config::lidx_t lidx_t
std::shared_ptr< localmatrixdata_t > A_diag
typename Config::scalar_t scalar_t
void symbolicMultiply_NonOverlapped()
CSRMatrixCommunicator< Config > d_csr_comm
std::shared_ptr< matrixdata_t > C
std::shared_ptr< localmatrixdata_t > BR_offd
void symbolicMultiply_Overlapped()
typename Config::gidx_t gidx_t
void numericMultiplyReuse()
std::shared_ptr< localmatrixdata_t > C_diag_diag
std::shared_ptr< localmatrixdata_t > BR_diag
static constexpr lidx_t SPACC_SIZE
void numericMultiply_NonOverlapped()
std::shared_ptr< localmatrixdata_t > C_offd
std::map< int, std::shared_ptr< localmatrixdata_t > > d_send_matrices
std::shared_ptr< matrixdata_t > B
std::map< int, SpGEMMCommInfo > d_src_info
void multiplyFused(std::shared_ptr< localmatrixdata_t > B_data, std::shared_ptr< localmatrixdata_t > BR_data, std::shared_ptr< localmatrixdata_t > C_data)
~CSRMatrixSpGEMMDefault()=default
std::shared_ptr< localmatrixdata_t > B_offd
void multiply(std::shared_ptr< localmatrixdata_t > A_data, std::shared_ptr< localmatrixdata_t > B_data, std::shared_ptr< localmatrixdata_t > C_data)
std::map< int, std::shared_ptr< localmatrixdata_t > > d_recv_matrices
std::map< int, SpGEMMCommInfo > d_dest_info
std::shared_ptr< localmatrixdata_t > C_diag_offd
typename Config::allocator_type allocator_type
void numericMultiply_Overlapped()
std::shared_ptr< localmatrixdata_t > C_offd_offd
std::shared_ptr< localmatrixdata_t > A_offd
typename matrixdata_t::localmatrixdata_t localmatrixdata_t
std::shared_ptr< matrixdata_t > A
std::shared_ptr< localmatrixdata_t > B_diag
std::shared_ptr< localmatrixdata_t > C_diag
std::shared_ptr< localmatrixdata_t > C_offd_diag
CSRMatrixSpGEMMDefault(std::shared_ptr< matrixdata_t > A_, std::shared_ptr< matrixdata_t > B_, std::shared_ptr< matrixdata_t > C_, bool overlap_comms_)
void multiplyReuse(std::shared_ptr< localmatrixdata_t > A_data, std::shared_ptr< localmatrixdata_t > B_data, std::shared_ptr< localmatrixdata_t > C_data)
#define AMP_DEBUG_ASSERT(EXP)
Assert error (debug only)
#define AMP_DEBUG_INSIST(EXP, MSG)
Insist error.
AMP_MPI getComm(const TYPE &obj)
Return the underlying MPI class for the object.
static constexpr bool IsGlobal
void insert_or_append(col_t col_idx)
void set_flag(col_t col_idx, lidx_t k)
std::vector< lidx_t > flags
lidx_t contains(col_t col_idx) const
std::vector< col_t > cols
std::vector< lidx_t > flag_inv
DenseAccumulator(int capacity_, gidx_t offset_)
void insert_or_append(col_t col_idx, scalar_t val, col_t *col_space, scalar_t *val_space)
std::vector< lidx_t > rownnz
std::vector< gidx_t > rowids
SpGEMMCommInfo(int numrow_)
static constexpr bool IsGlobal
lidx_t contains(col_t col_idx) const
void insert_or_append(col_t col_idx)
void set_flag(col_t col_idx, lidx_t k)
std::vector< uint16_t > flags
SparseAccumulator(int capacity_, gidx_t offset_)
uint16_t hash(col_t col_idx) const
std::vector< col_t > cols
void grow(col_t *col_space)
void insert_or_append(col_t col_idx, scalar_t val, col_t *col_space, scalar_t *val_space)