Advanced Multi-Physics (AMP)
On-Line Documentation
SpGEMM_Cuda.h
Go to the documentation of this file.
1#ifndef included_AMP_SpGEMM_Cuda
2#define included_AMP_SpGEMM_Cuda
3
4#include "AMP/utils/cuda/Helper_Cuda.h"
5
6#include <cusparse.h>
7
8#include <cstdint>
9#include <type_traits>
10
11namespace AMP::LinearAlgebra {
12
13// This class wraps the operations in the cusparse-spgemm example found at
14// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuSPARSE/spgemm/spgemm_example.c
15template<typename rowidx_t, typename colidx_t, typename scalar_t>
17{
18 // Only signed int 32's and 64's are supported for index types
19 // Only floats and doubles supported for value types
20 static_assert( std::is_same_v<rowidx_t, int> || std::is_same_v<rowidx_t, long long> );
21 static_assert( std::is_same_v<colidx_t, int> || std::is_same_v<colidx_t, long long> );
22 static_assert( std::is_same_v<scalar_t, float> || std::is_same_v<scalar_t, double> );
23
24 // 64 bit row pointers with 32 bit column indices is invalid
25 // TODO: cusparse does *not* allow mixed types for these currently,
26 // add this assert and remove the prior one if that ever changes
27 // static_assert( std::is_same_v<rowidx_t, int> ||
28 // (std::is_same_v<rowidx_t, long long> && std::is_same_v<colidx_t, long long>)
29 // );
30 static_assert( std::is_same_v<rowidx_t, colidx_t> );
31
32public:
33 VendorSpGEMM( const int64_t M_,
34 const int64_t N_,
35 const int64_t K_,
36 const int64_t A_nnz,
37 rowidx_t *A_rs,
38 colidx_t *A_cols,
39 scalar_t *A_vals,
40 const int64_t B_nnz,
41 rowidx_t *B_rs,
42 colidx_t *B_cols,
43 scalar_t *B_vals,
44 rowidx_t *C_rs );
45
47
48 int64_t getCnnz();
49
50 void compute( rowidx_t *C_rs, colidx_t *C_cols, scalar_t *C_vals );
51
52private:
53 const int64_t M;
54 const int64_t N;
55 const int64_t K;
56
57 const scalar_t alpha;
58 const scalar_t beta;
59
60 const cusparseIndexType_t itype;
61 const cusparseIndexType_t jtype;
62 const cudaDataType computeType;
63 const cusparseOperation_t opA;
64 const cusparseOperation_t opB;
65 const cusparseSpGEMMAlg_t alg;
66
67 cusparseHandle_t handle;
68 cusparseSpGEMMDescr_t spgemmDesc;
69
70 cusparseSpMatDescr_t matA;
71 cusparseSpMatDescr_t matB;
72 cusparseSpMatDescr_t matC;
73
76 void *dBuffer1;
77 void *dBuffer2;
78};
79
80} // namespace AMP::LinearAlgebra
81
82
83#endif
cusparseSpMatDescr_t matC — defined at SpGEMM_Cuda.h:72
cusparseSpMatDescr_t matA — defined at SpGEMM_Cuda.h:70
void compute(rowidx_t *C_rs, colidx_t *C_cols, scalar_t *C_vals)
VendorSpGEMM(const int64_t M_, const int64_t N_, const int64_t K_, const int64_t A_nnz, rowidx_t *A_rs, colidx_t *A_cols, scalar_t *A_vals, const int64_t B_nnz, rowidx_t *B_rs, colidx_t *B_cols, scalar_t *B_vals, rowidx_t *C_rs)
const cusparseSpGEMMAlg_t alg — defined at SpGEMM_Cuda.h:65
const cusparseIndexType_t jtype — defined at SpGEMM_Cuda.h:61
const cusparseOperation_t opB — defined at SpGEMM_Cuda.h:64
const cusparseIndexType_t itype — defined at SpGEMM_Cuda.h:60
const cudaDataType computeType — defined at SpGEMM_Cuda.h:62
cusparseSpMatDescr_t matB — defined at SpGEMM_Cuda.h:71
const cusparseOperation_t opA — defined at SpGEMM_Cuda.h:63
cusparseSpGEMMDescr_t spgemmDesc — defined at SpGEMM_Cuda.h:68



Advanced Multi-Physics (AMP)
Oak Ridge National Laboratory
Idaho National Laboratory
Los Alamos National Laboratory
This page automatically produced from the
source code by doxygen
Last updated: Tue Mar 10 2026 13:06:40.
Comments on this page