Expression Templates Library (ETL)
dot.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 namespace etl::impl::vec {
16 
23 template <typename V, typename L, typename R>
24 value_t<L> dot_impl(const L& lhs, const R& rhs) {
25  using vec_type = V;
26  using T = value_t<L>;
27 
28  static constexpr size_t vec_size = vec_type::template traits<T>::size;
29 
30  auto n = etl::size(lhs);
31 
32  static constexpr bool remainder = !padding || !all_padded<L, R>;
33  const size_t last = remainder ? prev_multiple(n, vec_size) : n;
34 
35  size_t i = 0;
36 
37  auto r1 = vec_type::template zero<T>();
38  auto r2 = vec_type::template zero<T>();
39  auto r3 = vec_type::template zero<T>();
40  auto r4 = vec_type::template zero<T>();
41  auto r5 = vec_type::template zero<T>();
42  auto r6 = vec_type::template zero<T>();
43  auto r7 = vec_type::template zero<T>();
44  auto r8 = vec_type::template zero<T>();
45 
46  if (n <= 4 * cache_size / sizeof(T)) {
47  for (; i + (vec_size * 7) < last; i += 8 * vec_size) {
48  auto a1 = lhs.template load<vec_type>(i + 0 * vec_size);
49  auto a2 = lhs.template load<vec_type>(i + 1 * vec_size);
50  auto a3 = lhs.template load<vec_type>(i + 2 * vec_size);
51  auto a4 = lhs.template load<vec_type>(i + 3 * vec_size);
52  auto a5 = lhs.template load<vec_type>(i + 4 * vec_size);
53  auto a6 = lhs.template load<vec_type>(i + 5 * vec_size);
54  auto a7 = lhs.template load<vec_type>(i + 6 * vec_size);
55  auto a8 = lhs.template load<vec_type>(i + 7 * vec_size);
56 
57  auto b1 = rhs.template load<vec_type>(i + 0 * vec_size);
58  auto b2 = rhs.template load<vec_type>(i + 1 * vec_size);
59  auto b3 = rhs.template load<vec_type>(i + 2 * vec_size);
60  auto b4 = rhs.template load<vec_type>(i + 3 * vec_size);
61  auto b5 = rhs.template load<vec_type>(i + 4 * vec_size);
62  auto b6 = rhs.template load<vec_type>(i + 5 * vec_size);
63  auto b7 = rhs.template load<vec_type>(i + 6 * vec_size);
64  auto b8 = rhs.template load<vec_type>(i + 7 * vec_size);
65 
66  r1 = vec_type::fmadd(a1, b1, r1);
67  r2 = vec_type::fmadd(a2, b2, r2);
68  r3 = vec_type::fmadd(a3, b3, r3);
69  r4 = vec_type::fmadd(a4, b4, r4);
70  r5 = vec_type::fmadd(a5, b5, r5);
71  r6 = vec_type::fmadd(a6, b6, r6);
72  r7 = vec_type::fmadd(a7, b7, r7);
73  r8 = vec_type::fmadd(a8, b8, r8);
74  }
75 
76  for (; i + (vec_size * 3) < last; i += 4 * vec_size) {
77  auto a1 = lhs.template load<vec_type>(i + 0 * vec_size);
78  auto a2 = lhs.template load<vec_type>(i + 1 * vec_size);
79  auto a3 = lhs.template load<vec_type>(i + 2 * vec_size);
80  auto a4 = lhs.template load<vec_type>(i + 3 * vec_size);
81 
82  auto b1 = rhs.template load<vec_type>(i + 0 * vec_size);
83  auto b2 = rhs.template load<vec_type>(i + 1 * vec_size);
84  auto b3 = rhs.template load<vec_type>(i + 2 * vec_size);
85  auto b4 = rhs.template load<vec_type>(i + 3 * vec_size);
86 
87  r1 = vec_type::fmadd(a1, b1, r1);
88  r2 = vec_type::fmadd(a2, b2, r2);
89  r3 = vec_type::fmadd(a3, b3, r3);
90  r4 = vec_type::fmadd(a4, b4, r4);
91  }
92  }
93 
94  for (; i + (vec_size * 1) < last; i += 2 * vec_size) {
95  auto a1 = lhs.template load<vec_type>(i + 0 * vec_size);
96  auto a2 = lhs.template load<vec_type>(i + 1 * vec_size);
97 
98  auto b1 = rhs.template load<vec_type>(i + 0 * vec_size);
99  auto b2 = rhs.template load<vec_type>(i + 1 * vec_size);
100 
101  r1 = vec_type::fmadd(a1, b1, r1);
102  r2 = vec_type::fmadd(a2, b2, r2);
103  }
104 
105  for (; i < last; i += vec_size) {
106  auto a1 = lhs.template load<vec_type>(i);
107  auto b1 = rhs.template load<vec_type>(i);
108 
109  r1 = vec_type::fmadd(a1, b1, r1);
110  }
111 
112  auto rsum = vec_type::add(vec_type::add(vec_type::add(r1, r2), vec_type::add(r3, r4)), vec_type::add(vec_type::add(r5, r6), vec_type::add(r7, r8)));
113 
114  auto p1 = vec_type::hadd(rsum);
115  auto p2 = T();
116 
117  for (; remainder && i + 1 < n; i += 2) {
118  p1 += lhs[i] * rhs[i];
119  p2 += lhs[i + 1] * rhs[i + 1];
120  }
121 
122  if (remainder && i < n) {
123  p1 += lhs[i] * rhs[i];
124  }
125 
126  return p1 + p2;
127 }
128 
135 template <typename L, typename R>
136 value_t<L> dot(const L& lhs, const R& rhs) {
137  lhs.ensure_cpu_up_to_date();
138  rhs.ensure_cpu_up_to_date();
139 
140  // The default vectorization scheme should be sufficient
141  return dot_impl<default_vec>(lhs, rhs);
142 }
143 
144 } //end of namespace etl::impl::vec
constexpr bool padding
Indicates if ETL is allowed to pad matrices and vectors.
Definition: config.hpp:135
value_t< A > dot(const A &a, const B &b)
Returns the dot product of the two given expressions.
Definition: expression_builder.hpp:594
Definition: bias_add.hpp:15
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
dot_impl
Enumeration describing the different implementations of dot.
Definition: dot_impl.hpp:20
constexpr size_t cache_size
Cache size of the machine.
Definition: config.hpp:168
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81