Expression Templates Library (ETL)
relu_derivative.hpp
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
8 #pragma once
9 
10 namespace etl {
11 
15 template <typename T>
17  static constexpr bool linear = true;
18  static constexpr bool thread_safe = true;
19  static constexpr bool desc_func = false;
20 
26  template <vector_mode_t V>
27  static constexpr bool vectorizable = true;
28 
32  template <typename L, typename R>
33  static constexpr bool gpu_computable = cudnn_enabled;
34 
39  static constexpr int complexity() {
40  return 1;
41  }
42 
46  template <typename V = default_vec>
47  using vec_type = typename V::template vec_type<T>;
48 
55  static constexpr T apply(const T& lhs, const T& rhs) noexcept {
56  return lhs > 0.0 ? rhs : 0.0;
57  }
58 
66  template <typename V = default_vec>
67  static ETL_STRONG_INLINE(vec_type<V>) load(const vec_type<V>& lhs, const vec_type<V>& rhs) noexcept {
68  auto t1 = V::round_up(V::min(V::set(T(1.0)), lhs));
69 
70  return V::mul(t1, rhs);
71  }
72 
81  template <typename L, typename R, typename Y>
82  static auto gpu_compute_hint(const L& lhs, const R& rhs, Y& y) noexcept {
83  decltype(auto) t1 = smart_gpu_compute_hint(lhs, y);
84  decltype(auto) t2 = smart_gpu_compute_hint(rhs, y);
85  decltype(auto) t3 = force_temporary_gpu_dim_only(t2);
86 
87  impl::cudnn::relu_backward(t1, t2, t3);
88 
89  return t3;
90  }
91 
100  template <typename L, typename R, typename Y>
101  static Y& gpu_compute(const L& lhs, const R& rhs, Y& y) noexcept {
102  decltype(auto) t1 = smart_gpu_compute_hint(lhs, y);
103  decltype(auto) t2 = smart_gpu_compute_hint(rhs, y);
104 
105  impl::cudnn::relu_backward(t1, t2, y);
106 
107  return y;
108  }
109 
114  static std::string desc() noexcept {
115  return "relu_back";
116  }
117 };
118 
119 } //end of namespace etl
static constexpr int complexity()
Estimate the complexity of the operator.
Definition: relu_derivative.hpp:39
static constexpr T apply(const T &lhs, const T &rhs) noexcept
Apply the binary operator on lhs and rhs.
Definition: relu_derivative.hpp:55
static constexpr bool vectorizable
Indicates if the expression is vectorizable using the given vector mode.
Definition: relu_derivative.hpp:27
static constexpr bool linear
Indicates if the operator is linear or not.
Definition: relu_derivative.hpp:17
static constexpr bool desc_func
Indicates if the description must be printed as a function.
Definition: relu_derivative.hpp:19
Binary operator for relu derivative.
Definition: relu_derivative.hpp:16
auto load(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:143
Root namespace for the ETL library.
Definition: adapter.hpp:15
static constexpr bool gpu_computable
Indicates if the operator can be computed on the GPU.
Definition: relu_derivative.hpp:33
constexpr bool cudnn_enabled
Indicates if the NVIDIA CUDNN library is available for ETL.
Definition: config.hpp:114
static constexpr bool thread_safe
Indicates if the operator is thread safe or not.
Definition: relu_derivative.hpp:18
static std::string desc() noexcept
Returns a textual representation of the operator.
Definition: relu_derivative.hpp:114
auto min(L &&lhs, R &&rhs)
Create an expression with the min value of lhs or rhs.
Definition: expression_builder.hpp:77
decltype(auto) force_temporary_gpu_dim_only(E &&expr)
Force a temporary out of the expression, without copying its content.
Definition: temporary.hpp:223
static auto gpu_compute_hint(const L &lhs, const R &rhs, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: relu_derivative.hpp:82
typename V::template vec_type< T > vec_type
Definition: relu_derivative.hpp:47
decltype(auto) smart_gpu_compute_hint(E &expr, Y &y)
Compute the expression into a representation that is GPU up to date.
Definition: helpers.hpp:368
static ETL_STRONG_INLINE(vec_type< V >) load(const vec_type< V > &lhs, const vec_type< V > &rhs) noexcept
Compute several applications of the operator at a time.
Definition: relu_derivative.hpp:67
static Y & gpu_compute(const L &lhs, const R &rhs, Y &y) noexcept
Compute the result of the operation using the GPU.
Definition: relu_derivative.hpp:101