mlpack
ub_tree_split_impl.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_UB_TREE_SPLIT_IMPL_HPP
15 
16 #include "ub_tree_split.hpp"
18 
19 namespace mlpack {
20 namespace tree {
21 
22 template<typename BoundType, typename MatType>
23 bool UBTreeSplit<BoundType, MatType>::SplitNode(BoundType& bound,
24  MatType& data,
25  const size_t begin,
26  const size_t count,
27  SplitInfo& splitInfo)
28 {
29  constexpr size_t order = sizeof(AddressElemType) * CHAR_BIT;
30  if (begin == 0 && count == data.n_cols)
31  {
32  // Calculate all addresses.
33  InitializeAddresses(data);
34 
35  // Probably this is not a good idea. Maybe it is better to get
36  // a number of distinct samples and find the median.
37  std::sort(addresses.begin(), addresses.end(), ComparePair);
38 
39  // Save the vector in order to rearrange the dataset later.
40  splitInfo.addresses = &addresses;
41  }
42  else
43  {
44  // We have already rearranged the dataset.
45  splitInfo.addresses = NULL;
46  }
47 
48  // The bound shouldn't contain too many subrectangles.
49  // In order to minimize the number of hyperrectangles we set last bits
50  // of the last address in the node to 1 and last bits of the first address
51  // in the next node to zero in such a way that the ordering is not
52  // disturbed.
53  if (begin + count < data.n_cols)
54  {
55  // Omit leading equal bits.
56  size_t row = 0;
57  arma::Col<AddressElemType>& lo = addresses[begin + count - 1].first;
58  const arma::Col<AddressElemType>& hi = addresses[begin + count].first;
59 
60  for (; row < data.n_rows; row++)
61  if (lo[row] != hi[row])
62  break;
63 
64  size_t bit = 0;
65 
66  for (; bit < order; bit++)
67  if ((lo[row] & ((AddressElemType) 1 << (order - 1 - bit))) !=
68  (hi[row] & ((AddressElemType) 1 << (order - 1 - bit))))
69  break;
70 
71  bit++;
72 
73  // Replace insignificant bits.
74  if (bit == order)
75  {
76  bit = 0;
77  row++;
78  }
79  else
80  {
81  for (; bit < order; bit++)
82  lo[row] |= ((AddressElemType) 1 << (order - 1 - bit));
83  row++;
84  }
85 
86  for (; row < data.n_rows; row++)
87  for (; bit < order; bit++)
88  lo[row] |= ((AddressElemType) 1 << (order - 1 - bit));
89  }
90 
91  // The bound shouldn't contain too many subrectangles.
92  // In order to minimize the number of hyperrectangles we set last bits
93  // of the first address in the next node to 0 and last bits of the last
94  // address in the previous node to 1 in such a way that the ordering is not
95  // disturbed.
96  if (begin > 0)
97  {
98  // Omit leading equal bits.
99  size_t row = 0;
100  const arma::Col<AddressElemType>& lo = addresses[begin - 1].first;
101  arma::Col<AddressElemType>& hi = addresses[begin].first;
102 
103  for (; row < data.n_rows; row++)
104  if (lo[row] != hi[row])
105  break;
106 
107  size_t bit = 0;
108 
109  for (; bit < order; bit++)
110  if ((lo[row] & ((AddressElemType) 1 << (order - 1 - bit))) !=
111  (hi[row] & ((AddressElemType) 1 << (order - 1 - bit))))
112  break;
113 
114  bit++;
115 
116  // Replace insignificant bits.
117  if (bit == order)
118  {
119  bit = 0;
120  row++;
121  }
122  else
123  {
124  for (; bit < order; bit++)
125  hi[row] &= ~((AddressElemType) 1 << (order - 1 - bit));
126  row++;
127  }
128 
129  for (; row < data.n_rows; row++)
130  for (; bit < order; bit++)
131  hi[row] &= ~((AddressElemType) 1 << (order - 1 - bit));
132  }
133 
134  // Set the minimum and the maximum addresses.
135  for (size_t k = 0; k < bound.Dim(); ++k)
136  {
137  bound.LoAddress()[k] = addresses[begin].first[k];
138  bound.HiAddress()[k] = addresses[begin + count - 1].first[k];
139  }
140  bound.UpdateAddressBounds(data.cols(begin, begin + count - 1));
141 
142  return true;
143 }
144 
145 template<typename BoundType, typename MatType>
146 void UBTreeSplit<BoundType, MatType>::InitializeAddresses(const MatType& data)
147 {
148  addresses.resize(data.n_cols);
149 
150  // Calculate all addresses.
151  for (size_t i = 0; i < data.n_cols; ++i)
152  {
153  addresses[i].first.zeros(data.n_rows);
154  bound::addr::PointToAddress(addresses[i].first, data.col(i));
155  addresses[i].second = i;
156  }
157 }
158 
159 template<typename BoundType, typename MatType>
160 size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
161  MatType& data,
162  const size_t begin,
163  const size_t count,
164  const SplitInfo& splitInfo)
165 {
166  // For the first time we have to rearrange the dataset.
167  if (splitInfo.addresses)
168  {
169  std::vector<size_t> newFromOld(data.n_cols);
170  std::vector<size_t> oldFromNew(data.n_cols);
171 
172  for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
173  {
174  newFromOld[i] = i;
175  oldFromNew[i] = i;
176  }
177 
178  for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
179  {
180  size_t index = (*splitInfo.addresses)[i].second;
181  size_t oldI = oldFromNew[i];
182  size_t newIndex = newFromOld[index];
183 
184  data.swap_cols(i, newFromOld[index]);
185 
186  size_t tmp = newFromOld[index];
187  newFromOld[index] = i;
188  newFromOld[oldI] = tmp;
189 
190  tmp = oldFromNew[i];
191  oldFromNew[i] = oldFromNew[newIndex];
192  oldFromNew[newIndex] = tmp;
193  }
194  }
195 
196  // Since the dataset is sorted we can easily obtain the split column.
197  return begin + count / 2;
198 }
199 
200 template<typename BoundType, typename MatType>
201 size_t UBTreeSplit<BoundType, MatType>::PerformSplit(
202  MatType& data,
203  const size_t begin,
204  const size_t count,
205  const SplitInfo& splitInfo,
206  std::vector<size_t>& oldFromNew)
207 {
208  // For the first time we have to rearrange the dataset.
209  if (splitInfo.addresses)
210  {
211  std::vector<size_t> newFromOld(data.n_cols);
212 
213  for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
214  newFromOld[i] = i;
215 
216  for (size_t i = 0; i < splitInfo.addresses->size(); ++i)
217  {
218  size_t index = (*splitInfo.addresses)[i].second;
219  size_t oldI = oldFromNew[i];
220  size_t newIndex = newFromOld[index];
221 
222  data.swap_cols(i, newFromOld[index]);
223 
224  size_t tmp = newFromOld[index];
225  newFromOld[index] = i;
226  newFromOld[oldI] = tmp;
227 
228  tmp = oldFromNew[i];
229  oldFromNew[i] = oldFromNew[newIndex];
230  oldFromNew[newIndex] = tmp;
231  }
232  }
233 
234  // Since the dataset is sorted we can easily obtain the split column.
235  return begin + count / 2;
236 }
237 
238 } // namespace tree
239 } // namespace mlpack
240 
241 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Bounds that are useful for binary space partitioning trees.