10 #include "etl/tmp.hpp" 13 #include "etl/op/flip_transformers.hpp" 14 #include "etl/op/rep_transformers.hpp" 15 #include "etl/op/reduc_transformers.hpp" 24 template <
typename L,
typename R>
33 static constexpr
bool gpu_computable =
false;
41 check_mmul_sizes(left, right);
51 const auto n = etl::dim<1>(
right);
54 const auto m = etl::dim<0>(
left);
67 const auto n = etl::dim<1>(
right);
70 const auto m = etl::dim<0>(
left);
84 for (
size_t k = 0; k <
columns(left); k++) {
97 bool alias(
const E& rhs)
const noexcept {
98 return left.alias(rhs) || right.alias(rhs);
109 right.visit(visitor);
118 left.ensure_cpu_up_to_date();
119 right.ensure_cpu_up_to_date();
128 left.ensure_gpu_up_to_date();
129 right.ensure_gpu_up_to_date();
139 return os <<
"mm_mul(" << transformer.
left <<
"," << transformer.
right <<
")";
143 template <
typename A,
typename B>
144 void check_mmul_sizes([[maybe_unused]]
const A& a, [[maybe_unused]]
const B& b) {
145 if constexpr (all_fast<A, B>) {
148 "Invalid sizes for multiplication");
150 cpp_assert(dim<1>(a) == dim<0>(b)
152 "Invalid sizes for multiplication");
166 static constexpr
bool gpu_computable =
false;
185 size_t i_i = i / (
etl::size(sub) + h - 1);
186 size_t i_j = i % (
etl::size(sub) + h - 1);
226 template <
typename E>
227 bool alias(
const E& rhs)
const noexcept {
228 return sub.alias(rhs);
246 sub.ensure_cpu_up_to_date();
255 sub.ensure_gpu_up_to_date();
268 static constexpr
bool gpu_computable =
false;
279 i1 = etl::dim<0>(sub);
280 i2 = etl::dim<1>(sub);
282 size_t c_height = (i1 + k1 - 1) * (i2 + k2 - 1);
283 size_t c_width = k1 * k2;
285 auto max_fill = c_height - ((i1 + k1 - 1) * ((c_width - 1) / k1) + (c_width - 1) % k1);
286 inner_paddings = max_fill - (i1 * i2);
287 inner_padding = inner_paddings / (i2 - 1);
296 size_t i_i = i / (k1 * k2);
297 size_t i_j = i % (k1 * k2);
298 return (*
this)(i_i, i_j);
308 size_t i_i = i / (k1 * k2);
309 size_t i_j = i % (k1 * k2);
310 return (*
this)(i_i, i_j);
320 auto top_padding = (i1 + k1 - 1) * (j / k1) + j % k1;
322 if (i < top_padding || i >= top_padding + (i1 * i2) + inner_paddings) {
325 auto inner = i - top_padding;
326 auto col = inner % (i1 + inner_padding);
327 auto block = inner / (i1 + inner_padding);
332 return sub(
col, block);
342 template <
typename E>
343 bool alias(
const E& rhs)
const noexcept {
344 return sub.alias(rhs);
363 sub.ensure_cpu_up_to_date();
372 sub.ensure_gpu_up_to_date();
385 template <
typename A,
typename M>
387 const size_t i1 = etl::dim<0>(sub);
388 const size_t i2 = etl::dim<1>(sub);
390 const size_t c_height = (i1 + k1 - 1) * (i2 + k2 - 1);
391 const size_t c_width = k1 * k2;
393 const auto max_fill = c_height - ((i1 + k1 - 1) * ((c_width - 1) / k1) + (c_width - 1) % k1);
394 const auto inner_paddings = max_fill - (i1 * i2);
395 const auto inner_padding = inner_paddings / (i2 - 1);
397 auto* __restrict mm = m.memory_start();
398 auto* __restrict ss = sub.memory_start();
402 for (
size_t j = 0; j < c_width; ++j) {
403 size_t big_i = (i1 + k1 - 1) * (j / k1) + j % k1;
405 for (
size_t ii = 0; ii < etl::dim<1>(sub); ++ii) {
406 for (
size_t jj = 0; jj < etl::dim<0>(sub); ++jj) {
407 mm[j * c_width + big_i] = ss[jj * i2 + ii];
410 big_i += inner_padding;
424 template <
typename A,
typename M>
426 if constexpr (all_dma<A, M>) {
429 const size_t i1 = etl::dim<0>(sub);
430 const size_t i2 = etl::dim<1>(sub);
432 const auto m_width = (i1 - k1 + 1) * (i2 - k2 + 1);
434 const auto mm = m.memory_start();
435 const auto ss = sub.memory_start();
437 for (
size_t b = 0; b < m_width; ++b) {
438 auto s_i = b % (i1 - k1 + 1);
439 auto s_j = b / (i1 - k1 + 1);
441 for (
size_t b_i = 0; b_i < k1; ++b_i) {
442 for (
size_t b_j = 0; b_j < k2; ++b_j) {
443 mm[(b_j * k1 + b_i) * m_width + b] = ss[(s_i + b_i) * i2 + s_j + b_j];
448 const size_t i1 = etl::dim<0>(sub);
449 const size_t i2 = etl::dim<1>(sub);
451 const size_t m_width = (i1 - k1 + 1) * (i2 - k2 + 1);
453 for (
size_t b = 0; b < m_width; ++b) {
454 auto s_i = b % (i1 - k1 + 1);
455 auto s_j = b / (i1 - k1 + 1);
457 for (
size_t b_i = 0; b_i < k1; ++b_i) {
458 for (
size_t b_j = 0; b_j < k2; ++b_j) {
459 m(b_j * k1 + b_i, b) = sub(s_i + b_i, s_j + b_j);
478 template <etl_dma A, etl_dma M>
480 const size_t i1 = etl::dim<0>(sub);
481 const size_t i2 = etl::dim<1>(sub);
483 const auto height = i1 - k1 + 1;
484 const auto width = i2 - k2 + 1;
486 const auto mm = m.memory_start();
487 const auto ss = sub.memory_start();
489 for (
size_t c = 0; c < k1 * k2; ++c) {
490 const size_t w_source = c % k2;
491 const size_t h_source = (c / k2) % k1;
492 const size_t c_source = c / (k1 * k2);
494 for (
size_t h = 0; h < height; ++h) {
495 const size_t block_source = (c_source * i1 + h + h_source) * i2 + w_source;
496 const size_t block_target = (c * height + h) * width;
513 template <etl_dma A, etl_dma M>
515 const auto N = etl::dim<0>(sub);
516 const auto i1 = etl::dim<1>(sub);
517 const auto i2 = etl::dim<2>(sub);
519 const auto height = i1 - k1 + 1;
520 const auto width = i2 - k2 + 1;
522 const auto mm = m.memory_start();
523 const auto ss = sub.memory_start();
525 for (
size_t w = 0; w < k1 * k2; ++w) {
526 const auto w_source = w % k2;
527 const auto h_source = (w / k2) % k1;
528 const auto c_source = w / (k1 * k2);
530 for (
size_t i = 0; i < N; ++i) {
531 for (
size_t h = 0; h < height; ++h) {
532 const auto block_source = ((c_source * i1 + h + h_source) * i2 + w_source) + (i) * (i1 * i2);
533 const auto block_target = (w * N + i) * (height * width) + h * width;
544 template <
typename LE,
typename RE>
553 static constexpr
bool is_etl =
true;
557 static constexpr
bool is_fast = l_traits::is_fast && r_traits::is_fast;
558 static constexpr
bool is_linear =
false;
560 static constexpr
bool is_value =
false;
561 static constexpr
bool is_direct =
false;
562 static constexpr
bool is_generator =
false;
563 static constexpr
bool is_padded =
false;
564 static constexpr
bool is_aligned =
false;
565 static constexpr
bool is_temporary = l_traits::is_temporary || r_traits::is_temporary;
566 static constexpr
bool gpu_computable =
false;
567 static constexpr
order storage_order = l_traits::is_generator ? r_traits::storage_order : l_traits::storage_order;
574 template <vector_mode_t V>
575 static constexpr
bool vectorizable =
false;
583 return dim(v, 0) *
dim(v, 1);
596 cpp_assert(d == 1,
"Only 2D mmul are supported");
606 static constexpr
size_t size() {
640 template <
typename E>
646 static constexpr
bool is_etl =
true;
651 static constexpr
bool is_linear =
false;
653 static constexpr
bool is_value =
false;
654 static constexpr
bool is_direct =
false;
655 static constexpr
bool is_generator =
false;
656 static constexpr
bool is_padded =
false;
657 static constexpr
bool is_aligned =
false;
659 static constexpr
bool gpu_computable =
false;
667 template <vector_mode_t V>
668 static constexpr
bool vectorizable =
false;
713 template <
typename E>
719 static constexpr
bool is_etl =
true;
724 static constexpr
bool is_linear =
false;
726 static constexpr
bool is_value =
false;
727 static constexpr
bool is_direct =
false;
728 static constexpr
bool is_generator =
false;
729 static constexpr
bool is_padded =
false;
730 static constexpr
bool is_aligned =
false;
732 static constexpr
bool gpu_computable =
false;
740 template <vector_mode_t V>
741 static constexpr
bool vectorizable =
false;
749 auto c_height = (etl::dim<0>(v.
sub) + v.
k1 - 1) * (etl::dim<1>(v.
sub) + v.
k2 - 1);
750 auto c_width = v.
k1 * v.
k2;
751 return c_height * c_width;
762 return (etl::dim<0>(v.
sub) + v.
k1 - 1) * (etl::dim<1>(v.
sub) + v.
k2 - 1);
constexpr bool is_magic_view
Traits indicating if the given ETL type is a magic view expression.
Definition: traits.hpp:311
D D
The number of dimensions.
Definition: dyn_matrix_view.hpp:24
order
Storage order of a matrix.
Definition: order.hpp:15
void im2col_direct_tr_multi(M &m, A &&sub, size_t k1, size_t k2)
Convert a sequence of images to a sequence of image columns to be multiplied by kernels of size (k1...
Definition: transformers.hpp:514
constexpr bool is_fast
Traits to test if the given ETL expression type is fast (sizes known at compile-time) ...
Definition: traits.hpp:588
Traits to get information about ETL types.
Definition: tmp.hpp:68
Root namespace for the ETL library.
Definition: adapter.hpp:15
auto dim(E &&value, size_t i) -> detail::identity_helper< E, dim_view< detail::build_identity_type< E >, D >>
Return a view representing the ith Dth dimension.
Definition: view_expression_builder.hpp:25
size_t columns(const E &expr)
Returns the number of columns of the given ETL expression.
Definition: helpers.hpp:78
Visitor to perform local evaluation when necessary.
Definition: eval_visitors.hpp:23
void direct_copy_n(const S *source, T *target, size_t n)
Performs a direct memory copy.
Definition: memory.hpp:35
constexpr bool is_transformer
Traits indicating if the given ETL type is a transformer expression.
Definition: traits.hpp:297
constexpr size_t size(const E &expr) noexcept
Returns the size of the given ETL expression.
Definition: helpers.hpp:108
requires(D > 0) struct dyn_base
Matrix with run-time fixed dimensions.
Definition: dyn_base.hpp:113
constexpr bool is_view
Traits indicating if the given ETL type is a view expression.
Definition: traits.hpp:304
auto col(E &&value, size_t i) -> detail::identity_helper< E, dim_view< detail::build_identity_type< E >, 2 >>
Returns view representing the ith column of the given expression.
Definition: view_expression_builder.hpp:47
constexpr bool is_thread_safe
Traits to test if the given ETL expression type is thread safe.
Definition: traits.hpp:687
typename decay_traits< E >::value_type value_t
Traits to extract the value type out of an ETL type.
Definition: tmp.hpp:81
void im2col_direct(M &m, A &&sub, size_t k1, size_t k2)
Convert an image to a sequence of image columns to be multiplied by kernels of size (k1...
Definition: transformers.hpp:425
void im2col_direct_tr(M &m, A &&sub, size_t k1, size_t k2)
Convert an image to a sequence of image columns to be multiplied by kernels of size (k1...
Definition: transformers.hpp:479
void convmtx2_direct_t(M &m, A &&sub, size_t k1, size_t k2)
Compute the convolution matrix of sub into m for a kernel of size (k1,k2)
Definition: transformers.hpp:386