libime
datrie.cpp
1 /*
2  * SPDX-FileCopyrightText: 2015-2017 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 
7 // Original Author
8 // cedar -- C++ implementation of Efficiently-updatable Double ARray trie
9 // $Id: cedarpp.h 1830 2014-06-16 06:17:42Z ynaga $
10 // SPDX-FileCopyrightText: 2009-2014 Naoki Yoshinaga
11 // <ynaga@tkl.iis.u-tokyo.ac.jp>
12 
13 #include "datrie.h"
14 #include <sys/types.h>
15 #include <algorithm>
16 #include <array>
17 #include <cassert>
18 #include <cmath>
19 #include <cstdint>
20 #include <cstring>
21 #include <fstream>
22 #include <ios>
23 #include <istream>
24 #include <iterator>
25 #include <limits>
26 #include <memory>
27 #include <optional>
28 #include <ostream>
29 #include <stdexcept>
30 #include <string>
31 #include <string_view>
32 #include <tuple>
33 #include <vector>
34 #include <fcitx-utils/macros.h>
35 #include "naivevector.h"
36 #include "utils_p.h"
37 
38 namespace libime {
39 
40 namespace {
41 
/// Overlays a trie value with the raw 32-bit integer kept in a node.
/// `result` is the integer view, `result_value` the typed view.
template <typename V>
union DecoderUnion {
    int32_t result;
    V result_value;
};
47 
/// Returns the object representation of `raw` as a 32-bit integer.
///
/// The bytes are copied with std::memcpy rather than read back through a
/// different union member, which is undefined behaviour in C++. For a type
/// narrower than 32 bits the bytes not covered by `V` stay zero.
template <typename V>
int32_t decodeValue(V raw) {
    static_assert(sizeof(V) <= sizeof(int32_t),
                  "value must fit into 32 bits");
    int32_t bits = 0;
    std::memcpy(&bits, &raw, sizeof(V));
    return bits;
}
54 
/// Sentinel provider for generic value types: -1 marks "no value" and -2
/// marks "no path".
template <typename T>
struct NanValue {
    static int32_t NO_VALUE() {
        return -1;
    }
    static int32_t NO_PATH() {
        return -2;
    }
};
60 
61 // Musl doesn't have nanf implementation we need, just check if they are the
62 // same value. If not, prefer old hardcoded value.
63 bool isGoodNanf() {
64  int32_t noValue = decodeValue(std::nanf("1"));
65  int32_t noPath = decodeValue(std::nanf("2"));
66  return (noValue != noPath);
67 }
68 
69 template <>
70 struct NanValue<float> {
71  static_assert(std::numeric_limits<float>::has_quiet_NaN,
72  "Require support for quiet NaN.");
73  static int32_t NO_VALUE() {
74  return isGoodNanf() ? decodeValue(std::nanf("1")) : 0x7fc00001;
75  }
76  static int32_t NO_PATH() {
77  return isGoodNanf() ? decodeValue(std::nanf("2")) : 0x7fc00002;
78  }
79 };
80 
/// Reinterprets the raw 32-bit cell content `raw` as a value of type `T`.
///
/// The first sizeof(T) bytes of `raw` are copied with std::memcpy instead of
/// being read through an inactive union member, which is undefined
/// behaviour in C++.
template <typename T>
constexpr T decodeImpl(int32_t raw) {
    static_assert(sizeof(T) <= sizeof(int32_t),
                  "value must fit into 32 bits");
    T value{};
    std::memcpy(&value, &raw, sizeof(T));
    return value;
}
87 
88 } // namespace
89 
90 // template<typename T>
91 // using vector_impl = std::vector<T>;
92 // naivevector is used due to realloc not available with std::vector.
93 template <typename T>
94 using vector_impl = naivevector<T>;
95 
96 template <typename V, bool ORDERED, int MAX_TRIAL>
98 public:
99  using base_type = DATrie<V>;
100  using value_type = typename base_type::value_type;
101  using position_type = typename base_type::position_type;
102  using updater_type = typename base_type::updater_type;
103  using callback_type = typename base_type::callback_type;
104  using decoder_type = DecoderUnion<V>;
105 
106  static inline const int32_t CEDAR_NO_VALUE = NanValue<V>::NO_VALUE();
107  static inline const int32_t CEDAR_NO_PATH = NanValue<V>::NO_PATH();
108 
109  static constexpr size_t MAX_ALLOC_SIZE = 1
110  << 16; // must be divisible by 256
111  using result_type = value_type;
112  using uchar = uint8_t;
113  static_assert(sizeof(value_type) <= sizeof(int32_t),
114  "value size need to be same as int32_t");
115  struct node {
116  union {
117  int32_t base;
118  value_type value;
119  };
120  int32_t check;
121  node(const int32_t base_ = 0, const int32_t check_ = 0)
122  : base(base_), check(check_) {}
123 
124  node(std::istream &in) : base(0), check(0) {
125  throw_if_io_fail(unmarshall(in, base));
126  throw_if_io_fail(unmarshall(in, check));
127  }
128 
129  FCITX_INLINE_DEFINE_DEFAULT_DTOR_AND_COPY(node);
130 
131  friend std::ostream &operator<<(std::ostream &out, const node &n) {
132  marshall(out, n.base) && marshall(out, n.check);
133  return out;
134  }
135  };
136 
137  struct npos_t {
138  uint32_t offset;
139  uint32_t index;
140 
141  npos_t() : offset(0), index(0) {}
142 
143  explicit npos_t(position_type i) {
144  offset = i >> 32;
145  index = i & 0xffffffffULL;
146  }
147 
148  operator bool() { return offset != 0 && index != 0; }
149 
150  bool operator==(const npos_t &other) const {
151  return offset == other.offset && index == other.index;
152  }
153 
154  bool operator!=(const npos_t &other) const {
155  return !operator==(other);
156  }
157 
158  position_type toInt() {
159  return (static_cast<position_type>(offset) << 32) | index;
160  }
161  };
162  struct block { // a block w/ 256 elements
163  int32_t prev; // prev block; 3 bytes
164  int32_t next; // next block; 3 bytes
165  int16_t num; // # empty elements; 0 - 256
166  int16_t reject; // minimum # branching failed to locate; soft limit
167  int32_t trial; // # trial
168  int32_t ehead; // first empty item
169  block() : prev(0), next(0), num(256), reject(257), trial(0), ehead(0) {}
170 
171  block(std::istream &in) : block() {
172  throw_if_io_fail(unmarshall(in, prev));
173  throw_if_io_fail(unmarshall(in, next));
174  throw_if_io_fail(unmarshall(in, num));
175  throw_if_io_fail(unmarshall(in, reject));
176  throw_if_io_fail(unmarshall(in, trial));
177  throw_if_io_fail(unmarshall(in, ehead));
178  }
179 
180  friend std::ostream &operator<<(std::ostream &out, const block &b) {
181  marshall(out, b.prev) && marshall(out, b.next) &&
182  marshall(out, b.num) && marshall(out, b.reject) &&
183  marshall(out, b.trial) && marshall(out, b.ehead);
184  return out;
185  }
186  };
187  struct ninfo { // x1.5 update speed; +.25 % memory (8n -> 10n)
188  uchar sibling; // right sibling (= 0 if not exist)
189  uchar child; // first child
190  ninfo() : sibling(0), child(0) {}
191 
192  ninfo(std::istream &in) {
193  throw_if_io_fail(unmarshall(in, sibling));
194  throw_if_io_fail(unmarshall(in, child));
195  }
196 
197  friend std::ostream &operator<<(std::ostream &out, const ninfo &n) {
198  marshall(out, n.sibling) && marshall(out, n.child);
199  return out;
200  }
201  };
202 
203  vector_impl<node> m_array;
204  vector_impl<char> m_tail;
205  vector_impl<int> m_tail0;
206  vector_impl<block> m_block;
207  vector_impl<ninfo> m_ninfo;
208 
209  int32_t m_bheadF; // first block of Full; 0
210  int32_t m_bheadC; // first block of Closed; 0 if no Closed
211  int32_t m_bheadO; // first block of Open; 0 if no Open
212  std::array<int, 257> m_reject;
213 
214  DATriePrivate() { init(); }
215  FCITX_INLINE_DEFINE_DEFAULT_DTOR_AND_COPY(DATriePrivate)
216 
217  size_t size() const { return m_ninfo.size(); }
218 
219  size_t capacity() const { return m_array.size(); }
220 
221  void clear() {
222  init();
223  m_array.shrink_to_fit();
224  m_block.shrink_to_fit();
225  m_tail.shrink_to_fit();
226  m_ninfo.shrink_to_fit();
227  m_tail0.shrink_to_fit();
228  }
229 
230  size_t num_keys() const {
231  size_t i = 0;
232  for (auto to = 0; to < static_cast<int>(size()); ++to) {
233  const node &n = m_array[to];
234  if (n.check >= 0 && (m_array[n.check].base == to || n.base < 0)) {
235  ++i;
236  }
237  }
238  return i;
239  }
240 
241  void open(std::istream &fin) {
242  uint32_t len = 0;
243  uint32_t size_ = 0;
244  throw_if_io_fail(unmarshall(fin, len));
245  throw_if_io_fail(unmarshall(fin, size_));
246 
247  const size_t length_ = static_cast<size_t>(len);
248 
249  m_tail.resize(length_);
250  m_tail0.resize(0);
251  m_array.reserve(size_);
252  m_array.resize(0);
253  m_ninfo.reserve(size_);
254  m_ninfo.resize(0);
255  m_block.reserve(size_ >> 8);
256  m_block.resize(0);
257 
258  throw_if_io_fail(fin.read(reinterpret_cast<char *>(m_tail.data()),
259  sizeof(char) * length_));
260 
261  for (auto i = 0U; i < size_; i++) {
262  m_array.emplace_back(fin);
263  }
264  m_array.resize(size_);
265 
266  throw_if_io_fail(unmarshall(fin, m_bheadF));
267  throw_if_io_fail(unmarshall(fin, m_bheadC));
268  throw_if_io_fail(unmarshall(fin, m_bheadO));
269 
270  for (auto i = 0U; i < size_; i++) {
271  m_ninfo.emplace_back(fin);
272  }
273 
274  for (auto i = 0U, end = size_ >> 8; i < end; i++) {
275  m_block.emplace_back(fin);
276  }
277  }
278 
279  void save(std::ostream &fout) {
280  shrink_tail();
281 
282  const uint32_t length = m_tail.size();
283  const uint32_t size_ = size();
284 
285  assert(m_block.size() << 8 == m_ninfo.size());
286  throw_if_io_fail(marshall(fout, length));
287  throw_if_io_fail(marshall(fout, size_));
288  throw_if_io_fail(fout.write(reinterpret_cast<char *>(m_tail.data()),
289  sizeof(char) * length));
290 
291  auto s = size_;
292  for (auto &n : m_array) {
293  throw_if_io_fail(fout << n);
294  if (--s == 0) {
295  break;
296  }
297  }
298 
299  throw_if_io_fail(marshall(fout, m_bheadF));
300  throw_if_io_fail(marshall(fout, m_bheadC));
301  throw_if_io_fail(marshall(fout, m_bheadO));
302 
303  for (auto &n : m_ninfo) {
304  throw_if_io_fail(fout << n);
305  }
306 
307  for (auto &b : m_block) {
308  throw_if_io_fail(fout << b);
309  }
310  }
311 
312  void init() {
313  m_bheadF = m_bheadC = m_bheadO = 0;
314  m_array.clear();
315  m_array.resize(256);
316  m_array[0] = node(0, -1);
317 
318  for (int i = 1; i < 256; ++i) {
319  m_array[i] =
320  node(i == 1 ? -255 : -(i - 1), i == 255 ? -1 : -(i + 1));
321  }
322 
323  m_ninfo.clear();
324  m_ninfo.resize(256);
325 
326  m_block.clear();
327  m_block.reserve(1);
328  m_block.resize(1);
329  m_block[0].ehead = 1;
330  // put a dummy entry
331  m_tail0.resize(0);
332  m_tail.clear();
333  m_tail.resize(sizeof(int32_t));
334 
335  for (auto i = 0; i <= 256; ++i) {
336  m_reject[i] = i + 1;
337  }
338  }
339 
340  void suffix(std::string &key, size_t len, npos_t pos) const {
341  key.clear();
342  key.resize(len);
343 
344  auto to = pos.index;
345  if (const int offset = pos.offset) {
346  size_t len_tail = std::strlen(&m_tail[-m_array[to].base]);
347  if (len > len_tail) {
348  len -= len_tail;
349  } else {
350  len_tail = len;
351  len = 0;
352  }
353  std::copy(&m_tail[static_cast<size_t>(offset) - len_tail],
354  &m_tail[static_cast<size_t>(offset)], key.begin() + len);
355  }
356  while (len--) {
357  const int from = m_array[to].check;
358  key[len] =
359  static_cast<char>(m_array[from].base ^ static_cast<int>(to));
360  to = from;
361  }
362  }
363 
364  template <typename U>
365  void update(const char *key, const U &callback) {
366  update(key, std::strlen(key), callback);
367  }
368 
369  template <typename U>
370  void update(const char *key, size_t len, const U &callback) {
371  npos_t from;
372  size_t pos(0);
373  update(key, from, pos, len, callback);
374  }
375 
376  template <typename U>
377  void update(const char *key, npos_t &from, size_t &pos, size_t len,
378  const U &callback) {
379  update(key, from, pos, len, callback, [](const int, const int) {});
380  }
381 
382  template <typename U, typename T>
383  void update(const char *key, npos_t &npos, size_t &pos, size_t len,
384  const U &callback, const T &cf) {
385  if (!len && !npos) {
386  throw std::invalid_argument("failed to insert zero-length key");
387  }
388 
389  auto &from = npos.index;
390  auto offset = npos.offset;
391  if (!offset) { // node on trie
392  for (const uchar *const key_ = reinterpret_cast<const uchar *>(key);
393  m_array[from].base >= 0; ++pos) {
394  if (pos == len) {
395  const auto to = _follow(from, 0, cf);
396  m_array[to].value = callback(m_array[to].value);
397  return;
398  }
399  from = static_cast<size_t>(_follow(from, key_[pos], cf));
400  }
401  offset = -m_array[from].base;
402  }
403  if (offset >= sizeof(int32_t)) { // go to m_tail
404  const size_t pos_orig = pos;
405  char *const tail = m_tail.data() + offset - pos;
406  while (pos < len && key[pos] == tail[pos]) {
407  ++pos;
408  }
409  //
410  if (pos == len && tail[pos] == '\0') { // found exact key
411  if (const ssize_t moved =
412  pos - pos_orig) { // search end on tail
413  npos.offset = offset + moved;
414  }
415  char *const data = &tail[len + 1];
416  storeDWord(data, callback(loadDWord<value_type>(data)));
417  return;
418  }
419  // otherwise, insert the common prefix in tail if any
420  if (npos.offset) {
421  npos.offset = 0; // reset to update tail offset
422  for (auto offset_ = static_cast<size_t>(-m_array[from].base);
423  offset_ < offset;) {
424  from = static_cast<size_t>(
425  _follow(from, static_cast<uchar>(m_tail[offset_]), cf));
426  ++offset_;
427  }
428  }
429  for (size_t pos_ = pos_orig; pos_ < pos; ++pos_) {
430  from = static_cast<size_t>(
431  _follow(from, static_cast<uchar>(key[pos_]), cf));
432  }
433  // for example:
434  // we originally had abcde in trie, then bcde is in tail
435  // if we want to insert abcdf
436  // pos would be 1 (things on trie that matched)
437  // and pos_orig will be 4 (things on tail + trie) that matched
438  // and tail[pos] will not be empty in this case
439  ssize_t moved = pos - pos_orig;
440  if (tail[pos]) { // remember to move offset to existing tail
441  const int to_ =
442  _follow(from, static_cast<uchar>(tail[pos]), cf);
443  m_array[to_].base = -static_cast<int>(offset + ++moved);
444  moved -= 1 + sizeof(value_type); // keep record
445  }
446  moved += offset;
447  for (ssize_t i = offset; i <= moved; i += 1 + sizeof(value_type)) {
448  if (m_tail0.capacity() == m_tail0.size()) {
449  auto quota =
450  m_tail0.capacity() + (m_tail0.size() >= MAX_ALLOC_SIZE
451  ? MAX_ALLOC_SIZE
452  : m_tail0.size());
453  m_tail0.reserve(quota);
454  }
455  m_tail0.push_back(i);
456  }
457  if (pos == len || tail[pos] == '\0') {
458  const int to = _follow(from, 0, cf);
459  if (pos == len) {
460  m_array[to].value = callback(m_array[to].value);
461  return;
462  }
463  // tail[pos] == 0, so the actual content in tail can be get rid
464  // of
465  // and tail0 is used to track those hole
466  m_array[to].value = loadDWord<value_type>(&tail[pos + 1]);
467  }
468  from = static_cast<size_t>(
469  _follow(from, static_cast<uchar>(key[pos]), cf));
470  ++pos;
471  }
472  const auto needed = len - pos + 1 + sizeof(value_type);
473  if (pos == len && !m_tail0.empty()) { // reuse
474  const int offset0 = *m_tail0.rbegin();
475  m_tail[offset0] = '\0';
476  m_array[from].base = -offset0;
477  m_tail0.pop_back();
478  char *const data = &m_tail[offset0 + 1];
479  storeDWord(data, callback(0));
480  return;
481  }
482  if (m_tail.capacity() < m_tail.size() + needed) {
483  auto quota =
484  m_tail.capacity() +
485  std::max(needed, std::min(MAX_ALLOC_SIZE, m_tail.size()));
486  m_tail.reserve(quota);
487  }
488  m_array[from].base = -m_tail.size();
489  const size_t pos_orig = pos;
490  auto old_length = m_tail.size();
491  m_tail.resize(m_tail.size() + needed);
492  char *const tail = &m_tail[old_length] - pos;
493  if (pos < len) {
494  do {
495  tail[pos] = key[pos];
496  } while (++pos < len);
497  npos.offset = old_length + len - pos_orig;
498  }
499  char *const data = &tail[len + 1];
500  storeDWord(data, callback(loadDWord<value_type>(data)));
501  }
502 
503  // easy-going erase () without compression
504  int erase(const char *key) { return erase(key, std::strlen(key)); }
505  int erase(const char *key, size_t len, npos_t npos = npos_t()) {
506  size_t pos = 0;
507  auto &from = npos.index;
508  const auto i = _find(key, npos, pos, len);
509  if (i == CEDAR_NO_PATH || i == CEDAR_NO_VALUE) {
510  return -1;
511  }
512  if (npos.offset) {
513  npos.offset = 0; // leave tail as is
514  }
515  bool flag = m_array[from].base < 0; // have sibling
516  int e = flag ? static_cast<int>(from) : m_array[from].base ^ 0;
517  from = m_array[e].check;
518  do {
519  const node &n = m_array[from];
520  flag = m_ninfo[n.base ^ m_ninfo[from].child].sibling;
521  if (flag) {
522  _pop_sibling(from, n.base, static_cast<uchar>(n.base ^ e));
523  }
524  _push_enode(e);
525  e = static_cast<int>(from);
526  from = static_cast<size_t>(m_array[from].check);
527  } while (!flag);
528  return 0;
529  }
530 
531  bool foreach(const callback_type &callback, npos_t root = npos_t()) const {
532  int32_t resultRaw;
533  size_t p(0);
534  npos_t from = root;
535  for (resultRaw = begin(from, p); resultRaw != CEDAR_NO_PATH;
536  resultRaw = next(from, p, root)) {
537  if (resultRaw != CEDAR_NO_VALUE &&
538  !callback(decodeImpl<V>(resultRaw), p, from.toInt())) {
539  return false;
540  }
541  }
542  return true;
543  }
544 
545  template <typename T>
546  void dump(T *result, const size_t result_len) const {
547  size_t num(0);
548  foreach([result, result_len, &num](value_type value, size_t len,
549  position_type pos) {
550  if (num < result_len) {
551  _set_result(&result[num++], value, len, npos_t(pos));
552  } else {
553  return false;
554  }
555  return true;
556  });
557  }
558  void shrink_tail() {
559  const size_t length_ =
560  static_cast<size_t>(m_tail.size()) -
561  (static_cast<size_t>(m_tail0.size()) * (1 + sizeof(value_type)));
562  decltype(m_tail) t;
563  // a dummy entry
564  t.resize(sizeof(int32_t));
565  t.reserve(length_);
566  for (int to = 0; to < static_cast<int>(size()); ++to) {
567  node &n = m_array[to];
568  if (n.check >= 0 && m_array[n.check].base != to && n.base < 0) {
569  char *const tail_(&m_tail[-n.base]);
570  n.base = -static_cast<int32_t>(t.size());
571  auto i = 0;
572  do {
573  t.push_back(tail_[i]);
574  } while (tail_[i++]);
575  t.resize(t.size() + sizeof(value_type));
576  storeDWord(&t[t.size() - sizeof(value_type)],
577  loadDWord<value_type>(&tail_[i]));
578  }
579  }
580  using std::swap;
581  swap(t, m_tail);
582  m_tail0.resize(0);
583  m_tail0.shrink_to_fit();
584  }
585 
586  // return the first child for a tree rooted by a given node
587  int32_t begin(npos_t &npos, size_t &len) const {
588  auto &from = npos.index;
589  int base =
590  npos.offset ? -static_cast<int>(npos.offset) : m_array[from].base;
591  if (base >= 0) { // on trie
592  uchar c = m_ninfo[from].child;
593  if (!from) {
594  c = m_ninfo[base ^ c].sibling;
595  if (!c) { // bug fix
596  return CEDAR_NO_PATH; // no entry
597  }
598  }
599  for (; c && base >= 0; ++len) {
600  from = static_cast<size_t>(base) ^ c;
601  base = m_array[from].base;
602  c = m_ninfo[from].child;
603  }
604  if (base >= 0) {
605  return m_array[base ^ c].base;
606  }
607  }
608  const size_t len_ = std::strlen(&m_tail[-base]);
609  npos.offset = static_cast<size_t>(-base) + len_;
610  len += len_;
611  return loadDWord<int32_t>(&m_tail[-base] + len_ + 1);
612  }
613  // return the next child if any
614  int32_t next(npos_t &npos, size_t &len,
615  const npos_t root = npos_t()) const {
616  uchar c = 0;
617  auto &from = npos.index;
618  if (const int offset = npos.offset) { // on tail
619  if (root.offset) {
620  return CEDAR_NO_PATH;
621  }
622  npos.offset = 0;
623  len -= static_cast<size_t>(offset - (-m_array[from].base));
624  } else {
625  c = m_ninfo[m_array[from].base].sibling;
626  }
627  for (; !c && npos != root; --len) {
628  c = m_ninfo[from].sibling;
629  from = static_cast<size_t>(m_array[from].check);
630  }
631  if (!c) {
632  return CEDAR_NO_PATH;
633  }
634  from = static_cast<size_t>(m_array[from].base) ^ c;
635  return begin(npos, ++len);
636  }
637  // follow/create edge
638  template <typename T>
639  int _follow(uint32_t &from, const uchar label, const T &cf) {
640  int to = 0;
641  const int base = m_array[from].base;
642  if (base < 0 || m_array[base ^ label].check < 0) {
643  to = _pop_enode(base, label, static_cast<int>(from));
644  _push_sibling(from, to ^ label, label, base >= 0);
645  } else if (m_array[base ^ label].check != static_cast<int>(from)) {
646  to = _resolve(from, base, label, cf);
647  } else {
648  to = base ^ label;
649  }
650  return to;
651  }
652 
653  // find key from double array
654  int32_t _find(const char *key, npos_t &npos, size_t &pos,
655  const size_t len) const {
656  auto &from = npos.index;
657  auto offset = npos.offset;
658  if (!offset) { // node on trie
659  for (const uchar *const key_ = reinterpret_cast<const uchar *>(key);
660  m_array[from].base >= 0;) {
661  if (pos == len) {
662  const node &n = m_array[m_array[from].base ^ 0];
663  if (n.check != static_cast<int>(from)) {
664  return CEDAR_NO_VALUE;
665  }
666  return n.base;
667  }
668  size_t to = static_cast<size_t>(m_array[from].base);
669  to ^= key_[pos];
670  if (m_array[to].check != static_cast<int>(from)) {
671  return CEDAR_NO_PATH;
672  }
673  ++pos;
674  from = to;
675  }
676  offset = -m_array[from].base;
677  }
678  // switch to _tail to match suffix
679  const size_t pos_orig = pos; // start position in reading _tail
680  const char *const tail = &m_tail[offset] - pos;
681  if (pos < len) {
682  do {
683  if (key[pos] != tail[pos]) {
684  break;
685  }
686  } while (++pos < len);
687  if (const int moved = pos - pos_orig) {
688  npos.offset = offset + moved;
689  }
690  if (pos < len) {
691  return CEDAR_NO_PATH; // input > tail, input != tail
692  }
693  }
694  if (tail[pos]) {
695  return CEDAR_NO_VALUE; // input < tail
696  }
697  return loadDWord<int32_t>(&tail[len + 1]);
698  }
699 
700  // explore new block to settle down
701  int _find_place() {
702  if (m_bheadC) {
703  return m_block[m_bheadC].ehead;
704  }
705  if (m_bheadO) {
706  return m_block[m_bheadO].ehead;
707  }
708  return _add_block() << 8;
709  }
710  int _find_place(const uchar *const first, const uchar *const last) {
711  if (auto bi = m_bheadO) {
712  const auto bz = m_block[m_bheadO].prev;
713  const auto nc = std::distance(first, last);
714  while (true) { // set candidate block
715  block &b = m_block[bi];
716  if (b.num >= nc && nc < b.reject) { // explore configuration
717  for (int e = b.ehead;;) {
718  const int base = e ^ *first;
719  for (const uchar *p = first;
720  m_array[base ^ *++p].check < 0;) {
721  if (p + 1 == last) {
722  b.ehead = e;
723  return b.ehead; // no conflict
724  }
725  }
726  e = -m_array[e].check;
727  if (e == b.ehead) {
728  break;
729  }
730  }
731  }
732  b.reject = nc;
733  if (b.reject < m_reject[b.num]) {
734  m_reject[b.num] = b.reject;
735  }
736  const int bi_ = b.next;
737  if (++b.trial == MAX_TRIAL) {
738  _transfer_block(bi, m_bheadO, m_bheadC);
739  }
740  if (bi == bz) {
741  break;
742  }
743  bi = bi_;
744  }
745  }
746  return _add_block() << 8;
747  }
748 
749  static void _set_result(result_type *x, value_type r, size_t /*len*/ = 0,
750  npos_t /*npos*/ = npos_t()) {
751  *x = r;
752  }
753  static void _set_result(std::tuple<value_type, size_t, position_type> *x,
754  value_type r, size_t len, npos_t npos) {
755  *x = std::make_tuple<>(r, len, npos.toInt());
756  }
757  void _pop_block(const int bi, int &head_in, const bool last) {
758  if (last) { // last one poped; Closed or Open
759  head_in = 0;
760  } else {
761  const block &b = m_block[bi];
762  m_block[b.prev].next = b.next;
763  m_block[b.next].prev = b.prev;
764  if (bi == head_in) {
765  head_in = b.next;
766  }
767  }
768  }
769  void _push_block(const int bi, int &head_out, const bool empty) {
770  block &b = m_block[bi];
771  if (empty) { // the destination is empty
772  head_out = b.prev = b.next = bi;
773  } else { // use most recently pushed
774  int &tail_out = m_block[head_out].prev;
775  b.prev = tail_out;
776  b.next = head_out;
777  head_out = tail_out = m_block[tail_out].next = bi;
778  }
779  }
780  int _add_block() {
781  // size depends on m_info.size()
782  if (size() == capacity()) { // allocate memory if needed
783  auto new_capacity =
784  capacity() +
785  (size() >= MAX_ALLOC_SIZE ? MAX_ALLOC_SIZE : size());
786  m_array.reserve(new_capacity);
787  m_array.resize(new_capacity);
788  m_ninfo.reserve(new_capacity);
789  m_block.reserve(new_capacity >> 8);
790  m_block.resize(size() >> 8);
791  }
792  assert(m_block.size() == size() >> 8);
793  m_block.resize(m_block.size() + 1);
794  m_block[size() >> 8].ehead = size();
795 
796  assert(m_array.size() >= size() + 256);
797  m_array[size()] = node(-(size() + 255), -(size() + 1));
798  for (auto i = size() + 1; i < size() + 255; ++i) {
799  m_array[i] = node(-(i - 1), -(i + 1));
800  }
801  m_array[size() + 255] = node(-(size() + 254), -size());
802  _push_block(size() >> 8, m_bheadO, !m_bheadO); // append to block Open
803  m_ninfo.resize(size() + 256);
804  return (size() >> 8) - 1;
805  }
806  // transfer block from one start w/ head_in to one start w/ head_out
807  void _transfer_block(const int bi, int &head_in, int &head_out) {
808  _pop_block(bi, head_in, bi == m_block[bi].next);
809  _push_block(bi, head_out, !head_out && m_block[bi].num);
810  }
811  // pop empty node from block; never transfer the special block (bi = 0)
812  int _pop_enode(const int base, const uchar label, const int from) {
813  const int e = base < 0 ? _find_place() : base ^ label;
814  const int bi = e >> 8;
815  node &n = m_array[e];
816  block &b = m_block[bi];
817  if (--b.num == 0) {
818  if (bi) {
819  _transfer_block(bi, m_bheadC, m_bheadF); // Closed to Full
820  }
821  } else { // release empty node from empty ring
822  m_array[-n.base].check = n.check;
823  m_array[-n.check].base = n.base;
824  if (e == b.ehead) {
825  b.ehead = -n.check; // set ehead
826  }
827  if (bi && b.num == 1 && b.trial != MAX_TRIAL) {
828  // Open to Closed
829  _transfer_block(bi, m_bheadO, m_bheadC);
830  }
831  }
832  // initialize the released node
833  if (label) {
834  n.base = -1;
835  } else {
836  n.value = value_type(0);
837  }
838  n.check = from;
839  if (base < 0) {
840  m_array[from].base = e ^ label;
841  }
842  return e;
843  }
844  // push empty node into empty ring
845  void _push_enode(const int e) {
846  const int bi = e >> 8;
847  block &b = m_block[bi];
848  if (++b.num == 1) { // Full to Closed
849  b.ehead = e;
850  m_array[e] = node(-e, -e);
851  if (bi) {
852  _transfer_block(bi, m_bheadF, m_bheadC); // Full to Closed
853  }
854  } else {
855  const int prev = b.ehead;
856  const int next = -m_array[prev].check;
857  m_array[e] = node(-prev, -next);
858  m_array[prev].check = m_array[next].base = -e;
859  if (b.num == 2 || b.trial == MAX_TRIAL) { // Closed to Open
860  if (bi) {
861  _transfer_block(bi, m_bheadC, m_bheadO);
862  }
863  }
864  b.trial = 0;
865  }
866  if (b.reject < m_reject[b.num]) {
867  b.reject = m_reject[b.num];
868  }
869  m_ninfo[e] = ninfo(); // reset ninfo; no child, no sibling
870  }
871  // push label to from's child
872  void _push_sibling(const int32_t from, const int base, const uchar label,
873  const bool flag = true) {
874  uchar *c = &m_ninfo[from].child;
875  if (flag && (ORDERED ? label > *c : !*c)) {
876  do {
877  c = &m_ninfo[base ^ *c].sibling;
878  } while (ORDERED && *c && *c < label);
879  }
880  m_ninfo[base ^ label].sibling = *c, *c = label;
881  }
882  // pop label from from's child
883  void _pop_sibling(const int32_t from, const int base, const uchar label) {
884  uchar *c = &m_ninfo[from].child;
885  while (*c != label) {
886  c = &m_ninfo[base ^ *c].sibling;
887  }
888  *c = m_ninfo[base ^ label].sibling;
889  }
890  // check whether to replace branching w/ the newly added node
891  bool _consult(const int base_n, const int base_p, uchar c_n,
892  uchar c_p) const {
893  do {
894  c_n = m_ninfo[base_n ^ c_n].sibling;
895  c_p = m_ninfo[base_p ^ c_p].sibling;
896  } while (c_n && c_p);
897  return c_p;
898  }
899  // enumerate (equal to or more than one) child nodes
900  uchar *_set_child(uchar *p, const int base, uchar c,
901  const std::optional<uchar> label = std::nullopt) {
902  if (!c) {
903  *p = c;
904  p++;
905  c = m_ninfo[base ^ c].sibling;
906  } // 0: terminal
907  if (ORDERED && label.has_value()) {
908  while (c && c < *label) {
909  *p = c;
910  p++;
911  c = m_ninfo[base ^ c].sibling;
912  }
913  }
914  if (label.has_value()) {
915  *p = *label;
916  p++;
917  }
918  while (c) {
919  *p = c;
920  p++;
921  c = m_ninfo[base ^ c].sibling;
922  }
923  return p;
924  }
925  // resolve conflict on base_n ^ label_n = base_p ^ label_p
926  template <typename T>
927  int _resolve(uint32_t &from_n, const int base_n, const uchar label_n,
928  const T &cf) {
929  // examine siblings of conflicted nodes
930  const int to_pn = base_n ^ label_n;
931  const int from_p = m_array[to_pn].check;
932  const int base_p = m_array[from_p].base;
933  const bool flag // whether to replace siblings of newly added
934  = _consult(base_n, base_p, m_ninfo[from_n].child,
935  m_ninfo[from_p].child);
936  uchar child[256];
937  uchar *const first = child;
938  uchar *const last =
939  flag ? _set_child(first, base_n, m_ninfo[from_n].child, label_n)
940  : _set_child(first, base_p, m_ninfo[from_p].child);
941  assert(first < last);
942  const int base =
943  (first + 1 == last ? _find_place() : _find_place(first, last)) ^
944  *first;
945  // replace & modify empty list
946  const int from = flag ? static_cast<int>(from_n) : from_p;
947  const int base_ = flag ? base_n : base_p;
948  if (flag && *first == label_n) {
949  m_ninfo[from].child = label_n; // new child
950  }
951  m_array[from].base = base; // new base
952  for (const uchar *p = first; p < last; ++p) { // to_ => to
953  const int to = _pop_enode(base, *p, from);
954  const int to_ = base_ ^ *p;
955  m_ninfo[to].sibling = (p + 1 == last) ? 0 : *(p + 1);
956  if (flag && to_ == to_pn) {
957  continue; // skip newcomer (no child)
958  }
959  cf(to_, to);
960  node &n = m_array[to];
961  node &n_ = m_array[to_];
962  n.base = n_.base; // copy base; bug fix
963  if (n.base > 0 && *p) {
964  uchar c = m_ninfo[to].child = m_ninfo[to_].child;
965  do {
966  m_array[n.base ^ c].check = to; // adjust grand son's check
967  } while ((c = m_ninfo[n.base ^ c].sibling));
968  }
969  if (!flag && to_ == static_cast<int>(from_n)) {
970  // parent node moved
971  from_n = static_cast<size_t>(to); // bug fix
972  }
973  if (!flag && to_ == to_pn) { // the address is immediately used
974  _push_sibling(from_n, to_pn ^ label_n, label_n);
975  m_ninfo[to_].child = 0; // remember to reset child
976  if (label_n) {
977  n_.base = -1;
978  } else {
979  n_.value = value_type(0);
980  }
981  n_.check = static_cast<int>(from_n);
982  } else {
983  _push_enode(to_);
984  }
985  }
986  return flag ? base ^ label_n : to_pn;
987  }
988 };
989 
990 template <typename T>
992 
993 template <typename T>
994 DATrie<T>::DATrie(const char *filename) : DATrie() {
995  std::ifstream fin(filename, std::ios::in | std::ios::binary);
996  throw_if_io_fail(fin);
997  d->open(fin);
998 }
999 
1000 template <typename T>
1001 DATrie<T>::DATrie(std::istream &fin) : DATrie() {
1002  d->open(fin);
1003 }
1004 
1005 template <typename T>
1006 DATrie<T>::~DATrie() = default;
1007 
1008 template <typename T>
1009 DATrie<T>::DATrie(DATrie<T> &&other) noexcept = default;
1010 
1011 template <typename T>
1012 DATrie<T>::DATrie(const DATrie<T> &other)
1013  : d(std::make_unique<DATriePrivate<T>>(*other.d)) {}
1014 
1015 template <typename T>
1016 DATrie<T> &DATrie<T>::operator=(DATrie<T> &&other) noexcept = default;
1017 
1018 template <typename T>
1019 DATrie<T> &DATrie<T>::operator=(const DATrie<T> &other) {
1020  if (this == &other) {
1021  return *this;
1022  }
1023 
1024  if (d) {
1025  *d = *other.d;
1026  } else {
1027  d = std::make_unique<DATriePrivate<T>>(*other.d);
1028  }
1029  return *this;
1030 }
1031 
1032 template <typename T>
1033 void DATrie<T>::load(std::istream &in) {
1034  clear();
1035  d->open(in);
1036 }
1037 
1038 template <typename T>
1039 void DATrie<T>::save(const char *filename) {
1040  std::ofstream fout(filename, std::ios::out | std::ios::binary);
1041  throw_if_io_fail(fout);
1042  save(fout);
1043 }
1044 
1045 template <typename T>
1046 void DATrie<T>::save(std::ostream &stream) {
1047  d->save(stream);
1048 }
1049 
1050 template <typename T>
1051 void DATrie<T>::set(const char *key, size_t len, value_type val) {
1052  d->update(key, len, [val](value_type) { return val; });
1053 }
1054 
1055 template <typename T>
1056 void DATrie<T>::update(const char *key, size_t len,
1057  DATrie<T>::updater_type updater) {
1058  d->update(key, len, updater);
1059 }
1060 
1061 template <typename T>
1062 size_t DATrie<T>::size() const {
1063  return d->num_keys();
1064 }
1065 
1066 template <typename T>
1067 bool DATrie<T>::empty() const {
1068  return d->foreach([](value_type, size_t, position_type) { return false; });
1069 }
1070 
1071 template <typename T>
1072 bool DATrie<T>::foreach(const char *prefix, size_t size, callback_type func,
1073  position_type _pos) const {
1074  size_t pos = 0;
1075  typename DATriePrivate<value_type>::npos_t from(_pos);
1076  if (d->_find(prefix, from, pos, size) ==
1078  return true;
1079  }
1080 
1081  return d->foreach(func, from);
1082 }
1083 
1084 template <typename T>
1085 bool DATrie<T>::foreach(callback_type func, position_type pos) const {
1086  typename DATriePrivate<value_type>::npos_t from(pos);
1087  return d->foreach(func, from);
1088 }
1089 
1090 template <typename T>
1091 void DATrie<T>::suffix(std::string &s, size_t len, position_type pos) const {
1092  d->suffix(s, len, typename DATriePrivate<T>::npos_t(pos));
1093 }
1094 
1095 template <typename T>
1096 void DATrie<T>::dump(value_type *data, std::size_t size) const {
1097  d->dump(data, size);
1098 }
1099 
1100 template <typename T>
1101 void DATrie<T>::dump(std::vector<typename DATrie<T>::value_type> &data) const {
1102  data.resize(size());
1103  d->dump(data.data(), data.size());
1104 }
1105 
1106 template <typename T>
1107 void DATrie<T>::dump(
1108  std::vector<std::tuple<typename DATrie<T>::value_type, size_t,
1109  typename DATrie<T>::position_type>> &data) const {
1110  data.resize(size());
1111  d->dump(data.data(), data.size());
1112 }
1113 
1114 template <typename T>
1115 bool DATrie<T>::erase(const char *key, size_t len, position_type from) {
1116  return d->erase(key, len, typename DATriePrivate<T>::npos_t(from)) == 0;
1117 }
1118 
1119 template <typename T>
1120 bool DATrie<T>::erase(position_type from) {
1121  return d->erase("", 0, typename DATriePrivate<T>::npos_t(from)) == 0;
1122 }
1123 
1124 template <typename T>
1125 typename DATrie<T>::value_type DATrie<T>::exactMatchSearch(const char *key,
1126  size_t len) const {
1127  return decodeImpl<T>(exactMatchSearchRaw(key, len));
1128 }
1129 
1130 template <typename T>
1131 int32_t DATrie<T>::exactMatchSearchRaw(const char *key, size_t len) const {
1132  size_t pos = 0;
1133  typename DATriePrivate<value_type>::npos_t npos;
1134  auto resultRaw = d->_find(key, npos, pos, len);
1135  if (resultRaw == DATriePrivate<value_type>::CEDAR_NO_PATH) {
1137  }
1138  return resultRaw;
1139 }
1140 
1141 template <typename T>
1142 bool DATrie<T>::hasExactMatch(std::string_view key) const {
1143  return isValid(exactMatchSearch(key));
1144 }
1145 
1146 template <typename T>
1147 typename DATrie<T>::value_type DATrie<T>::traverse(const char *key, size_t len,
1148  position_type &from) const {
1149  return decodeImpl<T>(traverseRaw(key, len, from));
1150 }
1151 
1152 template <typename T>
1153 int32_t DATrie<T>::traverseRaw(const char *key, size_t len,
1154  position_type &from) const {
1155  size_t pos = 0;
1156  typename DATriePrivate<T>::npos_t npos(from);
1157  auto result = d->_find(key, npos, pos, len);
1158  from = npos.toInt();
1159  return result;
1160 }
1161 
1162 template <typename T>
1163 void DATrie<T>::clear() {
1164  d->clear();
1165 }
1166 
1167 template <typename T>
1168 void DATrie<T>::shrink_tail() {
1169  d->shrink_tail();
1170 }
1171 
1172 template <typename T>
1173 bool DATrie<T>::isNoPath(value_type v) {
1174  typename DATriePrivate<T>::decoder_type d;
1175  d.result_value = v;
1176  return isNoPathRaw(d.result);
1177 }
1178 
1179 template <typename T>
1180 bool DATrie<T>::isNoValue(value_type v) {
1181  typename DATriePrivate<T>::decoder_type d;
1182  d.result_value = v;
1183  return isNoValueRaw(d.result);
1184 }
1185 
1186 template <typename T>
1187 bool DATrie<T>::isValid(value_type v) {
1188  typename DATriePrivate<T>::decoder_type d;
1189  d.result_value = v;
1190  return isValidRaw(d.result);
1191 }
1192 
1193 template <typename T>
1194 bool DATrie<T>::isNoPathRaw(int32_t v) {
1196 }
1197 
1198 template <typename T>
1199 bool DATrie<T>::isNoValueRaw(int32_t v) {
1201 }
1202 
1203 template <typename T>
1204 bool DATrie<T>::isValidRaw(int32_t v) {
1205  return !(isNoPathRaw(v) || isNoValueRaw(v));
1206 }
1207 
1208 template <typename T>
1209 T DATrie<T>::noPath() {
1210  return decodeImpl<T>(DATriePrivate<value_type>::CEDAR_NO_PATH);
1211 }
1212 
1213 template <typename T>
1214 T DATrie<T>::noValue() {
1215  return decodeImpl<T>(DATriePrivate<value_type>::CEDAR_NO_VALUE);
1216 }
1217 
1218 template <typename T>
1219 T DATrie<T>::decode(int32_t raw) {
1220  return decodeImpl<T>(raw);
1221 }
1222 
1223 template <typename T>
1224 size_t DATrie<T>::mem_size() const {
1225  // std::cout << "tail" << d->m_tail.size() << std::endl
1226  // << "tail0" << d->m_tail0.size() * sizeof(int) << std::endl
1227  // << "array" << sizeof(typename
1228  // decltype(d->m_array)::value_type) *
1229  // d->m_array.size() << std::endl
1230  // << "block" << sizeof(typename
1231  // decltype(d->m_block)::value_type) *
1232  // d->m_block.size() << std::endl
1233  // << "ninfo" << sizeof(typename
1234  // decltype(d->m_ninfo)::value_type) *
1235  // d->m_ninfo.size() << std::endl;
1236  return d->m_tail.size() + (d->m_tail0.size() * sizeof(int)) +
1237  (sizeof(typename decltype(d->m_array)::value_type) *
1238  d->m_array.size()) +
1239  (sizeof(typename decltype(d->m_block)::value_type) *
1240  d->m_block.size()) +
1241  (sizeof(typename decltype(d->m_ninfo)::value_type) *
1242  d->m_ninfo.size());
1243 }
1244 
// Explicit instantiations: the member definitions above live in this
// translation unit, so DATrie is instantiated here for each of these value
// types.
template class DATrie<float>;
template class DATrie<int32_t>;
template class DATrie<uint32_t>;
1248 } // namespace libime
// libime::DATrie: provides a DATrie implementation.
// This is a trie based on cedar <http://www.tkl.iis.u-tokyo.ac.jp/~ynaga/cedar/>.
// Definition: datrie.h:55