30 template <
typename O,
typename L>
32 if (impl::egblas::has_sbce && impl::egblas::has_dbce) {
45 template <
typename O,
typename L>
47 if (impl::egblas::has_bce_sloss && impl::egblas::has_bce_dloss) {
60 template <
typename O,
typename L>
62 if (impl::egblas::has_bce_serror && impl::egblas::has_bce_derror) {
76 template <
typename O,
typename L>
78 constexpr
auto impl = select_bce_impl<O, L>();
84 return impl::standard::bce(output, labels, alpha, beta);
89 output_gpu.ensure_gpu_up_to_date();
90 labels_gpu.ensure_gpu_up_to_date();
92 return impl::egblas::bce(
etl::size(output), alpha, beta, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
94 cpp_unreachable(
"Invalid selection for BCE");
106 template <
typename O,
typename L>
108 constexpr
auto impl = select_bce_loss_impl<O, L>();
114 return impl::standard::bce_loss(output, labels, scale);
119 output_gpu.ensure_gpu_up_to_date();
120 labels_gpu.ensure_gpu_up_to_date();
122 return impl::egblas::bce_loss(
etl::size(output), scale, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
124 cpp_unreachable(
"Invalid selection for BCE");
136 template <
typename O,
typename L>
138 constexpr
auto impl = select_bce_error_impl<O, L>();
144 return impl::standard::bce_error(output, labels, scale);
149 output_gpu.ensure_gpu_up_to_date();
150 labels_gpu.ensure_gpu_up_to_date();
152 return impl::egblas::bce_error(
etl::size(output), scale, output_gpu.gpu_memory(), 1, labels_gpu.gpu_memory(), 1);
154 cpp_unreachable(
"Invalid selection for BCE");
constexpr etl::bce_impl select_bce_error_impl()
Select the BCE implementation for an expression of type E.
Definition: bce.hpp:61
Standard implementation of the Binary Cross Entropy reduction.
Sum operation implementation.
Definition: bce.hpp:132
constexpr etl::bce_impl select_bce_impl()
Select the BCE implementation for an expression of type E.
Definition: bce.hpp:31
void force(Expr &&expr)
Force the internal evaluation of an expression.
Definition: evaluator.hpp:1292
Definition: expression_builder.hpp:699
static value_t< O > apply(const O &output, const L &labels, value_t< O > scale)
Apply the functor to e.
Definition: bce.hpp:137
EGBLAS wrappers for the bce operations.
auto scale(LE &&lhs, RE &&rhs)
Builds an expression representing the scalar multiplication of lhs and rhs.
Definition: binary_expression_builder.hpp:64
Sum operation implementation.
Definition: bce.hpp:102
static std::pair< value_t< O >, value_t< O > > apply(const O &output, const L &labels, value_t< O > alpha, value_t< O > beta)
Apply the functor to e.
Definition: bce.hpp:77
static value_t< O > apply(const O &output, const L &labels, value_t< O > scale)
Apply the functor to e.
Definition: bce.hpp:107
decltype(auto) smart_forward_gpu(E &expr)
Smart forwarding for a temporary expression that will be computed in GPU.
Definition: helpers.hpp:343
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
bce_impl
Enumeration describing the different implementations of BCE.
Definition: bce_impl.hpp:20
constexpr etl::bce_impl select_bce_loss_impl()
Select the BCE implementation for an expression of type E.
Definition: bce.hpp:46
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
Sum operation implementation.
Definition: bce.hpp:72