Expression Templates Library (ETL)
batch_hint_builder.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 #include "etl/expr/batch_k_scale_expr.hpp"
16 #include "etl/expr/batch_k_scale_plus_expr.hpp"
17 #include "etl/expr/batch_k_minus_scale_expr.hpp"
18 
19 namespace etl {
20 
/*!
 * \brief Tests whether a dimension count is exactly 2 or exactly 4.
 * \param dimensions The number of dimensions to test
 * \return true for 2 or 4, false for any other value
 */
constexpr bool is_2d4d(size_t dimensions) {
    switch (dimensions) {
        case 2:
        case 4:
            return true;
        default:
            return false;
    }
}
24 
30 template <typename Expr>
31 auto batch_hint(Expr&& expr) {
32  using expr_t = std::decay_t<Expr>;
33 
34  // If this becomes more complicated, make one detection function for each
35  // possible type instead of craming everything here
36 
37  if constexpr (is_binary_expr<Expr>) {
38  using value_type = typename expr_t::value_type;
39  using operator_type = typename expr_t::operator_type;
40 
41  using left_type = typename expr_t::left_type;
42  using right_type = typename expr_t::right_type;
43 
44  constexpr size_t left_dimensions = decay_traits<left_type>::dimensions();
45  constexpr size_t right_dimensions = decay_traits<right_type>::dimensions();
46 
47  if constexpr (std::same_as<operator_type, mul_binary_op<value_type>>) {
48  if constexpr (is_binary_expr<right_type>) {
49  auto& right_expr = expr.get_rhs();
50 
51  using right_value_type = typename right_type::value_type;
52  using right_operator_type = typename right_type::operator_type;
53 
54  using right_left_type = typename right_type::left_type;
55  using right_right_type = typename right_type::right_type;
56 
57  if constexpr (std::same_as<right_operator_type, minus_binary_op<right_value_type>>) {
58  constexpr size_t right_left_dimensions = decay_traits<right_left_type>::dimensions();
59  constexpr size_t right_right_dimensions = decay_traits<right_right_type>::dimensions();
60 
61  if constexpr (is_2d4d(right_left_dimensions) && right_right_dimensions == 1 && left_dimensions == 1
62  && all_homogeneous<right_left_type, right_right_type, left_type>) {
63  // Detect gamma[K] * (input[B, K, W, H]) - beta[k])
64  return batch_k_minus_scale(expr.get_lhs(), right_expr.get_lhs(), right_expr.get_rhs());
65  } else {
66  if constexpr (left_dimensions == 1 && is_2d4d(right_dimensions) && all_homogeneous<left_type, right_type>) {
67  // Detect gamma[K] * beta[B, K, W, H]
68  return batch_k_scale(expr.get_lhs(), expr.get_rhs());
69  } else {
70  return std::forward<Expr>(expr);
71  }
72  }
73  } else {
74  if constexpr (left_dimensions == 1 && is_2d4d(right_dimensions) && all_homogeneous<left_type, right_type>) {
75  // Detect gamma[K] * beta[B, K, W, H]
76  return batch_k_scale(expr.get_lhs(), expr.get_rhs());
77  } else {
78  return std::forward<Expr>(expr);
79  }
80  }
81  } else {
82  if constexpr (left_dimensions == 1 && is_2d4d(right_dimensions) && all_homogeneous<left_type, right_type>) {
83  // Detect gamma[K] * beta[B, K, W, H]
84  return batch_k_scale(expr.get_lhs(), expr.get_rhs());
85  } else {
86  return std::forward<Expr>(expr);
87  }
88  }
89  } else if constexpr (std::same_as<operator_type, plus_binary_op<value_type>>) {
90  if constexpr (is_binary_expr<left_type>) {
91  auto& left_expr = expr.get_lhs();
92 
93  using left_value_type = typename left_type::value_type;
94  using left_operator_type = typename left_type::operator_type;
95 
96  using left_left_type = typename left_type::left_type;
97  using left_right_type = typename left_type::right_type;
98 
99  if constexpr (std::same_as<left_operator_type, mul_binary_op<left_value_type>>) {
100  constexpr size_t left_left_dimensions = decay_traits<left_left_type>::dimensions();
101  constexpr size_t left_right_dimensions = decay_traits<left_right_type>::dimensions();
102 
103  if constexpr (left_left_dimensions == 1 && is_2d4d(left_right_dimensions) && right_dimensions == 1
104  && all_homogeneous<left_left_type, left_right_type, right_type>) {
105  // Detect (gamma[K] * input[B, K, W, H]) + beta[k]
106  return batch_k_scale_plus(left_expr.get_lhs(), left_expr.get_rhs(), expr.get_rhs());
107  } else {
108  return std::forward<Expr>(expr);
109  }
110  } else {
111  return std::forward<Expr>(expr);
112  }
113 
114  } else {
115  return std::forward<Expr>(expr);
116  }
117 
118  } else {
119  return std::forward<Expr>(expr);
120  }
121  } else {
122  return std::forward<Expr>(expr);
123  }
124 }
125 
126 } //end of namespace etl
batch_k_scale_expr< detail::build_type< A >, detail::build_type< B > > batch_k_scale(const A &a, const B &b)
Returns the batched expression a[K] * b[B, K(, W, H)], scaling b by a along its K dimension.
Definition: batch_k_scale_expr.hpp:1495
batch_k_scale_plus_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_scale_plus(const A &a, const B &b, const C &c)
Returns the batched expression (a[K] * b[B, K(, W, H)]) + c[K].
Definition: batch_k_scale_plus_expr.hpp:1575
value_t< sub_type > value_type
The value contained in the expression.
Definition: dyn_matrix_view.hpp:31
batch_k_minus_scale_expr< detail::build_type< A >, detail::build_type< B >, detail::build_type< C > > batch_k_minus_scale(const A &a, const B &b, const C &c)
Returns the batched expression a[K] * (b[B, K(, W, H)] - c[K]).
Definition: batch_k_minus_scale_expr.hpp:1575
Binary operator for scalar subtraction.
Definition: minus.hpp:16
auto batch_hint(Expr &&expr)
Build a special expression for batched expressions.
Definition: batch_hint_builder.hpp:31
Root namespace for the ETL library.
Definition: adapter.hpp:15
static constexpr size_t dimensions()
Return the number of dimensions of the expression.
Definition: traits_base.hpp:31
Binary operator for scalar multiplication.
Definition: div.hpp:13
Binary operator for scalar addition.
Definition: plus.hpp:154