1#ifndef included_CSRMatrixOperationsDefault_H_
2#define included_CSRMatrixOperationsDefault_H_
4#include "AMP/matrices/data/CSRMatrixData.h"
5#include "AMP/matrices/data/MatrixData.h"
6#include "AMP/matrices/operations/MatrixOperations.h"
7#include "AMP/matrices/operations/default/CSRLocalMatrixOperationsDefault.h"
8#include "AMP/matrices/operations/default/spgemm/CSRMatrixSpGEMMDefault.h"
9#include "AMP/vectors/Vector.h"
15template<
typename Config>
20 static_assert( std::is_same_v<typename allocator_type::value_type, void> );
27 using gidx_t =
typename Config::gidx_t;
28 using lidx_t =
typename Config::lidx_t;
43 void mult( std::shared_ptr<const Vector> x,
45 std::shared_ptr<Vector> y )
override;
55 std::shared_ptr<Vector> out )
override;
86 std::shared_ptr<MatrixData> B,
87 std::shared_ptr<MatrixData> C )
override;
133 std::shared_ptr<Vector> buf,
134 const bool remove_zeros =
false )
override;
161 template<
typename ConfigIn>
166 std::string
type()
const override {
return "CSRMatrixOperationsDefault"; }
184 std::map<std::pair<std::shared_ptr<matrixdata_t>, std::shared_ptr<matrixdata_t>>,
Class to manage reading/writing restart data.
CSRLocalMatrixData< Config > localmatrixdata_t
void writeRestart(int64_t fid) const override
Write restart data to file.
AMP::Scalar LinfNorm(const MatrixData &X) const override
Compute the maximum absolute row sum, i.e. the L-infinity norm of the matrix.
void mult(std::shared_ptr< const Vector > x, MatrixData const &A, std::shared_ptr< Vector > y) override
Matrix-vector multiplication.
std::map< std::pair< std::shared_ptr< matrixdata_t >, std::shared_ptr< matrixdata_t > >, CSRMatrixSpGEMMDefault< Config > > d_SpGEMMHelpers
void setIdentity(MatrixData &A) override
Set the matrix to the identity matrix.
void getRowSumsAbsolute(MatrixData const &A, std::shared_ptr< Vector > buf, const bool remove_zeros=false) override
Extract the absolute row sums into a vector.
typename Config::lidx_t lidx_t
typename Config::gidx_t gidx_t
typename Config::allocator_type allocator_type
typename matrixdata_t::localmatrixdata_t localmatrixdata_t
void copyCast(const MatrixData &X, MatrixData &Y) override
Give Y the same non-zero and distributed structure as X and copy the coefficients after up/down casting them.
void zero(MatrixData &A) override
Set all non-zero entries of the matrix to zero.
void setScalar(AMP::Scalar alpha, MatrixData &A) override
Set all non-zero entries of the matrix to a scalar value.
void scale(AMP::Scalar alpha, MatrixData &A) override
Scale the matrix by a scalar.
std::shared_ptr< localops_t > d_localops_offd
CSRMatrixOperationsDefault()
void scale(AMP::Scalar alpha, std::shared_ptr< const Vector > D, MatrixData &A) override
Scale the matrix by a scalar and a diagonal matrix.
void matMatMult(std::shared_ptr< MatrixData > A, std::shared_ptr< MatrixData > B, std::shared_ptr< MatrixData > C) override
Compute the product of two matrices.
void getRowSums(MatrixData const &A, std::shared_ptr< Vector > buf) override
Extract the row sums into a vector.
void extractDiagonal(MatrixData const &A, std::shared_ptr< Vector > buf) override
Extract the diagonal values into a vector.
void multTranspose(std::shared_ptr< const Vector > in, MatrixData const &A, std::shared_ptr< Vector > out) override
Matrix transpose-vector multiplication.
void setDiagonal(std::shared_ptr< const Vector > in, MatrixData &A) override
Set the diagonal to the values in a vector.
void axpy(AMP::Scalar alpha, const MatrixData &X, MatrixData &Y) override
Compute the linear combination Y = alpha * X + Y of two matrices.
void scaleInv(AMP::Scalar alpha, std::shared_ptr< const Vector > D, MatrixData &A) override
Scale the matrix by a scalar and the inverse of a diagonal matrix.
CSRMatrixOperationsDefault(int64_t, AMP::IO::RestartManager *)
static void copyCast(CSRMatrixData< typename ConfigIn::template set_alloc_t< Config::allocator > > *X, matrixdata_t *Y)
std::string type() const override
Return the type of the matrix operations class.
void copy(const MatrixData &X, MatrixData &Y) override
Give Y the same non-zero and distributed structure as X and copy the coefficients.
typename Config::scalar_t scalar_t
std::shared_ptr< localops_t > d_localops_diag
Scalar is a class used to store a scalar value that may be of different types/precisions.