1#ifndef included_CSRMatrixOperationsKokkos_H_
2#define included_CSRMatrixOperationsKokkos_H_
4#include "AMP/AMP_TPLs.h"
5#include "AMP/matrices/CSRConfig.h"
6#include "AMP/matrices/data/CSRMatrixData.h"
7#include "AMP/matrices/data/MatrixData.h"
8#include "AMP/matrices/operations/MatrixOperations.h"
9#include "AMP/matrices/operations/default/CSRMatrixOperationsDefault.h"
10#include "AMP/matrices/operations/kokkos/CSRLocalMatrixOperationsKokkos.h"
11#include "AMP/utils/Memory.h"
12#include "AMP/vectors/Vector.h"
15 #include "AMP/matrices/operations/device/CSRMatrixOperationsDevice.h"
22 #include "Kokkos_Core.hpp"
// Template parameter list of CSRMatrixOperationsKokkos: Config, followed by ExecSpace and
// ViewSpace, both of which carry default arguments.
// NOTE(review): two alternative sets of ExecSpace/ViewSpace defaults appear below. The
// preprocessor guard that selects between them, and the closing '>' of the list, are not
// visible in this listing -- confirm against the original header.
26template<
typename Config,
// Alternative 1, ExecSpace: std::conditional keyed on a comparison involving
// alloc_info<Config::allocator>::mem_loc; it selects Kokkos::DefaultHostExecutionSpace when
// the comparison is true and Kokkos::DefaultExecutionSpace otherwise.
// NOTE(review): the right-hand operand of the '==' is missing from this listing.
28 class ExecSpace =
typename std::conditional<alloc_info<Config::allocator>::mem_loc ==
30 Kokkos::DefaultHostExecutionSpace,
31 Kokkos::DefaultExecutionSpace>::type,
// Alternative 1, ViewSpace: nested std::conditional whose last visible branch is
// Kokkos::DefaultExecutionSpace::memory_space.
// NOTE(review): the conditions and the other branches are missing from this listing.
32 class ViewSpace =
typename std::conditional<
35 typename std::conditional<
38 typename Kokkos::DefaultExecutionSpace::memory_space>::type>::type
// Alternative 2: fixed defaults, Kokkos host execution space and Kokkos::HostSpace.
40 class ExecSpace = Kokkos::DefaultHostExecutionSpace,
41 class ViewSpace = Kokkos::HostSpace
// Matrix-operations class that publicly derives from MatrixOperations. Its aliases below tie
// it to CSRMatrixData<Config> and to the Kokkos local-operations helper.
// NOTE(review): the opening brace and access specifiers are not visible in this listing.
44class CSRMatrixOperationsKokkos :
public MatrixOperations
// Compile-time requirement: Config::allocator_type::value_type must be void.
47 static_assert( std::is_same_v<typename Config::allocator_type::value_type, void> );
// Convenience aliases for the configuration and the types it provides.
49 using config_type = Config;
50 using allocator_type =
typename Config::allocator_type;
// Matrix data type for this Config, and the local-matrix type it exposes.
51 using matrixdata_t = CSRMatrixData<Config>;
52 using localmatrixdata_t =
typename matrixdata_t::localmatrixdata_t;
// Local-operations helper, instantiated with this class's Config, ExecSpace and ViewSpace.
54 using localops_t = CSRLocalMatrixOperationsKokkos<Config, ExecSpace, ViewSpace>;
// Index and scalar types taken from Config.
// presumably gidx_t/lidx_t are global/local index types -- TODO confirm in CSRConfig.h
56 using gidx_t =
typename Config::gidx_t;
57 using lidx_t =
typename Config::lidx_t;
58 using scalar_t =
typename Config::scalar_t;
// Default constructor. The visible member initializers create two localops_t objects with
// std::make_shared, each constructed from d_exec_space, and store them in d_localops_diag and
// d_localops_offd.
// NOTE(review): the start of the initializer list (the ':' and any d_exec_space initializer)
// and the constructor body are not visible in this listing -- confirm against the original.
60 CSRMatrixOperationsKokkos()
62 d_localops_diag(
std::make_shared<localops_t>( d_exec_space ) ),
63 d_localops_offd(
std::make_shared<localops_t>( d_exec_space ) )
// Override of the base-class mult; only declared here, the definition lives elsewhere.
// Takes a read-only vector x and a writable vector y.
// NOTE(review): one parameter line between x and y is missing from this listing (presumably
// the matrix operand) -- confirm the full signature in the original header.
72 void mult( std::shared_ptr<const Vector> x,
74 std::shared_ptr<Vector> y )
override;
// Override of the base-class multTranspose; only declared here.
// Takes a read-only vector 'in' and a writable vector 'out'.
// NOTE(review): as with mult, one parameter line between the two is missing from this listing.
81 void multTranspose( std::shared_ptr<const Vector> in,
83 std::shared_ptr<Vector> out )
override;
// Override: scale overload taking a scalar alpha and a mutable matrix A. Declaration only.
// presumably computes A <- alpha * A -- TODO confirm against the implementation
89 void scale(
AMP::Scalar alpha, MatrixData &A )
override;
// Override: scale overload that additionally takes a read-only vector D. Declaration only.
// presumably a scaling of A by alpha and the entries of D -- TODO confirm the exact formula
97 void scale(
AMP::Scalar alpha, std::shared_ptr<const Vector> D, MatrixData &A )
override;
// Override: scaleInv, with the same parameter list as the vector overload of scale.
// Declaration only.
// presumably scales using the inverse of the entries of D -- TODO confirm
105 void scaleInv(
AMP::Scalar alpha, std::shared_ptr<const Vector> D, MatrixData &A )
override;
// Override: matMatMult taking three shared pointers to MatrixData (A, B, C). Declaration only.
// presumably computes C = A * B -- TODO confirm which argument receives the result
112 void matMatMult( std::shared_ptr<MatrixData> A,
113 std::shared_ptr<MatrixData> B,
114 std::shared_ptr<MatrixData> C )
override;
// Override: axpy taking a scalar alpha, a read-only matrix X and a mutable matrix Y.
// Declaration only.
// presumably computes Y <- alpha * X + Y -- TODO confirm against the implementation
121 void axpy(
AMP::Scalar alpha,
const MatrixData &
X, MatrixData &Y )
override;
// Override: setScalar taking a scalar alpha and a mutable matrix A. Declaration only.
// presumably assigns alpha to the stored entries of A -- TODO confirm
126 void setScalar(
AMP::Scalar alpha, MatrixData &A )
override;
// Override: zero taking a mutable matrix A. Declaration only.
131 void zero( MatrixData &A )
override;
// Override: setDiagonal taking a read-only vector 'in' and a mutable matrix A.
// Declaration only.
136 void setDiagonal( std::shared_ptr<const Vector> in, MatrixData &A )
override;
// Override: setIdentity taking a mutable matrix A. Declaration only.
140 void setIdentity( MatrixData &A )
override;
// Override: extractDiagonal taking a read-only matrix A and a writable vector buf.
// Declaration only.
// presumably buf receives the diagonal of A -- TODO confirm
146 void extractDiagonal( MatrixData
const &A, std::shared_ptr<Vector> buf )
override;
// Override: getRowSums taking a read-only matrix A and a writable vector buf.
// Declaration only.
152 void getRowSums( MatrixData
const &A, std::shared_ptr<Vector> buf )
override;
// Override: getRowSumsAbsolute taking a read-only matrix A, a writable vector buf and a flag
// remove_zeros that defaults to false. Declaration only.
// NOTE(review): the effect of remove_zeros is not visible here -- check the implementation.
158 void getRowSumsAbsolute( MatrixData
const &A,
159 std::shared_ptr<Vector> buf,
160 const bool remove_zeros =
false )
override;
// Override: const member LinfNorm taking a read-only matrix X and returning an AMP::Scalar.
// Declaration only.
// presumably the infinity norm of X -- TODO confirm the exact definition used
165 AMP::Scalar LinfNorm(
const MatrixData &
X )
const override;
// Override: copy taking a read-only matrix X and a mutable matrix Y. Declaration only.
// presumably X is the source and Y the destination -- TODO confirm
172 void copy(
const MatrixData &
X, MatrixData &Y )
override;
// Override: copyCast with the same parameter list as copy. Declaration only.
// presumably a copy that converts between differing data types -- TODO confirm
179 void copyCast(
const MatrixData &
X, MatrixData &Y )
override;
// Static templated copyCast overload. It takes raw pointers to a CSRMatrixData with a
// separate configuration ConfigIn (X) and to a CSRMatrixData with this class's Config (Y).
// Declaration only.
181 template<
typename ConfigIn>
182 static void copyCast( CSRMatrixData<ConfigIn> *
X, CSRMatrixData<Config> *Y );
// Override: returns the string "CSRMatrixOperationsKokkos" as this class's type name.
184 std::string type()
const override {
return "CSRMatrixOperationsKokkos"; }
// Override: const member writeRestart taking an int64_t fid. Declaration only.
// presumably fid identifies the restart file to write to -- TODO confirm
191 void writeRestart( int64_t fid )
const override;
// Member initializers that create d_localops_diag and d_localops_offd with std::make_shared,
// each constructed from d_exec_space (the same initializers as the default constructor).
// NOTE(review): the constructor signature these initializers belong to, and its body, are
// not visible in this listing -- presumably a second (restart) constructor; confirm.
195 d_localops_diag(
std::make_shared<localops_t>( d_exec_space ) ),
196 d_localops_offd(
std::make_shared<localops_t>( d_exec_space ) )
// Execution-space instance; the constructors pass it to both local-operations objects.
201 ExecSpace d_exec_space;
// Shared pointers to the two local-operations helpers.
// presumably for the diagonal and off-diagonal blocks, judging by the names -- TODO confirm
202 std::shared_ptr<localops_t> d_localops_diag;
203 std::shared_ptr<localops_t> d_localops_offd;
// Instance of the default CSR operations class for the same Config.
// presumably used for operations this class delegates -- TODO confirm in the implementation
207 CSRMatrixOperationsDefault<Config> d_matrixOpsDefault;
// Instance of the device CSR operations class, guarded by AMP_USE_DEVICE.
// NOTE(review): the matching #endif and the class's closing brace are not visible here.
208 #ifdef AMP_USE_DEVICE
209 CSRMatrixOperationsDevice<Config> d_matrixOpsDevice;
Class to manage reading/writing restart data.
Scalar is a class used to store a scalar variable that may be different types/precision.
void copy(size_t N, const T1 *src, T2 *dst)
Perform copy with conversion if necessary.
void copyCast(const size_t len, const T1 *vec_in, T2 *vec_out)
void zero(void *dest, std::size_t count)
Perform memory zero (pointer may be in any memory space)