30 template <
typename O,
typename L>
32 if (impl::egblas::has_smse && impl::egblas::has_dmse) {
45 template <
typename O,
typename L>
47 if (impl::egblas::has_mse_sloss && impl::egblas::has_mse_dloss) {
60 template <
typename O,
typename L>
62 if (impl::egblas::has_mse_serror && impl::egblas::has_mse_derror) {
76 template <
typename O,
typename L>
78 constexpr
auto impl = select_mse_impl<O, L>();
84 return impl::standard::mse(output, labels, alpha, beta);
89 output_gpu.ensure_gpu_up_to_date();
90 labels_gpu.ensure_gpu_up_to_date();
92 return impl::egblas::mse(
etl::size(output), alpha, beta, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
94 cpp_unreachable(
"Invalid selection for MSE");
106 template <
typename O,
typename L>
108 constexpr
auto impl = select_mse_loss_impl<O, L>();
114 return impl::standard::mse_loss(
etl::size(output), output, labels, scale);
119 output_gpu.ensure_gpu_up_to_date();
120 labels_gpu.ensure_gpu_up_to_date();
122 return impl::egblas::mse_loss(
etl::size(output), scale, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
124 cpp_unreachable(
"Invalid selection for MSE");
136 template <
typename O,
typename L>
138 constexpr
auto impl = select_mse_error_impl<O, L>();
144 return impl::standard::mse_error(
etl::size(output), output, labels, scale);
149 output_gpu.ensure_gpu_up_to_date();
150 labels_gpu.ensure_gpu_up_to_date();
152 return impl::egblas::mse_error(
etl::size(output), scale, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
154 cpp_unreachable(
"Invalid selection for MSE");
constexpr etl::mse_impl select_mse_impl()
Select the MSE implementation for an output expression of type O and a labels expression of type L.
Definition: mse.hpp:31
mse_impl
Enumeration describing the different implementations of MSE.
Definition: mse_impl.hpp:20
static std::pair< value_t< O >, value_t< O > > apply(const O &output, const L &labels, value_t< O > alpha, value_t< O > beta)
Apply the functor to the given output and labels expressions.
Definition: mse.hpp:77
MSE loss operation implementation.
Definition: mse.hpp:102
void force(Expr &&expr)
Force the internal evaluation of an expression.
Definition: evaluator.hpp:1292
static value_t< O > apply(const O &output, const L &labels, value_t< O > scale)
Apply the functor to the given output and labels expressions.
Definition: mse.hpp:137
Definition: expression_builder.hpp:699
MSE error operation implementation.
Definition: mse.hpp:132
static value_t< O > apply(const O &output, const L &labels, value_t< O > scale)
Apply the functor to the given output and labels expressions.
Definition: mse.hpp:107
constexpr etl::mse_impl select_mse_error_impl()
Select the MSE error implementation for an output expression of type O and a labels expression of type L.
Definition: mse.hpp:61
auto scale(LE &&lhs, RE &&rhs)
Builds an expression representing the scalar multiplication of lhs and rhs.
Definition: binary_expression_builder.hpp:64
decltype(auto) smart_forward_gpu(E &expr)
Smart forwarding for a temporary expression that will be computed in GPU.
Definition: helpers.hpp:343
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
EGBLAS wrappers for the mse operations.
constexpr etl::mse_impl select_mse_loss_impl()
Select the MSE loss implementation for an output expression of type O and a labels expression of type L.
Definition: mse.hpp:46
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
Standard implementation of the Mean Squared Error reduction.
MSE operation implementation.
Definition: mse.hpp:72