libime
pinyindictionary.cpp
1 /*
2  * SPDX-FileCopyrightText: 2017-2017 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 
7 #include "pinyindictionary.h"
8 #include <algorithm>
9 #include <cassert>
10 #include <cmath>
11 #include <cstddef>
12 #include <cstdint>
13 #include <fstream>
14 #include <iomanip>
15 #include <ios>
16 #include <istream>
17 #include <iterator>
18 #include <list>
19 #include <memory>
20 #include <optional>
21 #include <ostream>
22 #include <queue>
23 #include <stdexcept>
24 #include <string>
25 #include <string_view>
26 #include <tuple>
27 #include <unordered_set>
28 #include <utility>
29 #include <vector>
30 #include <boost/container_hash/hash.hpp>
31 #include <fcitx-utils/macros.h>
32 #include <fcitx-utils/signals.h>
33 #include <fcitx-utils/stringutils.h>
34 #include "libime/core/datrie.h"
35 #include "libime/core/dictionary.h"
36 #include "libime/core/languagemodel.h"
37 #include "libime/core/lattice.h"
38 #include "libime/core/lrucache.h"
39 #include "libime/core/segmentgraph.h"
40 #include "libime/core/triedictionary.h"
41 #include "libime/core/utils.h"
42 #include "libime/core/utils_p.h"
43 #include "libime/core/zstdfilter.h"
44 #include "libime/pinyin/pinyinmatchstate.h"
45 #include "constants.h"
46 #include "pinyindecoder_p.h"
47 #include "pinyinencoder.h"
48 #include "pinyinmatchstate_p.h"
49 
50 namespace libime {
51 
52 namespace {
// Cost (log10 of 0.5) charged once per fuzzy step, and once per extra
// syllable when a word is longer than the matched path.
const float fuzzyCost = std::log10(0.5F);
// Lower bound applied to the partial long word limit, to keep long word
// matching from running too slow.
constexpr size_t minimumLongWordLength = 3;
// Cost of the placeholder word emitted when a segment matches no word.
constexpr float invalidPinyinCost = -100.0F;
// Byte that sits between the encoded pinyin and the hanzi in a trie key.
constexpr char pinyinHanziSep = '!';

// Header values of the binary dictionary format.
constexpr uint32_t pinyinBinaryFormatMagic = 0x000fc613;
constexpr uint32_t pinyinBinaryFormatVersion = 0x2;
60 
61 struct PinyinSegmentGraphPathHasher {
62  PinyinSegmentGraphPathHasher(const SegmentGraph &graph) : graph_(graph) {}
63 
64  // Generate a "|" separated raw pinyin string from given path, skip all
65  // separator.
66  std::string pathToPinyins(const SegmentGraphPath &path) const {
67  std::string result;
68  result.reserve(path.size() + path.back()->index() -
69  path.front()->index() + 1);
70  const auto &data = graph_.data();
71  auto iter = path.begin();
72  while (iter + 1 < path.end()) {
73  auto begin = (*iter)->index();
74  auto end = (*std::next(iter))->index();
75  iter++;
76  if (data[begin] == '\'') {
77  continue;
78  }
79  while (begin < end) {
80  result.push_back(data[begin]);
81  begin++;
82  }
83  result.push_back('|');
84  }
85  return result;
86  }
87 
88  // Generate hash for path but avoid allocate the string.
89  size_t operator()(const SegmentGraphPath &path) const {
90  if (path.size() <= 1) {
91  return 0;
92  }
93  boost::hash<char> hasher;
94 
95  size_t seed = 0;
96  const auto &data = graph_.data();
97  auto iter = path.begin();
98  while (iter + 1 < path.end()) {
99  auto begin = (*iter)->index();
100  auto end = (*std::next(iter))->index();
101  iter++;
102  if (data[begin] == '\'') {
103  continue;
104  }
105  while (begin < end) {
106  boost::hash_combine(seed, hasher(data[begin]));
107  begin++;
108  }
109  boost::hash_combine(seed, hasher('|'));
110  }
111  return seed;
112  }
113 
114  // Check equality of pinyin string and the path. The string s should be
115  // equal to pathToPinyins(path), but this function just try to avoid
116  // allocate a string for comparison.
117  bool operator()(const SegmentGraphPath &path, const std::string &s) const {
118  if (path.size() <= 1) {
119  return false;
120  }
121  auto is = s.begin();
122  const auto &data = graph_.data();
123  auto iter = path.begin();
124  while (iter + 1 < path.end() && is != s.end()) {
125  auto begin = (*iter)->index();
126  auto end = (*std::next(iter))->index();
127  iter++;
128  if (data[begin] == '\'') {
129  continue;
130  }
131  while (begin < end && is != s.end()) {
132  if (*is != data[begin]) {
133  return false;
134  }
135  is++;
136  begin++;
137  }
138  if (begin != end) {
139  return false;
140  }
141 
142  if (is == s.end() || *is != '|') {
143  return false;
144  }
145  is++;
146  }
147  return iter + 1 == path.end() && is == s.end();
148  }
149 
150 private:
151  const SegmentGraph &graph_;
152 };
153 
154 struct SegmentGraphNodeGreater {
155  bool operator()(const SegmentGraphNode *lhs,
156  const SegmentGraphNode *rhs) const {
157  return lhs->index() > rhs->index();
158  }
159 };
160 
161 // Check if the prev not is a pinyin. Separator always contrains in its own
162 // segment.
163 const SegmentGraphNode *prevIsSeparator(const SegmentGraph &graph,
164  const SegmentGraphNode &node) {
165  if (node.prevSize() == 1) {
166  const auto range = node.prevs();
167  const auto &prev = range.front();
168  auto pinyin = graph.segment(prev, node);
169  if (pinyin.starts_with("\'")) {
170  return &prev;
171  }
172  }
173  return nullptr;
174 }
175 
176 inline void searchOneStep(
177  std::list<std::pair<const PinyinTrie *, PinyinTrie::position_type>> &nodes,
178  char current) {
179  std::list<std::pair<const PinyinTrie *, PinyinTrie::position_type>>
180  extraNodes;
181  auto iter = nodes.begin();
182  while (iter != nodes.end()) {
183  if (current != 0) {
184  const auto resultRaw =
185  iter->first->traverseRaw(&current, 1, iter->second);
186 
187  if (PinyinTrie::isNoPathRaw(resultRaw)) {
188  nodes.erase(iter++);
189  } else {
190  iter++;
191  }
192  } else {
193  bool changed = false;
194  for (char test = PinyinEncoder::firstFinal;
195  test <= PinyinEncoder::lastFinal; test++) {
196  decltype(extraNodes)::value_type p = *iter;
197  const auto resultRaw = p.first->traverseRaw(&test, 1, p.second);
198  if (!PinyinTrie::isNoPathRaw(resultRaw)) {
199  extraNodes.push_back(p);
200  changed = true;
201  }
202  }
203  if (changed) {
204  *iter = extraNodes.back();
205  extraNodes.pop_back();
206  iter++;
207  } else {
208  nodes.erase(iter++);
209  }
210  }
211  }
212  nodes.splice(nodes.end(), std::move(extraNodes));
213 }
214 
215 size_t fuzzyFactor(PinyinFuzzyFlags flags) {
216  size_t factor = 0;
217  if (flags.test(PinyinFuzzyFlag::Correction)) {
218  flags = flags.unset(PinyinFuzzyFlag::Correction);
219  factor += PINYIN_CORRECTION_FUZZY_FACTOR;
220  }
221  if (flags.test(PinyinFuzzyFlag::AdvancedTypo)) {
222  flags = flags.unset(PinyinFuzzyFlag::AdvancedTypo);
223  factor += PINYIN_ADVACNED_TYPO_FUZZY_FACTOR;
224  }
225  if (flags != 0) {
226  factor += 1;
227  }
228  return factor;
229 }
230 
231 PinyinDictionary::TrieType loadTextImpl(std::istream &in) {
232  PinyinDictionary::TrieType trie;
233 
234  size_t lineNo = 0;
235  std::string lineBuf;
236  while (!in.eof()) {
237  if (!std::getline(in, lineBuf)) {
238  break;
239  }
240  lineNo++;
241 
242  std::string_view line = lineBuf;
243  std::vector<std::string> tokens;
244  while (!line.empty()) {
245  std::string token;
246  auto consumed = fcitx::stringutils::consumeMaybeEscapedValue(
247  line, FCITX_WHITESPACE, &token);
248  if (!consumed.empty()) {
249  tokens.push_back(std::string(token));
250  }
251  }
252  if (tokens.size() == 3 || tokens.size() == 2) {
253  const std::string &hanzi = tokens[0];
254  std::string_view pinyin = tokens[1];
255 
256  try {
257  float prob = 0.0F;
258  if (tokens.size() == 3) {
259  prob = std::stof(tokens[2]);
260  }
262  pinyin, PinyinFuzzyFlag::VE_UE);
263  result.push_back(pinyinHanziSep);
264  result.insert(result.end(), hanzi.begin(), hanzi.end());
265  trie.set(result.data(), result.size(), prob);
266  } catch (const std::invalid_argument &e) {
267  LIBIME_ERROR()
268  << "Skipped line " << lineNo << ", exception: " << e.what();
269  }
270  }
271  }
272  return trie;
273 }
274 
275 PinyinDictionary::TrieType loadBinaryImpl(std::istream &in) {
276  PinyinDictionary::TrieType trie;
277  uint32_t magic = 0;
278  uint32_t version = 0;
279  throw_if_io_fail(unmarshall(in, magic));
280  if (magic != pinyinBinaryFormatMagic) {
281  throw std::invalid_argument("Invalid pinyin magic.");
282  }
283  throw_if_io_fail(unmarshall(in, version));
284  switch (version) {
285  case 0x1:
286  trie.load(in);
287  break;
288  case pinyinBinaryFormatVersion:
289  readZSTDCompressed(
290  in, [&trie](std::istream &compressIn) { trie.load(compressIn); });
291  break;
292  default:
293  throw std::invalid_argument("Invalid pinyin version.");
294  break;
295  }
296  return trie;
297 }
298 
299 } // namespace
300 
302 public:
303  explicit PinyinMatchContext(
304  const SegmentGraph &graph, const GraphMatchCallback &callback,
305  const std::unordered_set<const SegmentGraphNode *> &ignore,
306  PinyinMatchState *matchState)
307  : graph_(graph), hasher_(graph), callback_(callback), ignore_(ignore),
308  matchedPathsMap_(&matchState->d_func()->matchedPaths_),
309  nodeCacheMap_(&matchState->d_func()->nodeCacheMap_),
310  matchCacheMap_(&matchState->d_func()->matchCacheMap_),
311  flags_(matchState->fuzzyFlags()),
312  spProfile_(matchState->shuangpinProfile()),
313  correctionProfile_(matchState->correctionProfile()),
314  partialLongWordLimit_(matchState->partialLongWordLimit()) {}
315 
316  explicit PinyinMatchContext(
317  const SegmentGraph &graph, const GraphMatchCallback &callback,
318  const std::unordered_set<const SegmentGraphNode *> &ignore,
319  NodeToMatchedPinyinPathsMap &matchedPaths)
320  : graph_(graph), hasher_(graph), callback_(callback), ignore_(ignore),
321  matchedPathsMap_(&matchedPaths) {}
322 
323  PinyinMatchContext(const PinyinMatchContext &) = delete;
324 
325  const SegmentGraph &graph_;
326  PinyinSegmentGraphPathHasher hasher_;
327 
328  const GraphMatchCallback &callback_;
329  const std::unordered_set<const SegmentGraphNode *> &ignore_;
330  NodeToMatchedPinyinPathsMap *matchedPathsMap_;
331  PinyinTrieNodeCache *nodeCacheMap_ = nullptr;
332  PinyinMatchResultCache *matchCacheMap_ = nullptr;
333  PinyinFuzzyFlags flags_{PinyinFuzzyFlag::None};
334  std::shared_ptr<const ShuangpinProfile> spProfile_;
335  std::shared_ptr<const PinyinCorrectionProfile> correctionProfile_;
336  size_t partialLongWordLimit_ = 0;
337 };
338 
339 class PinyinDictionaryPrivate : fcitx::QPtrHolder<PinyinDictionary> {
340 public:
342  : fcitx::QPtrHolder<PinyinDictionary>(q) {}
343 
344  void addEmptyMatch(const PinyinMatchContext &context,
345  const SegmentGraphNode &currentNode,
346  MatchedPinyinPaths &currentMatches) const;
347 
348  void findMatchesBetween(const PinyinMatchContext &context,
349  const SegmentGraphNode &prevNode,
350  const SegmentGraphNode &currentNode,
351  MatchedPinyinPaths &currentMatches) const;
352 
353  bool matchWords(const PinyinMatchContext &context,
354  const MatchedPinyinPaths &newPaths) const;
355  bool matchWordsForOnePath(const PinyinMatchContext &context,
356  const MatchedPinyinPath &path) const;
357 
358  void matchNode(const PinyinMatchContext &context,
359  const SegmentGraphNode &currentNode) const;
360 
361  fcitx::ScopedConnection conn_;
362  std::vector<PinyinDictFlags> flags_;
363 };
364 
365 void PinyinDictionaryPrivate::addEmptyMatch(
366  const PinyinMatchContext &context, const SegmentGraphNode &currentNode,
367  MatchedPinyinPaths &currentMatches) const {
368  FCITX_Q();
369  const SegmentGraph &graph = context.graph_;
370  // Create a new starting point for current node, and put it in matchResult.
371  if (&currentNode != &graph.end() &&
372  !graph.segment(currentNode.index(), currentNode.index() + 1)
373  .starts_with("\'")) {
374  SegmentGraphPath vec;
375  if (const auto *prev = prevIsSeparator(graph, currentNode)) {
376  vec.push_back(prev);
377  }
378 
379  vec.push_back(&currentNode);
380  for (size_t i = 0; i < q->dictSize(); i++) {
381  if (flags_[i].test(PinyinDictFlag::FullMatch) &&
382  &currentNode != &graph.start()) {
383  continue;
384  }
385  if (flags_[i].test(PinyinDictFlag::Disabled)) {
386  continue;
387  }
388  const auto &trie = *q->trie(i);
389  currentMatches.emplace_back(&trie, 0, vec, flags_[i]);
390  currentMatches.back().triePositions().emplace_back(0, 0);
391  }
392  }
393 }
394 
395 PinyinTriePositions traverseAlongPathOneStepBySyllables(
396  const MatchedPinyinPath &path,
397  const MatchedPinyinSyllablesWithFuzzyFlags &syls) {
398  PinyinTriePositions positions;
399  for (const auto &pr : path.triePositions()) {
400  uint64_t _pos;
401  size_t fuzzies;
402  std::tie(_pos, fuzzies) = pr;
403  for (const auto &syl : syls) {
404  // make a copy
405  auto pos = _pos;
406  auto initial = static_cast<char>(syl.first);
407  const auto resultRaw = path.trie()->traverseRaw(&initial, 1, pos);
408  if (PinyinTrie::isNoPathRaw(resultRaw)) {
409  continue;
410  }
411  const auto &finals = syl.second;
412 
413  auto updateNext = [fuzzies, &path, &positions](PinyinFinal pyFinal,
414  size_t fuzzyFactor,
415  auto pos) {
416  auto final = static_cast<char>(pyFinal);
417  const auto resultRaw = path.trie()->traverseRaw(&final, 1, pos);
418 
419  if (!PinyinTrie::isNoPathRaw(resultRaw)) {
420  size_t newFuzzies = fuzzies + fuzzyFactor;
421  positions.emplace_back(pos, newFuzzies);
422  }
423  };
424  if (finals.size() > 1 || finals[0].first != PinyinFinal::Invalid) {
425  for (auto final : finals) {
426  updateNext(final.first, fuzzyFactor(final.second), pos);
427  }
428  } else if (!path.flags_.test(PinyinDictFlag::FullMatch)) {
429  for (char test = PinyinEncoder::firstFinal;
430  test <= PinyinEncoder::lastFinal; test++) {
431  updateNext(static_cast<PinyinFinal>(test), 1, pos);
432  }
433  }
434  }
435  }
436  return positions;
437 }
438 
439 template <typename T>
440 void matchWordsOnTrie(const PinyinTrie *userDict, const MatchedPinyinPath &path,
441  bool matchLongWord, const T &callback) {
442  for (const auto &pr : path.triePositions()) {
443  uint64_t pos;
444  size_t fuzzies;
445  std::tie(pos, fuzzies) = pr;
446  const float extraCost = fuzzies * fuzzyCost;
447  // This is an inaccuration estimation, since fuzzies may contain real
448  // fuzzy pinyin. But since this value is 10, it is a good estimate.
449  // After all 10 fuzzies in a word is kinda impossible.
450  const bool isCorrection = fuzzies >= PINYIN_CORRECTION_FUZZY_FACTOR;
451  if (matchLongWord) {
452  path.trie()->foreach(
453  [userDict, &path, &callback, extraCost, isCorrection](
454  PinyinTrie::value_type value, size_t len, uint64_t pos) {
455  std::string s;
456  s.reserve(len + (path.size() * 2));
457  path.trie()->suffix(s, len + (path.size() * 2), pos);
458  if (size_t separator =
459  s.find(pinyinHanziSep, path.size() * 2);
460  separator != std::string::npos) {
461  std::string_view view(s);
462  auto encodedPinyin = view.substr(0, separator);
463  auto hanzi = view.substr(separator + 1);
464  const size_t lengthDiff =
465  ((encodedPinyin.size() / 2) - path.size());
466  // Don't match long word for "custom".
467  if (path.trie() == userDict && value < 0 &&
468  lengthDiff > 0) {
469  return true;
470  }
471  float overLengthCost = fuzzyCost * lengthDiff;
472 
473  callback(encodedPinyin, hanzi,
474  value + extraCost + overLengthCost,
475  isCorrection);
476  }
477  return true;
478  },
479  pos);
480  } else {
481  const char sep = pinyinHanziSep;
482  const auto resultRaw = path.trie()->traverseRaw(&sep, 1, pos);
483  if (PinyinTrie::isNoPathRaw(resultRaw)) {
484  continue;
485  }
486 
487  path.trie()->foreach(
488  [&path, &callback, extraCost, isCorrection](
489  PinyinTrie::value_type value, size_t len, uint64_t pos) {
490  std::string s;
491  s.reserve(len + (path.size() * 2) + 1);
492  path.trie()->suffix(s, len + (path.size() * 2) + 1, pos);
493  std::string_view view(s);
494  auto encodedPinyin = view.substr(0, path.size() * 2);
495  auto hanzi = view.substr((path.size() * 2) + 1);
496  callback(encodedPinyin, hanzi, value + extraCost,
497  isCorrection);
498  return true;
499  },
500  pos);
501  }
502  }
503 }
504 
505 bool PinyinDictionaryPrivate::matchWordsForOnePath(
506  const PinyinMatchContext &context, const MatchedPinyinPath &path) const {
507  FCITX_Q();
508  bool matched = false;
509  assert(path.path_.size() >= 2);
510  const SegmentGraphNode &prevNode = *path.path_[path.path_.size() - 2];
511 
512  if (path.flags_.test(PinyinDictFlag::FullMatch) &&
513  (path.path_.front() != &context.graph_.start() ||
514  path.path_.back() != &context.graph_.end())) {
515  return false;
516  }
517 
518  // minimumLongWordLength is to prevent algorithm runs too slow.
519  const bool matchLongWordEnabled =
520  context.partialLongWordLimit_ &&
521  std::max(minimumLongWordLength, context.partialLongWordLimit_) + 1 <=
522  path.path_.size() &&
523  !path.flags_.test(PinyinDictFlag::FullMatch);
524 
525  const bool matchLongWord =
526  (path.path_.back() == &context.graph_.end() && matchLongWordEnabled);
527 
528  auto foundOneWord = [&path, &prevNode, &matched, &context](
529  std::string_view encodedPinyin, WordNode &word,
530  float cost, bool isCorrection) {
531  context.callback_(path.path_, word, cost,
532  std::make_unique<PinyinLatticeNodePrivate>(
533  encodedPinyin, isCorrection));
534  if (path.size() == 1 &&
535  path.path_[path.path_.size() - 2] == &prevNode) {
536  matched = true;
537  }
538  };
539 
540  if (context.matchCacheMap_) {
541  auto &matchCache = (*context.matchCacheMap_)[path.trie()];
542  auto *result =
543  matchCache.find(path.path_, context.hasher_, context.hasher_);
544  if (!result) {
545  result =
546  matchCache.insert(context.hasher_.pathToPinyins(path.path_));
547  result->clear();
548 
549  auto &items = *result;
550  matchWordsOnTrie(
551  q->trie(PinyinDictionary::UserDict), path, matchLongWordEnabled,
552  [&items](std::string_view encodedPinyin, std::string_view hanzi,
553  float cost, bool isCorrection) {
554  items.emplace_back(hanzi, cost, encodedPinyin,
555  isCorrection);
556  });
557  }
558  for (auto &item : *result) {
559  if (!matchLongWord &&
560  item.encodedPinyin_.size() / 2 > path.size()) {
561  continue;
562  }
563  foundOneWord(item.encodedPinyin_, item.word_, item.value_,
564  item.isCorrection_);
565  }
566  } else {
567  matchWordsOnTrie(
568  q->trie(PinyinDictionary::UserDict), path, matchLongWord,
569  [&foundOneWord](std::string_view encodedPinyin,
570  std::string_view hanzi, float cost,
571  bool isCorrection) {
572  WordNode word(hanzi, InvalidWordIndex);
573  foundOneWord(encodedPinyin, word, cost, isCorrection);
574  });
575  }
576 
577  return matched;
578 }
579 
580 bool PinyinDictionaryPrivate::matchWords(
581  const PinyinMatchContext &context,
582  const MatchedPinyinPaths &newPaths) const {
583  bool matched = false;
584  for (const auto &path : newPaths) {
585  matched |= matchWordsForOnePath(context, path);
586  }
587 
588  return matched;
589 }
590 
591 void PinyinDictionaryPrivate::findMatchesBetween(
592  const PinyinMatchContext &context, const SegmentGraphNode &prevNode,
593  const SegmentGraphNode &currentNode,
594  MatchedPinyinPaths &currentMatches) const {
595  const SegmentGraph &graph = context.graph_;
596  auto &matchedPathsMap = *context.matchedPathsMap_;
597  auto pinyin = graph.segment(prevNode, currentNode);
598  // If predecessor is a separator, just copy every existing match result
599  // over and don't traverse on the trie.
600  if (pinyin.starts_with("\'")) {
601  const auto &prevMatches = matchedPathsMap[&prevNode];
602  for (const auto &match : prevMatches) {
603  // copy the path, and append current node.
604  auto path = match.path_;
605  path.push_back(&currentNode);
606  currentMatches.emplace_back(match.result_, std::move(path),
607  match.flags_);
608  }
609  // If the last segment is separator, there
610  if (&currentNode == &graph.end()) {
611  WordNode word("", 0);
612  context.callback_({&prevNode, &currentNode}, word, 0, nullptr);
613  }
614  return;
615  }
616 
617  const auto syls =
618  context.spProfile_
619  ? PinyinEncoder::shuangpinToSyllablesWithFuzzyFlags(
620  pinyin, *context.spProfile_, context.flags_)
621  : PinyinEncoder::stringToSyllablesWithFuzzyFlags(
622  pinyin, context.correctionProfile_.get(), context.flags_);
623  const MatchedPinyinPaths &prevMatchedPaths = matchedPathsMap[&prevNode];
624  MatchedPinyinPaths newPaths;
625  for (const auto &path : prevMatchedPaths) {
626  // Make a copy of path so we can modify based on it.
627  auto segmentPath = path.path_;
628  segmentPath.push_back(&currentNode);
629 
630  // A map from trie (dict) to a lru cache.
631  if (context.nodeCacheMap_) {
632  auto &nodeCache = (*context.nodeCacheMap_)[path.trie()];
633  auto *p =
634  nodeCache.find(segmentPath, context.hasher_, context.hasher_);
635  std::shared_ptr<MatchedPinyinTrieNodes> result;
636  if (!p) {
637  result = std::make_shared<MatchedPinyinTrieNodes>(
638  path.trie(), path.size() + 1);
639  nodeCache.insert(context.hasher_.pathToPinyins(segmentPath),
640  result);
641  result->triePositions_ =
642  traverseAlongPathOneStepBySyllables(path, syls);
643  } else {
644  result = *p;
645  assert(result->size_ == path.size() + 1);
646  }
647 
648  if (!result->triePositions_.empty()) {
649  newPaths.emplace_back(result, segmentPath, path.flags_);
650  }
651  } else {
652  // make an empty one
653  newPaths.emplace_back(path.trie(), path.size() + 1, segmentPath,
654  path.flags_);
655 
656  newPaths.back().result_->triePositions_ =
657  traverseAlongPathOneStepBySyllables(path, syls);
658  // if there's nothing, pop it.
659  if (newPaths.back().triePositions().empty()) {
660  newPaths.pop_back();
661  }
662  }
663  }
664 
665  if (!context.ignore_.contains(&currentNode)) {
666  // after we match current syllable, we first try to match word.
667  if (!matchWords(context, newPaths)) {
668  // If we failed to match any length 1 word, add a new empty word
669  // to make lattice connect together.
670  SegmentGraphPath vec;
671  vec.reserve(3);
672  if (const auto *prevPrev =
673  prevIsSeparator(context.graph_, prevNode)) {
674  vec.push_back(prevPrev);
675  }
676  vec.push_back(&prevNode);
677  vec.push_back(&currentNode);
678  WordNode word(pinyin, InvalidWordIndex);
679  context.callback_(vec, word, invalidPinyinCost, nullptr);
680  }
681  }
682 
683  std::move(newPaths.begin(), newPaths.end(),
684  std::back_inserter(currentMatches));
685 }
686 
687 void PinyinDictionaryPrivate::matchNode(
688  const PinyinMatchContext &context,
689  const SegmentGraphNode &currentNode) const {
690  auto &matchedPathsMap = *context.matchedPathsMap_;
691  // Check if the node has been searched already.
692  if (matchedPathsMap.contains(&currentNode)) {
693  return;
694  }
695  auto &currentMatches = matchedPathsMap[&currentNode];
696  // To create a new start.
697  addEmptyMatch(context, currentNode, currentMatches);
698 
699  // Iterate all predecessor and search from them.
700  for (const auto &prevNode : currentNode.prevs()) {
701  findMatchesBetween(context, prevNode, currentNode, currentMatches);
702  }
703 }
704 
705 void PinyinDictionary::matchPrefixImpl(
706  const SegmentGraph &graph, const GraphMatchCallback &callback,
707  const std::unordered_set<const SegmentGraphNode *> &ignore,
708  void *helper) const {
709  FCITX_D();
710 
711  NodeToMatchedPinyinPathsMap localMatchedPaths;
712  PinyinMatchContext context =
713  helper ? PinyinMatchContext{graph, callback, ignore,
714  static_cast<PinyinMatchState *>(helper)}
715  : PinyinMatchContext{graph, callback, ignore, localMatchedPaths};
716 
717  // A queue to make sure that node with smaller index will be visted first
718  // because we want to make sure every predecessor node are visited before
719  // visit the current node.
720  using SegmentGraphNodeQueue =
721  std::priority_queue<const SegmentGraphNode *,
722  std::vector<const SegmentGraphNode *>,
724  SegmentGraphNodeQueue q;
725 
726  const auto &start = graph.start();
727  q.push(&start);
728 
729  // The match is done with a bfs.
730  // E.g
731  // xian is
732  // start - xi - an - end
733  // \ /
734  // -- xian ---
735  // We start with start, then xi, then an and xian, then end.
736  while (!q.empty()) {
737  const auto *currentNode = q.top();
738  q.pop();
739 
740  // Push successors into the queue.
741  for (const auto &node : currentNode->nexts()) {
742  q.push(&node);
743  }
744 
745  d->matchNode(context, *currentNode);
746  }
747 }
748 
749 void PinyinDictionary::matchWords(const char *data, size_t size,
750  PinyinMatchCallback callback) const {
751  if (!PinyinEncoder::isValidUserPinyin(data, size)) {
752  return;
753  }
754 
755  FCITX_D();
756  std::list<std::pair<const PinyinTrie *, PinyinTrie::position_type>> nodes;
757  for (size_t i = 0; i < dictSize(); i++) {
758  if (d->flags_[i].test(PinyinDictFlag::Disabled)) {
759  continue;
760  }
761  const auto &trie = *this->trie(i);
762  nodes.emplace_back(&trie, 0);
763  }
764  for (size_t i = 0; i <= size && !nodes.empty(); i++) {
765  char current;
766  if (i < size) {
767  current = data[i];
768  } else {
769  current = pinyinHanziSep;
770  }
771  searchOneStep(nodes, current);
772  }
773 
774  for (auto &node : nodes) {
775  node.first->foreach(
776  [&node, &callback, size](PinyinTrie::value_type value, size_t len,
777  uint64_t pos) {
778  std::string s;
779  node.first->suffix(s, len + size + 1, pos);
780 
781  auto view = std::string_view(s);
782  return callback(view.substr(0, size), view.substr(size + 1),
783  value);
784  },
785  node.second);
786  }
787 }
788 
789 void PinyinDictionary::matchWordsPrefix(const char *data, size_t size,
790  PinyinMatchCallback callback) const {
791  if (!PinyinEncoder::isValidUserPinyin(data, size)) {
792  return;
793  }
794 
795  FCITX_D();
796  std::list<std::pair<const PinyinTrie *, PinyinTrie::position_type>> nodes;
797  for (size_t i = 0; i < dictSize(); i++) {
798  if (d->flags_[i].test(PinyinDictFlag::Disabled)) {
799  continue;
800  }
801  const auto &trie = *this->trie(i);
802  nodes.emplace_back(&trie, 0);
803  }
804  for (size_t i = 0; i < size && !nodes.empty(); i++) {
805  searchOneStep(nodes, data[i]);
806  }
807 
808  for (auto &node : nodes) {
809  node.first->foreach(
810  [&node, &callback, size](PinyinTrie::value_type value, size_t len,
811  uint64_t pos) {
812  std::string s;
813  node.first->suffix(s, len + size, pos);
814 
815  std::string_view view(s);
816  if (auto sep = view.find(pinyinHanziSep, size);
817  sep != std::string::npos) {
818  return callback(view.substr(0, sep), view.substr(sep + 1),
819  value);
820  }
821  return true;
822  },
823  node.second);
824  }
825 }
826 
827 PinyinDictionary::PinyinDictionary()
828  : d_ptr(std::make_unique<PinyinDictionaryPrivate>(this)) {
829  FCITX_D();
830  d->conn_ = connect<TrieDictionary::dictSizeChanged>([this](size_t size) {
831  FCITX_D();
832  d->flags_.resize(size);
833  });
834  d->flags_.resize(dictSize());
835 }
836 
837 PinyinDictionary::~PinyinDictionary() {}
838 
839 void PinyinDictionary::load(size_t idx, const char *filename,
840  PinyinDictFormat format) {
841  std::ifstream in(filename, std::ios::in | std::ios::binary);
842  throw_if_io_fail(in);
843  load(idx, in, format);
844 }
845 
846 void PinyinDictionary::load(size_t idx, std::istream &in,
847  PinyinDictFormat format) {
848  setTrie(idx, load(in, format));
849 }
850 
851 PinyinDictionary::TrieType PinyinDictionary::load(std::istream &in,
852  PinyinDictFormat format) {
853  switch (format) {
854  case PinyinDictFormat::Text:
855  return loadTextImpl(in);
856  case PinyinDictFormat::Binary:
857  return loadBinaryImpl(in);
858  default:
859  throw std::invalid_argument("invalid format type");
860  }
861 }
862 
863 void PinyinDictionary::loadText(size_t idx, std::istream &in) {
864  *mutableTrie(idx) = loadTextImpl(in);
865 }
866 
867 void PinyinDictionary::loadBinary(size_t idx, std::istream &in) {
868  *mutableTrie(idx) = loadBinaryImpl(in);
869 }
870 
871 void PinyinDictionary::save(size_t idx, const char *filename,
872  PinyinDictFormat format) {
873  std::ofstream fout(filename, std::ios::out | std::ios::binary);
874  throw_if_io_fail(fout);
875  save(idx, fout, format);
876 }
877 
878 void PinyinDictionary::save(size_t idx, std::ostream &out,
879  PinyinDictFormat format) {
880  switch (format) {
881  case PinyinDictFormat::Text:
882  saveText(idx, out);
883  break;
884  case PinyinDictFormat::Binary: {
885  throw_if_io_fail(marshall(out, pinyinBinaryFormatMagic));
886  throw_if_io_fail(marshall(out, pinyinBinaryFormatVersion));
887 
888  writeZSTDCompressed(out, [this, idx](std::ostream &compressOut) {
889  mutableTrie(idx)->save(compressOut);
890  });
891  } break;
892  default:
893  throw std::invalid_argument("invalid format type");
894  }
895 }
896 
897 void PinyinDictionary::saveText(size_t idx, std::ostream &out) {
898  std::string buf;
899  std::ios state(nullptr);
900  state.copyfmt(out);
901  const auto &trie = *this->trie(idx);
902  trie.foreach([&trie, &buf, &out](float value, size_t _len,
903  PinyinTrie::position_type pos) {
904  trie.suffix(buf, _len, pos);
905  auto sep = buf.find(pinyinHanziSep);
906  if (sep == std::string::npos) {
907  return true;
908  }
909  auto fullPinyin = PinyinEncoder::decodeFullPinyin(buf.data(), sep);
910  std::string_view ref(buf);
911  out << fcitx::stringutils::escapeForValue(ref.substr(sep + 1)) << " "
912  << fullPinyin << " " << std::setprecision(16) << value << '\n';
913  return true;
914  });
915  out.copyfmt(state);
916 }
917 
918 void PinyinDictionary::addWord(size_t idx, std::string_view fullPinyin,
919  std::string_view hanzi, float cost) {
921  fullPinyin, PinyinFuzzyFlag::VE_UE);
922  result.push_back(pinyinHanziSep);
923  result.insert(result.end(), hanzi.begin(), hanzi.end());
924  TrieDictionary::addWord(idx, std::string_view(result.data(), result.size()),
925  cost);
926 }
927 
928 std::optional<float>
929 PinyinDictionary::lookupWord(size_t idx, std::string_view fullPinyin,
930  std::string_view hanzi) const {
932  fullPinyin, PinyinFuzzyFlag::VE_UE);
933  result.push_back(pinyinHanziSep);
934  result.insert(result.end(), hanzi.begin(), hanzi.end());
935  auto value = trie(idx)->exactMatchSearchRaw(result.data(), result.size());
936  if (PinyinTrie::isValidRaw(value)) {
937  return value;
938  }
939  return std::nullopt;
940 }
941 
942 bool PinyinDictionary::removeWord(size_t idx, std::string_view fullPinyin,
943  std::string_view hanzi) {
945  fullPinyin, PinyinFuzzyFlag::VE_UE);
946  result.push_back(pinyinHanziSep);
947  result.insert(result.end(), hanzi.begin(), hanzi.end());
948  return TrieDictionary::removeWord(
949  idx, std::string_view(result.data(), result.size()));
950 }
951 
952 void PinyinDictionary::setFlags(size_t idx, PinyinDictFlags flags) {
953  FCITX_D();
954  if (idx >= dictSize()) {
955  return;
956  }
957  d->flags_.resize(dictSize());
958  d->flags_[idx] = flags;
959 }
960 } // namespace libime
PinyinDictionary is a set of dictionaries for Pinyin.
Provide a DATrie implementation.
static std::vector< char > encodeFullPinyinWithFlags(std::string_view pinyin, PinyinFuzzyFlags flags)
Encode a quote separated pinyin string.