// Include list of this translation unit, followed by the default constructor
// of PinyinPrediction. The constructor only initialises d_ptr with a freshly
// made PinyinPredictionPrivate.
// NOTE(review): the numbers embedded in each line come from the source
// listing this text was extracted from; they are not part of the code.
6 #include "pinyinprediction.h" 12 #include <string_view> 14 #include <unordered_set> 17 #include <fcitx-utils/macros.h> 18 #include <fcitx-utils/misc.h> 19 #include <fcitx-utils/stringutils.h> 20 #include "libime/core/historybigram.h" 21 #include "libime/core/languagemodel.h" 22 #include "libime/core/prediction.h" 23 #include "libime/pinyin/pinyindictionary.h" 32 PinyinPrediction::PinyinPrediction()
33 : d_ptr(std::make_unique<PinyinPredictionPrivate>()) {}
// Destructor with an empty body, defined out of line in this file.
// NOTE(review): presumably kept out of line so that d_ptr's pointee type is
// complete where it is destroyed; confirm against the header.
35 PinyinPrediction::~PinyinPrediction() {}
// Builds a list of (predicted text, PinyinPredictionSource) pairs for
// `sentence`. Candidates come from Prediction::predictWithScore (tagged
// Model) and, when lastEncodedPinyin is not empty, from the dictionary's
// matchWordsPrefix callback (tagged Dictionary).
// NOTE(review): this listing skips several original lines (the embedded
// numbers jump, e.g. 42 -> 44, 55 -> 60, 88 -> 95). The function-name line,
// some statements and all closing braces are not shown, so comments below
// describe only what is visible.
42 std::vector<std::pair<std::string, PinyinPredictionSource>>
44 const std::vector<std::string> &sentence,
45 std::string_view lastEncodedPinyin,
size_t maxSize) {
47 std::vector<std::pair<std::string, PinyinPredictionSource>> finalResult;
// No encoded pinyin or no words: take the model's scored predictions as they
// are, drop the score, and tag every entry as Model.
// NOTE(review): the end of this branch (presumably an early return) is not
// shown.
49 if (lastEncodedPinyin.empty() || sentence.empty()) {
50 auto result = Prediction::predictWithScore(state, sentence, maxSize);
51 std::ranges::transform(result, std::back_inserter(finalResult),
52 [](std::pair<std::string, float> &value) {
53 return std::make_pair(
54 std::move(value.first),
55 PinyinPredictionSource::Model);
// Comparator for the heap operations below: true when lhs has the larger
// float score, or, on equal scores, the lexicographically smaller string.
60 auto cmp = [](
auto &lhs,
auto &rhs) {
61 if (std::get<float>(lhs) != std::get<float>((rhs))) {
62 return std::get<float>(lhs) > std::get<float>(rhs);
64 return std::get<std::string>(lhs) < std::get<std::string>(rhs);
// Seed the candidate list with the model's predictions, keeping the score
// this time, then arrange it as a heap under cmp.
67 auto result = Prediction::predictWithScore(state, sentence, maxSize);
68 std::vector<std::tuple<std::string, float, PinyinPredictionSource>>
70 std::ranges::transform(result, std::back_inserter(intermedidateResult),
71 [](std::pair<std::string, float> &value) {
72 return std::make_tuple(
73 std::move(value.first), value.second,
74 PinyinPredictionSource::Model);
76 std::ranges::make_heap(intermedidateResult, cmp);
// Start from the model's null state and score every word except the last.
// NOTE(review): the declaration of outState and whatever follows the score()
// call inside the loop are not shown.
78 State prevState = model()->nullState();
80 std::vector<WordNode> nodes;
// Texts currently present in intermedidateResult, used to reject duplicates.
81 std::unordered_set<std::string> dup;
82 if (!sentence.empty()) {
83 nodes.reserve(sentence.size());
84 for (
const auto &word : fcitx::MakeIterRange(
85 sentence.begin(), std::prev(sentence.end()))) {
86 auto idx = model()->index(word);
87 nodes.emplace_back(word, idx);
88 model()->score(prevState, nodes.back(), outState);
// Score the last word from prevState and add that value to every model
// candidate's score; each candidate's text is also recorded in dup.
95 nodes.emplace_back(sentence.back(), model()->index(sentence.back()));
96 float adjust = model()->score(prevState, nodes.back(), outState);
97 for (
auto &result : intermedidateResult) {
98 std::get<float>(result) += adjust;
99 dup.insert(std::get<std::string>(result));
// Hand lastEncodedPinyin to the dictionary's matchWordsPrefix; the lambda is
// called with each reported entry (hz text and cost).
103 d->dict_->matchWordsPrefix(
104 lastEncodedPinyin.data(), lastEncodedPinyin.size(),
105 [
this, &sentence, &prevState, &cmp, &intermedidateResult, &dup,
106 maxSize](std::string_view, std::string_view hz,
float cost) {
// Only entries strictly longer than the last sentence word and beginning
// with it are used; the part after that word becomes the candidate text.
107 if (sentence.back().size() < hz.size() &&
108 hz.starts_with(sentence.back())) {
110 std::string newWord(hz.substr(sentence.back().size()));
// NOTE(review): the body of this duplicate check is not shown; presumably
// the candidate is skipped.
111 if (dup.contains(newWord)) {
// New dictionary candidate scored as cost + singleWordScore(prevState, hz).
// NOTE(review): the first tuple member (line 116 of the listing) is missing.
115 std::tuple<std::string, float, PinyinPredictionSource> newItem{
117 cost + model()->singleWordScore(prevState, hz),
118 PinyinPredictionSource::Dictionary};
120 dup.insert(std::get<std::string>(newItem));
121 intermedidateResult.push_back(std::move(newItem));
122 std::ranges::push_heap(intermedidateResult, cmp);
// Keep at most maxSize candidates: pop_heap moves the heap's top element to
// the back, which is then removed.
// NOTE(review): the call that receives the popped element's text (line 125
// of the listing) is not shown.
123 while (intermedidateResult.size() > maxSize) {
124 std::ranges::pop_heap(intermedidateResult, cmp);
126 std::get<std::string>(intermedidateResult.back()));
127 intermedidateResult.pop_back();
// Turn the heap into a range sorted under cmp, then copy the
// (text, source) part of each tuple into finalResult.
133 std::ranges::sort_heap(intermedidateResult, cmp);
134 std::ranges::transform(
135 intermedidateResult, std::back_inserter(finalResult), [](
auto &value) {
136 return std::make_pair(std::move(std::get<std::string>(value)),
137 std::get<PinyinPredictionSource>(value));
// Overload for a sentence given as HistoryBigram::WordWithCode entries. It
// splits the entries into plain words plus the second member of the last
// entry, then delegates to predict(state, words, lastPinyin, maxSize).
// NOTE(review): the function-name line, the remaining parameters and the
// closing braces are missing from this listing.
143 std::vector<std::pair<std::string, PinyinPredictionSource>>
146 const std::vector<libime::HistoryBigram::WordWithCode> &sentence,
// Copy only the word of each entry.
148 std::vector<std::string> words;
149 words.reserve(sentence.size());
150 for (
const auto &[word, code] : sentence) {
151 words.push_back(word);
// lastPinyin stays empty for an empty sentence; otherwise it views the
// second member of the last entry.
153 std::string_view lastPinyin;
154 if (!sentence.empty()) {
155 lastPinyin = sentence.back().second;
157 return predict(state, words, lastPinyin, maxSize);
// Overload returning plain strings: forwards sentence and maxSize unchanged
// to Prediction::predict.
// NOTE(review): the function-name and parameter lines (161-162 of the
// listing) are not shown.
160 std::vector<std::string>
163 return Prediction::predict(sentence, maxSize);
std::vector< std::pair< std::string, PinyinPredictionSource > > predict(const State &state, const std::vector< std::string > &sentence, std::string_view lastEncodedPinyin, size_t maxSize=0)
Predict from model and pinyin dictionary for the last sentence being typed.
PinyinDictionary is a set of dictionaries for Pinyin.
void setPinyinDictionary(const PinyinDictionary *dict)
Set the pinyin dictionary used for prediction.