libime
pinyinprediction.cpp
1 /*
2  * SPDX-FileCopyrightText: 2023-2023 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 #include "pinyinprediction.h"
7 #include <algorithm>
8 #include <cstddef>
9 #include <iterator>
10 #include <memory>
11 #include <string>
12 #include <string_view>
13 #include <tuple>
14 #include <unordered_set>
15 #include <utility>
16 #include <vector>
17 #include <fcitx-utils/macros.h>
18 #include <fcitx-utils/misc.h>
19 #include <fcitx-utils/stringutils.h>
20 #include "libime/core/historybigram.h"
21 #include "libime/core/languagemodel.h"
22 #include "libime/core/prediction.h"
23 #include "libime/pinyin/pinyindictionary.h"
24 
25 namespace libime {
26 
28 public:
29  const PinyinDictionary *dict_ = nullptr;
30 };
31 
32 PinyinPrediction::PinyinPrediction()
33  : d_ptr(std::make_unique<PinyinPredictionPrivate>()) {}
34 
35 PinyinPrediction::~PinyinPrediction() {}
36 
38  FCITX_D();
39  d->dict_ = dict;
40 }
41 
42 std::vector<std::pair<std::string, PinyinPredictionSource>>
43 PinyinPrediction::predict(const State &state,
44  const std::vector<std::string> &sentence,
45  std::string_view lastEncodedPinyin, size_t maxSize) {
46  FCITX_D();
47  std::vector<std::pair<std::string, PinyinPredictionSource>> finalResult;
48 
49  if (lastEncodedPinyin.empty() || sentence.empty()) {
50  auto result = Prediction::predictWithScore(state, sentence, maxSize);
51  std::ranges::transform(result, std::back_inserter(finalResult),
52  [](std::pair<std::string, float> &value) {
53  return std::make_pair(
54  std::move(value.first),
55  PinyinPredictionSource::Model);
56  });
57  return finalResult;
58  }
59 
60  auto cmp = [](auto &lhs, auto &rhs) {
61  if (std::get<float>(lhs) != std::get<float>((rhs))) {
62  return std::get<float>(lhs) > std::get<float>(rhs);
63  }
64  return std::get<std::string>(lhs) < std::get<std::string>(rhs);
65  };
66 
67  auto result = Prediction::predictWithScore(state, sentence, maxSize);
68  std::vector<std::tuple<std::string, float, PinyinPredictionSource>>
69  intermedidateResult;
70  std::ranges::transform(result, std::back_inserter(intermedidateResult),
71  [](std::pair<std::string, float> &value) {
72  return std::make_tuple(
73  std::move(value.first), value.second,
74  PinyinPredictionSource::Model);
75  });
76  std::ranges::make_heap(intermedidateResult, cmp);
77 
78  State prevState = model()->nullState();
79  State outState;
80  std::vector<WordNode> nodes;
81  std::unordered_set<std::string> dup;
82  if (!sentence.empty()) {
83  nodes.reserve(sentence.size());
84  for (const auto &word : fcitx::MakeIterRange(
85  sentence.begin(), std::prev(sentence.end()))) {
86  auto idx = model()->index(word);
87  nodes.emplace_back(word, idx);
88  model()->score(prevState, nodes.back(), outState);
89  prevState = outState;
90  }
91  // We record the last score for the sentence word to adjust the partial
92  // score. E.g. for 无, model may contain 压力 and dict contain 聊 score
93  // of 聊 should be P(...|无聊) and score of 压力 should be P(...|无) *
94  // P(...无|压力) adjust is the P(...|无) here.
95  nodes.emplace_back(sentence.back(), model()->index(sentence.back()));
96  float adjust = model()->score(prevState, nodes.back(), outState);
97  for (auto &result : intermedidateResult) {
98  std::get<float>(result) += adjust;
99  dup.insert(std::get<std::string>(result));
100  }
101  }
102 
103  d->dict_->matchWordsPrefix(
104  lastEncodedPinyin.data(), lastEncodedPinyin.size(),
105  [this, &sentence, &prevState, &cmp, &intermedidateResult, &dup,
106  maxSize](std::string_view, std::string_view hz, float cost) {
107  if (sentence.back().size() < hz.size() &&
108  hz.starts_with(sentence.back())) {
109 
110  std::string newWord(hz.substr(sentence.back().size()));
111  if (dup.contains(newWord)) {
112  return true;
113  }
114 
115  std::tuple<std::string, float, PinyinPredictionSource> newItem{
116  std::move(newWord),
117  cost + model()->singleWordScore(prevState, hz),
118  PinyinPredictionSource::Dictionary};
119 
120  dup.insert(std::get<std::string>(newItem));
121  intermedidateResult.push_back(std::move(newItem));
122  std::ranges::push_heap(intermedidateResult, cmp);
123  while (intermedidateResult.size() > maxSize) {
124  std::ranges::pop_heap(intermedidateResult, cmp);
125  dup.erase(
126  std::get<std::string>(intermedidateResult.back()));
127  intermedidateResult.pop_back();
128  }
129  }
130  return true;
131  });
132 
133  std::ranges::sort_heap(intermedidateResult, cmp);
134  std::ranges::transform(
135  intermedidateResult, std::back_inserter(finalResult), [](auto &value) {
136  return std::make_pair(std::move(std::get<std::string>(value)),
137  std::get<PinyinPredictionSource>(value));
138  });
139 
140  return finalResult;
141 }
142 
143 std::vector<std::pair<std::string, PinyinPredictionSource>>
145  const State &state,
146  const std::vector<libime::HistoryBigram::WordWithCode> &sentence,
147  size_t maxSize) {
148  std::vector<std::string> words;
149  words.reserve(sentence.size());
150  for (const auto &[word, code] : sentence) {
151  words.push_back(word);
152  }
153  std::string_view lastPinyin;
154  if (!sentence.empty()) {
155  lastPinyin = sentence.back().second;
156  }
157  return predict(state, words, lastPinyin, maxSize);
158 }
159 
160 std::vector<std::string>
161 PinyinPrediction::predict(const std::vector<std::string> &sentence,
162  size_t maxSize) {
163  return Prediction::predict(sentence, maxSize);
164 }
165 
166 } // namespace libime
std::vector< std::pair< std::string, PinyinPredictionSource > > predict(const State &state, const std::vector< std::string > &sentence, std::string_view lastEncodedPinyin, size_t maxSize=0)
Predict from the model and the pinyin dictionary for the last sentence being typed.
PinyinDictionary is a set of dictionaries for Pinyin.
void setPinyinDictionary(const PinyinDictionary *dict)
Set the pinyin dictionary used for prediction.