1#ifndef included_AMP_CSRMatrixSpGEMMDevice
2#define included_AMP_CSRMatrixSpGEMMDevice
4#include "AMP/AMP_TPLs.h"
5#include "AMP/matrices/data/CSRMatrixCommunicator.h"
6#include "AMP/matrices/data/CSRMatrixData.h"
7#include "AMP/utils/AMP_MPI.h"
8#include "AMP/utils/Memory.h"
11 #include "AMP/matrices/operations/device/spgemm/cuda/SpGEMM_Cuda.h"
15 #include "AMP/matrices/operations/device/spgemm/hip/SpGEMM_Hip.h"
25template<
typename Config>
33 using lidx_t =
typename Config::lidx_t;
34 using gidx_t =
typename Config::gidx_t;
37 static_assert( std::is_same_v<typename allocator_type::value_type, void> );
41 std::shared_ptr<matrixdata_t> B_,
42 std::shared_ptr<matrixdata_t> C_ )
58 "CSRMatrixSpGEMMDevice: All three matrices must have the same communicator" );
65 void multiply( std::shared_ptr<localmatrixdata_t> A_data,
66 std::shared_ptr<localmatrixdata_t> B_data,
67 std::shared_ptr<localmatrixdata_t> C_data );
74 void merge( std::shared_ptr<localmatrixdata_t> inL,
75 std::shared_ptr<localmatrixdata_t> inR,
76 std::shared_ptr<localmatrixdata_t> out );
79 std::shared_ptr<matrixdata_t>
A;
80 std::shared_ptr<matrixdata_t>
B;
81 std::shared_ptr<matrixdata_t>
C;
84 std::shared_ptr<localmatrixdata_t>
A_diag;
85 std::shared_ptr<localmatrixdata_t>
A_offd;
86 std::shared_ptr<localmatrixdata_t>
B_diag;
87 std::shared_ptr<localmatrixdata_t>
B_offd;
90 std::shared_ptr<localmatrixdata_t>
BR_diag;
91 std::shared_ptr<localmatrixdata_t>
BR_offd;
94 std::shared_ptr<localmatrixdata_t>
C_diag;
95 std::shared_ptr<localmatrixdata_t>
C_offd;
Provides C++ wrapper around MPI routines.
CSRLocalMatrixData< Config > localmatrixdata_t
std::shared_ptr< localmatrixdata_t > A_offd
CSRMatrixSpGEMMDevice(std::shared_ptr< matrixdata_t > A_, std::shared_ptr< matrixdata_t > B_, std::shared_ptr< matrixdata_t > C_)
std::map< int, std::shared_ptr< localmatrixdata_t > > d_send_matrices
std::shared_ptr< matrixdata_t > A
std::map< int, SpGEMMCommInfo > d_src_info
typename Config::scalar_t scalar_t
~CSRMatrixSpGEMMDevice()=default
CSRMatrixCommunicator< Config > d_csr_comm
std::shared_ptr< localmatrixdata_t > C_offd_diag
void multiply(std::shared_ptr< localmatrixdata_t > A_data, std::shared_ptr< localmatrixdata_t > B_data, std::shared_ptr< localmatrixdata_t > C_data)
typename Config::allocator_type allocator_type
std::shared_ptr< matrixdata_t > B
std::shared_ptr< localmatrixdata_t > C_diag_offd
std::map< int, std::shared_ptr< localmatrixdata_t > > d_recv_matrices
std::shared_ptr< localmatrixdata_t > A_diag
typename Config::lidx_t lidx_t
std::map< int, SpGEMMCommInfo > d_dest_info
std::shared_ptr< localmatrixdata_t > C_diag
typename Config::gidx_t gidx_t
std::shared_ptr< matrixdata_t > C
std::shared_ptr< localmatrixdata_t > C_diag_diag
std::shared_ptr< localmatrixdata_t > C_offd_offd
CSRMatrixSpGEMMDevice()=default
std::shared_ptr< localmatrixdata_t > BR_offd
std::shared_ptr< localmatrixdata_t > BR_diag
std::shared_ptr< localmatrixdata_t > C_offd
std::shared_ptr< localmatrixdata_t > B_diag
std::shared_ptr< localmatrixdata_t > B_offd
void merge(std::shared_ptr< localmatrixdata_t > inL, std::shared_ptr< localmatrixdata_t > inR, std::shared_ptr< localmatrixdata_t > out)
typename matrixdata_t::localmatrixdata_t localmatrixdata_t
#define AMP_DEBUG_INSIST(EXP, MSG)
Insist error.
AMP_MPI getComm(const TYPE &obj)
Return the underlying MPI class for the object.
SpGEMMCommInfo(int numrow_)
std::vector< lidx_t > rownnz
std::vector< gidx_t > rowids