rocPRIM
block_histogram_sort.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_
22 #define ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_
23 
24 #include <type_traits>
25 
26 #include "../../config.hpp"
27 #include "../../detail/various.hpp"
28 
29 #include "../../intrinsics.hpp"
30 #include "../../functional.hpp"
31 
32 #include "../block_radix_sort.hpp"
33 #include "../block_discontinuity.hpp"
34 
35 BEGIN_ROCPRIM_NAMESPACE
36 
37 namespace detail
38 {
39 
40 template<
41  class T,
42  unsigned int BlockSizeX,
43  unsigned int BlockSizeY,
44  unsigned int BlockSizeZ,
45  unsigned int ItemsPerThread,
46  unsigned int Bins
47 >
49 {
50  static constexpr unsigned int BlockSize = BlockSizeX * BlockSizeY * BlockSizeZ;
51  static_assert(
52  std::is_convertible<T, unsigned int>::value,
53  "T must be convertible to unsigned int"
54  );
55 
56 private:
59 
60 public:
62  {
63  typename radix_sort::storage_type sort;
64  struct
65  {
66  typename discontinuity::storage_type flag;
67  unsigned int start[Bins];
68  unsigned int end[Bins];
69  };
70  };
71 
73 
74  template<class Counter>
75  ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE
76  void composite(T (&input)[ItemsPerThread],
77  Counter hist[Bins])
78  {
79  ROCPRIM_SHARED_MEMORY storage_type storage;
80  this->composite(input, hist, storage);
81  }
82 
83  template<class Counter>
84  ROCPRIM_DEVICE ROCPRIM_INLINE
85  void composite(T (&input)[ItemsPerThread],
86  Counter hist[Bins],
87  storage_type& storage)
88  {
89  // TODO: Check, MSVC rejects the code with the static assertion, yet compiles fine for all tested types. Predicate likely too strict
90  //static_assert(
91  // std::is_convertible<unsigned int, Counter>::value,
92  // "unsigned int must be convertible to Counter"
93  //);
94  constexpr auto tile_size = BlockSize * ItemsPerThread;
95  const auto flat_tid = ::rocprim::flat_block_thread_id<BlockSizeX, BlockSizeY, BlockSizeZ>();
96  unsigned int head_flags[ItemsPerThread];
97  discontinuity_op flags_op(storage);
98  storage_type_& storage_ = storage.get();
99 
100  radix_sort().sort(input, storage_.sort);
101  ::rocprim::syncthreads(); // Fix race condition that appeared on Vega10 hardware, storage LDS is reused below.
102 
103  ROCPRIM_UNROLL
104  for(unsigned int offset = 0; offset < Bins; offset += BlockSize)
105  {
106  const unsigned int offset_tid = offset + flat_tid;
107  if(offset_tid < Bins)
108  {
109  storage_.start[offset_tid] = tile_size;
110  storage_.end[offset_tid] = tile_size;
111  }
112  }
114 
115  discontinuity().flag_heads(head_flags, input, flags_op, storage_.flag);
117 
118  // The start of the first bin is not overwritten since the input is sorted
119  // and the starts are based on the second item.
120  // The very first item is never used as `b` in the operator
121  // This means that this should not need synchromization, but in practice it does.
122  if(flat_tid == 0)
123  {
124  storage_.start[static_cast<unsigned int>(input[0])] = 0;
125  }
127 
128  ROCPRIM_UNROLL
129  for(unsigned int offset = 0; offset < Bins; offset += BlockSize)
130  {
131  const unsigned int offset_tid = offset + flat_tid;
132  if(offset_tid < Bins)
133  {
134  Counter count = static_cast<Counter>(storage_.end[offset_tid] - storage_.start[offset_tid]);
135  hist[offset_tid] += count;
136  }
137  }
138  }
139 
140 private:
141  struct discontinuity_op
142  {
143  storage_type &storage;
144 
145  ROCPRIM_DEVICE ROCPRIM_INLINE
146  discontinuity_op(storage_type &storage) : storage(storage)
147  {
148  }
149 
150  ROCPRIM_DEVICE ROCPRIM_INLINE
151  bool operator()(const T& a, const T& b, unsigned int b_index) const
152  {
153  storage_type_& storage_ = storage.get();
154  if(static_cast<unsigned int>(a) != static_cast<unsigned int>(b))
155  {
156  storage_.start[static_cast<unsigned int>(b)] = b_index;
157  storage_.end[static_cast<unsigned int>(a)] = b_index;
158  return true;
159  }
160  else
161  {
162  return false;
163  }
164  }
165  };
166 };
167 
168 } // end namespace detail
169 
170 END_ROCPRIM_NAMESPACE
171 
172 #endif // ROCPRIM_BLOCK_DETAIL_BLOCK_HISTOGRAM_SORT_HPP_
The block_discontinuity class is a block level parallel primitive which provides methods for flagging...
Definition: block_discontinuity.hpp:82
Definition: block_histogram_sort.hpp:48
The block_radix_sort class is a block level parallel primitive which provides methods for sorting of ...
Definition: block_radix_sort.hpp:97
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
ROCPRIM_DEVICE ROCPRIM_INLINE void flag_heads(Flag(&head_flags)[ItemsPerThread], const T(&input)[ItemsPerThread], FlagOp flag_op, storage_type &storage)
Tags head_flags that indicate discontinuities between items partitioned across the thread block...
Definition: block_discontinuity.hpp:156
ROCPRIM_DEVICE ROCPRIM_INLINE void syncthreads()
Synchronize all threads in a block (tile)
Definition: thread.hpp:216
ROCPRIM_DEVICE ROCPRIM_INLINE void sort(Key(&keys)[ItemsPerThread], storage_type &storage, unsigned int begin_bit=0, unsigned int end_bit=8 *sizeof(Key))
Performs ascending radix sort over keys partitioned across threads in a block.
Definition: block_radix_sort.hpp:179
Definition: block_histogram_sort.hpp:61