libime
historybigram.cpp
1 /*
2  * SPDX-FileCopyrightText: 2017-2017 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 #include "historybigram.h"
7 #include <algorithm>
8 #include <array>
9 #include <cassert>
10 #include <cmath>
11 #include <cstddef>
12 #include <cstdint>
13 #include <functional>
14 #include <istream>
15 #include <iterator>
16 #include <list>
17 #include <memory>
18 #include <ostream>
19 #include <ranges>
20 #include <stdexcept>
21 #include <string>
22 #include <string_view>
23 #include <unordered_set>
24 #include <utility>
25 #include <vector>
26 #include <fcitx-utils/macros.h>
27 #include <fcitx-utils/stringutils.h>
28 #include "constants.h"
29 #include "datrie.h"
30 #include "lattice.h"
31 #include "utils_p.h"
32 #include "zstdfilter.h"
33 
34 namespace libime {
35 
36 namespace {
37 
38 using WordWithCode = HistoryBigram::WordWithCode;
39 using WordWithCodeView = HistoryBigram::WordWithCodeView;
40 
41 constexpr uint32_t historyBinaryFormatMagic = 0x000fc315;
42 constexpr uint32_t historyBinaryFormatVersion = 0x4;
43 constexpr char bigramSeparator = '\x01';
44 constexpr char wordCodeSeparator = '\x02';
45 
46 std::string wordAndCodeToString(WordWithCodeView wordAndCode) {
47  std::string s{std::get<0>(wordAndCode)};
48  if (s.empty()) {
49  return s;
50  }
51  auto code = std::get<1>(wordAndCode);
52  if (!code.empty()) {
53  s += wordCodeSeparator;
54  s += code;
55  }
56  return s;
57 }
58 
59 WordWithCode bigramWordWithCode(WordWithCodeView prev, WordWithCodeView cur) {
60  std::string s;
61  s.append(std::get<0>(prev));
62  s += bigramSeparator;
63  s.append(std::get<0>(cur));
64 
65  auto code1 = std::get<1>(prev);
66  auto code2 = std::get<1>(cur);
67  std::string concatCode;
68  if (code1.empty() && code2.empty()) {
69  concatCode = "";
70  } else {
71  concatCode = code1;
72  concatCode += bigramSeparator;
73  concatCode += code2;
74  }
75  return {s, concatCode};
76 }
77 
78 struct WeightedTrie {
79  using TrieType = DATrie<int32_t>;
80 
81 public:
82  WeightedTrie() = default;
83 
84  void clear() { trie_.clear(); }
85 
86  const TrieType &trie() const { return trie_; }
87 
88  int32_t weightedSize() const { return weightedSize_; }
89 
90  int32_t freq(WordWithCodeView wordAndCode) const {
91  // If query with code, the match will be {word, ""} + {word, code}.
92  // If query without code, the match will be {word, ""} + {word,
93  // separator}.
94  TrieType::position_type pos = 0;
95  auto result = 0;
96  auto v = trie_.traverse(wordAndCode.first, pos);
97  if (TrieType::isValid(v)) {
98  result += v;
99  } else if (TrieType::isNoPath(v)) {
100  return 0;
101  }
102  const char separator[] = {wordCodeSeparator, '\0'};
103  v = trie_.traverse(separator, pos);
104  if (!TrieType::isNoPath(v)) {
105  if (!wordAndCode.second.empty() &&
106  wordAndCode.second.front() != bigramSeparator &&
107  wordAndCode.second.back() != bigramSeparator) {
108  v = trie_.traverse(wordAndCode.second, pos);
109  if (TrieType::isValid(v)) {
110  result += v;
111  }
112  } else {
113  trie_.foreach(
114  [this, &result, &wordAndCode](TrieType::value_type value,
115  size_t len,
116  TrieType::position_type pos) {
117  if (len == 0) {
118  return true;
119  }
120  if (!wordAndCode.second.empty()) {
121  assert(
122  wordAndCode.second.front() == bigramSeparator ||
123  wordAndCode.second.back() == bigramSeparator);
124  std::string codeInTrie;
125  trie().suffix(codeInTrie, len, pos);
126  if (wordAndCode.second.front() == bigramSeparator &&
127  !codeInTrie.ends_with(wordAndCode.second)) {
128  return true;
129  }
130  if (wordAndCode.second.back() == bigramSeparator &&
131  !codeInTrie.starts_with(wordAndCode.second)) {
132  return true;
133  }
134  }
135  result += value;
136  return true;
137  },
138  pos);
139  }
140  }
141  return result;
142  }
143 
144  void incFreq(WordWithCodeView wordAndCode, int32_t delta) {
145  auto s = wordAndCodeToString(wordAndCode);
146 
147  trie_.update(s.data(), s.size(),
148  [delta](int32_t v) { return v + delta; });
149  weightedSize_ += delta;
150  }
151 
152  void decFreq(WordWithCodeView wordAndCode, int32_t delta) {
153  auto s = wordAndCodeToString(wordAndCode);
154  auto v = trie_.exactMatchSearch(s.data(), s.size());
155  if (TrieType::isNoValue(v)) {
156  return;
157  }
158  if (v <= delta) {
159  trie_.erase(s.data(), s.size());
160  decWeightedSize(v);
161  } else {
162  v -= delta;
163  trie_.set(s.data(), s.size(), v);
164  decWeightedSize(delta);
165  }
166  }
167 
168  void fillPredict(std::unordered_set<std::string> &words,
169  std::string_view word, size_t maxSize) const {
170  trie_.foreach(word,
171  [this, &words, maxSize](TrieType::value_type, size_t len,
172  TrieType::position_type pos) {
173  std::string buf;
174  trie().suffix(buf, len, pos);
175  auto separatorPos = buf.find(wordCodeSeparator);
176  if (separatorPos != std::string::npos) {
177  buf.erase(separatorPos);
178  }
179  // Skip special word.
180  if (buf == "<s>" || buf == "</s>") {
181  return true;
182  }
183  words.emplace(std::move(buf));
184 
185  return maxSize <= 0 || words.size() < maxSize;
186  });
187  }
188 
189 private:
190  void decWeightedSize(int32_t v) {
191  weightedSize_ -= v;
192  weightedSize_ = std::max(weightedSize_, 0);
193  }
194 
195  int32_t weightedSize_ = 0;
196  TrieType trie_;
197 };
198 
199 class HistoryBigramPool {
200 public:
201  HistoryBigramPool(size_t maxSize) : maxSize_(maxSize) {}
202 
203  void load(std::istream &in) {
204  clear();
205  uint32_t count = 0;
206  throw_if_io_fail(unmarshall(in, count));
207  while (count--) {
208  uint32_t size = 0;
209  throw_if_io_fail(unmarshall(in, size));
210  std::vector<WordWithCode> sentence;
211  while (size--) {
212  std::string buffer;
213  throw_if_io_fail(unmarshallString(in, buffer));
214  std::string_view bufferView{buffer};
215  size_t separatorPos = bufferView.find(wordCodeSeparator);
216  if (separatorPos != std::string_view::npos) {
217  sentence.emplace_back(
218  std::string(bufferView.substr(0, separatorPos)),
219  std::string(bufferView.substr(separatorPos + 1)));
220  } else {
221  sentence.emplace_back(std::move(buffer), "");
222  }
223  }
224  add(sentence);
225  }
226  }
227 
228  void loadText(std::istream &in) {
229  clear();
230  std::string buf;
231  std::vector<std::string> lines;
232  while (std::getline(in, buf)) {
233  lines.emplace_back(buf);
234  if (lines.size() >= maxSize_) {
235  break;
236  }
237  }
238  for (auto &line : lines | std::views::reverse) {
239  std::string_view lineView{line};
240  std::vector<std::string> tokens;
241  bool withCode = false;
242  while (!lineView.empty()) {
243  std::string token;
244  auto consumed = fcitx::stringutils::consumeMaybeEscapedValue(
245  lineView, FCITX_WHITESPACE, &token);
246  if (!consumed.empty()) {
247  tokens.push_back(std::move(token));
248  }
249  if (tokens.size() == 1 && !lineView.empty() &&
250  lineView.front() == '\t') {
251  withCode = true;
252  }
253  }
254 
255  if (withCode) {
256  if (tokens.size() % 2 != 0) {
257  continue;
258  }
259  add(std::views::iota(static_cast<size_t>(0),
260  tokens.size() / 2) |
261  std::views::transform([&tokens](size_t i) {
262  return WordWithCode{tokens[i * 2], tokens[(i * 2) + 1]};
263  }));
264 
265  } else {
266  add(tokens |
267  std::views::transform([](const auto &word) -> WordWithCode {
268  std::vector<std::string> wordWithMaybeCode =
269  fcitx::stringutils::split(
270  word, "\t",
271  fcitx::stringutils::SplitBehavior::KeepEmpty);
272  if (wordWithMaybeCode.size() == 2) {
273  return WordWithCode{wordWithMaybeCode[0],
274  wordWithMaybeCode[1]};
275  }
276  return WordWithCode{word, ""};
277  }));
278  }
279  }
280  }
281 
282  void save(std::ostream &out) {
283  uint32_t count = recent_.size();
284  throw_if_io_fail(marshall(out, count));
285  // When we do save, we need to reverse the history order.
286  // Because loading the history is done by call "add", which basically
287  // expect the history from old to new.
288  for (auto &sentence : recent_ | std::views::reverse) {
289  uint32_t size = sentence.size();
290  throw_if_io_fail(marshall(out, size));
291  for (const auto &s : sentence) {
292  throw_if_io_fail(marshallString(out, wordAndCodeToString(s)));
293  }
294  }
295  }
296 
297  void dump(std::ostream &out) const {
298  for (const auto &sentence : recent_) {
299  bool first = true;
300  bool hasCode = std::ranges::any_of(sentence, [](const auto &item) {
301  return !std::get<1>(item).empty();
302  });
303  for (const auto &s : sentence) {
304  if (first) {
305  first = false;
306  } else {
307  out << " ";
308  }
309  out << fcitx::stringutils::escapeForValue(std::get<0>(s));
310  if (hasCode) {
311  out << "\t"
312  << fcitx::stringutils::escapeForValue(std::get<1>(s));
313  }
314  }
315  out << '\n';
316  }
317  }
318 
319  void clear() {
320  recent_.clear();
321  unigram_.clear();
322  bigram_.clear();
323  size_ = 0;
324  }
325 
326  template <typename R>
327  std::list<std::vector<WordWithCode>> add(const R &sentence) {
328  std::list<std::vector<WordWithCode>> popedSentence;
329  if (sentence.empty()) {
330  return popedSentence;
331  }
332  // Validate data.
333  if (std::ranges::any_of(sentence, [](const auto &item) {
334  const auto &[word, code] = item;
335  return word.find('\0') != std::string::npos;
336  })) {
337  return popedSentence;
338  }
339  while (recent_.size() >= maxSize_) {
340  remove(recent_.back());
341  popedSentence.splice(popedSentence.end(), recent_,
342  std::prev(recent_.end()));
343  }
344 
345  std::vector<WordWithCode> newSentence;
346  auto delta = 1;
347  for (auto iter = sentence.begin(), end = sentence.end(); iter != end;
348  iter++) {
349  unigram_.incFreq(*iter, delta);
350  auto next = std::ranges::next(iter);
351  if (next != end) {
352  incBigram(*iter, *next, delta);
353  }
354  newSentence.push_back(*iter);
355  }
356  recent_.push_front(std::move(newSentence));
357  unigram_.incFreq({"<s>", ""}, delta);
358  unigram_.incFreq({"</s>", ""}, delta);
359  incBigram({"<s>", ""}, sentence.front(), delta);
360  incBigram(sentence.back(), {"</s>", ""}, delta);
361 
362  return popedSentence;
363  }
364 
365  int32_t unigramFreq(WordWithCodeView s) const { return unigram_.freq(s); }
366 
367  int32_t bigramFreq(WordWithCodeView s1, WordWithCodeView s2) const {
368  return bigram_.freq(bigramWordWithCode(s1, s2));
369  }
370 
371  bool isUnknown(WordWithCodeView word) const {
372  return unigramFreq(word) == 0;
373  }
374 
375  size_t maxSize() const { return maxSize_; }
376 
377  size_t realSize() const { return recent_.size(); }
378 
379  void forget(std::string_view word, std::string_view code) {
380  auto iter = recent_.begin();
381  while (iter != recent_.end()) {
382  if (std::find_if(
383  iter->begin(), iter->end(), [word, code](const auto &item) {
384  const auto &[w, c] = item;
385  return w == word && (code.empty() || c == code);
386  }) != iter->end()) {
387  remove(*iter);
388  iter = recent_.erase(iter);
389  } else {
390  ++iter;
391  }
392  }
393  }
394 
395  void fillPredict(std::unordered_set<std::string> &words,
396  std::string_view word, size_t maxSize = 0) const {
397  bigram_.fillPredict(words, word, maxSize);
398  }
399 
400  bool maybeAppendToLatestSentence(const std::vector<WordWithCode> &context,
401  std::vector<WordWithCode> &newSentence) {
402  if (recent_.empty() || newSentence.empty()) {
403  return false;
404  }
405  auto &latestSentence = recent_.front();
406  if (latestSentence.size() < context.size() ||
407  !std::ranges::equal(
408  context,
409  std::views::drop(latestSentence,
410  latestSentence.size() - context.size()))) {
411  return false;
412  }
413 
414  const int delta = 1;
415  decBigram(latestSentence.back(), {"</s>", ""}, delta);
416  for (auto &item : newSentence) {
417  unigram_.incFreq(item, delta);
418  incBigram(latestSentence.back(), item, delta);
419  latestSentence.push_back(std::move(item));
420  }
421  incBigram(latestSentence.back(), {"</s>", ""}, delta);
422 
423  return true;
424  }
425 
426 private:
427  template <typename R>
428  void remove(const R &sentence) {
429  const int delta = 1;
430  for (auto iter = sentence.begin(), end = sentence.end(); iter != end;
431  iter++) {
432  unigram_.decFreq(*iter, delta);
433  auto next = std::next(iter);
434  if (next != end) {
435  decBigram(*iter, *next, delta);
436  }
437  }
438  decBigram({"<s>", ""}, sentence.front(), delta);
439  decBigram(sentence.back(), {"</s>", ""}, delta);
440  }
441 
442  void decBigram(WordWithCodeView s1, WordWithCodeView s2, int32_t delta) {
443  bigram_.decFreq(bigramWordWithCode(s1, s2), delta);
444  }
445 
446  void incBigram(WordWithCodeView s1, WordWithCodeView s2, int delta) {
447  bigram_.incFreq(bigramWordWithCode(s1, s2), delta);
448  }
449 
450  const size_t maxSize_;
451 
452  // Used when maxSize_ != 0.
453  size_t size_ = 0;
454  std::list<std::vector<WordWithCode>> recent_;
455 
456  // Used for look up
457  WeightedTrie unigram_;
458  WeightedTrie bigram_;
459 };
460 
461 } // namespace
462 
// The weight given to each pool is defined as follows, for n pools:
//   (1 - p)          first pool
//   p * (1 - p)      second pool
//   p^2 * (1 - p)    third pool
//   ...
//   p^(n - 1)        n-th (last) pool
// That is, pool i gets (1 - p) * p^(i - 1), except the last pool, which
// takes the remaining p^(n - 1) so that all weights sum to 1.
// p is derived from alpha as p = 1 / (1 + alpha).
472 public:
473  void populateSentence(std::list<std::vector<WordWithCode>> popedSentence) {
474  for (size_t i = 1; !popedSentence.empty() && i < pools_.size(); i++) {
475  std::list<std::vector<WordWithCode>> nextSentences;
476  while (!popedSentence.empty()) {
477  auto newPopedSentence = pools_[i].add(popedSentence.front());
478  popedSentence.pop_front();
479  nextSentences.splice(nextSentences.end(), newPopedSentence);
480  }
481  popedSentence = std::move(nextSentences);
482  }
483  }
484 
485  float unigramFreq(WordWithCodeView word) const {
486  assert(pools_.size() == poolWeight_.size());
487  float freq = 0;
488  for (size_t i = 0; i < pools_.size(); i++) {
489  freq += pools_[i].unigramFreq(word) * poolWeight_[i];
490  }
491  return freq;
492  }
493 
494  float bigramFreq(WordWithCodeView prev, WordWithCodeView cur) const {
495  assert(pools_.size() == poolWeight_.size());
496  float freq = 0;
497  for (size_t i = 0; i < pools_.size(); i++) {
498  freq += pools_[i].bigramFreq(prev, cur) * poolWeight_[i];
499  }
500  return freq;
501  }
502 
503  float unigramSize() const {
504  float size = 0;
505  for (size_t i = 0; i < pools_.size(); i++) {
506  size += pools_[i].maxSize() * poolWeight_[i];
507  }
508  return size;
509  }
510 
// A log probability (base 10).
512  float unknown_ =
513  std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY);
514  bool useOnlyUnigram_ = false;
515  std::vector<HistoryBigramPool> pools_;
516  std::vector<float> poolWeight_;
517 };
518 
519 HistoryBigram::HistoryBigram()
520  : d_ptr(std::make_unique<HistoryBigramPrivate>()) {
521  FCITX_D();
522  const float p = 1.0 / (1 + HISTORY_BIGRAM_ALPHA_VALUE);
523  constexpr std::array<int, 3> poolSize = {128, 8192, 65536};
524  d->pools_.reserve(poolSize.size());
525  d->poolWeight_.reserve(poolSize.size());
526  for (auto size : poolSize) {
527  d->pools_.emplace_back(size);
528  float portion = 1.0F;
529  if (d->pools_.size() != poolSize.size()) {
530  portion *= 1 - p;
531  }
532  portion *= std::pow(p, d->pools_.size() - 1);
533  d->poolWeight_.push_back(portion / d->pools_.back().maxSize());
534  }
535  setUnknownPenalty(
536  std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY));
537 }
538 
539 FCITX_DEFINE_DEFAULT_DTOR_AND_MOVE(HistoryBigram)
540 
541 void HistoryBigram::setUnknownPenalty(float unknown) {
542  FCITX_D();
543  d->unknown_ = unknown;
544 }
545 
546 float HistoryBigram::unknownPenalty() const {
547  FCITX_D();
548  return d->unknown_;
549 }
550 
551 void HistoryBigram::setUseOnlyUnigram(bool useOnlyUnigram) {
552  FCITX_D();
553  d->useOnlyUnigram_ = useOnlyUnigram;
554 }
555 
556 bool HistoryBigram::useOnlyUnigram() const {
557  FCITX_D();
558  return d->useOnlyUnigram_;
559 }
560 
561 void HistoryBigram::add(const libime::SentenceResult &sentence) {
562  addWithCode(sentence, nullptr);
563 }
564 
565 void HistoryBigram::addWithCode(
566  const libime::SentenceResult &sentence,
567  const ValidationCodeExtractor &validationCodeExtractor) {
568  FCITX_D();
569  d->populateSentence(d->pools_[0].add(
570  sentence.sentence() |
571  std::views::transform(
572  [&validationCodeExtractor](const auto &item) -> WordWithCode {
573  return {item->word(), validationCodeExtractor
574  ? validationCodeExtractor(item)
575  : ""};
576  })));
577 }
578 
579 void HistoryBigram::add(const std::vector<std::string> &sentence) {
580  FCITX_D();
581  d->populateSentence(d->pools_[0].add(
582  sentence | std::views::transform([](const auto &word) -> WordWithCode {
583  return WordWithCode{word, ""};
584  })));
585 }
586 
587 void HistoryBigram::addWithCode(
588  const std::vector<WordWithCode> &sentenceWithValidationCode) {
589  FCITX_D();
590  d->populateSentence(d->pools_[0].add(sentenceWithValidationCode));
591 }
592 
593 bool HistoryBigram::isUnknown(std::string_view v) const {
594  FCITX_D();
595  return std::ranges::all_of(d->pools_, [v](const HistoryBigramPool &pool) {
596  return pool.isUnknown({v, ""});
597  });
598 }
599 
600 float HistoryBigram::score(std::string_view prev, std::string_view cur) const {
601  return scoreWithCode({prev, ""}, {cur, ""});
602 }
603 
604 float HistoryBigram::scoreWithCode(WordWithCodeView prev,
605  WordWithCodeView cur) const {
606  FCITX_D();
607  if (prev.first.empty()) {
608  prev.first = "<s>";
609  }
610  if (cur.first.empty()) {
611  cur.first = "<unk>";
612  }
613 
614  auto uf0 = d->unigramFreq(prev);
615  auto bf = d->bigramFreq(prev, cur);
616  auto uf1 = d->unigramFreq(cur);
617 
618  float bigramWeight = d->useOnlyUnigram_ ? 0.0F : 0.8F;
619  // add 0.5 to avoid div 0
620  float pr = 0.0F;
621  pr += bigramWeight * bf / float(uf0 + (d->poolWeight_[0] / 2));
622  pr += (1.0F - bigramWeight) * uf1 /
623  float(d->unigramSize() + (d->poolWeight_[0] / 2));
624 
625  pr = std::min<double>(pr, 1.0);
626  if (pr == 0) {
627  return d->unknown_;
628  }
629 
630  return std::log10(pr);
631 }
632 
633 void HistoryBigram::load(std::istream &in) {
634  FCITX_D();
635  uint32_t magic = 0;
636  uint32_t version = 0;
637  throw_if_io_fail(unmarshall(in, magic));
638  if (magic != historyBinaryFormatMagic) {
639  throw std::invalid_argument("Invalid history magic.");
640  }
641  throw_if_io_fail(unmarshall(in, version));
642  switch (version) {
643  case 1:
644  std::ranges::for_each(d->pools_ | std::views::take(2),
645  [&in](auto &pool) { pool.load(in); });
646  break;
647  case 2:
648  std::ranges::for_each(d->pools_, [&in](auto &pool) { pool.load(in); });
649  break;
650  case 3:
651  case historyBinaryFormatVersion:
652  // For version 3 and version 4, the format is the same, but version 4
653  // contains additional code data, bump the version to it not backward
654  // compatible with version 3.
655  readZSTDCompressed(in, [d](std::istream &compressIn) {
656  std::ranges::for_each(d->pools_, [&compressIn](auto &pool) {
657  pool.load(compressIn);
658  });
659  });
660  break;
661  default:
662  throw std::invalid_argument("Invalid history version.");
663  }
664 }
665 
666 void HistoryBigram::loadText(std::istream &in) {
667  FCITX_D();
668  std::ranges::for_each(d->pools_, [&in](auto &pool) { pool.loadText(in); });
669 }
670 
671 void HistoryBigram::save(std::ostream &out) {
672  FCITX_D();
673  throw_if_io_fail(marshall(out, historyBinaryFormatMagic));
674  throw_if_io_fail(marshall(out, historyBinaryFormatVersion));
675 
676  writeZSTDCompressed(out, [d](std::ostream &compressOut) {
677  std::ranges::for_each(
678  d->pools_, [&compressOut](auto &pool) { pool.save(compressOut); });
679  });
680 }
681 
682 void HistoryBigram::dump(std::ostream &out) {
683  FCITX_D();
684  std::ranges::for_each(d->pools_,
685  [&out](const auto &pool) { pool.dump(out); });
686 }
687 
688 void HistoryBigram::clear() {
689  FCITX_D();
690  std::ranges::for_each(d->pools_, std::mem_fn(&HistoryBigramPool::clear));
691 }
692 
693 void HistoryBigram::forget(std::string_view word) { forget(word, ""); }
694 
695 void HistoryBigram::forget(std::string_view word, std::string_view code) {
696  FCITX_D();
697  std::ranges::for_each(
698  d->pools_, [word, code](auto &pool) { pool.forget(word, code); });
699 }
700 
701 void HistoryBigram::fillPredict(std::unordered_set<std::string> &words,
702  const std::vector<std::string> &sentence,
703  size_t maxSize) const {
704  FCITX_D();
705  if (maxSize > 0 && words.size() >= maxSize) {
706  return;
707  }
708  std::string lookup;
709  if (!sentence.empty()) {
710  lookup = sentence.back();
711  } else {
712  lookup = "<s>";
713  }
714  lookup += bigramSeparator;
715  std::ranges::for_each(
716  d->pools_, [&words, &lookup, maxSize](const HistoryBigramPool &pool) {
717  pool.fillPredict(words, lookup, maxSize);
718  });
719 }
720 
721 bool HistoryBigram::containsBigram(std::string_view prev,
722  std::string_view cur) const {
723  FCITX_D();
724  return std::ranges::any_of(
725  d->pools_, [&prev, &cur](const HistoryBigramPool &pool) {
726  return pool.bigramFreq({prev, ""}, {cur, ""}) > 0;
727  });
728 }
729 
730 float HistoryBigram::unigramFrequency(WordWithCodeView word) const {
731  FCITX_D();
732  return d->unigramFreq(word);
733 }
734 
735 float HistoryBigram::bigramFrequency(WordWithCodeView prev,
736  WordWithCodeView cur) const {
737  FCITX_D();
738  return d->bigramFreq(prev, cur);
739 }
740 
741 int32_t HistoryBigram::rawUnigramFrequency(WordWithCodeView word) const {
742  FCITX_D();
743  int32_t freq = 0;
744  for (const auto &pool : d->pools_) {
745  freq += pool.unigramFreq(word);
746  }
747  return freq;
748 }
749 
750 int32_t HistoryBigram::rawBigramFrequency(WordWithCodeView prev,
751  WordWithCodeView cur) const {
752  FCITX_D();
753  int32_t freq = 0;
754  for (const auto &pool : d->pools_) {
755  freq += pool.bigramFreq(prev, cur);
756  }
757  return freq;
758 }
759 
760 float HistoryBigram::score(const WordNode *prev, const WordNode *cur) const {
761  return scoreWithCode(prev, cur, nullptr);
762 }
763 
764 float HistoryBigram::scoreWithCode(
765  const WordNode *prev, const WordNode *cur,
766  const ValidationCodeExtractor &extractor) const {
767  return scoreWithCode(
768  {prev ? prev->word() : "", extractor && prev ? extractor(prev) : ""},
769  {cur ? cur->word() : "", extractor && cur ? extractor(cur) : ""});
770 }
771 
772 void HistoryBigram::addWithContext(const std::vector<WordWithCode> &context,
773  std::vector<WordWithCode> newSentence) {
774  FCITX_D();
775  if (context.empty() ||
776  !d->pools_[0].maybeAppendToLatestSentence(context, newSentence)) {
777  addWithCode(newSentence);
778  }
779 }
780 
781 } // namespace libime
int32_t rawUnigramFrequency(WordWithCodeView word) const
Query the raw frequency of the unigram.
float bigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const
Query the weighted frequency of the bigram.
Provide a DATrie implementation.
float unigramFrequency(WordWithCodeView word) const
Query the weighted frequency of the unigram.
void fillPredict(std::unordered_set< std::string > &words, const std::vector< std::string > &sentence, size_t maxSize) const
Fill the prediction based on current sentence.
int32_t rawBigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const
Query the raw frequency of the bigram.