rocPRIM
device_transform.hpp
1 // Copyright (c) 2017-2021 Advanced Micro Devices, Inc. All rights reserved.
2 //
3 // Permission is hereby granted, free of charge, to any person obtaining a copy
4 // of this software and associated documentation files (the "Software"), to deal
5 // in the Software without restriction, including without limitation the rights
6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 // copies of the Software, and to permit persons to whom the Software is
8 // furnished to do so, subject to the following conditions:
9 //
10 // The above copyright notice and this permission notice shall be included in
11 // all copies or substantial portions of the Software.
12 //
13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 // THE SOFTWARE.
20 
21 #ifndef ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_
22 #define ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_
23 
24 #include <type_traits>
25 #include <iterator>
26 
27 #include "../../config.hpp"
28 #include "../../detail/various.hpp"
29 #include "../../detail/match_result_type.hpp"
30 
31 #include "../../intrinsics.hpp"
32 #include "../../functional.hpp"
33 #include "../../types.hpp"
34 
35 #include "../../block/block_load.hpp"
36 #include "../../block/block_store.hpp"
37 
38 BEGIN_ROCPRIM_NAMESPACE
39 
40 namespace detail
41 {
42 
43 // Wrapper for unpacking tuple to be used with BinaryFunction.
44 // See transform function which accepts two input iterators.
45 template<class T1, class T2, class BinaryFunction>
47 {
48  using result_type = typename ::rocprim::detail::invoke_result<BinaryFunction, T1, T2>::type;
49 
50  ROCPRIM_HOST_DEVICE inline
51  unpack_binary_op() = default;
52 
53  ROCPRIM_HOST_DEVICE inline
54  unpack_binary_op(BinaryFunction binary_op) : binary_op_(binary_op)
55  {
56  }
57 
58  ROCPRIM_HOST_DEVICE inline
59  ~unpack_binary_op() = default;
60 
61  ROCPRIM_HOST_DEVICE inline
62  result_type operator()(const ::rocprim::tuple<T1, T2>& t)
63  {
64  return binary_op_(::rocprim::get<0>(t), ::rocprim::get<1>(t));
65  }
66 
67 private:
68  BinaryFunction binary_op_;
69 };
70 
71 template<
72  unsigned int BlockSize,
73  unsigned int ItemsPerThread,
74  class ResultType,
75  class InputIterator,
76  class OutputIterator,
77  class UnaryFunction
78 >
79 ROCPRIM_DEVICE ROCPRIM_INLINE
80 void transform_kernel_impl(InputIterator input,
81  const size_t input_size,
82  OutputIterator output,
83  UnaryFunction transform_op)
84 {
85  using input_type = typename std::iterator_traits<InputIterator>::value_type;
86  using output_type = typename std::iterator_traits<OutputIterator>::value_type;
87  using result_type =
88  typename std::conditional<
89  std::is_void<output_type>::value, ResultType, output_type
90  >::type;
91 
92  constexpr unsigned int items_per_block = BlockSize * ItemsPerThread;
93 
94  const unsigned int flat_id = ::rocprim::detail::block_thread_id<0>();
95  const unsigned int flat_block_id = ::rocprim::detail::block_id<0>();
96  const unsigned int block_offset = flat_block_id * items_per_block;
97  const unsigned int number_of_blocks = ::rocprim::detail::grid_size<0>();
98  const unsigned int valid_in_last_block = input_size - block_offset;
99 
100  input_type input_values[ItemsPerThread];
101  result_type output_values[ItemsPerThread];
102 
103  if(flat_block_id == (number_of_blocks - 1)) // last block
104  {
105  block_load_direct_striped<BlockSize>(
106  flat_id,
107  input + block_offset,
108  input_values,
109  valid_in_last_block
110  );
111 
112  ROCPRIM_UNROLL
113  for(unsigned int i = 0; i < ItemsPerThread; i++)
114  {
115  if(BlockSize * i + flat_id < valid_in_last_block)
116  {
117  output_values[i] = transform_op(input_values[i]);
118  }
119  }
120 
121  block_store_direct_striped<BlockSize>(
122  flat_id,
123  output + block_offset,
124  output_values,
125  valid_in_last_block
126  );
127  }
128  else
129  {
130  block_load_direct_striped<BlockSize>(
131  flat_id,
132  input + block_offset,
133  input_values
134  );
135 
136  ROCPRIM_UNROLL
137  for(unsigned int i = 0; i < ItemsPerThread; i++)
138  {
139  output_values[i] = transform_op(input_values[i]);
140  }
141 
142  block_store_direct_striped<BlockSize>(
143  flat_id,
144  output + block_offset,
145  output_values
146  );
147  }
148 }
149 
150 } // end of detail namespace
151 
152 END_ROCPRIM_NAMESPACE
153 
154 #endif // ROCPRIM_DEVICE_DETAIL_DEVICE_TRANSFORM_HPP_
Deprecated: Configuration of device-level scan primitives.
Definition: block_histogram.hpp:62
Definition: device_transform.hpp:46
Definition: test_device_histogram.cpp:94
ROCPRIM_DEVICE ROCPRIM_INLINE unsigned int flat_block_id()
Returns flat (linear, 1D) block identifier in a multidimensional grid.
Definition: thread.hpp:178