Expression Templates Library (ETL)
optimizer.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl {
11 
/*!
 * \brief Simple traits to test if an expression is optimizable.
 *
 * The primary template covers every expression type without a dedicated
 * specialization: such expressions are reported as not optimizable, both
 * at their own level and deeper in their tree.
 *
 * \tparam Expr The expression type
 */
template <typename Expr>
struct optimizable {
    /*!
     * \brief Indicates if the given expression is optimizable or not.
     * \return Always false for the generic case
     */
    static bool is(const Expr& /*expr*/) {
        return false;
    }

    /*!
     * \brief Indicates if the given expression or one of its sub expressions
     * is optimizable or not.
     * \return Always false for the generic case
     */
    static bool is_deep(const Expr& /*expr*/) {
        return false;
    }
};
39 
45 template <typename T, typename Expr, typename UnaryOp>
46 struct optimizable<etl::unary_expr<T, Expr, UnaryOp>> {
48  static constexpr bool is(const etl::unary_expr<T, Expr, UnaryOp>& /*unused*/) {
49  return std::is_same_v<UnaryOp, plus_unary_op<T>>;
50  }
51 
53  static bool is_deep(const etl::unary_expr<T, Expr, UnaryOp>& expr) {
54  return is(expr) || is_optimizable_deep(expr.value);
55  }
56 };
57 
63 template <typename T, typename BinaryOp>
64 struct optimizable<etl::binary_expr<T, etl::scalar<T>, BinaryOp, etl::scalar<T>>> {
66  static constexpr bool is(const etl::binary_expr<T, etl::scalar<T>, BinaryOp, etl::scalar<T>>& /*unused*/) {
67  if (std::is_same_v<BinaryOp, mul_binary_op<T>>) {
68  return true;
69  }
70 
71  if (std::is_same_v<BinaryOp, plus_binary_op<T>>) {
72  return true;
73  }
74 
75  if (std::is_same_v<BinaryOp, div_binary_op<T>>) {
76  return true;
77  }
78 
79  if (std::is_same_v<BinaryOp, minus_binary_op<T>>) {
80  return true;
81  }
82 
83  return false;
84  }
85 
87  static bool is_deep(const etl::binary_expr<T, etl::scalar<T>, BinaryOp, etl::scalar<T>>& expr) {
88  return is(expr) || is_optimizable_deep(expr.lhs) || is_optimizable_deep(expr.rhs);
89  }
90 };
91 
97 template <typename T, typename BinaryOp, typename RightExpr>
98 struct optimizable<etl::binary_expr<T, etl::scalar<T>, BinaryOp, RightExpr>> {
100  static bool is(const etl::binary_expr<T, etl::scalar<T>, BinaryOp, RightExpr>& expr) {
101  if (expr.lhs.value == 1.0 && std::is_same_v<BinaryOp, mul_binary_op<T>>) {
102  return true;
103  }
104 
105  if (expr.lhs.value == 0.0 && std::is_same_v<BinaryOp, mul_binary_op<T>>) {
106  return true;
107  }
108 
109  if (expr.lhs.value == 0.0 && std::is_same_v<BinaryOp, plus_binary_op<T>>) {
110  return true;
111  }
112 
113  if (expr.lhs.value == 0.0 && std::is_same_v<BinaryOp, div_binary_op<T>>) {
114  return true;
115  }
116 
117  return false;
118  }
119 
121  static bool is_deep(const etl::binary_expr<T, etl::scalar<T>, BinaryOp, RightExpr>& expr) {
122  return is(expr) || is_optimizable_deep(expr.lhs) || is_optimizable_deep(expr.rhs);
123  }
124 };
125 
131 template <typename T, typename LeftExpr, typename BinaryOp>
132 struct optimizable<etl::binary_expr<T, LeftExpr, BinaryOp, etl::scalar<T>>> {
134  static bool is(const etl::binary_expr<T, LeftExpr, BinaryOp, etl::scalar<T>>& expr) {
135  if (expr.rhs.value == 1.0 && std::is_same_v<BinaryOp, mul_binary_op<T>>) {
136  return true;
137  }
138 
139  if (expr.rhs.value == 0.0 && std::is_same_v<BinaryOp, mul_binary_op<T>>) {
140  return true;
141  }
142 
143  if (expr.rhs.value == 0.0 && std::is_same_v<BinaryOp, plus_binary_op<T>>) {
144  return true;
145  }
146 
147  if (expr.rhs.value == 0.0 && std::is_same_v<BinaryOp, minus_binary_op<T>>) {
148  return true;
149  }
150 
151  if (expr.rhs.value == 1.0 && std::is_same_v<BinaryOp, div_binary_op<T>>) {
152  return true;
153  }
154 
155  return false;
156  }
157 
159  static bool is_deep(const etl::binary_expr<T, LeftExpr, BinaryOp, etl::scalar<T>>& expr) {
160  return is(expr) || is_optimizable_deep(expr.lhs) || is_optimizable_deep(expr.rhs);
161  }
162 };
163 
169 template <typename T, typename LeftExpr, typename BinaryOp, typename RightExpr>
170 struct optimizable<etl::binary_expr<T, LeftExpr, BinaryOp, RightExpr>> {
173  return false;
174  }
175 
178  return is_optimizable_deep(expr.lhs) || is_optimizable_deep(expr.rhs);
179  }
180 };
181 
187 template <typename Expr>
188 bool is_optimizable(const Expr& expr) {
190 }
191 
197 template <typename Expr>
198 bool is_optimizable_deep(const Expr& expr) {
200 }
201 
/*!
 * \brief Transformer functor for optimizable expressions.
 *
 * The generic version only prints a diagnostic: according to its message,
 * it is not expected to be reached.
 *
 * \tparam Expr The expression type
 */
template <typename Expr>
struct transformer {
    /*!
     * \brief Transform the expression using the given builder.
     *
     * The builder is never invoked by this generic version.
     */
    template <typename Builder>
    static void transform(Builder /*builder*/, const Expr& /*expr*/) {
        std::cout << "Arrived in parent, should not happen" << std::endl;
    }
};
217 
223 template <typename T, typename Expr, typename UnaryOp>
224 struct transformer<etl::unary_expr<T, Expr, UnaryOp>> {
230  template <typename Builder>
231  static void transform([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const etl::unary_expr<T, Expr, UnaryOp>& expr) {
232  if constexpr (std::is_same_v<UnaryOp, plus_unary_op<T>>) {
233  parent_builder(expr.value);
234  }
235  }
236 };
237 
243 template <typename T, typename BinaryOp>
244 struct transformer<etl::binary_expr<T, etl::scalar<T>, BinaryOp, etl::scalar<T>>> {
250  template <typename Builder>
251  static void transform([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const etl::binary_expr<T, etl::scalar<T>, BinaryOp, etl::scalar<T>>& expr) {
252  if constexpr (std::is_same_v<BinaryOp, mul_binary_op<T>>) {
253  parent_builder(etl::scalar<T>(expr.lhs.value * expr.rhs.value));
254  } else if constexpr (std::is_same_v<BinaryOp, plus_binary_op<T>>) {
255  parent_builder(etl::scalar<T>(expr.lhs.value + expr.rhs.value));
256  } else if constexpr (std::is_same_v<BinaryOp, minus_binary_op<T>>) {
257  parent_builder(etl::scalar<T>(expr.lhs.value - expr.rhs.value));
258  } else if constexpr (std::is_same_v<BinaryOp, div_binary_op<T>>) {
259  parent_builder(etl::scalar<T>(expr.lhs.value / expr.rhs.value));
260  }
261  }
262 };
263 
269 template <typename T, typename BinaryOp, typename RightExpr>
270 struct transformer<etl::binary_expr<T, etl::scalar<T>, BinaryOp, RightExpr>> {
276  template <typename Builder>
277  static void transform([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const etl::binary_expr<T, etl::scalar<T>, BinaryOp, RightExpr>& expr) {
278  if constexpr (std::is_same_v<BinaryOp, mul_binary_op<T>>) {
279  if (expr.lhs.value == 1.0) {
280  parent_builder(expr.rhs);
281  } else if (expr.lhs.value == 0.0) {
282  parent_builder(expr.lhs);
283  }
284  } else if constexpr (std::is_same_v<BinaryOp, plus_binary_op<T>>) {
285  if (expr.lhs.value == 0.0) {
286  parent_builder(expr.rhs);
287  }
288  } else if constexpr (std::is_same_v<BinaryOp, div_binary_op<T>>) {
289  if (expr.lhs.value == 0.0) {
290  parent_builder(expr.lhs);
291  }
292  }
293  }
294 };
295 
301 template <typename T, typename LeftExpr, typename BinaryOp>
302 struct transformer<etl::binary_expr<T, LeftExpr, BinaryOp, etl::scalar<T>>> {
308  template <typename Builder>
309  static void transform([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const etl::binary_expr<T, LeftExpr, BinaryOp, etl::scalar<T>>& expr) {
310  if constexpr (std::is_same_v<BinaryOp, mul_binary_op<T>>) {
311  if (expr.rhs.value == 1.0) {
312  parent_builder(expr.lhs);
313  } else if (expr.rhs.value == 0.0) {
314  parent_builder(expr.rhs);
315  }
316  } else if constexpr (std::is_same_v<BinaryOp, plus_binary_op<T>>) {
317  if (expr.rhs.value == 0.0) {
318  parent_builder(expr.lhs);
319  }
320  } else if constexpr (std::is_same_v<BinaryOp, minus_binary_op<T>>) {
321  if (expr.rhs.value == 0.0) {
322  parent_builder(expr.lhs);
323  }
324  } else if constexpr (std::is_same_v<BinaryOp, div_binary_op<T>>) {
325  if (expr.rhs.value == 1.0) {
326  parent_builder(expr.lhs);
327  }
328  }
329  }
330 };
331 
337 template <typename Builder, typename Expr>
338 void transform(Builder parent_builder, const Expr& expr) {
339  transformer<std::decay_t<Expr>>::transform(parent_builder, expr);
340 }
341 
/*!
 * \brief An optimizer for the given expression type.
 *
 * The generic version handles leaf expressions: it only prints a
 * diagnostic and does not invoke the builder.
 *
 * \tparam Expr The expression type
 */
template <typename Expr>
struct optimizer {
    /*!
     * \brief Optimize the expression using the given builder.
     */
    template <typename Builder>
    static void apply(Builder /*parent_builder*/, const Expr& /*expr*/) {
        std::cout << "Leaf node" << std::endl;
    }
};
357 
361 template <typename T, typename Expr, typename UnaryOp>
362 struct optimizer<etl::unary_expr<T, Expr, UnaryOp>> {
368  template <typename Builder>
369  static void apply(Builder parent_builder, const etl::unary_expr<T, Expr, UnaryOp>& expr) {
370  if (is_optimizable(expr)) {
371  transform(parent_builder, expr);
372  } else if (is_optimizable_deep(expr.value)) {
373  auto value_builder = [&](auto&& new_value) {
374  parent_builder(etl::unary_expr<T, etl::detail::build_type<decltype(new_value)>, UnaryOp>(new_value));
375  };
376 
377  optimize(value_builder, expr.value);
378  } else {
379  parent_builder(expr);
380  }
381  }
382 };
383 
387 template <typename T, typename LeftExpr, typename BinaryOp, typename RightExpr>
388 struct optimizer<etl::binary_expr<T, LeftExpr, BinaryOp, RightExpr>> {
394  template <typename Builder>
395  static void apply(Builder parent_builder, const etl::binary_expr<T, LeftExpr, BinaryOp, RightExpr>& expr) {
396  if (is_optimizable(expr)) {
397  transform(parent_builder, expr);
398  } else if (is_optimizable_deep(expr.lhs)) {
399  auto lhs_builder = [&](auto&& new_lhs) {
400  parent_builder(etl::binary_expr<T, etl::detail::build_type<decltype(new_lhs)>, BinaryOp, RightExpr>(new_lhs, expr.rhs));
401  };
402 
403  optimize(lhs_builder, expr.lhs);
404  } else if (is_optimizable_deep(expr.rhs)) {
405  auto rhs_builder = [&](auto&& new_rhs) {
406  parent_builder(etl::binary_expr<T, LeftExpr, BinaryOp, etl::detail::build_type<decltype(new_rhs)>>(expr.lhs, new_rhs));
407  };
408 
409  optimize(rhs_builder, expr.rhs);
410  } else {
411  parent_builder(expr);
412  }
413  }
414 };
415 
422 template <typename Builder, typename Expr>
423 void optimize(Builder parent_builder, Expr& expr) {
424  optimizer<std::decay_t<Expr>>::apply(parent_builder, expr);
425 }
426 
/*!
 * \brief Optimize an expression and pass the optimized expression to the
 * given functor.
 *
 * Optimization passes are applied repeatedly, through recursion, as long as
 * is_optimizable_deep reports something left to simplify.
 *
 * \param expr The expression to optimize
 * \param result Functor receiving the final expression
 */
template <typename Expr, typename Result>
void optimized_forward(Expr& expr, Result result) {
    if (!is_optimizable_deep(expr)) {
        result(expr);
        return;
    }

    optimize(
        [result](auto&& optimized_expr) mutable {
            optimized_forward(optimized_expr, result);
        },
        expr);
}
442 
443 } //end of namespace etl
static void transform([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const etl::binary_expr< T, etl::scalar< T >, BinaryOp, etl::scalar< T >> &expr)
Transform the expression using the given builder.
Definition: optimizer.hpp:251
Unary operation computing the plus operation.
Definition: plus.hpp:17
Transformer functor for optimizable expressions.
Definition: expr_fwd.hpp:19
bool is_optimizable_deep(const Expr &expr)
Function to test if expr or sub parts of expr are optimizable.
Definition: optimizer.hpp:198
static constexpr bool is(const etl::binary_expr< T, etl::scalar< T >, BinaryOp, etl::scalar< T >> &)
Indicates if the given expression is optimizable or not.
Definition: optimizer.hpp:66
void optimized_forward(Expr &expr, Result result)
Optimize an expression and pass the optimized expression to the given functor.
Definition: optimizer.hpp:434
static void transform([[maybe_unused]] Builder builder, [[maybe_unused]] const Expr &expr)
Transform the expression using the given builder.
Definition: optimizer.hpp:213
static bool is(const etl::binary_expr< T, etl::scalar< T >, BinaryOp, RightExpr > &expr)
Indicates if the given expression is optimizable or not.
Definition: optimizer.hpp:100
Binary operator for scalar subtraction.
Definition: minus.hpp:16
static void apply(Builder parent_builder, const etl::binary_expr< T, LeftExpr, BinaryOp, RightExpr > &expr)
Optimize the expression using the given builder.
Definition: optimizer.hpp:395
static bool is_deep(const etl::binary_expr< T, LeftExpr, BinaryOp, RightExpr > &expr)
Indicates if the given expression or one of its sub expressions is optimizable or not...
Definition: optimizer.hpp:177
static bool is_deep(const etl::unary_expr< T, Expr, UnaryOp > &expr)
Indicates if the given expression or one of its sub expressions is optimizable or not...
Definition: optimizer.hpp:53
static bool is(const etl::binary_expr< T, LeftExpr, BinaryOp, RightExpr > &)
Indicates if the given expression is optimizable or not.
Definition: optimizer.hpp:172
Binary operator for scalar division.
Definition: div.hpp:349
static bool is_deep([[maybe_unused]] const Expr &expr)
Indicates if the given expression or one of its sub expressions is optimizable or not...
Definition: optimizer.hpp:35
A unary expression.
Definition: unary_expr.hpp:126
An optimizer for the given expression type.
Definition: expr_fwd.hpp:16
A binary expression.
Definition: binary_expr.hpp:18
Root namespace for the ETL library.
Definition: adapter.hpp:15
static constexpr bool is(const etl::unary_expr< T, Expr, UnaryOp > &)
Indicates if the given expression is optimizable or not.
Definition: optimizer.hpp:48
static void transform([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const etl::binary_expr< T, LeftExpr, BinaryOp, etl::scalar< T >> &expr)
Transform the expression using the given builder.
Definition: optimizer.hpp:309
static void apply([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const Expr &expr)
Optimize the expression using the given builder.
Definition: optimizer.hpp:353
std::conditional_t< is_etl_value< T >, const std::decay_t< T > &, std::decay_t< T > > build_type
Helper to build the type for a sub expression.
Definition: expression_helpers.hpp:24
static bool is([[maybe_unused]] const Expr &expr)
Indicates if the given expression is optimizable or not.
Definition: optimizer.hpp:26
Represents a scalar value.
Definition: concepts_base.hpp:19
bool is_optimizable(const Expr &expr)
Function to test if expr is optimizable.
Definition: optimizer.hpp:188
Binary operator for scalar multiplication.
Definition: div.hpp:13
Simple traits to test if an expression is optimizable.
Definition: expr_fwd.hpp:13
void transform(Builder parent_builder, const Expr &expr)
Function to transform the expression into its optimized form.
Definition: optimizer.hpp:338
static bool is_deep(const etl::binary_expr< T, etl::scalar< T >, BinaryOp, etl::scalar< T >> &expr)
Indicates if the given expression or one of its sub expressions is optimizable or not...
Definition: optimizer.hpp:87
static void apply(Builder parent_builder, const etl::unary_expr< T, Expr, UnaryOp > &expr)
Optimize the expression using the given builder.
Definition: optimizer.hpp:369
static void transform([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const etl::unary_expr< T, Expr, UnaryOp > &expr)
Transform the expression using the given builder.
Definition: optimizer.hpp:231
static bool is(const etl::binary_expr< T, LeftExpr, BinaryOp, etl::scalar< T >> &expr)
Indicates if the given expression is optimizable or not.
Definition: optimizer.hpp:134
static void transform([[maybe_unused]] Builder parent_builder, [[maybe_unused]] const etl::binary_expr< T, etl::scalar< T >, BinaryOp, RightExpr > &expr)
Transform the expression using the given builder.
Definition: optimizer.hpp:277
static bool is_deep(const etl::binary_expr< T, LeftExpr, BinaryOp, etl::scalar< T >> &expr)
Indicates if the given expression or one of its sub expressions is optimizable or not...
Definition: optimizer.hpp:159
static bool is_deep(const etl::binary_expr< T, etl::scalar< T >, BinaryOp, RightExpr > &expr)
Indicates if the given expression or one of its sub expressions is optimizable or not...
Definition: optimizer.hpp:121
Binary operator for scalar addition.
Definition: plus.hpp:154
void optimize(Builder parent_builder, Expr &expr)
Optimize an expression and reconstruct the parent from the optimized expression.
Definition: optimizer.hpp:423