31 #ifndef MATRIX_PRODUCT_IMPL_H_
32 #define MATRIX_PRODUCT_IMPL_H_
43 #include <viennacl/linalg/prod.hpp>
44 #include <viennacl/matrix.hpp>
45 #endif // HAVE_VIENNACL
53 namespace implementation
/** Generic matrix-product implementation, parameterized by a compute
 * backend and a matrix type. Only the declaration of compute() is visible
 * in this fragment; the backend-specific specializations further down
 * (Eigen3, ViennaCL) provide the actual work.
 * NOTE(review): presumably the primary template is intentionally left
 * without a definition so that an unsupported Backend fails to build —
 * confirm against the full header.
 *
 * @tparam Backend compute backend that selects the specialization
 * @tparam Matrix matrix type; must expose a nested Scalar typedef
 */
59 template <enum Backend,
class Matrix>
// Element type of the operands, taken from Matrix::Scalar.
63 typedef typename Matrix::Scalar
T;
/** Computes op(A) * op(B) into C, where op() is either the identity or
 * the transpose, selected by the corresponding flag.
 *
 * @param A first operand
 * @param B second operand
 * @param C result matrix
 * @param transpose_A use A^T instead of A
 * @param transpose_B use B^T instead of B
 * @param overwrite in the specializations visible in this file, the
 * product is assigned to C in one branch ('=') and added to C's existing
 * contents in the other ('+='); presumably overwrite selects the
 * assigning branch — confirm against the full header.
 */
75 static void compute(Matrix A, Matrix B, Matrix C,
76 bool transpose_A,
bool transpose_B,
bool overwrite);
/* Eigen3 backend specialization of the matrix product (closed by the
 * HAVE_EIGEN3 #endif that follows this fragment).
 */
82 template <>
template <
class Matrix>
// Element type of the operands, taken from Matrix::Scalar.
85 typedef typename Matrix::Scalar
T;
// Tail of compute()'s parameter list. The head of the signature is not
// visible in this fragment; the doxygen index lists it as
// compute(SGMatrix<T> A, SGMatrix<T> B, SGMatrix<T> C, ...).
99 bool transpose_A,
bool transpose_B,
bool overwrite)
// Assigning branch: C_eig is replaced by op(A) * op(B).
// NOTE(review): the guarding condition and the construction of A_eig,
// B_eig and C_eig sit on lines missing from this fragment. Presumably this
// branch runs when overwrite is true, and the *_eig objects are Eigen views
// over A, B and C's storage (the index mentions
// MatrixXt = Eigen::Matrix<T, Dynamic, Dynamic>), so writes to C_eig reach
// C — confirm.
107 if (transpose_A && transpose_B)
108 C_eig = A_eig.transpose() * B_eig.transpose();
110 else if (transpose_A)
111 C_eig = A_eig.transpose() * B_eig;
113 else if (transpose_B)
114 C_eig = A_eig * B_eig.transpose();
// Neither operand transposed.
117 C_eig = A_eig * B_eig;
// Accumulating branch: op(A) * op(B) is added to the current contents of
// C_eig. NOTE(review): presumably taken when overwrite is false.
// NOTE(review): Eigen's matrix product assumes aliasing by default (it
// evaluates into a temporary), so both branches are likely safe even if C
// aliases A or B — confirm against Eigen's aliasing documentation.
121 if (transpose_A && transpose_B)
122 C_eig += A_eig.transpose() * B_eig.transpose();
124 else if (transpose_A)
125 C_eig += A_eig.transpose() * B_eig;
127 else if (transpose_B)
128 C_eig += A_eig * B_eig.transpose();
// Neither operand transposed.
131 C_eig += A_eig * B_eig;
135 #endif // HAVE_EIGEN3
/* ViennaCL backend specialization of the matrix product, operating on
 * CGPUMatrix<T> operands (closed by the HAVE_VIENNACL #endif that follows
 * this fragment).
 */
140 template <>
template <
class Matrix>
// Element type of the operands, taken from Matrix::Scalar.
143 typedef typename Matrix::Scalar
T;
/* Computes op(A) * op(B) using viennacl::linalg::prod, with
 * viennacl::trans applied to an operand when its transpose flag is set.
 * The result is written through C.vcl_matrix().
 * NOTE(review): A, B and C are taken by value. Writes through
 * C.vcl_matrix() only reach the caller if copies of CGPUMatrix share the
 * underlying ViennaCL buffer — confirm against CGPUMatrix's copy semantics.
 */
155 static void compute(CGPUMatrix<T> A, CGPUMatrix<T> B, CGPUMatrix<T> C,
156 bool transpose_A,
bool transpose_B,
bool overwrite)
// Assigning branch: C is replaced by op(A) * op(B).
// NOTE(review): the guarding condition is on a line missing from this
// fragment; presumably this branch runs when overwrite is true.
160 if (transpose_A && transpose_B)
161 C.vcl_matrix() = viennacl::linalg::prod(
162 viennacl::trans(A.vcl_matrix()), viennacl::trans(B.vcl_matrix()));
164 else if (transpose_A)
165 C.vcl_matrix() = viennacl::linalg::prod(
166 viennacl::trans(A.vcl_matrix()), B.vcl_matrix());
168 else if (transpose_B)
169 C.vcl_matrix() = viennacl::linalg::prod(
170 A.vcl_matrix(), viennacl::trans(B.vcl_matrix()));
// Neither operand transposed.
173 C.vcl_matrix() = viennacl::linalg::prod(A.vcl_matrix(), B.vcl_matrix());
// Accumulating branch: op(A) * op(B) is added to C's current contents.
// NOTE(review): presumably taken when overwrite is false.
177 if (transpose_A && transpose_B)
178 C.vcl_matrix() += viennacl::linalg::prod(
179 viennacl::trans(A.vcl_matrix()), viennacl::trans(B.vcl_matrix()));
181 else if (transpose_A)
182 C.vcl_matrix() += viennacl::linalg::prod(
183 viennacl::trans(A.vcl_matrix()), B.vcl_matrix());
185 else if (transpose_B)
186 C.vcl_matrix() += viennacl::linalg::prod(
187 A.vcl_matrix(), viennacl::trans(B.vcl_matrix()));
// Neither operand transposed.
190 C.vcl_matrix() += viennacl::linalg::prod(A.vcl_matrix(), B.vcl_matrix());
195 #endif // HAVE_VIENNACL
202 #endif // MATRIX_PRODUCT_IMPL_H_
static void compute(Matrix A, Matrix B, Matrix C, bool transpose_A, bool transpose_B, bool overwrite)
void matrix_product(Matrix A, Matrix B, Matrix C, bool transpose_A=false, bool transpose_B=false, bool overwrite=true)
All classes and functions are contained in the shogun namespace.
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > MatrixXt
static void compute(SGMatrix< T > A, SGMatrix< T > B, SGMatrix< T > C, bool transpose_A, bool transpose_B, bool overwrite)