Expression Templates Library (ETL)
reduc_transformers.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl {
11 
17 template <typename T>
19  using sub_type = T;
21 
22  friend struct etl_traits<argmax_transformer>;
23 
24  static constexpr bool gpu_computable = false;
25 
26 private:
27  sub_type sub;
28 
29 public:
34  explicit argmax_transformer(sub_type expr) : sub(expr) {}
35 
41  value_type operator[](size_t i) const {
42  return value_type(max_index(sub(i)));
43  }
44 
51  value_type read_flat(size_t i) const {
52  return value_type(max_index(sub(i)));
53  }
54 
58  template <typename... Sizes>
59  value_type operator()(size_t i, Sizes... /*sizes*/) const {
60  return value_type(max_index(sub(i)));
61  }
62 
68  template <typename E>
69  bool alias(const E& rhs) const noexcept {
70  return sub.alias(rhs);
71  }
72 
73  // Internals
74 
79  template <typename V>
80  void visit(V&& visitor) const {
81  sub.visit(std::forward<V>(visitor));
82  }
83 
88  void ensure_cpu_up_to_date() const {
89  // Need to ensure sub value
90  sub.ensure_cpu_up_to_date();
91  }
92 
97  void ensure_gpu_up_to_date() const {
98  // Need to ensure both LHS and RHS
99  sub.ensure_gpu_up_to_date();
100  }
101 
108  friend std::ostream& operator<<(std::ostream& os, const argmax_transformer& transformer) {
109  return os << "argmax(" << transformer.sub << ")";
110  }
111 };
112 
118 template <typename T>
120  using sub_type = T;
122 
123  friend struct etl_traits<argmin_transformer>;
124 
125  static constexpr bool gpu_computable = false;
126 
127 private:
128  sub_type sub;
129 
130 public:
135  explicit argmin_transformer(sub_type expr) : sub(expr) {}
136 
142  value_type operator[](size_t i) const {
143  return value_type(min_index(sub(i)));
144  }
145 
152  value_type read_flat(size_t i) const {
153  return value_type(min_index(sub(i)));
154  }
155 
159  template <typename... Sizes>
160  value_type operator()(size_t i, Sizes... /*sizes*/) const {
161  return value_type(min_index(sub(i)));
162  }
163 
169  template <typename E>
170  bool alias(const E& rhs) const noexcept {
171  return sub.alias(rhs);
172  }
173 
174  // Internals
175 
180  template <typename V>
181  void visit(V&& visitor) const {
182  sub.visit(std::forward<V>(visitor));
183  }
184 
189  void ensure_cpu_up_to_date() const {
190  // Need to ensure sub value
191  sub.ensure_cpu_up_to_date();
192  }
193 
198  void ensure_gpu_up_to_date() const {
199  // Need to ensure both LHS and RHS
200  sub.ensure_gpu_up_to_date();
201  }
202 
209  friend std::ostream& operator<<(std::ostream& os, const argmin_transformer& transformer) {
210  return os << "argmin(" << transformer.sub << ")";
211  }
212 };
213 
218 template <typename T>
220  using sub_type = T;
222 
223  friend struct etl_traits<sum_r_transformer>;
224 
225  static constexpr bool gpu_computable = false;
226 
227 private:
228  sub_type sub;
229 
230 public:
235  explicit sum_r_transformer(sub_type expr) : sub(expr) {}
236 
242  value_type operator[](size_t i) const {
243  return sum(sub(i));
244  }
245 
252  value_type read_flat(size_t i) const {
253  return sum(sub(i));
254  }
255 
259  template <typename... Sizes>
260  value_type operator()(size_t i, Sizes... /*sizes*/) const {
261  return sum(sub(i));
262  }
263 
269  template <typename E>
270  bool alias(const E& rhs) const noexcept {
271  return sub.alias(rhs);
272  }
273 
274  // Internals
275 
280  template <typename V>
281  void visit(V&& visitor) const {
282  sub.visit(std::forward<V>(visitor));
283  }
284 
289  void ensure_cpu_up_to_date() const {
290  // Need to ensure sub value
291  sub.ensure_cpu_up_to_date();
292  }
293 
298  void ensure_gpu_up_to_date() const {
299  // Need to ensure both LHS and RHS
300  sub.ensure_gpu_up_to_date();
301  }
302 
309  friend std::ostream& operator<<(std::ostream& os, const sum_r_transformer& transformer) {
310  return os << "sum_r(" << transformer.sub << ")";
311  }
312 };
313 
318 template <typename T>
320  using sub_type = T;
322 
323  friend struct etl_traits<mean_r_transformer>;
324 
325  static constexpr bool gpu_computable = false;
326 
327 private:
328  sub_type sub;
329 
330 public:
335  explicit mean_r_transformer(sub_type expr) : sub(expr) {}
336 
342  value_type operator[](size_t i) const {
343  return mean(sub(i));
344  }
345 
352  value_type read_flat(size_t i) const {
353  return mean(sub(i));
354  }
355 
359  template <typename... Sizes>
360  value_type operator()(size_t i, Sizes... /*sizes*/) const {
361  return mean(sub(i));
362  }
363 
369  template <typename E>
370  bool alias(const E& rhs) const noexcept {
371  return sub.alias(rhs);
372  }
373 
374  // Internals
375 
380  template <typename V>
381  void visit(V&& visitor) const {
382  sub.visit(std::forward<V>(visitor));
383  }
384 
389  void ensure_cpu_up_to_date() const {
390  // Need to ensure sub value
391  sub.ensure_cpu_up_to_date();
392  }
393 
398  void ensure_gpu_up_to_date() const {
399  // Need to ensure both LHS and RHS
400  sub.ensure_gpu_up_to_date();
401  }
402 
409  friend std::ostream& operator<<(std::ostream& os, const mean_r_transformer& transformer) {
410  return os << "mean_r(" << transformer.sub << ")";
411  }
412 };
413 
418 template <typename T>
420  using sub_type = T;
422 
423  friend struct etl_traits<sum_l_transformer>;
424 
425  static constexpr bool gpu_computable = false;
426 
427 private:
428  sub_type sub;
429 
430 public:
435  explicit sum_l_transformer(sub_type expr) : sub(expr) {}
436 
442  value_type operator[](size_t j) const {
443  value_type m = 0.0;
444 
445  for (size_t i = 0; i < dim<0>(sub); ++i) {
446  m += sub[j + i * (etl::size(sub) / dim<0>(sub))];
447  }
448 
449  return m;
450  }
451 
458  value_type read_flat(size_t j) const noexcept {
459  value_type m = 0.0;
460 
461  for (size_t i = 0; i < dim<0>(sub); ++i) {
462  m += sub.read_flat(j + i * (etl::size(sub) / dim<0>(sub)));
463  }
464 
465  return m;
466  }
467 
474  template <typename... Sizes>
475  value_type operator()(size_t j, Sizes... sizes) const {
476  value_type m = 0.0;
477 
478  for (size_t i = 0; i < dim<0>(sub); ++i) {
479  m += sub(i, j, sizes...);
480  }
481 
482  return m;
483  }
484 
490  template <typename E>
491  bool alias(const E& rhs) const noexcept {
492  return sub.alias(rhs);
493  }
494 
495  // Internals
496 
501  template <typename V>
502  void visit(V&& visitor) const {
503  sub.visit(std::forward<V>(visitor));
504  }
505 
510  void ensure_cpu_up_to_date() const {
511  // Need to ensure sub value
512  sub.ensure_cpu_up_to_date();
513  }
514 
519  void ensure_gpu_up_to_date() const {
520  // Need to ensure both LHS and RHS
521  sub.ensure_gpu_up_to_date();
522  }
523 
530  friend std::ostream& operator<<(std::ostream& os, const sum_l_transformer& transformer) {
531  return os << "sum_l(" << transformer.sub << ")";
532  }
533 };
534 
539 template <typename T>
541  using sub_type = T;
543 
544  friend struct etl_traits<mean_l_transformer>;
545 
546  static constexpr bool gpu_computable = false;
547 
548 private:
549  sub_type sub;
550 
551 public:
556  explicit mean_l_transformer(sub_type expr) : sub(expr) {}
557 
563  value_type operator[](size_t j) const {
564  value_type m = 0.0;
565 
566  for (size_t i = 0; i < dim<0>(sub); ++i) {
567  m += sub[j + i * (etl::size(sub) / dim<0>(sub))];
568  }
569 
570  return m / value_type(dim<0>(sub));
571  }
572 
579  value_type read_flat(size_t j) const noexcept {
580  value_type m = 0.0;
581 
582  for (size_t i = 0; i < dim<0>(sub); ++i) {
583  m += sub.read_flat(j + i * (etl::size(sub) / dim<0>(sub)));
584  }
585 
586  return m / value_type(dim<0>(sub));
587  }
588 
595  template <typename... Sizes>
596  value_type operator()(size_t j, Sizes... sizes) const {
597  value_type m = 0.0;
598 
599  for (size_t i = 0; i < dim<0>(sub); ++i) {
600  m += sub(i, j, sizes...);
601  }
602 
603  return m / value_type(dim<0>(sub));
604  }
605 
611  template <typename E>
612  bool alias(const E& rhs) const noexcept {
613  return sub.alias(rhs);
614  }
615 
616  // Internals
617 
622  template <typename V>
623  void visit(V&& visitor) const {
624  sub.visit(std::forward<V>(visitor));
625  }
626 
631  void ensure_cpu_up_to_date() const {
632  // Need to ensure sub value
633  sub.ensure_cpu_up_to_date();
634  }
635 
640  void ensure_gpu_up_to_date() const {
641  // Need to ensure both LHS and RHS
642  sub.ensure_gpu_up_to_date();
643  }
644 
651  friend std::ostream& operator<<(std::ostream& os, const mean_l_transformer& transformer) {
652  return os << "mean_l(" << transformer.sub << ")";
653  }
654 };
655 
659 template <typename T>
661  cpp::specialization_of<etl::argmax_transformer, T>
662  || cpp::specialization_of<etl::argmin_transformer, T>
663  || cpp::specialization_of<etl::sum_r_transformer, T>
664  || cpp::specialization_of<etl::mean_r_transformer, T>)
665 struct etl_traits<T> {
666  using expr_t = T;
669 
670  static constexpr bool is_etl = true;
671  static constexpr bool is_transformer = true;
672  static constexpr bool is_view = false;
673  static constexpr bool is_magic_view = false;
674  static constexpr bool is_fast = etl_traits<sub_expr_t>::is_fast;
675  static constexpr bool is_linear = false;
677  static constexpr bool is_value = false;
678  static constexpr bool is_direct = false;
679  static constexpr bool is_generator = false;
680  static constexpr bool is_padded = false;
681  static constexpr bool is_aligned = false;
682  static constexpr bool is_temporary = etl_traits<sub_expr_t>::is_temporary;
683  static constexpr bool gpu_computable = false;
684  static constexpr order storage_order = etl_traits<sub_expr_t>::storage_order;
685 
691  template <vector_mode_t V>
692  static constexpr bool vectorizable = false;
693 
699  static size_t size(const expr_t& v) {
700  return etl::dim<0>(v.sub);
701  }
702 
709  static size_t dim(const expr_t& v, [[maybe_unused]] size_t d) {
710  return etl::dim<0>(v.sub);
711  }
712 
717  static constexpr size_t size() {
718  return etl_traits<sub_expr_t>::template dim<0>();
719  }
720 
726  template <size_t D>
727  static constexpr size_t dim() {
728  return etl_traits<sub_expr_t>::template dim<0>();
729  }
730 
735  static constexpr size_t dimensions() {
736  return 1;
737  }
738 
743  static constexpr int complexity() noexcept {
744  return 1;
745  }
746 };
747 
751 template <typename T>
752 requires(cpp::specialization_of<etl::sum_l_transformer, T> || cpp::specialization_of<etl::mean_l_transformer, T>)
753 struct etl_traits<T> {
754  using expr_t = T;
757 
758  static constexpr bool is_etl = true;
759  static constexpr bool is_transformer = true;
760  static constexpr bool is_view = false;
761  static constexpr bool is_magic_view = false;
762  static constexpr bool is_fast = etl_traits<sub_expr_t>::is_fast;
763  static constexpr bool is_linear = false;
765  static constexpr bool is_value = false;
766  static constexpr bool is_direct = false;
767  static constexpr bool is_generator = false;
768  static constexpr bool is_padded = false;
769  static constexpr bool is_aligned = false;
770  static constexpr bool is_temporary = etl_traits<sub_expr_t>::is_temporary;
771  static constexpr bool gpu_computable = false;
772  static constexpr order storage_order = etl_traits<sub_expr_t>::storage_order;
773 
779  template <vector_mode_t V>
780  static constexpr bool vectorizable = false;
781 
787  static size_t size(const expr_t& v) {
788  return etl::size(v.sub) / etl::dim<0>(v.sub);
789  }
790 
797  static size_t dim(const expr_t& v, size_t d) {
798  return etl::dim(v.sub, d + 1);
799  }
800 
805  static constexpr size_t size() {
807  }
808 
814  template <size_t D>
815  static constexpr size_t dim() {
816  return etl_traits<sub_expr_t>::template dim<D + 1>();
817  }
818 
823  static constexpr size_t dimensions() {
825  }
826 
831  static constexpr int complexity() noexcept {
832  return 1;
833  }
834 };
835 
836 } //end of namespace etl
value_t< E > mean(E &&values)
Returns the mean of all the values contained in the given expression.
Definition: expression_builder.hpp:650
Transform (dynamic) that returns the index of the minimum element over the right dimensions.
Definition: reduc_transformers.hpp:119
constexpr int complexity([[maybe_unused]] const E &expr) noexcept
Return the complexity of the expression.
Definition: helpers.hpp:38
T sub_type
The type on which the expression works.
Definition: reduc_transformers.hpp:220
void ensure_gpu_up_to_date() const
Ensures that the GPU memory is allocated and that the GPU memory is up to date.
Definition: reduc_transformers.hpp:640
value_type operator[](size_t i) const
Returns the value at the given index.
Definition: reduc_transformers.hpp:142
Transformer functor for optimizable expression.
Definition: expr_fwd.hpp:19
Transform (dynamic) that sums the expression from the left, effectively removing the left dimension...
Definition: reduc_transformers.hpp:419
value_t< T > value_type
The type of value.
Definition: reduc_transformers.hpp:542
void ensure_cpu_up_to_date() const
Copy back from the GPU to the expression memory if necessary.
Definition: reduc_transformers.hpp:88
void ensure_cpu_up_to_date() const
Copy back from the GPU to the expression memory if necessary.
Definition: reduc_transformers.hpp:189
argmax_transformer(sub_type expr)
Construct a new transformer around the given expression.
Definition: reduc_transformers.hpp:34
value_t< sub_type > value_type
The value contained in the expression.
Definition: dyn_matrix_view.hpp:31
constexpr bool is_magic_view
Traits indicating if the given ETL type is a magic view expression.
Definition: traits.hpp:311
friend std::ostream & operator<<(std::ostream &os, const mean_r_transformer &transformer)
Display the transformer on the given stream.
Definition: reduc_transformers.hpp:409
value_type operator[](size_t j) const
Returns the value at the given index.
Definition: reduc_transformers.hpp:442
T sub_type
The type on which the expression works.
Definition: reduc_transformers.hpp:120
friend std::ostream & operator<<(std::ostream &os, const mean_l_transformer &transformer)
Display the transformer on the given stream.
Definition: reduc_transformers.hpp:651
value_type read_flat(size_t i) const
Returns the value at the given index. This function never has side effects.
Definition: reduc_transformers.hpp:252
order
Storage order of a matrix.
Definition: order.hpp:15
bool alias(const E &rhs) const noexcept
Test if this expression aliases with the given expression.
Definition: reduc_transformers.hpp:170
T sub_type
The type on which the expression works.
Definition: reduc_transformers.hpp:420
void visit(V &&visitor) const
Apply the given visitor to this expression and its descendants.
Definition: reduc_transformers.hpp:181
bool alias(const E &rhs) const noexcept
Test if this expression aliases with the given expression.
Definition: reduc_transformers.hpp:370
bool alias(const E &rhs) const noexcept
Test if this expression aliases with the given expression.
Definition: reduc_transformers.hpp:612
Transform (dynamic) that averages the expression from the left, effectively removing the left dimensi...
Definition: reduc_transformers.hpp:540
value_t< T > value_type
The type of value.
Definition: reduc_transformers.hpp:421
value_type operator()(size_t i, Sizes...) const
Returns the value at the given position (i, sizes...)
Definition: reduc_transformers.hpp:360
T sub_type
The sub type.
Definition: dyn_matrix_view.hpp:30
value_type read_flat(size_t i) const
Returns the value at the given index. This function never has side effects.
Definition: reduc_transformers.hpp:152
value_type read_flat(size_t j) const noexcept
Returns the value at the given index. This function never has side effects.
Definition: reduc_transformers.hpp:579
void ensure_cpu_up_to_date() const
Copy back from the GPU to the expression memory if necessary.
Definition: reduc_transformers.hpp:389
constexpr bool is_fast
Traits to test if the given ETL expression type is fast (sizes known at compile-time) ...
Definition: traits.hpp:588
argmin_transformer(sub_type expr)
Construct a new transformer around the given expression.
Definition: reduc_transformers.hpp:135
value_type operator[](size_t i) const
Returns the value at the given index.
Definition: reduc_transformers.hpp:242
void ensure_gpu_up_to_date() const
Ensures that the GPU memory is allocated and that the GPU memory is up to date.
Definition: reduc_transformers.hpp:398
void ensure_cpu_up_to_date() const
Copy back from the GPU to the expression memory if necessary.
Definition: reduc_transformers.hpp:510
bool alias(const E &rhs) const noexcept
Test if this expression aliases with the given expression.
Definition: reduc_transformers.hpp:491
value_t< T > value_type
The type of value.
Definition: reduc_transformers.hpp:321
Traits to get information about ETL types.
Definition: tmp.hpp:68
Root namespace for the ETL library.
Definition: adapter.hpp:15
void ensure_gpu_up_to_date() const
Ensures that the GPU memory is allocated and that the GPU memory is up to date.
Definition: reduc_transformers.hpp:97
void ensure_cpu_up_to_date() const
Copy back from the GPU to the expression memory if necessary.
Definition: reduc_transformers.hpp:631
static constexpr size_t dimensions()
Return the number of dimensions of the expression.
Definition: traits_base.hpp:31
value_type operator()(size_t i, Sizes...) const
Returns the value at the given position (i, sizes...)
Definition: reduc_transformers.hpp:59
size_t max_index(E &&values)
Returns the index of the maximum element contained in the expression.
Definition: expression_builder.hpp:720
value_type operator[](size_t j) const
Returns the value at the given index.
Definition: reduc_transformers.hpp:563
value_type operator()(size_t i, Sizes...) const
Returns the value at the given position (i, sizes...)
Definition: reduc_transformers.hpp:160
auto dim(E &&value, size_t i) -> detail::identity_helper< E, dim_view< detail::build_identity_type< E >, D >>
Return a view representing the ith Dth dimension.
Definition: view_expression_builder.hpp:25
value_t< T > value_type
The type of value.
Definition: reduc_transformers.hpp:121
friend std::ostream & operator<<(std::ostream &os, const sum_l_transformer &transformer)
Display the transformer on the given stream.
Definition: reduc_transformers.hpp:530
friend std::ostream & operator<<(std::ostream &os, const argmax_transformer &transformer)
Display the transformer on the given stream.
Definition: reduc_transformers.hpp:108
void ensure_gpu_up_to_date() const
Ensures that the GPU memory is allocated and that the GPU memory is up to date.
Definition: reduc_transformers.hpp:298
value_t< T > value_type
The type of value.
Definition: reduc_transformers.hpp:221
friend std::ostream & operator<<(std::ostream &os, const argmin_transformer &transformer)
Display the transformer on the given stream.
Definition: reduc_transformers.hpp:209
value_type operator()(size_t j, Sizes... sizes) const
Access to the value at the given (j, sizes...) position.
Definition: reduc_transformers.hpp:596
value_t< T > value_type
The type of value.
Definition: reduc_transformers.hpp:20
T sub_type
The type on which the expression works.
Definition: reduc_transformers.hpp:19
value_type operator()(size_t i, Sizes...) const
Returns the value at the given position (i, sizes...)
Definition: reduc_transformers.hpp:260
value_type read_flat(size_t j) const noexcept
Returns the value at the given index. This function never has side effects.
Definition: reduc_transformers.hpp:458
constexpr bool is_transformer
Traits indicating if the given ETL type is a transformer expression.
Definition: traits.hpp:297
value_t< E > sum(E &&values)
Returns the sum of all the values contained in the given expression.
Definition: expression_builder.hpp:624
T sub_type
The type on which the expression works.
Definition: reduc_transformers.hpp:320
Transform (dynamic) that averages the expression from the right, effectively removing the right dimen...
Definition: reduc_transformers.hpp:319
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
value_type read_flat(size_t i) const
Returns the value at the given index. This function never has side effects.
Definition: reduc_transformers.hpp:352
void visit(V &&visitor) const
Apply the given visitor to this expression and its descendants.
Definition: reduc_transformers.hpp:80
sum_r_transformer(sub_type expr)
Construct a new transformer around the given expression.
Definition: reduc_transformers.hpp:235
requires(D > 0) struct dyn_base
Matrix with run-time fixed dimensions.
Definition: dyn_base.hpp:113
constexpr bool is_view
Traits indicating if the given ETL type is a view expression.
Definition: traits.hpp:304
Transform (dynamic) that returns the index of the maximum element over the right dimensions.
Definition: reduc_transformers.hpp:18
value_type read_flat(size_t i) const
Returns the value at the given index. This function never has side effects.
Definition: reduc_transformers.hpp:51
void visit(V &&visitor) const
Apply the given visitor to this expression and its descendants.
Definition: reduc_transformers.hpp:502
bool alias(const E &rhs) const noexcept
Test if this expression aliases with the given expression.
Definition: reduc_transformers.hpp:69
void ensure_cpu_up_to_date() const
Copy back from the GPU to the expression memory if necessary.
Definition: reduc_transformers.hpp:289
value_type operator()(size_t j, Sizes... sizes) const
Access to the value at the given (j, sizes...) position.
Definition: reduc_transformers.hpp:475
Transform (dynamic) that sums the expression from the right, effectively removing the right dimension...
Definition: reduc_transformers.hpp:219
void visit(V &&visitor) const
Apply the given visitor to this expression and its descendants.
Definition: reduc_transformers.hpp:281
constexpr bool is_thread_safe
Traits to test if the given ETL expression type is thread safe.
Definition: traits.hpp:687
void ensure_gpu_up_to_date() const
Ensures that the GPU memory is allocated and that the GPU memory is up to date.
Definition: reduc_transformers.hpp:198
value_type operator[](size_t i) const
Returns the value at the given index.
Definition: reduc_transformers.hpp:41
size_t min_index(E &&values)
Returns the index of the minimum element contained in the expression.
Definition: expression_builder.hpp:753
sum_l_transformer(sub_type expr)
Construct a new transformer around the given expression.
Definition: reduc_transformers.hpp:435
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
mean_r_transformer(sub_type expr)
Construct a new transformer around the given expression.
Definition: reduc_transformers.hpp:335
void ensure_gpu_up_to_date() const
Ensures that the GPU memory is allocated and that the GPU memory is up to date.
Definition: reduc_transformers.hpp:519
bool alias(const E &rhs) const noexcept
Test if this expression aliases with the given expression.
Definition: reduc_transformers.hpp:270
T sub_type
The type on which the expression works.
Definition: reduc_transformers.hpp:541
value_type operator[](size_t i) const
Returns the value at the given index.
Definition: reduc_transformers.hpp:342
void visit(V &&visitor) const
Apply the given visitor to this expression and its descendants.
Definition: reduc_transformers.hpp:381
mean_l_transformer(sub_type expr)
Construct a new transformer around the given expression.
Definition: reduc_transformers.hpp:556
void visit(V &&visitor) const
Apply the given visitor to this expression and its descendants.
Definition: reduc_transformers.hpp:623
friend std::ostream & operator<<(std::ostream &os, const sum_r_transformer &transformer)
Display the transformer on the given stream.
Definition: reduc_transformers.hpp:309