30 template <
typename O,
typename L>
32 if (impl::egblas::has_scce && impl::egblas::has_dcce) {
45 template <
typename O,
typename L>
47 if (impl::egblas::has_cce_sloss && impl::egblas::has_cce_dloss) {
60 template <
typename O,
typename L>
62 if (impl::egblas::has_cce_serror && impl::egblas::has_cce_derror) {
76 template <
typename O,
typename L>
78 constexpr
auto impl = select_cce_impl<O, L>();
84 return impl::standard::cce(output, labels, alpha, beta);
89 output_gpu.ensure_gpu_up_to_date();
90 labels_gpu.ensure_gpu_up_to_date();
92 return impl::egblas::cce(etl::dim<0>(output), etl::dim<1>(output), alpha, beta, output_gpu.gpu_memory(), labels_gpu.gpu_memory());
94 cpp_unreachable(
"Invalid selection for CCE");
106 template <
typename O,
typename L>
108 constexpr
auto impl = select_cce_loss_impl<O, L>();
114 return impl::standard::cce_loss(output, labels, scale);
119 output_gpu.ensure_gpu_up_to_date();
120 labels_gpu.ensure_gpu_up_to_date();
122 return impl::egblas::cce_loss(
etl::size(output), scale, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
124 cpp_unreachable(
"Invalid selection for CCE");
136 template <
typename O,
typename L>
138 constexpr
auto impl = select_cce_error_impl<O, L>();
144 return impl::standard::cce_error(output, labels, scale);
149 output_gpu.ensure_gpu_up_to_date();
150 labels_gpu.ensure_gpu_up_to_date();
152 return impl::egblas::cce_error(etl::dim<0>(output), etl::dim<1>(output), scale, output_gpu.gpu_memory(), labels_gpu.gpu_memory());
154 cpp_unreachable(
"Invalid selection for CCE");
static std::pair< value_t< O >, value_t< O > > apply(const O &output, const L &labels, value_t< O > alpha, value_t< O > beta)
Apply the functor to the given output and labels.
Definition: cce.hpp:77
EGBLAS wrappers for the cce operations.
constexpr etl::cce_impl select_cce_error_impl()
Select the CCE error implementation for an output of type O and labels of type L.
Definition: cce.hpp:61
void force(Expr &&expr)
Force the internal evaluation of an expression.
Definition: evaluator.hpp:1292
Definition: expression_builder.hpp:699
CCE error operation implementation.
Definition: cce.hpp:132
constexpr etl::cce_impl select_cce_impl()
Select the CCE implementation for an output of type O and labels of type L.
Definition: cce.hpp:31
cce_impl
Enumeration describing the different implementations of CCE.
Definition: cce_impl.hpp:20
auto scale(LE &&lhs, RE &&rhs)
Builds an expression representing the scalar multiplication of lhs and rhs.
Definition: binary_expression_builder.hpp:64
CCE loss operation implementation.
Definition: cce.hpp:102
Standard implementation of the Categorical Cross Entropy reduction.
constexpr etl::cce_impl select_cce_loss_impl()
Select the CCE loss implementation for an output of type O and labels of type L.
Definition: cce.hpp:46
decltype(auto) smart_forward_gpu(E &expr)
Smart forwarding for a temporary expression that will be computed in GPU.
Definition: helpers.hpp:343
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
static value_t< O > apply(const O &output, const L &labels, value_t< O > scale)
Apply the functor to the given output and labels.
Definition: cce.hpp:137
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
CCE operation implementation.
Definition: cce.hpp:72
static value_t< O > apply(const O &output, const L &labels, value_t< O > scale)
Apply the functor to the given output and labels.
Definition: cce.hpp:107