15 #include "etl/expr/batch_k_scale_expr.hpp" 16 #include "etl/expr/batch_k_scale_plus_expr.hpp" 17 #include "etl/expr/batch_k_minus_scale_expr.hpp" 21 constexpr
bool is_2d4d(
size_t dimensions) {
22 return dimensions == 2 || dimensions == 4;
30 template <
typename Expr>
32 using expr_t = std::decay_t<Expr>;
37 if constexpr (is_binary_expr<Expr>) {
39 using operator_type =
typename expr_t::operator_type;
41 using left_type =
typename expr_t::left_type;
42 using right_type =
typename expr_t::right_type;
48 if constexpr (is_binary_expr<right_type>) {
49 auto& right_expr = expr.get_rhs();
52 using right_operator_type =
typename right_type::operator_type;
54 using right_left_type =
typename right_type::left_type;
55 using right_right_type =
typename right_type::right_type;
61 if constexpr (is_2d4d(right_left_dimensions) && right_right_dimensions == 1 && left_dimensions == 1
62 && all_homogeneous<right_left_type, right_right_type, left_type>) {
66 if constexpr (left_dimensions == 1 && is_2d4d(right_dimensions) && all_homogeneous<left_type, right_type>) {
70 return std::forward<Expr>(expr);
74 if constexpr (left_dimensions == 1 && is_2d4d(right_dimensions) && all_homogeneous<left_type, right_type>) {
78 return std::forward<Expr>(expr);
82 if constexpr (left_dimensions == 1 && is_2d4d(right_dimensions) && all_homogeneous<left_type, right_type>) {
86 return std::forward<Expr>(expr);
90 if constexpr (is_binary_expr<left_type>) {
91 auto& left_expr = expr.get_lhs();
94 using left_operator_type =
typename left_type::operator_type;
96 using left_left_type =
typename left_type::left_type;
97 using left_right_type =
typename left_type::right_type;
103 if constexpr (left_left_dimensions == 1 && is_2d4d(left_right_dimensions) && right_dimensions == 1
104 && all_homogeneous<left_left_type, left_right_type, right_type>) {
108 return std::forward<Expr>(expr);
111 return std::forward<Expr>(expr);
115 return std::forward<Expr>(expr);
119 return std::forward<Expr>(expr);
122 return std::forward<Expr>(expr);
batch_k_scale_expr< detail::build_type< A >, detail::build_type< B > > batch_k_scale(const A &a, const B &b)
Returns the transpose of the given expression.
Definition: batch_k_scale_expr.hpp:1495
batch_k_scale_plus_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_scale_plus(const A &a, const B &b, const C &c)
Returns the transpose of the given expression.
Definition: batch_k_scale_plus_expr.hpp:1575
value_t< sub_type > value_type
The value contained in the expression.
Definition: dyn_matrix_view.hpp:31
batch_k_minus_scale_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_minus_scale(const A &a, const B &b, const C &c)
Returns the transpose of the given expression.
Definition: batch_k_minus_scale_expr.hpp:1575
Binary operator for scalar subtraction.
Definition: minus.hpp:16
auto batch_hint(Expr &&expr)
Build a special expression for batched expressions.
Definition: batch_hint_builder.hpp:31
Root namespace for the ETL library.
Definition: adapter.hpp:15
static constexpr size_t dimensions()
Return the number of dimensions of the expression.
Definition: traits_base.hpp:31
Binary operator for scalar multiplication.
Definition: div.hpp:13
Binary operator for scalar addition.
Definition: plus.hpp:154