Expression Templates Library (ETL)
mse.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
16 #pragma once
17 
18 //Include the implementations
19 #include "etl/impl/std/mse.hpp"
20 #include "etl/impl/egblas/mse.hpp"
21 
22 namespace etl::detail {
23 
30 template <typename O, typename L>
32  if (impl::egblas::has_smse && impl::egblas::has_dmse) {
33  return etl::mse_impl::EGBLAS;
34  }
35 
36  return etl::mse_impl::STD;
37 }
38 
45 template <typename O, typename L>
47  if (impl::egblas::has_mse_sloss && impl::egblas::has_mse_dloss) {
48  return etl::mse_impl::EGBLAS;
49  }
50 
51  return etl::mse_impl::STD;
52 }
53 
60 template <typename O, typename L>
62  if (impl::egblas::has_mse_serror && impl::egblas::has_mse_derror) {
63  return etl::mse_impl::EGBLAS;
64  }
65 
66  return etl::mse_impl::STD;
67 }
68 
72 struct mse_impl {
76  template <typename O, typename L>
77  static std::pair<value_t<O>, value_t<O>> apply(const O& output, const L& labels, value_t<O> alpha, value_t<O> beta) {
78  constexpr auto impl = select_mse_impl<O, L>();
79 
80  if constexpr (impl == etl::mse_impl::STD) {
81  etl::force(output);
82  etl::force(labels);
83 
84  return impl::standard::mse(output, labels, alpha, beta);
85  } else if constexpr (impl == etl::mse_impl::EGBLAS) {
86  decltype(auto) output_gpu = smart_forward_gpu(output);
87  decltype(auto) labels_gpu = smart_forward_gpu(labels);
88 
89  output_gpu.ensure_gpu_up_to_date();
90  labels_gpu.ensure_gpu_up_to_date();
91 
92  return impl::egblas::mse(etl::size(output), alpha, beta, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
93  } else {
94  cpp_unreachable("Invalid selection for MSE");
95  }
96  }
97 };
98 
106  template <typename O, typename L>
107  static value_t<O> apply(const O& output, const L& labels, value_t<O> scale) {
108  constexpr auto impl = select_mse_loss_impl<O, L>();
109 
110  if constexpr (impl == etl::mse_impl::STD) {
111  etl::force(output);
112  etl::force(labels);
113 
114  return impl::standard::mse_loss(etl::size(output), output, labels, scale);
115  } else if constexpr (impl == etl::mse_impl::EGBLAS) {
116  decltype(auto) output_gpu = smart_forward_gpu(output);
117  decltype(auto) labels_gpu = smart_forward_gpu(labels);
118 
119  output_gpu.ensure_gpu_up_to_date();
120  labels_gpu.ensure_gpu_up_to_date();
121 
122  return impl::egblas::mse_loss(etl::size(output), scale, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
123  } else {
124  cpp_unreachable("Invalid selection for MSE");
125  }
126  }
127 };
128 
136  template <typename O, typename L>
137  static value_t<O> apply(const O& output, const L& labels, value_t<O> scale) {
138  constexpr auto impl = select_mse_error_impl<O, L>();
139 
140  if constexpr (impl == etl::mse_impl::STD) {
141  etl::force(output);
142  etl::force(labels);
143 
144  return impl::standard::mse_error(etl::size(output), output, labels, scale);
145  } else if constexpr (impl == etl::mse_impl::EGBLAS) {
146  decltype(auto) output_gpu = smart_forward_gpu(output);
147  decltype(auto) labels_gpu = smart_forward_gpu(labels);
148 
149  output_gpu.ensure_gpu_up_to_date();
150  labels_gpu.ensure_gpu_up_to_date();
151 
152  return impl::egblas::mse_error(etl::size(output), scale, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
153  } else {
154  cpp_unreachable("Invalid selection for MSE");
155  }
156  }
157 };
158 
159 } //end of namespace etl::detail
constexpr etl::mse_impl select_mse_impl()
Select the MSE implementation for output type O and labels type L.
Definition: mse.hpp:31
mse_impl
Enumeration describing the different implementations of MSE.
Definition: mse_impl.hpp:20
Standard implementation.
static std::pair< value_t< O >, value_t< O > > apply(const O &output, const L &labels, value_t< O > alpha, value_t< O > beta)
Apply the functor to the given output and labels.
Definition: mse.hpp:77
MSE loss operation implementation.
Definition: mse.hpp:102
void force(Expr &&expr)
Force the internal evaluation of an expression.
Definition: evaluator.hpp:1292
static value_t< O > apply(const O &output, const L &labels, value_t< O > scale)
Apply the functor to the given output and labels.
Definition: mse.hpp:137
Definition: expression_builder.hpp:699
MSE error operation implementation.
Definition: mse.hpp:132
GPU implementation.
static value_t< O > apply(const O &output, const L &labels, value_t< O > scale)
Apply the functor to the given output and labels.
Definition: mse.hpp:107
constexpr etl::mse_impl select_mse_error_impl()
Select the MSE error implementation for output type O and labels type L.
Definition: mse.hpp:61
auto scale(LE &&lhs, RE &&rhs)
Builds an expression representing the scalar multiplication of lhs and rhs.
Definition: binary_expression_builder.hpp:64
decltype(auto) smart_forward_gpu(E &expr)
Smart forwarding for a temporary expression that will be computed in GPU.
Definition: helpers.hpp:343
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
EGBLAS wrappers for the mse operations.
constexpr etl::mse_impl select_mse_loss_impl()
Select the MSE loss implementation for output type O and labels type L.
Definition: mse.hpp:46
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
Standard implementation of the Mean Squared Error reduction.
MSE operation implementation.
Definition: mse.hpp:72