mgcpp
A C++ Math Library Based on CUDA
gemm.hpp
Go to the documentation of this file.
1 
2 // Copyright RedPortal, mujjingun 2017 - 2018.
3 // Distributed under the Boost Software License, Version 1.0.
4 // (See accompanying file LICENSE or copy at
5 // http://www.boost.org/LICENSE_1_0.txt)
6 
7 #ifndef _MGCPP_OPERATIONS_GEMM_HPP_
8 #define _MGCPP_OPERATIONS_GEMM_HPP_
9 
13 
14 #include <cstdlib>
15 
16 namespace mgcpp {
17 namespace strict {
// Declares the three-operand gemm overload. A, B and C are dense_matrix
// operands taken by const reference; all three share the same `Type`
// template argument, while their concrete matrix types (ADense, BDense,
// CDense) may differ.
// The return type is deduced (decltype(auto)) from the definition, which is
// not part of this header; mgcpp/operations/gemm.tpp is included at the end
// of this file and presumably provides it.
// NOTE(review): presumably evaluates A * B + C — confirm against the
// definition.
25 template <typename ADense, typename BDense, typename CDense, typename Type>
26 inline decltype(auto) gemm(dense_matrix<ADense, Type> const& A,
27  dense_matrix<BDense, Type> const& B,
28  dense_matrix<CDense, Type> const& C);
29 
// Declares the scaled gemm overload taking C by const reference.
//
// Template parameters:
//   ADense, BDense, CDense - concrete matrix types behind each dense_matrix.
//   Type                   - template argument shared by A, B and C.
//   ScalarAlpha/ScalarBeta - types of the two scaling factors.
//   (unnamed)              - SFINAE guard: the overload only participates in
//                            overload resolution when is_scalar<>::value is
//                            true for both ScalarAlpha and ScalarBeta.
//
// Parameters: alpha, A, B, beta, C in that order; the matrices are read
// through const references.
// The return type is deduced from the definition, which is not visible here.
// NOTE(review): presumably evaluates alpha * A * B + beta * C — confirm
// against the definition in gemm.tpp.
39 template <
40  typename ADense,
41  typename BDense,
42  typename CDense,
43  typename Type,
44  typename ScalarAlpha,
45  typename ScalarBeta,
46  typename = typename std::enable_if<is_scalar<ScalarAlpha>::value &&
47  is_scalar<ScalarBeta>::value>::type>
48 inline decltype(auto) gemm(ScalarAlpha alpha,
49  dense_matrix<ADense, Type> const& A,
50  dense_matrix<BDense, Type> const& B,
51  ScalarBeta beta,
52  dense_matrix<CDense, Type> const& C);
53 
// Declares the scaled gemm overload taking C by rvalue reference.
//
// The template parameters and the is_scalar-based SFINAE guard are the same
// as for the overload that takes C by const reference; only the last
// parameter differs. Because its type is dense_matrix<CDense, Type>&& (not a
// bare template parameter), it is a true rvalue reference and binds only to
// temporaries or std::move'd matrices.
// NOTE(review): presumably this lets the implementation reuse C's storage
// for the result instead of allocating — confirm against gemm.tpp.
64 template <
65  typename ADense,
66  typename BDense,
67  typename CDense,
68  typename Type,
69  typename ScalarAlpha,
70  typename ScalarBeta,
71  typename = typename std::enable_if<is_scalar<ScalarAlpha>::value &&
72  is_scalar<ScalarBeta>::value>::type>
73 inline decltype(auto) gemm(ScalarAlpha alpha,
74  dense_matrix<ADense, Type> const& A,
75  dense_matrix<BDense, Type> const& B,
76  ScalarBeta beta,
77  dense_matrix<CDense, Type>&& C);
78 
// Per-operand operation selector, passed as mode_A / mode_B to the gemm
// overloads that accept transpose modes. Each enumerator is defined as the
// matching cuBLAS operation constant.
// NOTE(review): presumably this allows a direct cast to cublasOperation_t in
// the implementation — confirm against gemm.tpp.
79 enum class trans_mode {
80  same = CUBLAS_OP_N, // cuBLAS "N": operand used as-is (no transpose)
81  transposed = CUBLAS_OP_T, // cuBLAS "T": operand transposed
82  conj_trans = CUBLAS_OP_C // cuBLAS "C": operand conjugate-transposed
83 };
84 
// Declares the scaled gemm overload with explicit per-operand transpose
// modes, taking C by const reference.
//
// Template parameters:
//   ADense, BDense, CDense - concrete matrix types behind each dense_matrix.
//   Type                   - template argument shared by A, B and C.
//   ScalarAlpha/ScalarBeta - types of the two scaling factors.
//   (unnamed)              - SFINAE guard: the overload only participates in
//                            overload resolution when is_scalar<>::value is
//                            true for both ScalarAlpha and ScalarBeta.
//
// Parameters: alpha, mode_A, mode_B, A, B, beta, C. mode_A and mode_B are
// trans_mode selectors for A and B respectively.
// The return type is deduced from the definition, which is not visible here.
// NOTE(review): the second half of the enable_if condition and the A / B
// parameters follow the same pattern as the gemm overloads without
// transpose modes; presumably the result is
// alpha * op(A) * op(B) + beta * C — confirm both against gemm.tpp.
99 template <
100  typename ADense,
101  typename BDense,
102  typename CDense,
103  typename Type,
104  typename ScalarAlpha,
105  typename ScalarBeta,
106  typename = typename std::enable_if<is_scalar<ScalarAlpha>::value &&
107  is_scalar<ScalarBeta>::value>::type>
108 inline decltype(auto) gemm(ScalarAlpha alpha,
109  trans_mode mode_A,
110  trans_mode mode_B,
111  dense_matrix<ADense, Type> const& A,
112  dense_matrix<BDense, Type> const& B,
113  ScalarBeta beta,
114  dense_matrix<CDense, Type> const& C);
115 
// Declares the scaled gemm overload with explicit per-operand transpose
// modes, taking C by rvalue reference.
//
// Template parameters:
//   ADense, BDense, CDense - concrete matrix types behind each dense_matrix.
//   Type                   - template argument shared by A, B and C.
//   ScalarAlpha/ScalarBeta - types of the two scaling factors.
//   (unnamed)              - SFINAE guard: the overload only participates in
//                            overload resolution when is_scalar<>::value is
//                            true for both ScalarAlpha and ScalarBeta.
//
// Parameters: alpha, mode_A, mode_B, A, B, beta, C. mode_A and mode_B are
// trans_mode selectors for A and B respectively. C's parameter type is
// dense_matrix<CDense, Type>&&, a true rvalue reference, so it binds only to
// temporaries or std::move'd matrices.
// The return type is deduced from the definition, which is not visible here.
// NOTE(review): the second half of the enable_if condition, the A / B
// parameters and the rvalue-reference C parameter follow the same pattern as
// the gemm overloads without transpose modes; presumably the result is
// alpha * op(A) * op(B) + beta * C — confirm both against gemm.tpp.
131 template <
132  typename ADense,
133  typename BDense,
134  typename CDense,
135  typename Type,
136  typename ScalarAlpha,
137  typename ScalarBeta,
138  typename = typename std::enable_if<is_scalar<ScalarAlpha>::value &&
139  is_scalar<ScalarBeta>::value>::type>
140 inline decltype(auto) gemm(ScalarAlpha alpha,
141  trans_mode mode_A,
142  trans_mode mode_B,
143  dense_matrix<ADense, Type> const& A,
144  dense_matrix<BDense, Type> const& B,
145  ScalarBeta beta,
146  dense_matrix<CDense, Type>&& C);
147 } // namespace strict
148 } // namespace mgcpp
149 
150 #include <mgcpp/operations/gemm.tpp>
151 #endif
Definition: adapter_base.hpp:12
Definition: shape.hpp:33
Definition: is_scalar.hpp:16
decltype(auto) gemm(dense_matrix< ADense, Type > const &A, dense_matrix< BDense, Type > const &B, dense_matrix< CDense, Type > const &C)
Definition: dense_matrix.hpp:15
trans_mode
Definition: gemm.hpp:79