1 #ifndef DASH__ALGORITHM__SUMMA_H_ 2 #define DASH__ALGORITHM__SUMMA_H_ 4 #include <dash/Exception.h> 5 #include <dash/Future.h> 6 #include <dash/Pattern.h> 7 #include <dash/Types.h> 8 #include <dash/algorithm/Copy.h> 9 #include <dash/util/Trace.h> 14 #ifdef DASH_ENABLE_MKL 16 #include <mkl_types.h> 17 #include <mkl_cblas.h> 19 #include <mkl_lapack.h> 21 #elif defined(DASH_ENABLE_BLAS) 28 #define DASH_ALGORITHM_SUMMA_ASYNC_INIT_PREFETCH 34 #if defined(DASH_ENABLE_MKL) || defined(DASH_ENABLE_BLAS) 38 template<
typename ValueType>
55 template <
typename ValueType>
71 "Called fallback implementation of DGEMM (only enabled in Debug)");
74 for (
auto i = 0; i < n; ++i) {
76 for (
auto j = 0; j < p; ++j) {
79 for (
auto k = 0; k < m; ++k) {
83 auto value = A[ik] * B[kj];
90 #endif // defined(DASH_ENABLE_MKL) || defined(DASH_ENABLE_BLAS) 124 template<
typename MatrixType>
125 using summa_pattern_constraints =
130 typename MatrixType::pattern_type>;
147 typename MatrixTypeA,
148 typename MatrixTypeB,
160 typedef typename MatrixTypeA::value_type value_type;
161 typedef typename MatrixTypeA::index_type index_t;
162 typedef typename MatrixTypeA::size_type extent_t;
166 typedef std::array<index_t, 2> coords_t;
174 typename MatrixTypeC::pattern_type
181 dash::pattern_layout_properties<>,
182 typename MatrixTypeC::pattern_type
186 std::is_floating_point<value_type>::value,
187 "dash::summa expects matrix element type double or float");
189 DASH_LOG_DEBUG(
"dash::summa()");
199 "pattern of first matrix argument does not match constraints");
209 "pattern of second matrix argument does not match constraints");
219 "pattern of result matrix does not match constraints");
221 DASH_LOG_TRACE(
"dash::summa",
"matrix pattern properties valid");
223 if (shifted_tiling) {
224 DASH_LOG_TRACE(
"dash::summa",
225 "using communication scheme for diagonal-shift mapping");
227 if (minimal_tiling) {
228 DASH_LOG_TRACE(
"dash::summa",
229 "using communication scheme for minimal partitioning");
238 auto unit_id = team.myid();
240 auto pattern_a = A.pattern();
241 auto pattern_b = B.pattern();
242 auto pattern_c = C.pattern();
243 auto m = pattern_a.extent(0);
244 #if DASH_ENABLE_TRACE_LOGGING 245 auto n = pattern_a.extent(1);
246 auto p = pattern_b.extent(0);
248 const dash::MemArrange memory_order = pattern_a.memory_order();
254 "Extents of first operand in dimension 1 do not match extents of " 255 "second operand in dimension 0");
260 "Extents of result matrix in dimension 0 do not match extents of " 261 "first operand in dimension 0");
266 "Extents of result matrix in dimension 1 do not match extents of " 267 "second operand in dimension 1");
269 DASH_LOG_TRACE(
"dash::summa",
"matrix pattern extents valid");
272 auto block_size_m = pattern_a.block(0).extent(0);
273 auto block_size_n = pattern_b.block(0).extent(1);
274 auto block_size_p = pattern_b.block(0).extent(0);
275 auto num_blocks_m = m / block_size_m;
276 #if DASH_ENABLE_TRACE_LOGGING 277 auto num_blocks_n = n / block_size_n;
278 auto num_blocks_p = p / block_size_p;
281 auto block_a_size = block_size_n * block_size_m;
282 auto block_b_size = block_size_m * block_size_p;
284 auto teamspec = C.pattern().teamspec();
285 auto unit_ts_coords = teamspec.coords(unit_id);
287 DASH_LOG_TRACE(
"dash::summa",
"blocks:",
288 "m:", num_blocks_m,
"*", block_size_m,
289 "n:", num_blocks_n,
"*", block_size_n,
290 "p:", num_blocks_p,
"*", block_size_p);
291 DASH_LOG_TRACE(
"dash::summa",
293 "cols:", teamspec.extent(0),
294 "rows:", teamspec.extent(1),
295 "unit team coords:", unit_ts_coords);
296 DASH_LOG_TRACE(
"dash::summa",
"allocating local temporary blocks, sizes:",
300 #ifdef DASH_ENABLE_MKL 301 value_type * buf_block_a_get = (value_type *)(mkl_malloc(
302 sizeof(value_type) * block_a_size, 64));
303 value_type * buf_block_b_get = (value_type *)(mkl_malloc(
304 sizeof(value_type) * block_b_size, 64));
305 value_type * buf_block_a_comp = (value_type *)(mkl_malloc(
306 sizeof(value_type) * block_a_size, 64));
307 value_type * buf_block_b_comp = (value_type *)(mkl_malloc(
308 sizeof(value_type) * block_b_size, 64));
310 auto *buf_block_a_get =
new value_type[block_a_size];
311 auto *buf_block_b_get =
new value_type[block_b_size];
312 auto *buf_block_a_comp =
new value_type[block_a_size];
313 auto *buf_block_b_comp =
new value_type[block_b_size];
317 value_type * local_block_a_get = buf_block_a_get;
318 value_type * local_block_b_get = buf_block_b_get;
319 value_type * local_block_a_comp = buf_block_a_comp;
320 value_type * local_block_b_comp = buf_block_b_comp;
321 value_type * local_block_a_get_bac =
nullptr;
322 value_type * local_block_b_get_bac =
nullptr;
323 value_type * local_block_a_comp_bac =
nullptr;
324 value_type * local_block_b_comp_bac =
nullptr;
331 auto l_block_c_get = C.local.block(0);
332 auto l_block_c_get_view = l_block_c_get.begin().viewspec();
333 index_t l_block_c_get_row = l_block_c_get_view.offset(1) / block_size_n;
334 index_t l_block_c_get_col = l_block_c_get_view.offset(0) / block_size_p;
336 coords_t block_a_get_coords = coords_t {{
static_cast<index_t
>(unit_ts_coords[0]),
337 l_block_c_get_row }};
338 coords_t block_b_get_coords = coords_t {{ l_block_c_get_col,
339 static_cast<index_t
>(unit_ts_coords[0]) }};
342 auto l_block_c_comp = l_block_c_get;
343 auto l_block_c_comp_view = l_block_c_comp.begin().viewspec();
344 index_t l_block_c_comp_row = l_block_c_comp_view.offset(1) / block_size_n;
345 index_t l_block_c_comp_col = l_block_c_comp_view.offset(0) / block_size_p;
349 auto block_a = A.block(block_a_get_coords);
350 auto block_a_lptr = block_a.begin().local();
351 auto block_b = B.block(block_b_get_coords);
352 auto block_b_lptr = block_b.begin().local();
353 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.block.a",
354 "block:", block_a_get_coords,
355 "local:", block_a_lptr !=
nullptr,
356 "unit:", block_a.begin().lpos().unit,
357 "view:", block_a.begin().viewspec());
361 trace.enter_state(
"prefetch");
362 if (block_a_lptr ==
nullptr) {
363 #ifdef DASH_ALGORITHM_SUMMA_ASYNC_INIT_PREFETCH 370 [=]() {
return local_block_a_comp + block_a.size(); });
373 local_block_a_comp_bac = local_block_a_comp;
374 local_block_a_comp = block_a_lptr;
377 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.block.b",
378 "block:", block_b_get_coords,
379 "local:", block_b_lptr !=
nullptr,
380 "unit:", block_b.begin().lpos().unit,
381 "view:", block_b.begin().viewspec());
382 if (block_b_lptr ==
nullptr) {
383 #ifdef DASH_ALGORITHM_SUMMA_ASYNC_INIT_PREFETCH 390 [=]() {
return local_block_b_comp + block_b.size(); });
393 local_block_b_comp_bac = local_block_b_comp;
394 local_block_b_comp = block_b_lptr;
396 #ifdef DASH_ALGORITHM_SUMMA_ASYNC_INIT_PREFETCH 397 if (block_a_lptr ==
nullptr) {
398 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.block.a.wait",
399 "waiting for prefetching of block A from unit",
400 block_a.begin().lpos().unit);
403 if (block_b_lptr ==
nullptr) {
404 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.block.b.wait",
405 "waiting for prefetching of block B from unit",
406 block_b.begin().lpos().unit);
410 trace.exit_state(
"prefetch");
412 DASH_LOG_TRACE(
"dash::summa",
"summa.block",
413 "prefetching of blocks completed");
417 extent_t num_local_blocks_c = pattern_c.local_blockspec().size();
419 DASH_LOG_TRACE(
"dash::summa",
"summa.block.C",
420 "C.num.local.blocks:", num_local_blocks_c,
421 "C.num.column.blocks:", num_blocks_m);
423 for (extent_t lb = 0; lb < num_local_blocks_c; ++lb) {
425 l_block_c_comp = C.local.block(lb);
426 l_block_c_comp_view = l_block_c_comp.begin().viewspec();
427 l_block_c_comp_row = l_block_c_comp_view.offset(1) / block_size_n;
428 l_block_c_comp_col = l_block_c_comp_view.offset(0) / block_size_p;
430 l_block_c_get = l_block_c_comp;
431 l_block_c_get_view = l_block_c_comp_view;
432 l_block_c_get_row = l_block_c_get_row;
433 l_block_c_get_col = l_block_c_get_col;
434 DASH_LOG_TRACE(
"dash::summa",
"summa.block.comp",
"C.local.block",
436 "row:", l_block_c_comp_row,
437 "col:", l_block_c_comp_col,
438 "view:", l_block_c_comp_view);
442 for (extent_t block_k = 0; block_k < num_blocks_m; ++block_k) {
443 DASH_LOG_TRACE(
"dash::summa",
"summa.block.k", block_k,
444 "active local block in C:", lb);
450 bool last = (lb == num_local_blocks_c - 1) &&
451 (block_k == num_blocks_m - 1);
454 auto block_get_k =
static_cast<index_t
>(block_k + 1);
455 block_get_k = (block_get_k + unit_ts_coords[0]) % num_blocks_m;
457 if (block_k == num_blocks_m - 1) {
459 block_get_k = unit_ts_coords[0];
460 l_block_c_get = C.local.block(lb + 1);
461 l_block_c_get_view = l_block_c_get.begin().viewspec();
462 l_block_c_get_row = l_block_c_get_view.offset(1) / block_size_n;
463 l_block_c_get_col = l_block_c_get_view.offset(0) / block_size_p;
466 block_a_get_coords = coords_t {{ block_get_k, l_block_c_get_row }};
467 block_b_get_coords = coords_t {{ l_block_c_get_col, block_get_k }};
469 block_a = A.block(block_a_get_coords);
470 block_a_lptr = block_a.begin().local();
471 block_b = B.block(block_b_get_coords);
472 block_b_lptr = block_b.begin().local();
473 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.block.a",
474 "block:", block_a_get_coords,
475 "local:", block_a_lptr !=
nullptr,
476 "unit:", block_a.begin().lpos().unit,
477 "view:", block_a.begin().viewspec());
478 if (block_a_lptr ==
nullptr) {
482 local_block_a_get_bac =
nullptr;
484 local_block_a_get_bac = local_block_a_get;
485 local_block_a_get = block_a_lptr;
487 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.block.b",
488 "block:", block_b_get_coords,
489 "local:", block_b_lptr !=
nullptr,
490 "unit:", block_b.begin().lpos().unit,
491 "view:", block_b.begin().viewspec());
492 if (block_b_lptr ==
nullptr) {
496 local_block_b_get_bac =
nullptr;
498 local_block_b_get_bac = local_block_b_get;
499 local_block_b_get = block_b_lptr;
502 DASH_LOG_TRACE(
"dash::summa",
" ->",
503 "last block multiplication",
504 "lb:", lb,
"bk:", block_k);
509 DASH_LOG_TRACE(
"dash::summa",
"summa.block.comp.multiply",
510 "multiplying local block matrices",
511 "C.local.block.comp:", lb,
512 "view:", l_block_c_comp.begin().viewspec());
514 trace.enter_state(
"multiply");
515 dash::internal::mmult_local<value_type>(
518 l_block_c_comp.begin().local(),
523 trace.exit_state(
"multiply");
525 if (local_block_a_comp_bac !=
nullptr) {
526 local_block_a_comp = local_block_a_comp_bac;
527 local_block_a_comp_bac =
nullptr;
529 if (local_block_b_comp_bac !=
nullptr) {
530 local_block_b_comp = local_block_b_comp_bac;
531 local_block_b_comp_bac =
nullptr;
537 trace.enter_state(
"prefetch");
538 if (block_a_lptr ==
nullptr) {
539 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.block.a.wait",
540 "waiting for prefetching of block A from unit",
541 block_a.begin().lpos().unit);
544 if (block_b_lptr ==
nullptr) {
545 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.block.b.wait",
546 "waiting for prefetching of block B from unit",
547 block_b.begin().lpos().unit);
550 DASH_LOG_TRACE(
"dash::summa",
"summa.prefetch.completed",
551 "local copies of next blocks received");
552 trace.exit_state(
"prefetch");
557 std::swap(local_block_a_get, local_block_a_comp);
558 std::swap(local_block_b_get, local_block_b_comp);
559 if (local_block_a_get_bac !=
nullptr) {
560 local_block_a_comp_bac = local_block_a_get_bac;
561 local_block_a_get_bac =
nullptr;
563 if (local_block_b_get_bac !=
nullptr) {
564 local_block_b_comp_bac = local_block_b_get_bac;
565 local_block_b_get_bac =
nullptr;
571 DASH_LOG_TRACE(
"dash::summa",
"locally completed");
572 #ifdef DASH_ENABLE_MKL 573 mkl_free(buf_block_a_get);
574 mkl_free(buf_block_b_get);
575 mkl_free(buf_block_a_comp);
576 mkl_free(buf_block_b_comp);
578 delete[] buf_block_a_get;
579 delete[] buf_block_b_get;
580 delete[] buf_block_a_comp;
581 delete[] buf_block_b_comp;
584 DASH_LOG_TRACE(
"dash::summa",
"waiting for other units");
585 trace.enter_state(
"barrier");
587 trace.exit_state(
"barrier");
589 DASH_LOG_TRACE(
"dash::summa >",
"finished");
603 typename MatrixTypeA,
604 typename MatrixTypeB,
605 typename MatrixTypeC >
618 typename MatrixTypeA,
619 typename MatrixTypeB,
620 typename MatrixTypeC >
630 ->
typename std::enable_if<
631 summa_pattern_constraints<MatrixTypeA>::satisfied::value &&
632 summa_pattern_constraints<MatrixTypeB>::satisfied::value &&
633 summa_pattern_constraints<MatrixTypeC>::satisfied::value,
643 #endif // DASH__ALGORITHM__SUMMA_H_ All blocks have identical size.
This class is a simple memory pool that allocates elements of type ValueType.
Local element order corresponds to a logical linearization within single blocks (if blocked) or withi...
void summa(MatrixTypeA &A, MatrixTypeB &B, MatrixTypeC &C)
Multiplies two matrices using the SUMMA algorithm.
dash::Future< ValueType * > copy_async(InputIt in_first, InputIt in_last, OutputIt out_first)
Asynchronous variant of dash::copy.
Traits for compile-time pattern constraints checking, suitable as a helper for template definitions e...
dash::pattern_partitioning_properties< dash::pattern_partitioning_tag::rectangular, dash::pattern_partitioning_tag::balanced, dash::pattern_partitioning_tag::ndimensional > summa_pattern_partitioning_constraints
Constraints on pattern partitioning properties of matrix operands passed to dash::summa.
Implementation of a future used to wait for an operation to complete and access the value returned by...
The number of blocks assigned to units may differ.
Minimal number of blocks in every dimension, typically at most one block per unit.
Block extents are constant for every dimension.
void mmult(MatrixTypeA &A, MatrixTypeB &B, MatrixTypeC &C)
Function adapter to an implementation of matrix-matrix multiplication (xDGEMM) depending on the matri...
Units are mapped to blocks in diagonal chains in all hyperplanes.
A Team instance specifies a subset of all available units.
void wait()
Wait for the value to become available.
Elements are contiguous in local memory within a single block and thus indexed blockwise.
dash::pattern_mapping_properties< dash::pattern_mapping_tag::multiple, dash::pattern_mapping_tag::unbalanced > summa_pattern_mapping_constraints
Constraints on pattern mapping properties of matrix operands passed to dash::summa.
bool check_pattern_constraints(const PatternType &pattern)
Traits for compile- and run-time pattern constraints checking, suitable for property checks where det...
Units are mapped to more than one block.
OutputIt copy(InputIt in_first, InputIt in_last, OutputIt out_first)
Copies the elements in the range, defined by [in_first, in_last), to another range beginning at out_first.
dash::pattern_layout_properties< dash::pattern_layout_tag::blocked, dash::pattern_layout_tag::linear > summa_pattern_layout_constraints
Constraints on pattern layout properties of matrix operands passed to dash::summa.
struct dash::dart_operation ValueType
Reduce operands to their minimum value.
Data range is partitioned in at least two dimensions.
Generic type of mapping properties of a model satisfying the Pattern concept.