Expression Templates Library (ETL)
transpose.hpp
Go to the documentation of this file.
1 //=======================================================================
2 // Copyright (c) 2014-2023 Baptiste Wicht
3 // Distributed under the terms of the MIT License.
4 // (See accompanying file LICENSE or copy at
5 // http://opensource.org/licenses/MIT)
6 //=======================================================================
7 
13 #pragma once
14 
15 namespace etl::impl::vec {
16 
/*!
 * \brief Transpose a single 4x4 tile (generic scalar version).
 *
 * For every di, dj in [0, 4), copies A2[(i2 + di) * M + (j2 + dj)] into
 * C2[(j2 + dj) * N + (i2 + di)]: the tile of the row-major N x M input whose
 * top-left corner is (i2, j2) is written, transposed, into the row-major
 * M x N output. Tile rows are processed in order, and each row's four columns
 * in order.
 *
 * \tparam V The vector mode; unused here, explicit specializations (e.g. for
 * sse_vec) select on it
 * \param N Number of rows of the input (row stride of the output)
 * \param M Number of columns of the input (row stride of the input)
 * \param A2 Pointer to the first element of the input
 * \param C2 Pointer to the first element of the output
 * \param i2 First input row of the tile
 * \param j2 First input column of the tile
 */
template <typename V, typename T>
inline void transpose_block_4x4_kernel(size_t N, size_t M, const T* A2, T* C2, size_t i2, size_t j2) {
    for (size_t di = 0; di < 4; ++di) {
        // Start of the tile's row (i2 + di) in the input
        const T* src_row = A2 + (i2 + di) * M + j2;

        for (size_t dj = 0; dj < 4; ++dj) {
            C2[(j2 + dj) * N + (i2 + di)] = src_row[dj];
        }
    }
}
39 
#ifdef __SSE3__
// sse_vec is only defined when __SSE3__ is enabled, hence the guard

/*!
 * \brief SSE specialization of the 4x4 tile transpose, for float only.
 *
 * Loads four (unaligned) rows of four floats starting at input element
 * (i2, j2), transposes them in registers with _MM_TRANSPOSE4_PS and stores
 * the four resulting rows (unaligned) starting at output element (j2, i2).
 *
 * \param N Number of rows of the input (row stride of the output)
 * \param M Number of columns of the input (row stride of the input)
 * \param A2 Pointer to the first element of the input
 * \param C2 Pointer to the first element of the output
 * \param i2 First input row of the tile
 * \param j2 First input column of the tile
 */
template <>
inline void transpose_block_4x4_kernel<sse_vec>(size_t N, size_t M, const float* A2, float* C2, size_t i2, size_t j2) {
    const float* src = A2 + i2 * M + j2; // input element (i2, j2)
    float* dst       = C2 + j2 * N + i2; // output element (j2, i2)

    auto row0 = sse_vec::loadu(src);
    auto row1 = sse_vec::loadu(src + M);
    auto row2 = sse_vec::loadu(src + 2 * M);
    auto row3 = sse_vec::loadu(src + 3 * M);

    // In-place 4x4 transpose of the four registers
    _MM_TRANSPOSE4_PS(row0.value, row1.value, row2.value, row3.value);

    sse_vec::storeu(dst, row0);
    sse_vec::storeu(dst + N, row1);
    sse_vec::storeu(dst + 2 * N, row2);
    sse_vec::storeu(dst + 3 * N, row3);
}

#endif
62 
63 template <typename V, typename A, typename C>
64 void transpose_impl(const A& a, C&& c) {
65  const size_t N = etl::dim<0>(a);
66  const size_t M = etl::dim<1>(a);
67 
68  const auto* A2 = a.memory_start();
69  auto* C2 = c.memory_start();
70 
71  if constexpr (decay_traits<A>::storage_order == order::RowMajor) {
72  constexpr size_t block_size = 16;
73  constexpr size_t kernel_block_size = 4;
74 
75 
76  auto batch_fun_i = [&](const size_t ifirst, const size_t ilast) {
77  cpp_assert(ilast <= N, "Invalid dispatch");
78 
79  size_t i = ifirst;
80 
81  for (; i + block_size - 1 < ilast; i += block_size) {
82  size_t j = 0;
83 
84  // Compute blocks of 16x16
85  for (; j + block_size - 1 < M; j += block_size) {
86  for (size_t i2 = i; i2 < i + block_size; i2 += kernel_block_size) {
87  for (size_t j2 = j; j2 < j + block_size; j2 += kernel_block_size) {
88  transpose_block_4x4_kernel<V>(N, M, A2, C2, i2, j2);
89  }
90  }
91  }
92 
93  // Compute blocks of 16x4
94  for (; j + kernel_block_size - 1 < M; j += kernel_block_size) {
95  for (size_t i2 = i; i2 < i + block_size; i2 += kernel_block_size) {
96  transpose_block_4x4_kernel<V>(N, M, A2, C2, i2, j);
97  }
98  }
99 
100  // Compute the left overs
101  for (; j < M; ++j) {
102  for (size_t i2 = i; i2 < i + block_size; ++i2) {
103  C2[j * N + i2] = A2[i2 * M + j];
104  }
105  }
106  }
107 
108  for (; i + kernel_block_size - 1 < ilast; i += kernel_block_size) {
109  size_t j = 0;
110 
111  // Compute blocks of 4x4
112  for (; j + kernel_block_size - 1 < M; j += kernel_block_size) {
113  transpose_block_4x4_kernel<V>(N, M, A2, C2, i, j);
114  }
115 
116  // Compute the leftovers
117  for (; j < M; ++j) {
118  for (size_t i2 = i; i2 < i + kernel_block_size; ++i2) {
119  C2[j * N + i2] = A2[i2 * M + j];
120  }
121  }
122  }
123 
124  for (; i < ilast; ++i) {
125  for (size_t j = 0; j < M; ++j) {
126  C2[j * N + i] = A2[i * M + j];
127  }
128  }
129  };
130 
131  engine_dispatch_1d(batch_fun_i, 0, N, engine_select_parallel(N, threads * 2 * block_size));
132  } else {
133  //TODO Optimize properly for column major
134  for (size_t j = 0; j < M; ++j) {
135  for (size_t i = 0; i < N; ++i) {
136  C2[i * M + j] = A2[j * N + i];
137  }
138  }
139  }
140 }
141 
142 template <typename A, typename C>
143 void transpose([[maybe_unused]] A&& a, [[maybe_unused]] C&& c) {
144  if constexpr (all_vectorizable<vector_mode, A, C> && sse3_enabled) {
145 #ifdef __SSE3__
146 // sse_vec will only be defined if __SSE3__is enabled
147  transpose_impl<sse_vec>(a, c);
148 #endif
149  } else {
150  cpp_unreachable("Invalid call to vec::batch_outer");
151  }
152 }
153 
154 } //end of namespace etl::impl::vec
void engine_dispatch_1d(Functor &&functor, size_t first, size_t last, [[maybe_unused]] size_t threshold, [[maybe_unused]] size_t n_threads=etl::threads)
Dispatch the elements of a range to a functor in a parallel manner, using the global thread engine...
Definition: parallel_support.hpp:708
Definition: bias_add.hpp:15
typename V::template vec_type< value_type > vec_type
The vectorization type for V.
Definition: dyn_matrix_view.hpp:43
auto transpose(const E &value)
Returns the transpose of the given expression.
Definition: expression_builder.hpp:528
void storeu(vec_type< V > in, size_t i) noexcept
Store several elements in the matrix at once.
Definition: dyn_matrix_view.hpp:187
auto loadu(size_t x) const noexcept
Load several elements of the expression at once.
Definition: dyn_matrix_view.hpp:154
const size_t threads
The number of threads ETL can use in parallel mode.
Definition: config.hpp:45
bool engine_select_parallel([[maybe_unused]] size_t n, [[maybe_unused]] size_t threshold=parallel_threshold)
Indicates if a 1D evaluation should run in parallel.
Definition: parallel_support.hpp:679
constexpr bool sse3_enabled
Indicates if SSE3 is available.
Definition: config.hpp:215
Row-Major storage.
transpose_impl
Enumeration describing the different implementations of transpose.
Definition: transpose_impl.hpp:20