libime
prediction.cpp
1 /*
2  * SPDX-FileCopyrightText: 2017-2017 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 
7 #include "prediction.h"
8 #include <algorithm>
9 #include <cstddef>
10 #include <memory>
11 #include <string>
12 #include <unordered_set>
13 #include <utility>
14 #include <vector>
15 #include <fcitx-utils/macros.h>
16 #include "datrie.h"
17 #include "historybigram.h"
18 #include "languagemodel.h"
19 
20 namespace libime {
21 
23 public:
24  const LanguageModel *model_ = nullptr;
25  const HistoryBigram *bigram_ = nullptr;
26 };
27 
28 Prediction::Prediction() : d_ptr(std::make_unique<PredictionPrivate>()) {}
29 
30 Prediction::~Prediction() = default;
31 
32 void Prediction::setLanguageModel(const LanguageModel *model) {
33  FCITX_D();
34  d->model_ = model;
35 }
36 
37 void Prediction::setHistoryBigram(const HistoryBigram *bigram) {
38  FCITX_D();
39  d->bigram_ = bigram;
40 }
41 
42 const LanguageModel *Prediction::model() const {
43  FCITX_D();
44  return d->model_;
45 }
46 
47 const HistoryBigram *Prediction::historyBigram() const {
48  FCITX_D();
49  return d->bigram_;
50 }
51 
52 std::vector<std::string>
53 Prediction::predict(const std::vector<std::string> &sentence,
54  size_t realMaxSize) {
55  FCITX_D();
56  if (!d->model_) {
57  return {};
58  }
59 
60  State state = d->model_->nullState();
61  State outState;
62  std::vector<WordNode> nodes;
63  nodes.reserve(sentence.size());
64  for (const auto &word : sentence) {
65  auto idx = d->model_->index(word);
66  nodes.emplace_back(word, idx);
67  d->model_->score(state, nodes.back(), outState);
68  state = outState;
69  }
70  return predict(state, sentence, realMaxSize);
71 }
72 
73 std::vector<std::pair<std::string, float>>
74 Prediction::predictWithScore(const State &state,
75  const std::vector<std::string> &sentence,
76  size_t realMaxSize) {
77  FCITX_D();
78  if (!d->model_) {
79  return {};
80  }
81  // Search more get less.
82  size_t maxSize = realMaxSize * 2;
83  std::unordered_set<std::string> words;
84 
85  if (auto file = d->model_->languageModelFile()) {
86  std::string search = "<unk>";
87  if (!sentence.empty()) {
88  search = sentence.back();
89  }
90  search += "|";
91  const auto &trie = file->predictionTrie();
92  trie.foreach(search, [&trie, &words,
93  maxSize](DATrie<float>::value_type, size_t len,
95  std::string buf;
96  trie.suffix(buf, len, pos);
97  words.emplace(std::move(buf));
98 
99  return maxSize <= 0 || words.size() < maxSize;
100  });
101  }
102 
103  if (d->bigram_) {
104  d->bigram_->fillPredict(words, sentence, maxSize);
105  }
106 
107  std::vector<std::pair<std::string, float>> temps;
108  for (auto word : words) {
109  auto score = d->model_->singleWordScore(state, word);
110  temps.emplace_back(std::move(word), score);
111  }
112  std::sort(temps.begin(), temps.end(), [](auto &lhs, auto &rhs) {
113  if (lhs.second != rhs.second) {
114  return lhs.second > rhs.second;
115  }
116  return lhs.first < rhs.first;
117  });
118 
119  if (realMaxSize && temps.size() > realMaxSize) {
120  temps.resize(realMaxSize);
121  }
122  return temps;
123 }
124 
125 std::vector<std::string>
126 Prediction::predict(const State &state,
127  const std::vector<std::string> &sentence,
128  size_t realMaxSize) {
129 
130  auto temps = predictWithScore(state, sentence, realMaxSize);
131  std::vector<std::string> result;
132  result.reserve(temps.size());
133  for (auto &temp : temps) {
134  result.emplace_back(std::move(temp.first));
135  }
136  return result;
137 }
138 
139 } // namespace libime
Provide a DATrie implementation.
This is a trie based on cedar (http://www.tkl.iis.u-tokyo.ac.jp/~ynaga/cedar/).
Definition: datrie.h:55