Advanced Multi-Physics (AMP)
On-Line Documentation
CSRMatrixSpGEMMDefault.h
Go to the documentation of this file.
1#ifndef included_AMP_CSRMatrixSpGEMMDefault
2#define included_AMP_CSRMatrixSpGEMMDefault
3
4#include "AMP/matrices/data/CSRMatrixCommunicator.h"
5#include "AMP/matrices/data/CSRMatrixData.h"
6#include "AMP/utils/AMP_MPI.h"
7
8#include <map>
9#include <memory>
10#include <vector>
11
12namespace AMP::LinearAlgebra {
13
// ===========================================================================
// NOTE(review): this block is a doxygen source-listing page, not the raw
// header.  The integer fused to the front of each line is the original
// source line number from CSRMatrixSpGEMMDefault.h.  Gaps in that numbering
// (15, 20-21, 28, 62, 67, 69-71, 74-77, 97-99, 127, 131, 134-136, 151-152,
// 155, 173, 180, 196, 198-203, 210, 217, 222, 236, 238-242) are lines the
// page extraction dropped because they contained hyperlinked identifiers.
// The cross-reference list at the bottom of the page recovers some of them
// (class name CSRMatrixSpGEMMDefault, the localmatrixdata_t alias, the
// AMP_DEBUG_INSIST macro); every gap flagged below must be checked against
// the real header before this text is treated as compilable code.
// ===========================================================================
// Purpose (from the visible members and comments): default SpGEMM driver
// that forms the sparse product C = A * B from distributed CSR matrix data,
// optionally overlapping the communication that pulls remote rows of B
// ("BRemote") with local computation.
14template<typename Config>
// [gap: line 15 was the class head; the member list at the bottom of the
//  page names it CSRMatrixSpGEMMDefault]
16{
17public:
// Type aliases taken from the Config policy type.
18 using allocator_type = typename Config::allocator_type;
19 using config_type = Config;
// [gap: lines 20-21 declared matrixdata_t and localmatrixdata_t; the
//  cross-reference list shows
//  "typename matrixdata_t::localmatrixdata_t localmatrixdata_t"]
22 using lidx_t = typename Config::lidx_t;
23 using gidx_t = typename Config::gidx_t;
24 using scalar_t = typename Config::scalar_t;
25
// The allocator must be untyped (value_type == void), i.e. rebindable.
26 static_assert( std::is_same_v<typename allocator_type::value_type, void> );
27
// Default constructor: null operands, zero rows, no communication needed.
// [gap: line 28 held the default-constructor signature]
29 : A( nullptr ),
30 B( nullptr ),
31 C( nullptr ),
32 A_diag( nullptr ),
33 A_offd( nullptr ),
34 B_diag( nullptr ),
35 B_offd( nullptr ),
36 C_diag( nullptr ),
37 C_offd( nullptr ),
38 d_overlap_comms( false ),
39 d_num_rows( 0 ),
40 d_need_comms( false )
41 {
42 }
// Main constructor: binds the three matrices, caches their diag/offd
// blocks, and enables comm/compute overlap only when requested AND the
// communicator actually has more than one rank.
43 CSRMatrixSpGEMMDefault( std::shared_ptr<matrixdata_t> A_,
44 std::shared_ptr<matrixdata_t> B_,
45 std::shared_ptr<matrixdata_t> C_,
46 bool overlap_comms_ )
47 : A( A_ ),
48 B( B_ ),
49 C( C_ ),
50 A_diag( A->getDiagMatrix() ),
51 A_offd( A->getOffdMatrix() ),
52 B_diag( B->getDiagMatrix() ),
53 B_offd( B->getOffdMatrix() ),
54 C_diag( C->getDiagMatrix() ),
55 C_offd( C->getOffdMatrix() ),
56 d_overlap_comms( ( A->getComm().getSize() > 1 ) && overlap_comms_ ),
57 d_num_rows( static_cast<lidx_t>( A->numLocalRows() ) ),
58 comm( A->getComm() ),
59 d_csr_comm( A->getRightCommList() ),
60 d_need_comms( true )
61 {
// [gap: line 62 opened a debug-only check — presumably "AMP_DEBUG_INSIST("
//  given the macro reference at the bottom of the page and the message
//  argument two lines down]
63 comm == B->getComm() && comm == C->getComm(),
64 "CSRMatrixSpGEMMDefault: All three matrices must have the same communicator" );
65 }
66
// [gap: lines 67 and 69-71 held the rest of the public interface
//  (destructor and/or public multiply entry points) — not recoverable
//  from this page]
68
72
73protected:
// [gap: lines 74-77 dropped — likely a linked comment block or declarations]
78
// Two-pass SpGEMM split: SYMBOLIC presumably discovers the sparsity
// pattern of C and NUMERIC fills the coefficients (the standard two-phase
// scheme — confirm against the .hpp implementation).  DIAG/OFFD select
// which block of the distributed CSR layout a kernel targets.
79 enum class Mode { SYMBOLIC, NUMERIC };
80 enum class BlockType { DIAG, OFFD };
81
// Single-block product: accumulate the contribution of A_data * B_data
// into C_data for one (mode, block) combination.
82 template<Mode mode_t, BlockType block_t>
83 void multiply( std::shared_ptr<localmatrixdata_t> A_data,
84 std::shared_ptr<localmatrixdata_t> B_data,
85 std::shared_ptr<localmatrixdata_t> C_data );
86
// Fused variant that consumes the local B block together with the pulled
// remote rows (BR_data) in one pass.
87 template<Mode mode_t, BlockType block_t>
88 void multiplyFused( std::shared_ptr<localmatrixdata_t> B_data,
89 std::shared_ptr<localmatrixdata_t> BR_data,
90 std::shared_ptr<localmatrixdata_t> C_data );
91
// Numeric-only re-multiply that reuses a previously computed structure
// of C (see the structure re-use note above d_src_info below).
92 template<BlockType block_t>
93 void multiplyReuse( std::shared_ptr<localmatrixdata_t> A_data,
94 std::shared_ptr<localmatrixdata_t> B_data,
95 std::shared_ptr<localmatrixdata_t> C_data );
96
// [gap: lines 97-99 dropped — declarations not recoverable from this page]
100
// Merge the four partial products (C_diag_diag ... C_offd_offd, below)
// into the final C_diag / C_offd blocks.
101 void mergeDiag();
102 void mergeOffd();
103
104 // default starting size for sparse accumulators
105 static constexpr lidx_t SPACC_SIZE = 256;
106
107 // Matrix data of operands and output
108 std::shared_ptr<matrixdata_t> A;
109 std::shared_ptr<matrixdata_t> B;
110 std::shared_ptr<matrixdata_t> C;
111
112 // diag and offd blocks of input matrices
113 std::shared_ptr<localmatrixdata_t> A_diag;
114 std::shared_ptr<localmatrixdata_t> A_offd;
115 std::shared_ptr<localmatrixdata_t> B_diag;
116 std::shared_ptr<localmatrixdata_t> B_offd;
117
118 // Matrix data formed from remote rows of B that get pulled to each process
119 std::shared_ptr<localmatrixdata_t> BR_diag;
120 std::shared_ptr<localmatrixdata_t> BR_offd;
121
122 // Blocks of C matrix
123 std::shared_ptr<localmatrixdata_t> C_diag;
124 std::shared_ptr<localmatrixdata_t> C_offd;
125
126 // flag for whether overlapped communication/computation should be done
// [gap: line 127 declared d_overlap_comms — a bool, per the constructor
//  initializer lists above]
128
129 // number of local rows in A and C are the same, and many loops
130 // run over this range
// [gap: line 131 declared d_num_rows — a lidx_t, per the constructors]
132
133 // Communicator
// [gap: lines 134-136 declared comm, d_csr_comm and d_need_comms —
//  inferred from the constructor initializer lists; confirm the types
//  (AMP_MPI / CSRMatrixCommunicator / bool) against the real header]
137
138 // To overlap comms and calcs it is easiest to form the output in four
139 // blocks and merge them together at the end
140 std::shared_ptr<localmatrixdata_t> C_diag_diag; // from A_diag * B_diag
141 std::shared_ptr<localmatrixdata_t> C_diag_offd; // from A_diag * B_offd
142 std::shared_ptr<localmatrixdata_t> C_offd_diag; // from A_offd * BR_diag
143 std::shared_ptr<localmatrixdata_t> C_offd_offd; // from A_offd * BR_offd
144
145 // The following all support the communication needed to build BRemote
146 // these are worth preserving to allow repeated SpGEMMs to re-use the
147 // structure with potentially new coefficients
148
149 // struct to hold fields that are needed in both
150 // the "source" and "destination" perspectives
// [gap: lines 151-152 opened the struct ("struct SpGEMMCommInfo {") and
//  likely a defaulted default constructor]
153 SpGEMMCommInfo( int numrow_ ) : numrow( numrow_ ) {}
154 // number of rows to send or receive
// [gap: line 155 declared the numrow member initialized above]
156 // ids of rows to send/receive
157 std::vector<gidx_t> rowids;
158 // number of non-zeros in those rows
159 std::vector<lidx_t> rownnz;
160 };
161
162 // Source information, things expected from other ranks
163 std::map<int, SpGEMMCommInfo> d_src_info;
164
165 // Destination information, things sent to other ranks
166 std::map<int, SpGEMMCommInfo> d_dest_info;
167
// Per-rank staging matrices holding the rows of B exchanged above.
168 std::map<int, std::shared_ptr<localmatrixdata_t>> d_send_matrices;
169 std::map<int, std::shared_ptr<localmatrixdata_t>> d_recv_matrices;
170
171 // Internal row accumulator classes
// Dense accumulator: direct-addressed flag array (flags initialized to -1,
// apparently the "empty" sentinel); offset is presumably the first owned
// global column of the block — confirm against the implementation.
172 template<typename col_t>
// [gap: line 173 opened the struct, presumably "struct DenseAccumulator {"]
174 DenseAccumulator( int capacity_, gidx_t offset_ )
175 : capacity( capacity_ ),
176 offset( offset_ ),
177 num_inserted( 0 ),
178 total_inserted( 0 ),
179 total_collisions( 0 ),
// [gap: line 180 initialized another statistics counter]
181 total_clears( 0 ),
182 total_grows( 0 ),
183 flags( capacity, -1 )
184 {
// Accumulators are only instantiated for the two index types of Config.
185 static_assert( std::is_same_v<col_t, gidx_t> || std::is_same_v<col_t, lidx_t> );
186 }
187
// Record a column (first overload), or a column plus value written into
// caller-provided column/value space (second overload).
188 void insert_or_append( col_t col_idx );
189 void insert_or_append( col_t col_idx, scalar_t val, col_t *col_space, scalar_t *val_space );
190 void clear();
191 lidx_t contains( col_t col_idx ) const;
192 void set_flag( col_t col_idx, lidx_t k );
193
// True when this accumulator indexes by global column ids.
194 static constexpr bool IsGlobal = std::is_same_v<gidx_t, col_t>;
195
// [gap: line 196 declared capacity (initialized in the ctor above)]
197 const col_t offset;
// [gap: lines 198-203 declared num_inserted and the statistics counters
//  (total_inserted, total_collisions, total_clears, total_grows, ...)
//  initialized in the constructor above]
204 std::vector<lidx_t> flags;
205 std::vector<lidx_t> flag_inv;
206 std::vector<col_t> cols;
207 };
208
// Sparse accumulator: small growable hash table with uint16_t slots;
// 0xFFFF appears to be the empty-slot sentinel (see flags init below),
// and SPACC_SIZE above is documented as its default starting size.
209 template<typename col_t>
// [gap: line 210 opened the struct, presumably "struct SparseAccumulator {"]
211 SparseAccumulator( int capacity_, gidx_t offset_ )
212 : capacity( capacity_ ),
213 offset( offset_ ),
214 num_inserted( 0 ),
215 total_inserted( 0 ),
216 total_collisions( 0 ),
// [gap: line 217 initialized another statistics counter]
218 total_clears( 0 ),
219 total_grows( 0 ),
220 flags( capacity, 0xFFFF )
221 {
// [gap: line 222 dropped — possibly a debug assertion]
223 static_assert( std::is_same_v<col_t, gidx_t> || std::is_same_v<col_t, lidx_t> );
224 }
225
// Map a column index to a table slot.
226 uint16_t hash( col_t col_idx ) const;
227 void insert_or_append( col_t col_idx );
228 void insert_or_append( col_t col_idx, scalar_t val, col_t *col_space, scalar_t *val_space );
229 void clear();
230 lidx_t contains( col_t col_idx ) const;
231 void set_flag( col_t col_idx, lidx_t k );
232
233 static constexpr bool IsGlobal = std::is_same_v<gidx_t, col_t>;
234
235 uint16_t capacity;
// [gap: line 236 declared a member between capacity and num_inserted —
//  likely the offset; confirm against the real header]
237 uint16_t num_inserted;
// [gap: lines 238-242 declared the statistics counters initialized in
//  the constructor above]
243 std::vector<uint16_t> flags;
244 std::vector<col_t> cols;
245
246 private:
// Rehash into a larger table, rewriting the caller's column space.
247 void grow( col_t *col_space );
248 };
249};
250
251} // namespace AMP::LinearAlgebra
252
253#endif
Provides C++ wrapper around MPI routines.
Definition AMP_MPI.h:63
CSRLocalMatrixData< Config > localmatrixdata_t
std::shared_ptr< localmatrixdata_t > A_diag
std::shared_ptr< localmatrixdata_t > BR_offd
std::shared_ptr< localmatrixdata_t > C_diag_diag
std::shared_ptr< localmatrixdata_t > BR_diag
std::shared_ptr< localmatrixdata_t > C_offd
std::map< int, std::shared_ptr< localmatrixdata_t > > d_send_matrices
void multiplyFused(std::shared_ptr< localmatrixdata_t > B_data, std::shared_ptr< localmatrixdata_t > BR_data, std::shared_ptr< localmatrixdata_t > C_data)
std::shared_ptr< localmatrixdata_t > B_offd
void multiply(std::shared_ptr< localmatrixdata_t > A_data, std::shared_ptr< localmatrixdata_t > B_data, std::shared_ptr< localmatrixdata_t > C_data)
std::map< int, std::shared_ptr< localmatrixdata_t > > d_recv_matrices
std::shared_ptr< localmatrixdata_t > C_diag_offd
typename Config::allocator_type allocator_type
std::shared_ptr< localmatrixdata_t > C_offd_offd
std::shared_ptr< localmatrixdata_t > A_offd
typename matrixdata_t::localmatrixdata_t localmatrixdata_t
std::shared_ptr< localmatrixdata_t > B_diag
std::shared_ptr< localmatrixdata_t > C_diag
std::shared_ptr< localmatrixdata_t > C_offd_diag
CSRMatrixSpGEMMDefault(std::shared_ptr< matrixdata_t > A_, std::shared_ptr< matrixdata_t > B_, std::shared_ptr< matrixdata_t > C_, bool overlap_comms_)
void multiplyReuse(std::shared_ptr< localmatrixdata_t > A_data, std::shared_ptr< localmatrixdata_t > B_data, std::shared_ptr< localmatrixdata_t > C_data)
#define AMP_DEBUG_ASSERT(EXP)
Assert error (debug only)
#define AMP_DEBUG_INSIST(EXP, MSG)
Insist error.
AMP_MPI getComm(const TYPE &obj)
Return the underlying MPI class for the object.
void insert_or_append(col_t col_idx, scalar_t val, col_t *col_space, scalar_t *val_space)
void insert_or_append(col_t col_idx, scalar_t val, col_t *col_space, scalar_t *val_space)



Advanced Multi-Physics (AMP)
Oak Ridge National Laboratory
Idaho National Laboratory
Los Alamos National Laboratory
This page automatically produced from the
source code by doxygen
Last updated: Tue Mar 10 2026 13:06:40.
Comments on this page