Advanced Multi-Physics (AMP)
On-Line Documentation
CSRMatrixSpGEMMDevice.h
Go to the documentation of this file.
1#ifndef included_AMP_CSRMatrixSpGEMMDevice
2#define included_AMP_CSRMatrixSpGEMMDevice
3
4#include "AMP/AMP_TPLs.h"
5#include "AMP/matrices/data/CSRMatrixCommunicator.h"
6#include "AMP/matrices/data/CSRMatrixData.h"
7#include "AMP/utils/AMP_MPI.h"
8#include "AMP/utils/Memory.h"
9
10#ifdef AMP_USE_CUDA
11 #include "AMP/matrices/operations/device/spgemm/cuda/SpGEMM_Cuda.h"
12#endif
13
14#ifdef AMP_USE_HIP
15 #include "AMP/matrices/operations/device/spgemm/hip/SpGEMM_Hip.h"
16#endif
17
18#include <map>
19#include <memory>
20#include <type_traits>
21#include <vector>
22
23namespace AMP::LinearAlgebra {
24
25template<typename Config>
27{
28public:
29 using allocator_type = typename Config::allocator_type;
30 using config_type = Config;
33 using lidx_t = typename Config::lidx_t;
34 using gidx_t = typename Config::gidx_t;
35 using scalar_t = typename Config::scalar_t;
36
37 static_assert( std::is_same_v<typename allocator_type::value_type, void> );
38
40 CSRMatrixSpGEMMDevice( std::shared_ptr<matrixdata_t> A_,
41 std::shared_ptr<matrixdata_t> B_,
42 std::shared_ptr<matrixdata_t> C_ )
43 : A( A_ ),
44 B( B_ ),
45 C( C_ ),
46 A_diag( A->getDiagMatrix() ),
47 A_offd( A->getOffdMatrix() ),
48 B_diag( B->getDiagMatrix() ),
49 B_offd( B->getOffdMatrix() ),
50 C_diag( C->getDiagMatrix() ),
51 C_offd( C->getOffdMatrix() ),
52 d_num_rows( static_cast<lidx_t>( A->numLocalRows() ) ),
53 comm( A->getComm() ),
54 d_csr_comm( A->getRightCommList() )
55 {
57 comm == B->getComm() && comm == C->getComm(),
58 "CSRMatrixSpGEMMDevice: All three matrices must have the same communicator" );
59 }
60
62
63 void multiply();
64
65 void multiply( std::shared_ptr<localmatrixdata_t> A_data,
66 std::shared_ptr<localmatrixdata_t> B_data,
67 std::shared_ptr<localmatrixdata_t> C_data );
68
69protected:
73
74 void merge( std::shared_ptr<localmatrixdata_t> inL,
75 std::shared_ptr<localmatrixdata_t> inR,
76 std::shared_ptr<localmatrixdata_t> out );
77
78 // Matrix data of operands and output
79 std::shared_ptr<matrixdata_t> A;
80 std::shared_ptr<matrixdata_t> B;
81 std::shared_ptr<matrixdata_t> C;
82
83 // diag and offd blocks of input matrices
84 std::shared_ptr<localmatrixdata_t> A_diag;
85 std::shared_ptr<localmatrixdata_t> A_offd;
86 std::shared_ptr<localmatrixdata_t> B_diag;
87 std::shared_ptr<localmatrixdata_t> B_offd;
88
89 // Matrix data formed from remote rows of B that get pulled to each process
90 std::shared_ptr<localmatrixdata_t> BR_diag;
91 std::shared_ptr<localmatrixdata_t> BR_offd;
92
93 // Blocks of C matrix
94 std::shared_ptr<localmatrixdata_t> C_diag;
95 std::shared_ptr<localmatrixdata_t> C_offd;
96
97 // number of local rows in A and C are the same, and many loops
98 // run over this range
100
101 // Communicator
104
105 // To overlap comms and calcs it is easiest to form the output in four
106 // blocks and merge them together at the end
107 std::shared_ptr<localmatrixdata_t> C_diag_diag; // from A_diag * B_diag
108 std::shared_ptr<localmatrixdata_t> C_diag_offd; // from A_diag * B_offd
109 std::shared_ptr<localmatrixdata_t> C_offd_diag; // from A_offd * BR_diag
110 std::shared_ptr<localmatrixdata_t> C_offd_offd; // from A_offd * BR_offd
111
112 // The following all support the communication needed to build BRemote
113 // these are worth preserving to allow repeated SpGEMMs to re-use the
114 // structure with potentially new coefficients
115
116 // struct to hold fields that are needed in both
117 // the "source" and "destination" perspectives
119 SpGEMMCommInfo() = default;
120 SpGEMMCommInfo( int numrow_ ) : numrow( numrow_ ) {}
121 // number of rows to send or receive
123 // ids of rows to send/receive
124 std::vector<gidx_t> rowids;
125 // number of non-zeros in those rows
126 std::vector<lidx_t> rownnz;
127 };
128
129 // Source information, things expected from other ranks
130 std::map<int, SpGEMMCommInfo> d_src_info;
131
132 // Destination information, things sent to other ranks
133 std::map<int, SpGEMMCommInfo> d_dest_info;
134
135 std::map<int, std::shared_ptr<localmatrixdata_t>> d_send_matrices;
136 std::map<int, std::shared_ptr<localmatrixdata_t>> d_recv_matrices;
137};
138
139} // namespace AMP::LinearAlgebra
140
141#endif
Provides C++ wrapper around MPI routines.
Definition AMP_MPI.h:63
CSRLocalMatrixData< Config > localmatrixdata_t
std::shared_ptr< localmatrixdata_t > A_offd
CSRMatrixSpGEMMDevice(std::shared_ptr< matrixdata_t > A_, std::shared_ptr< matrixdata_t > B_, std::shared_ptr< matrixdata_t > C_)
std::map< int, std::shared_ptr< localmatrixdata_t > > d_send_matrices
std::shared_ptr< localmatrixdata_t > C_offd_diag
void multiply(std::shared_ptr< localmatrixdata_t > A_data, std::shared_ptr< localmatrixdata_t > B_data, std::shared_ptr< localmatrixdata_t > C_data)
typename Config::allocator_type allocator_type
std::shared_ptr< localmatrixdata_t > C_diag_offd
std::map< int, std::shared_ptr< localmatrixdata_t > > d_recv_matrices
std::shared_ptr< localmatrixdata_t > A_diag
std::map< int, SpGEMMCommInfo > d_dest_info
std::shared_ptr< localmatrixdata_t > C_diag
std::shared_ptr< localmatrixdata_t > C_diag_diag
std::shared_ptr< localmatrixdata_t > C_offd_offd
std::shared_ptr< localmatrixdata_t > BR_offd
std::shared_ptr< localmatrixdata_t > BR_diag
std::shared_ptr< localmatrixdata_t > C_offd
std::shared_ptr< localmatrixdata_t > B_diag
std::shared_ptr< localmatrixdata_t > B_offd
void merge(std::shared_ptr< localmatrixdata_t > inL, std::shared_ptr< localmatrixdata_t > inR, std::shared_ptr< localmatrixdata_t > out)
typename matrixdata_t::localmatrixdata_t localmatrixdata_t
#define AMP_DEBUG_INSIST(EXP, MSG)
Insist error.
AMP_MPI getComm(const TYPE &obj)
Return the underlying MPI class for the object.



Advanced Multi-Physics (AMP)
Oak Ridge National Laboratory
Idaho National Laboratory
Los Alamos National Laboratory
This page automatically produced from the
source code by doxygen
Last updated: Tue Mar 10 2026 13:06:40.
Comments on this page