6 #include "historybigram.h" 22 #include <string_view> 23 #include <unordered_set> 26 #include <fcitx-utils/macros.h> 27 #include <fcitx-utils/stringutils.h> 28 #include "constants.h" 32 #include "zstdfilter.h" 38 using WordWithCode = HistoryBigram::WordWithCode;
39 using WordWithCodeView = HistoryBigram::WordWithCodeView;
// Magic number at the start of the binary history format: written first by
// HistoryBigram::save() and required by HistoryBigram::load().
41 constexpr uint32_t historyBinaryFormatMagic = 0x000fc315;
// Format version written by HistoryBigram::save(); load() has a dedicated
// case for it.
42 constexpr uint32_t historyBinaryFormatVersion = 0x4;
// Separator used when building a bigram code (see bigramWordWithCode) and
// appended to the lookup key in HistoryBigram::fillPredict.
43 constexpr
char bigramSeparator =
'\x01';
// Separator between a word and its code inside a flattened "word + code"
// string (see wordAndCodeToString and HistoryBigramPool::load).
44 constexpr
char wordCodeSeparator =
'\x02';
// Flatten a (word, code) pair into one string. The result is used as the trie
// key by WeightedTrie::incFreq/decFreq and as the serialized entry in
// HistoryBigramPool::save().
// NOTE(review): several lines of this function are absent from this listing.
// The visible part starts from the word and appends wordCodeSeparator; the
// condition guarding that append, the append of the code itself and the
// return statement are not shown — confirm against the full source.
46 std::string wordAndCodeToString(WordWithCodeView wordAndCode) {
47 std::string s{std::get<0>(wordAndCode)};
51 auto code = std::get<1>(wordAndCode);
53 s += wordCodeSeparator;
// Build the (word, code) key that represents the bigram "prev followed by cur".
// The word part receives both words; the code part is built from both codes
// with bigramSeparator.
// NOTE(review): the declaration of `s`, whatever is inserted between the two
// words, and the bodies of the code branches are missing from this listing.
59 WordWithCode bigramWordWithCode(WordWithCodeView prev, WordWithCodeView cur) {
61 s.append(std::get<0>(prev));
63 s.append(std::get<0>(cur));
65 auto code1 = std::get<1>(prev);
66 auto code2 = std::get<1>(cur);
67 std::string concatCode;
// Special-cases the situation where neither word carries a code (branch body
// not shown; presumably leaves concatCode empty — TODO confirm).
68 if (code1.empty() && code2.empty()) {
72 concatCode += bigramSeparator;
75 return {s, concatCode};
// Trie type used for storage: maps a flattened key to an int32_t value.
79 using TrieType = DATrie<int32_t>;
// Default construction; weightedSize_ starts at 0 via its in-class initializer.
82 WeightedTrie() =
default;
84 void clear() { trie_.clear(); }
86 const TrieType &trie()
const {
return trie_; }
88 int32_t weightedSize()
const {
return weightedSize_; }
// Look up the frequency recorded for a (word, code) pair.
// NOTE(review): this listing omits many lines of the function (among them the
// declaration of `result`, the foreach call that receives the lambda, and
// every return statement). The comments below describe only the visible steps.
90 int32_t freq(WordWithCodeView wordAndCode)
const {
94 TrieType::position_type pos = 0;
// Walk the trie along the word text, starting from the root position.
96 auto v = trie_.traverse(wordAndCode.first, pos);
97 if (TrieType::isValid(v)) {
99 }
else if (TrieType::isNoPath(v)) {
// Continue the same traversal through wordCodeSeparator, i.e. into the
// entries that were stored together with a code.
102 const char separator[] = {wordCodeSeparator,
'\0'};
103 v = trie_.traverse(separator, pos);
104 if (!TrieType::isNoPath(v)) {
// A non-empty code that neither starts nor ends with bigramSeparator is
// followed directly in the trie.
105 if (!wordAndCode.second.empty() &&
106 wordAndCode.second.front() != bigramSeparator &&
107 wordAndCode.second.back() != bigramSeparator) {
108 v = trie_.traverse(wordAndCode.second, pos);
109 if (TrieType::isValid(v)) {
// Callback applied to entries below the current position.
114 [
this, &result, &wordAndCode](TrieType::value_type value,
116 TrieType::position_type pos) {
// With a code that starts or ends with bigramSeparator, the code stored in
// the trie is read back and compared: a leading separator is checked with
// ends_with, a trailing separator with starts_with. Entries failing the
// check enter a branch whose body is not shown (presumably skipped —
// TODO confirm).
120 if (!wordAndCode.second.empty()) {
122 wordAndCode.second.front() == bigramSeparator ||
123 wordAndCode.second.back() == bigramSeparator);
124 std::string codeInTrie;
125 trie().suffix(codeInTrie, len, pos);
126 if (wordAndCode.second.front() == bigramSeparator &&
127 !codeInTrie.ends_with(wordAndCode.second)) {
130 if (wordAndCode.second.back() == bigramSeparator &&
131 !codeInTrie.starts_with(wordAndCode.second)) {
// Add `delta` to the value stored for the flattened (word, code) key and to
// the running weight counter.
144 void incFreq(WordWithCodeView wordAndCode, int32_t delta) {
145 auto s = wordAndCodeToString(wordAndCode);
// The updater receives the current value and returns the incremented one.
147 trie_.update(s.data(), s.size(),
148 [delta](int32_t v) {
return v + delta; });
149 weightedSize_ += delta;
// Decrease the value stored for the flattened (word, code) key.
// NOTE(review): the branch conditions and several statements are missing from
// this listing. Visible steps: exact lookup of the key, a check for "no
// value", an erase() of the key, a set() of the key to `v`, and a decrease of
// the running weight through decWeightedSize(). Presumably the entry is
// erased when its count would not stay positive and rewritten otherwise —
// TODO confirm.
152 void decFreq(WordWithCodeView wordAndCode, int32_t delta) {
153 auto s = wordAndCodeToString(wordAndCode);
154 auto v = trie_.exactMatchSearch(s.data(), s.size());
155 if (TrieType::isNoValue(v)) {
159 trie_.erase(s.data(), s.size());
163 trie_.set(s.data(), s.size(), v);
164 decWeightedSize(delta);
// Collect candidate words into `words` from trie entries related to `word`.
// NOTE(review): the call that receives the lambda (and how `word` is used as
// the lookup prefix) is not part of this listing.
168 void fillPredict(std::unordered_set<std::string> &words,
169 std::string_view word,
size_t maxSize)
const {
171 [
this, &words, maxSize](TrieType::value_type,
size_t len,
172 TrieType::position_type pos) {
// Recover the key text for this entry and cut off everything from
// wordCodeSeparator onwards, leaving only the word part.
174 trie().suffix(buf, len, pos);
175 auto separatorPos = buf.find(wordCodeSeparator);
176 if (separatorPos != std::string::npos) {
177 buf.erase(separatorPos);
// Sentence boundary markers are handled separately (branch body not shown;
// presumably they are not offered as predictions — TODO confirm).
180 if (buf ==
"<s>" || buf ==
"</s>") {
183 words.emplace(std::move(buf));
// Return value of the callback: true while maxSize is 0 (it is unsigned, so
// `<= 0` means "no limit") or the set is still below the limit.
185 return maxSize <= 0 || words.size() < maxSize;
// Decrease the running weight counter.
// NOTE(review): the statement that applies `v` is missing from this listing;
// the visible line clamps the counter so it never drops below 0.
190 void decWeightedSize(int32_t v) {
192 weightedSize_ = std::max(weightedSize_, 0);
// Running total maintained by incFreq()/decFreq().
195 int32_t weightedSize_ = 0;
// A bounded pool of recent sentences together with unigram/bigram frequency
// tries derived from them (see the members recent_, unigram_ and bigram_).
199 class HistoryBigramPool {
// maxSize is the number of sentences the pool keeps before add() starts
// evicting the oldest ones.
201 HistoryBigramPool(
size_t maxSize) : maxSize_(maxSize) {}
// Read sentences from the binary format: a sentence count, then per sentence
// an entry count followed by that many strings. I/O failures are reported
// through throw_if_io_fail.
// NOTE(review): loop headers, local declarations and the call that stores the
// finished sentence are missing from this listing.
203 void load(std::istream &in) {
206 throw_if_io_fail(unmarshall(in, count));
209 throw_if_io_fail(unmarshall(in, size));
210 std::vector<WordWithCode> sentence;
213 throw_if_io_fail(unmarshallString(in, buffer));
// Each serialized entry is "word" or "word<wordCodeSeparator>code"; split it
// back into its two parts.
214 std::string_view bufferView{buffer};
215 size_t separatorPos = bufferView.find(wordCodeSeparator);
216 if (separatorPos != std::string_view::npos) {
217 sentence.emplace_back(
218 std::string(bufferView.substr(0, separatorPos)),
219 std::string(bufferView.substr(separatorPos + 1)));
// No separator: the entry has no code.
221 sentence.emplace_back(std::move(buffer),
"");
// Read sentences from a text stream, one sentence per line, and add them to
// this pool.
// NOTE(review): many lines are missing from this listing (declarations of
// `buf` and `token`, several branch bodies, parts of the add() calls); the
// comments describe the visible steps only.
228 void loadText(std::istream &in) {
231 std::vector<std::string> lines;
// Collect lines; there is a check against maxSize_ whose body is not shown
// (presumably it stops reading — TODO confirm).
232 while (std::getline(in, buf)) {
233 lines.emplace_back(buf);
234 if (lines.size() >= maxSize_) {
// Lines are processed in reverse order of how they were read.
238 for (
auto &line : lines | std::views::reverse) {
239 std::string_view lineView{line};
240 std::vector<std::string> tokens;
241 bool withCode =
false;
// Tokenize the line on whitespace, honouring escaped values.
242 while (!lineView.empty()) {
244 auto consumed = fcitx::stringutils::consumeMaybeEscapedValue(
245 lineView, FCITX_WHITESPACE, &token);
246 if (!consumed.empty()) {
247 tokens.push_back(std::move(token));
// A tab directly after the first token is detected here (branch body not
// shown; presumably it marks the line as carrying codes — TODO confirm).
249 if (tokens.size() == 1 && !lineView.empty() &&
250 lineView.front() ==
'\t') {
// An odd token count is special-cased (body not shown).
256 if (tokens.size() % 2 != 0) {
// Tokens are consumed in pairs: tokens[2i] is the word, tokens[2i + 1] its
// code.
259 add(std::views::iota(static_cast<size_t>(0),
261 std::views::transform([&tokens](
size_t i) {
262 return WordWithCode{tokens[i * 2], tokens[(i * 2) + 1]};
// Alternative path: each token is split (arguments not shown) and, when the
// split yields exactly two parts, they become word and code; otherwise the
// whole token is the word with an empty code.
267 std::views::transform([](
const auto &word) -> WordWithCode {
268 std::vector<std::string> wordWithMaybeCode =
269 fcitx::stringutils::split(
271 fcitx::stringutils::SplitBehavior::KeepEmpty);
272 if (wordWithMaybeCode.size() == 2) {
273 return WordWithCode{wordWithMaybeCode[0],
274 wordWithMaybeCode[1]};
276 return WordWithCode{word,
""};
// Write this pool in the binary format read by load(): the sentence count,
// then for each sentence its entry count followed by each entry flattened
// with wordAndCodeToString. I/O failures are reported through
// throw_if_io_fail.
282 void save(std::ostream &out) {
283 uint32_t count = recent_.size();
284 throw_if_io_fail(marshall(out, count));
// recent_ is filled with push_front in add(), so iterating it in reverse
// writes the sentences from the back of the list to the front.
288 for (
auto &sentence : recent_ | std::views::reverse) {
289 uint32_t size = sentence.size();
290 throw_if_io_fail(marshall(out, size));
291 for (
const auto &s : sentence) {
292 throw_if_io_fail(marshallString(out, wordAndCodeToString(s)));
// Write the stored sentences to `out` as text, in the order they appear in
// recent_.
// NOTE(review): the separators written between entries and the use of
// `hasCode` are on lines missing from this listing.
297 void dump(std::ostream &out)
const {
298 for (
const auto &sentence : recent_) {
// True when at least one entry of the sentence has a non-empty code.
300 bool hasCode = std::ranges::any_of(sentence, [](
const auto &item) {
301 return !std::get<1>(item).empty();
303 for (
const auto &s : sentence) {
// Words and codes are passed through escapeForValue before being written.
309 out << fcitx::stringutils::escapeForValue(std::get<0>(s));
312 << fcitx::stringutils::escapeForValue(std::get<1>(s));
// Add a sentence (a range of WordWithCode) to the front of this pool and
// update the unigram/bigram counts.
// Returns the sentences that were evicted to make room, so the caller can
// hand them to another pool; the list is empty when nothing was evicted or
// when the sentence was rejected.
// NOTE(review): the declaration of `delta` and a few loop lines are missing
// from this listing.
326 template <
typename R>
327 std::list<std::vector<WordWithCode>> add(
const R &sentence) {
328 std::list<std::vector<WordWithCode>> popedSentence;
329 if (sentence.empty()) {
330 return popedSentence;
// Reject sentences in which any word contains a NUL character.
333 if (std::ranges::any_of(sentence, [](
const auto &item) {
334 const auto &[word, code] = item;
335 return word.find(
'\0') != std::string::npos;
337 return popedSentence;
// Evict from the back of recent_ until there is room: undo the counts of the
// evicted sentence, then move its list node into the result.
339 while (recent_.size() >= maxSize_) {
340 remove(recent_.back());
341 popedSentence.splice(popedSentence.end(), recent_,
342 std::prev(recent_.end()));
345 std::vector<WordWithCode> newSentence;
// Count every entry as a unigram and every adjacent pair as a bigram, while
// copying the entries into the stored sentence.
347 for (
auto iter = sentence.begin(), end = sentence.end(); iter != end;
349 unigram_.incFreq(*iter, delta);
350 auto next = std::ranges::next(iter);
352 incBigram(*iter, *next, delta);
354 newSentence.push_back(*iter);
356 recent_.push_front(std::move(newSentence));
// Account for the sentence boundary markers: "<s>" before the first entry
// and "</s>" after the last one.
357 unigram_.incFreq({
"<s>",
""}, delta);
358 unigram_.incFreq({
"</s>",
""}, delta);
359 incBigram({
"<s>",
""}, sentence.front(), delta);
360 incBigram(sentence.back(), {
"</s>",
""}, delta);
362 return popedSentence;
365 int32_t unigramFreq(WordWithCodeView s)
const {
return unigram_.freq(s); }
// Frequency stored for the bigram "s1 followed by s2": the pair is combined
// into one key with bigramWordWithCode and looked up in the bigram trie.
367 int32_t bigramFreq(WordWithCodeView s1, WordWithCodeView s2)
const {
368 return bigram_.freq(bigramWordWithCode(s1, s2));
// True when this pool has no unigram count for `word`.
371 bool isUnknown(WordWithCodeView word)
const {
372 return unigramFreq(word) == 0;
375 size_t maxSize()
const {
return maxSize_; }
377 size_t realSize()
const {
return recent_.size(); }
// Drop every stored sentence that contains `word`. An empty `code` matches
// the word with any code; otherwise the code must match as well.
// NOTE(review): the algorithm call that applies the predicate, the handling
// of the frequency counts for a dropped sentence and the iterator advance for
// non-matching sentences are on lines missing from this listing.
379 void forget(std::string_view word, std::string_view code) {
380 auto iter = recent_.begin();
381 while (iter != recent_.end()) {
383 iter->begin(), iter->end(), [word, code](
const auto &item) {
384 const auto &[w, c] = item;
385 return w == word && (code.empty() || c == code);
// erase() returns the iterator to the next sentence, so the loop continues
// from there.
388 iter = recent_.erase(iter);
// Collect prediction candidates for `word` from this pool's bigram trie into
// `words`. maxSize defaults to 0, which the trie-level fillPredict treats as
// "no limit".
395 void fillPredict(std::unordered_set<std::string> &words,
396 std::string_view word,
size_t maxSize = 0)
const {
397 bigram_.fillPredict(words, word, maxSize);
// Try to extend the most recently added sentence with `newSentence` instead
// of storing a separate sentence.
// NOTE(review): the return statements, the comparison call around the
// context check and the declaration of `delta` are missing from this listing.
// Elements of `newSentence` are moved from once the append loop runs.
400 bool maybeAppendToLatestSentence(
const std::vector<WordWithCode> &context,
401 std::vector<WordWithCode> &newSentence) {
402 if (recent_.empty() || newSentence.empty()) {
405 auto &latestSentence = recent_.front();
// The latest sentence must be at least as long as the context; its last
// context.size() entries are then examined (comparison call not shown).
406 if (latestSentence.size() < context.size() ||
409 std::views::drop(latestSentence,
410 latestSentence.size() - context.size()))) {
// The old last entry no longer ends the sentence: take back its bigram with
// "</s>".
415 decBigram(latestSentence.back(), {
"</s>",
""}, delta);
// Append each new entry, counting it as a unigram and as a bigram with the
// entry that currently ends the sentence.
416 for (
auto &item : newSentence) {
417 unigram_.incFreq(item, delta);
418 incBigram(latestSentence.back(), item, delta);
419 latestSentence.push_back(std::move(item));
// The new last entry now ends the sentence.
421 incBigram(latestSentence.back(), {
"</s>",
""}, delta);
// Take back the frequency counts of a sentence: each entry as a unigram,
// each adjacent pair as a bigram, and the bigrams with the "<s>" / "</s>"
// boundary markers.
// NOTE(review): the declaration of `delta`, part of the loop header and any
// decrement of the "<s>"/"</s>" unigrams are on lines missing from this
// listing.
427 template <
typename R>
428 void remove(
const R &sentence) {
430 for (
auto iter = sentence.begin(), end = sentence.end(); iter != end;
432 unigram_.decFreq(*iter, delta);
433 auto next = std::next(iter);
435 decBigram(*iter, *next, delta);
438 decBigram({
"<s>",
""}, sentence.front(), delta);
439 decBigram(sentence.back(), {
"</s>",
""}, delta);
// Decrease the count of the bigram "s1 followed by s2" by `delta`.
442 void decBigram(WordWithCodeView s1, WordWithCodeView s2, int32_t delta) {
443 bigram_.decFreq(bigramWordWithCode(s1, s2), delta);
// Increase the count of the bigram "s1 followed by s2" by `delta`.
446 void incBigram(WordWithCodeView s1, WordWithCodeView s2,
int delta) {
447 bigram_.incFreq(bigramWordWithCode(s1, s2), delta);
// Capacity of the pool in sentences; add() evicts while recent_ has reached
// it.
450 const size_t maxSize_;
// Stored sentences; add() inserts at the front and evicts from the back.
454 std::list<std::vector<WordWithCode>> recent_;
// Frequency of single entries.
457 WeightedTrie unigram_;
// Frequency of adjacent entry pairs, keyed through bigramWordWithCode.
458 WeightedTrie bigram_;
// Cascade evicted sentences through the pools: sentences handed in are added
// to pools_[1]; whatever that pool evicts in turn is added to pools_[2], and
// so on. The loop stops when nothing is left to pass on or the pools are
// exhausted, so sentences evicted by the last pool are discarded.
473 void populateSentence(std::list<std::vector<WordWithCode>> popedSentence) {
474 for (
size_t i = 1; !popedSentence.empty() && i < pools_.size(); i++) {
475 std::list<std::vector<WordWithCode>> nextSentences;
476 while (!popedSentence.empty()) {
477 auto newPopedSentence = pools_[i].add(popedSentence.front());
478 popedSentence.pop_front();
// Collect this pool's evictions for the next round.
479 nextSentences.splice(nextSentences.end(), newPopedSentence);
481 popedSentence = std::move(nextSentences);
// Weighted unigram frequency of `word`: each pool's count multiplied by that
// pool's weight, accumulated over all pools.
// NOTE(review): the declaration of `freq` and the return are missing from
// this listing.
485 float unigramFreq(WordWithCodeView word)
const {
486 assert(pools_.size() == poolWeight_.size());
488 for (
size_t i = 0; i < pools_.size(); i++) {
489 freq += pools_[i].unigramFreq(word) * poolWeight_[i];
// Weighted bigram frequency of "prev followed by cur": each pool's count
// multiplied by that pool's weight, accumulated over all pools.
// NOTE(review): the declaration of `freq` and the return are missing from
// this listing.
494 float bigramFreq(WordWithCodeView prev, WordWithCodeView cur)
const {
495 assert(pools_.size() == poolWeight_.size());
497 for (
size_t i = 0; i < pools_.size(); i++) {
498 freq += pools_[i].bigramFreq(prev, cur) * poolWeight_[i];
// Weighted size used as the unigram denominator in scoring: each pool's
// capacity (maxSize(), not its current fill level) multiplied by that pool's
// weight, accumulated over all pools.
// NOTE(review): the declaration of `size` and the return are missing from
// this listing.
503 float unigramSize()
const {
505 for (
size_t i = 0; i < pools_.size(); i++) {
506 size += pools_[i].maxSize() * poolWeight_[i];
// NOTE(review): the next line is the tail of a member initializer whose first
// line is missing from this listing (presumably the unknown-penalty member
// `unknown_`, which HistoryBigram assigns elsewhere in this file).
513 std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY);
// When true, scoreWithCode() gives the bigram term a weight of 0.
514 bool useOnlyUnigram_ =
false;
// Sentence pools; new sentences go to pools_[0] and evictions cascade to the
// following pools (see populateSentence).
515 std::vector<HistoryBigramPool> pools_;
// One weight per pool, parallel to pools_.
516 std::vector<float> poolWeight_;
// Construct the history with three pools of 128, 8192 and 65536 sentences
// and compute a weight for each pool.
// NOTE(review): the FCITX_D()-style declaration of `d`, the body of the
// `pools_.size() != poolSize.size()` branch and the start of the final
// statement are missing from this listing.
519 HistoryBigram::HistoryBigram()
520 : d_ptr(std::make_unique<HistoryBigramPrivate>()) {
522 const float p = 1.0 / (1 + HISTORY_BIGRAM_ALPHA_VALUE);
523 constexpr std::array<int, 3> poolSize = {128, 8192, 65536};
524 d->pools_.reserve(poolSize.size());
525 d->poolWeight_.reserve(poolSize.size());
526 for (
auto size : poolSize) {
527 d->pools_.emplace_back(size);
528 float portion = 1.0F;
// Every pool except the last one gets an extra adjustment (body not shown).
529 if (d->pools_.size() != poolSize.size()) {
// Scale by p raised to the pool's index, then normalize by the pool's
// capacity so the weight applies per stored sentence.
532 portion *= std::pow(p, d->pools_.size() - 1);
533 d->poolWeight_.push_back(portion / d->pools_.back().maxSize());
536 std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY));
543 d->unknown_ = unknown;
546 float HistoryBigram::unknownPenalty()
const {
// Store the flag that makes scoreWithCode() ignore the bigram term.
551 void HistoryBigram::setUseOnlyUnigram(
bool useOnlyUnigram) {
553 d->useOnlyUnigram_ = useOnlyUnigram;
// Current value of the "use only unigram" flag.
556 bool HistoryBigram::useOnlyUnigram()
const {
558 return d->useOnlyUnigram_;
562 addWithCode(sentence,
nullptr);
// Add a sentence to the first pool, pairing every item's word with the code
// produced by `validationCodeExtractor`, then cascade any evicted sentences
// to the later pools.
// NOTE(review): the first parameter line and the end of the lambda (the value
// used when no extractor is set) are missing from this listing.
565 void HistoryBigram::addWithCode(
567 const ValidationCodeExtractor &validationCodeExtractor) {
569 d->populateSentence(d->pools_[0].add(
570 sentence.sentence() |
571 std::views::transform(
572 [&validationCodeExtractor](
const auto &item) -> WordWithCode {
// The extractor is only called when it is set.
573 return {item->word(), validationCodeExtractor
574 ? validationCodeExtractor(item)
// Add a sentence of plain words (each with an empty code) to the first pool
// and cascade any evicted sentences to the later pools.
579 void HistoryBigram::add(
const std::vector<std::string> &sentence) {
581 d->populateSentence(d->pools_[0].add(
582 sentence | std::views::transform([](
const auto &word) -> WordWithCode {
583 return WordWithCode{word,
""};
// Add a sentence of (word, code) pairs to the first pool and cascade any
// evicted sentences to the later pools.
587 void HistoryBigram::addWithCode(
588 const std::vector<WordWithCode> &sentenceWithValidationCode) {
590 d->populateSentence(d->pools_[0].add(sentenceWithValidationCode));
// True when no pool has a unigram count for `v` (looked up with an empty
// code).
593 bool HistoryBigram::isUnknown(std::string_view v)
const {
595 return std::ranges::all_of(d->pools_, [v](
const HistoryBigramPool &pool) {
596 return pool.isUnknown({v,
""});
// Score the transition prev -> cur for plain words; forwards to
// scoreWithCode() with empty codes.
600 float HistoryBigram::score(std::string_view prev, std::string_view cur)
const {
601 return scoreWithCode({prev,
""}, {cur,
""});
// Score the transition prev -> cur as the base-10 logarithm of an
// interpolated probability built from the weighted bigram and unigram
// frequencies.
// NOTE(review): the handling of an empty `prev`/`cur` word, the declaration
// of `pr` and whatever happens between the clamp and the final log10 are on
// lines missing from this listing.
604 float HistoryBigram::scoreWithCode(WordWithCodeView prev,
605 WordWithCodeView cur)
const {
607 if (prev.first.empty()) {
610 if (cur.first.empty()) {
614 auto uf0 = d->unigramFreq(prev);
615 auto bf = d->bigramFreq(prev, cur);
616 auto uf1 = d->unigramFreq(cur);
// The bigram term is switched off entirely in unigram-only mode.
618 float bigramWeight = d->useOnlyUnigram_ ? 0.0F : 0.8F;
// Bigram term: bf relative to the frequency of prev. Unigram term: uf1
// relative to the weighted pool size. Half of the first pool's weight is
// added to each denominator.
621 pr += bigramWeight * bf / float(uf0 + (d->poolWeight_[0] / 2));
622 pr += (1.0F - bigramWeight) * uf1 /
623 float(d->unigramSize() + (d->poolWeight_[0] / 2));
// Keep the probability at or below 1.
625 pr = std::min<double>(pr, 1.0);
630 return std::log10(pr);
// Load the history from the binary format.
// Throws std::invalid_argument when the magic number does not match or the
// version is not handled; I/O failures are reported through throw_if_io_fail.
// NOTE(review): the switch statement, the case labels for older versions and
// a few declarations are missing from this listing.
633 void HistoryBigram::load(std::istream &in) {
636 uint32_t version = 0;
637 throw_if_io_fail(unmarshall(in, magic));
638 if (magic != historyBinaryFormatMagic) {
639 throw std::invalid_argument(
"Invalid history magic.");
641 throw_if_io_fail(unmarshall(in, version));
// One version path reads data for the first two pools only.
644 std::ranges::for_each(d->pools_ | std::views::take(2),
645 [&in](
auto &pool) { pool.load(in); });
// Another path reads every pool directly from the stream.
648 std::ranges::for_each(d->pools_, [&in](
auto &pool) { pool.load(in); });
// The current version stores the pools inside a zstd-compressed stream.
651 case historyBinaryFormatVersion:
655 readZSTDCompressed(in, [d](std::istream &compressIn) {
656 std::ranges::for_each(d->pools_, [&compressIn](
auto &pool) {
657 pool.load(compressIn);
662 throw std::invalid_argument(
"Invalid history version.");
// Load the history from text: each pool in turn reads its share from the same
// stream.
666 void HistoryBigram::loadText(std::istream &in) {
668 std::ranges::for_each(d->pools_, [&in](
auto &pool) { pool.loadText(in); });
// Save the history in the binary format: magic number and version
// uncompressed, followed by all pools inside a zstd-compressed stream.
// I/O failures on the header are reported through throw_if_io_fail.
671 void HistoryBigram::save(std::ostream &out) {
673 throw_if_io_fail(marshall(out, historyBinaryFormatMagic));
674 throw_if_io_fail(marshall(out, historyBinaryFormatVersion));
676 writeZSTDCompressed(out, [d](std::ostream &compressOut) {
677 std::ranges::for_each(
678 d->pools_, [&compressOut](
auto &pool) { pool.save(compressOut); });
// Write every pool's sentences to `out` as text, pool by pool.
682 void HistoryBigram::dump(std::ostream &out) {
684 std::ranges::for_each(d->pools_,
685 [&out](
const auto &pool) { pool.dump(out); });
// Clear every pool.
688 void HistoryBigram::clear() {
690 std::ranges::for_each(d->pools_, std::mem_fn(&HistoryBigramPool::clear));
693 void HistoryBigram::forget(std::string_view word) { forget(word,
""); }
// Forget `word` (optionally restricted to `code`) in every pool.
695 void HistoryBigram::forget(std::string_view word, std::string_view code) {
697 std::ranges::for_each(
698 d->pools_, [word, code](
auto &pool) { pool.forget(word, code); });
702 const std::vector<std::string> &sentence,
703 size_t maxSize)
const {
705 if (maxSize > 0 && words.size() >= maxSize) {
709 if (!sentence.empty()) {
710 lookup = sentence.back();
714 lookup += bigramSeparator;
715 std::ranges::for_each(
716 d->pools_, [&words, &lookup, maxSize](
const HistoryBigramPool &pool) {
717 pool.fillPredict(words, lookup, maxSize);
// True when at least one pool has a positive count for the bigram
// "prev followed by cur" (both looked up with empty codes).
721 bool HistoryBigram::containsBigram(std::string_view prev,
722 std::string_view cur)
const {
724 return std::ranges::any_of(
725 d->pools_, [&prev, &cur](
const HistoryBigramPool &pool) {
726 return pool.bigramFreq({prev,
""}, {cur,
""}) > 0;
732 return d->unigramFreq(word);
736 WordWithCodeView cur)
const {
738 return d->bigramFreq(prev, cur);
744 for (
const auto &pool : d->pools_) {
745 freq += pool.unigramFreq(word);
751 WordWithCodeView cur)
const {
754 for (
const auto &pool : d->pools_) {
755 freq += pool.bigramFreq(prev, cur);
// Score the transition between two word nodes without validation codes;
// forwards to scoreWithCode() with a null extractor.
760 float HistoryBigram::score(
const WordNode *prev,
const WordNode *cur)
const {
761 return scoreWithCode(prev, cur,
nullptr);
// Score the transition between two word nodes. A null node contributes an
// empty word; the code is taken from `extractor` only when both the extractor
// and the node are set, and is empty otherwise.
// NOTE(review): the parameter line declaring `prev` and `cur` is missing from
// this listing.
764 float HistoryBigram::scoreWithCode(
766 const ValidationCodeExtractor &extractor)
const {
767 return scoreWithCode(
768 {prev ? prev->word() :
"", extractor && prev ? extractor(prev) :
""},
769 {cur ? cur->word() :
"", extractor && cur ? extractor(cur) :
""});
// Add `newSentence`, preferring to append it to the most recent sentence of
// the first pool when `context` allows it. With an empty context, or when the
// append is refused, the sentence is added as a new one via addWithCode().
772 void HistoryBigram::addWithContext(
const std::vector<WordWithCode> &context,
773 std::vector<WordWithCode> newSentence) {
775 if (context.empty() ||
776 !d->pools_[0].maybeAppendToLatestSentence(context, newSentence)) {
777 addWithCode(newSentence);
int32_t rawUnigramFrequency(WordWithCodeView word) const
Query the raw frequency of the unigram.
float bigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const
Query the weighted frequency of the bigram.
Provide a DATrie implementation.
float unigramFrequency(WordWithCodeView word) const
Query the weighted frequency of the unigram.
void fillPredict(std::unordered_set< std::string > &words, const std::vector< std::string > &sentence, size_t maxSize) const
Fill the prediction based on current sentence.
int32_t rawBigramFrequency(WordWithCodeView prev, WordWithCodeView cur) const
Query the raw frequency of the bigram.