Expression Templates Library (ETL)
cce.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 namespace etl::impl::standard {
16 
22 template <typename O, typename L>
23 value_t<O> cce_loss(const O& output, const L& labels, value_t<O> scale) {
24  return scale * etl::sum(log(output) >> labels);
25 }
26 
32 template <typename O, typename L>
33 value_t<O> cce_error(const O& output, const L& labels, value_t<O> scale) {
34  return scale * sum(min(abs(argmax(labels) - argmax(output)), 1.0));
35 }
36 
43 template <typename O, typename L>
44 std::pair<value_t<O>, value_t<O>> cce(const O& output, const L& labels, value_t<O> alpha, value_t<O> beta) {
45  return std::make_pair(cce_loss(output, labels, alpha), cce_error(output, labels, beta));
46 }
47 
48 } //end of namespace etl::impl::standard
Definition: prob_pooling.hpp:10
auto abs(E &&value)
Apply the absolute value to each value of the given expression.
Definition: expression_builder.hpp:54
auto scale(LE &&lhs, RE &&rhs)
Builds an expression representing the scalar multiplication of lhs and rhs.
Definition: binary_expression_builder.hpp:64
auto min(L &&lhs, R &&rhs)
Create an expression with the minimum value of lhs and rhs.
Definition: expression_builder.hpp:77
value_t< E > sum(E &&values)
Returns the sum of all the values contained in the given expression.
Definition: expression_builder.hpp:624
auto argmax(E &&value)
Returns the indices of the maximum values in the first axis of the given matrix. If passed a vector...
Definition: expression_builder.hpp:408
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
auto log(E &&value) -> detail::unary_helper< E, log_unary_op >
Apply logarithm (base e) on each value of the given expression.
Definition: function_expression_builder.hpp:64