30#include "Teuchos_Assert.hpp"
32template <
typename OrdinalType,
typename FadType>
39 workspace_pointer(
NULL)
47template <
typename OrdinalType,
typename FadType>
50 use_dynamic(
a.use_dynamic),
51 workspace_size(
a.workspace_size),
55 if (workspace_size > 0) {
57 workspace_pointer = workspace;
62template <
typename OrdinalType,
typename FadType>
76 if (workspace_size > 0)
80template <
typename OrdinalType,
typename FadType>
94template <
typename OrdinalType,
typename FadType>
141template <
typename OrdinalType,
typename FadType>
166 cdot = &A[0].fastAccessDx(0);
177 for (OrdinalType j=0; j<n; j++) {
178 for (OrdinalType
i=0;
i<m;
i++) {
179 val[j*m+
i] = A[j*lda+
i].val();
180 for (OrdinalType k=0; k<n_dot; k++)
181 dot[(k*n+j)*m+
i] = A[j*lda+
i].fastAccessDx(k);
190template <
typename OrdinalType,
typename FadType>
193unpack(
const ValueType&
a, OrdinalType& n_dot, ValueType&
val,
194 const ValueType*& dot)
const
201template <
typename OrdinalType,
typename FadType>
204unpack(
const ValueType*
a, OrdinalType n, OrdinalType inc,
205 OrdinalType& n_dot, OrdinalType& inc_val, OrdinalType& inc_dot,
206 const ValueType*& cval,
const ValueType*& cdot)
const
215template <
typename OrdinalType,
typename FadType>
218unpack(
const ValueType* A, OrdinalType m, OrdinalType n, OrdinalType lda,
219 OrdinalType& n_dot, OrdinalType& lda_val, OrdinalType& lda_dot,
220 const ValueType*& cval,
const ValueType*& cdot)
const
229template <
typename OrdinalType,
typename FadType>
240template <
typename OrdinalType,
typename FadType>
254template <
typename OrdinalType,
typename FadType>
268template <
typename OrdinalType,
typename FadType>
279 "ArrayTraits::unpack(): FadType has wrong number of " <<
280 "derivative components. Got " <<
n_dot <<
293 dot = &
a.fastAccessDx(0);
298template <
typename OrdinalType,
typename FadType>
318 "ArrayTraits::unpack(): FadType has wrong number of " <<
319 "derivative components. Got " <<
n_dot <<
331 val = allocate_array(n);
339 dot = &
a[0].fastAccessDx(0);
359template <
typename OrdinalType,
typename FadType>
380 "ArrayTraits::unpack(): FadType has wrong number of " <<
381 "derivative components. Got " <<
n_dot <<
393 val = allocate_array(m*n);
402 dot = &A[0].fastAccessDx(0);
424template <
typename OrdinalType,
typename FadType>
439 a.fastAccessDx(
i) =
dot[
i];
442template <
typename OrdinalType,
typename FadType>
472template <
typename OrdinalType,
typename FadType>
492 if (A[0].size() !=
n_dot)
498 if (A[0].
dx() !=
dot)
505template <
typename OrdinalType,
typename FadType>
515template <
typename OrdinalType,
typename FadType>
532template <
typename OrdinalType,
typename FadType>
549template <
typename OrdinalType,
typename FadType>
560 "ArrayTraits::allocate_array(): " <<
561 "Requested workspace memory beyond size allocated. " <<
562 "Workspace size is " << workspace_size <<
563 ", currently used is " << workspace_pointer-workspace <<
564 ", requested size is " << size <<
".");
569 workspace_pointer += size;
573template <
typename OrdinalType,
typename FadType>
578 if (use_dynamic &&
ptr !=
NULL)
581 workspace_pointer -= size;
584template <
typename OrdinalType,
typename FadType>
590 (&(
a[n-1].val())-&(
a[0].
val()) == n-1) &&
591 (
a[n-1].dx()-
a[0].dx() == n-1);
594template <
typename OrdinalType,
typename FadType>
605template <
typename OrdinalType,
typename FadType>
609 arrayTraits(
x.arrayTraits),
611 use_default_impl(
x.use_default_impl)
615template <
typename OrdinalType,
typename FadType>
621template <
typename OrdinalType,
typename FadType>
627 if (use_default_impl) {
648 "BLAS::SCAL(): All arguments must have " <<
649 "the same number of derivative components, or none");
667template <
typename OrdinalType,
typename FadType>
673 if (use_default_impl) {
684 !arrayTraits.is_array_contiguous(
x, n,
n_x_dot) ||
685 !arrayTraits.is_array_contiguous(
y, n,
n_y_dot))
694template <
typename OrdinalType,
typename FadType>
695template <
typename alpha_type,
typename x_type>
701 if (use_default_impl) {
733 "BLAS::AXPY(): All arguments must have " <<
734 "the same number of derivative components, or none");
753template <
typename OrdinalType,
typename FadType>
754template <
typename x_type,
typename y_type>
760 if (use_default_impl)
795template <
typename OrdinalType,
typename FadType>
800 if (use_default_impl)
801 return BLASType::NRM2(n,
x,
incx);
827template <
typename OrdinalType,
typename FadType>
838 if (use_default_impl) {
839 BLASType::GEMV(
trans,m,n,
alpha,A,
lda,
x,
incx,
beta,
y,
incy);
845 if (
trans != Teuchos::NO_TRANS) {
898template <
typename OrdinalType,
typename FadType>
899template <
typename A_type>
906 if (use_default_impl) {
926 "BLAS::TRMV(): All arguments must have " <<
927 "the same number of derivative components, or none");
942 if (gemv_Ax.size() != std::size_t(n))
963template <
typename OrdinalType,
typename FadType>
964template <
typename alpha_type,
typename x_type,
typename y_type>
972 if (use_default_impl) {
1017template <
typename OrdinalType,
typename FadType>
1028 if (use_default_impl) {
1029 BLASType::GEMM(
transa,
transb,m,n,
k,
alpha,A,
lda,
B,
ldb,
beta,
C,
ldc);
1035 if (
transa != Teuchos::NO_TRANS) {
1042 if (
transb != Teuchos::NO_TRANS) {
1100template <
typename OrdinalType,
typename FadType>
1111 if (use_default_impl) {
1112 BLASType::SYMM(
side,
uplo,m,n,
alpha,A,
lda,
B,
ldb,
beta,
C,
ldc);
1118 if (
side == Teuchos::RIGHT_SIDE) {
1174template <
typename OrdinalType,
typename FadType>
1175template <
typename alpha_type,
typename A_type>
1179 Teuchos::ETransp
transa, Teuchos::EDiag
diag,
1184 if (use_default_impl) {
1185 BLASType::TRMM(
side,
uplo,
transa,
diag,m,n,
alpha,A,
lda,
B,
ldb);
1191 if (
side == Teuchos::RIGHT_SIDE) {
1234template <
typename OrdinalType,
typename FadType>
1235template <
typename alpha_type,
typename A_type>
1239 Teuchos::ETransp
transa, Teuchos::EDiag
diag,
1244 if (use_default_impl) {
1245 BLASType::TRSM(
side,
uplo,
transa,
diag,m,n,
alpha,A,
lda,
B,
ldb);
1251 if (
side == Teuchos::RIGHT_SIDE) {
1294template <
typename OrdinalType,
typename FadType>
1295template <
typename x_type,
typename y_type>
1318 "BLAS::Fad_DOT(): All arguments must have " <<
1319 "the same number of derivative components, or none");
1325 blas.GEMV(Teuchos::TRANS, n,
n_x_dot, 1.0,
x_dot, n,
y,
incy, 0.0,
z_dot,
1335 !Teuchos::ScalarTraits<ValueType>::isComplex)
1336 blas.GEMV(Teuchos::TRANS, n,
n_y_dot, 1.0,
y_dot, n,
x,
incx, 1.0,
z_dot,
1347template <
typename OrdinalType,
typename FadType>
1386 "BLAS::Fad_GEMV(): All arguments must have " <<
1387 "the same number of derivative components, or none");
1393 if (
trans == Teuchos::TRANS) {
1417 if (gemv_Ax.size() != std::size_t(n))
1419 blas.
GEMV(
trans, m, n, 1.0, A,
lda,
x,
incx, 0.0, &gemv_Ax[0],
1442 blas.
GEMV(
trans, m, n,
alpha, A,
lda,
x,
incx,
beta,
y,
incy);
1445template <
typename OrdinalType,
typename FadType>
1446template <
typename alpha_type,
typename x_type,
typename y_type>
1478 "BLAS::Fad_GER(): All arguments must have " <<
1479 "the same number of derivative components, or none");
1499template <
typename OrdinalType,
typename FadType>
1540 "BLAS::Fad_GEMM(): All arguments must have " <<
1541 "the same number of derivative components, or none");
1544 if (
transa != Teuchos::NO_TRANS) {
1549 if (
transb != Teuchos::NO_TRANS) {
1570 if (gemm_AB.size() != std::size_t(m*n))
1571 gemm_AB.resize(m*n);
1572 blas.
GEMM(
transa,
transb, m, n,
k, 1.0, A,
lda,
B,
ldb, 0.0, &gemm_AB[0],
1615 blas.
GEMM(
transa,
transb, m, n,
k,
alpha, A,
lda,
B,
ldb,
beta,
C,
ldc);
1618template <
typename OrdinalType,
typename FadType>
1657 "BLAS::Fad_SYMM(): All arguments must have " <<
1658 "the same number of derivative components, or none");
1661 if (
side == Teuchos::RIGHT_SIDE) {
1682 if (gemm_AB.size() != std::size_t(m*n))
1683 gemm_AB.resize(m*n);
1684 blas.
SYMM(
side,
uplo, m, n, 1.0, A,
lda,
B,
ldb, 0.0, &gemm_AB[0],
1727 blas.
SYMM(
side,
uplo, m, n,
alpha, A,
lda,
B,
ldb,
beta,
C,
ldc);
1730template <
typename OrdinalType,
typename FadType>
1731template <
typename alpha_type,
typename A_type>
1735 Teuchos::EUplo
uplo,
1737 Teuchos::EDiag
diag,
1761 "BLAS::Fad_TRMM(): All arguments must have " <<
1762 "the same number of derivative components, or none");
1765 if (
side == Teuchos::RIGHT_SIDE) {
1771 blas.
TRMM(
side,
uplo,
transa,
diag, m, n,
alpha, A,
lda,
B_dot+
i*
ldb_dot*n,
1776 if (gemm_AB.size() != std::size_t(m*n))
1777 gemm_AB.resize(m*n);
1783 blas.
TRMM(
side,
uplo,
transa,
diag, m, n, 1.0, A,
lda, &gemm_AB[0],
1798 if (gemm_AB.size() != std::size_t(m*n))
1799 gemm_AB.resize(m*n);
1833 blas.
TRMM(
side,
uplo,
transa,
diag, m, n,
alpha, A,
lda,
B,
ldb);
1836template <
typename OrdinalType,
typename FadType>
1837template <
typename alpha_type,
typename A_type>
1841 Teuchos::EUplo
uplo,
1843 Teuchos::EDiag
diag,
1867 "BLAS::Fad_TRSM(): All arguments must have " <<
1868 "the same number of derivative components, or none");
1871 if (
side == Teuchos::RIGHT_SIDE) {
1899 blas.
TRSM(
side,
uplo,
transa,
diag, m, n,
alpha, A,
lda,
B,
ldb);
1903 if (gemm_AB.size() != std::size_t(m*n))
1904 gemm_AB.resize(m*n);
1925 if (
side == Teuchos::LEFT_SIDE)
1926 blas.
TRSM(
side,
uplo,
transa,
diag, m, n*
n_dot, 1.0, A,
lda,
B_dot,
1930 blas.
TRSM(
side,
uplo,
transa,
diag, m, n, 1.0, A,
lda,
B_dot+
i*
ldb_dot*n,
expr expr expr fastAccessDx(i)) FAD_UNARYOP_MACRO(exp
void free_array(const ValueType *ptr, OrdinalType size) const
OrdinalType workspace_size
Size of static workspace.
Sacado::dummy< ValueType, scalar_type >::type ScalarType
ValueType * allocate_array(OrdinalType size) const
ValueType * workspace
Workspace for holding contiguous values/derivatives.
ValueType * workspace_pointer
Pointer to current free entry in workspace.
bool is_array_contiguous(const FadType *a, OrdinalType n, OrdinalType n_dot) const
ArrayTraits(bool use_dynamic=true, OrdinalType workspace_size=0)
Fad specializations for Teuchos::BLAS wrappers.
void TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Solves the matrix equations: op(A)*X=alpha*B or X*op(A)=alpha*B where X and B are m by n matrices,...
void AXPY(const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Perform the operation: y <- y+alpha*x.
void Fad_SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of SYMM.
Teuchos::ScalarTraits< FadType >::magnitudeType MagnitudeType
BLAS(bool use_default_impl=true, bool use_dynamic=true, OrdinalType static_workspace_size=0)
Default constructor.
virtual ~BLAS()
Destructor.
void Fad_GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const B_type *B, const OrdinalType ldb, const OrdinalType n_B_dot, const B_type *B_dot, const OrdinalType ldb_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *C, const OrdinalType ldc, const OrdinalType n_C_dot, ValueType *C_dot, const OrdinalType ldc_dot, const OrdinalType n_dot) const
Implementation of GEMM.
void Fad_TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRMM.
void Fad_TRSM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, ValueType *B, const OrdinalType ldb, const OrdinalType n_B_dot, ValueType *B_dot, const OrdinalType ldb_dot, const OrdinalType n_dot) const
Implementation of TRSM.
void COPY(const OrdinalType n, const FadType *x, const OrdinalType incx, FadType *y, const OrdinalType incy) const
Copy the vector x to the vector y.
MagnitudeType NRM2(const OrdinalType n, const FadType *x, const OrdinalType incx) const
Compute the 2-norm of the vector x.
void Fad_GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const A_type *A, const OrdinalType lda, const OrdinalType n_A_dot, const A_type *A_dot, const OrdinalType lda_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const beta_type &beta, const OrdinalType n_beta_dot, const beta_type *beta_dot, ValueType *y, const OrdinalType incy, const OrdinalType n_y_dot, ValueType *y_dot, const OrdinalType incy_dot, const OrdinalType n_dot) const
Implementation of GEMV.
void Fad_GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const OrdinalType n_alpha_dot, const alpha_type *alpha_dot, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType *A, const OrdinalType lda, const OrdinalType n_A_dot, ValueType *A_dot, const OrdinalType lda_dot, const OrdinalType n_dot) const
Implementation of GER.
void GEMV(Teuchos::ETransp trans, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const x_type *x, const OrdinalType incx, const beta_type &beta, FadType *y, const OrdinalType incy) const
Performs the matrix-vector operation: y <- alpha*A*x+beta*y or y <- alpha*A'*x+beta*y where A ...
void SYMM(Teuchos::ESide side, Teuchos::EUplo uplo, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*A*B+beta*C or C <- alpha*B*A+beta*C where A is an m ...
void TRMV(Teuchos::EUplo uplo, Teuchos::ETransp trans, Teuchos::EDiag diag, const OrdinalType n, const A_type *A, const OrdinalType lda, FadType *x, const OrdinalType incx) const
Performs the matrix-vector operation: x <- A*x or x <- A'*x where A is a unit/non-unit n by n ...
void TRMM(Teuchos::ESide side, Teuchos::EUplo uplo, Teuchos::ETransp transa, Teuchos::EDiag diag, const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const A_type *A, const OrdinalType lda, FadType *B, const OrdinalType ldb) const
Performs the matrix-matrix operation: B <- alpha*op(A)*B or B <- alpha*B*op(A) where op...
FadType DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy) const
Form the dot product of the vectors x and y.
void GER(const OrdinalType m, const OrdinalType n, const alpha_type &alpha, const x_type *x, const OrdinalType incx, const y_type *y, const OrdinalType incy, FadType *A, const OrdinalType lda) const
Performs the rank 1 operation: A <- alpha*x*y'+A.
void Fad_DOT(const OrdinalType n, const x_type *x, const OrdinalType incx, const OrdinalType n_x_dot, const x_type *x_dot, const OrdinalType incx_dot, const y_type *y, const OrdinalType incy, const OrdinalType n_y_dot, const y_type *y_dot, const OrdinalType incy_dot, ValueType &z, const OrdinalType n_z_dot, ValueType *zdot) const
Implementation of DOT.
void GEMM(Teuchos::ETransp transa, Teuchos::ETransp transb, const OrdinalType m, const OrdinalType n, const OrdinalType k, const alpha_type &alpha, const A_type *A, const OrdinalType lda, const B_type *B, const OrdinalType ldb, const beta_type &beta, FadType *C, const OrdinalType ldc) const
Performs the matrix-matrix operation: C <- alpha*op(A)*op(B)+beta*C where op(A) is either A or A',...
Teuchos::DefaultBLASImpl< OrdinalType, FadType > BLASType
void SCAL(const OrdinalType n, const FadType &alpha, FadType *x, const OrdinalType incx) const
Scale the vector x by the constant alpha.
Base template specification for ValueType.