libime
languagemodel.cpp
1 /*
2  * SPDX-FileCopyrightText: 2017-2017 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 
7 #include "languagemodel.h"
8 #include <algorithm>
9 #include <cassert>
10 #include <cmath>
11 #include <cstdlib>
12 #include <fstream>
13 #include <ios>
14 #include <memory>
15 #include <string>
16 #include <string_view>
17 #include <type_traits>
18 #include <unordered_map>
19 #include <utility>
20 #include <vector>
21 #include <fcitx-utils/fs.h>
22 #include <fcitx-utils/macros.h>
23 #include <fcitx-utils/stringutils.h>
24 #include "config.h"
25 #include "constants.h"
26 #include "datrie.h"
27 #include "lattice.h"
28 #include "lm/config.hh"
29 #include "lm/lm_exception.hh"
30 #include "lm/model.hh"
31 #include "lm/return.hh"
32 #include "lm/state.hh"
33 #include "lm/word_index.hh"
34 #include "util/string_piece.hh"
35 #include "utils.h"
36 
37 namespace libime {
38 
40 public:
41  StaticLanguageModelFilePrivate(const char *file,
42  const lm::ngram::Config &config)
43  : model_(file, config), file_(file) {}
44  lm::ngram::QuantArrayTrieModel model_;
45  std::string file_;
46  mutable bool predictionLoaded_ = false;
47  mutable DATrie<float> prediction_;
48 };
49 
50 StaticLanguageModelFile::StaticLanguageModelFile(const char *file) {
51  lm::ngram::Config config;
52  config.sentence_marker_missing = lm::SILENT;
53  d_ptr = std::make_unique<StaticLanguageModelFilePrivate>(file, config);
54 }
55 
56 StaticLanguageModelFile::~StaticLanguageModelFile() {}
57 
58 const DATrie<float> &StaticLanguageModelFile::predictionTrie() const {
59  FCITX_D();
60  if (!d->predictionLoaded_) {
61  d->predictionLoaded_ = true;
62  try {
63  std::ifstream fin;
64  fin.open(d->file_ + ".predict", std::ios::in | std::ios::binary);
65  if (fin) {
66  DATrie<float> trie;
67  trie.load(fin);
68  d->prediction_ = std::move(trie);
69  }
70  } catch (...) {
71  }
72  }
73  return d->prediction_;
74 }
75 
76 static_assert(sizeof(void *) + sizeof(lm::ngram::State) <= StateSize, "Size");
77 
78 LanguageModelBase::~LanguageModelBase() {}
79 
80 bool LanguageModelBase::isNodeUnknown(const LatticeNode &node) const {
81  return isUnknown(node.idx(), node.word());
82 }
83 
84 float LanguageModelBase::singleWordScore(std::string_view word) const {
85  auto idx = index(word);
86  State dummy;
87  WordNode node(word, idx);
88  return score(nullState(), node, dummy);
89 }
90 
91 float LanguageModelBase::singleWordScore(const State &state,
92  std::string_view word) const {
93  return wordsScore(state, std::vector<std::string_view>{word});
94 }
95 
96 float LanguageModelBase::wordsScore(
97  const State &_state, const std::vector<std::string_view> &words) const {
98  float s = 0;
99  State state = _state;
100  State outState;
101  std::vector<WordNode> nodes;
102  for (auto word : words) {
103  auto idx = index(word);
104  nodes.emplace_back(word, idx);
105  s += score(state, nodes.back(), outState);
106  state = outState;
107  }
108  return s;
109 }
110 
111 static_assert(std::is_standard_layout_v<lm::ngram::State> &&
112  std::is_trivial_v<lm::ngram::State>,
113  "State should be pod");
114 static_assert(std::is_same_v<WordIndex, lm::WordIndex>,
115  "word index should be same type");
116 
117 static inline lm::ngram::State &lmState(State &state) {
118  return *reinterpret_cast<lm::ngram::State *>(state.data());
119 }
120 static inline const lm::ngram::State &lmState(const State &state) {
121  return *reinterpret_cast<const lm::ngram::State *>(state.data());
122 }
123 
125 public:
126  LanguageModelPrivate(std::shared_ptr<const StaticLanguageModelFile> file)
127  : file_(std::move(file)) {}
128 
129  auto *model() { return file_ ? &file_->d_func()->model_ : nullptr; }
130  const auto *model() const {
131  return file_ ? &file_->d_func()->model_ : nullptr;
132  }
133 
134  std::shared_ptr<const StaticLanguageModelFile> file_;
135  State beginState_;
136  State nullState_;
137  float unknown_ =
138  std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY);
139 };
140 
141 LanguageModel::LanguageModel(const char *file)
142  : LanguageModel(std::make_shared<StaticLanguageModelFile>(file)) {}
143 
144 LanguageModel::LanguageModel(
145  std::shared_ptr<const StaticLanguageModelFile> file)
146  : d_ptr(std::make_unique<LanguageModelPrivate>(std::move(file))) {
147  FCITX_D();
148  if (d->model()) {
149  lmState(d->beginState_) = d->model()->BeginSentenceState();
150  lmState(d->nullState_) = d->model()->NullContextState();
151  }
152 }
153 
154 LanguageModel::~LanguageModel() {}
155 
156 size_t LanguageModel::maxOrder() { return KENLM_MAX_ORDER; }
157 
158 std::shared_ptr<const StaticLanguageModelFile>
159 LanguageModel::languageModelFile() const {
160  FCITX_D();
161  return d->file_;
162 }
163 
164 WordIndex LanguageModel::beginSentence() const {
165  FCITX_D();
166  if (!d->model()) {
167  return 0;
168  }
169  const auto &v = d->model()->GetVocabulary();
170  return v.BeginSentence();
171 }
172 
173 WordIndex LanguageModel::endSentence() const {
174  FCITX_D();
175  if (!d->model()) {
176  return 0;
177  }
178  const auto &v = d->model()->GetVocabulary();
179  return v.EndSentence();
180 }
181 
182 WordIndex LanguageModel::unknown() const {
183  FCITX_D();
184  if (!d->model()) {
185  return 0;
186  }
187  const auto &v = d->model()->GetVocabulary();
188  return v.NotFound();
189 }
190 
191 WordIndex LanguageModel::index(std::string_view word) const {
192  FCITX_D();
193  if (!d->model()) {
194  return 0;
195  }
196  const auto &v = d->model()->GetVocabulary();
197  return v.Index(StringPiece{word.data(), word.size()});
198 }
199 
200 const State &LanguageModel::beginState() const {
201  FCITX_D();
202  return d->beginState_;
203 }
204 
205 const State &LanguageModel::nullState() const {
206  FCITX_D();
207  return d->nullState_;
208 }
209 
210 float LanguageModel::score(const State &state, const WordNode &node,
211  State &out) const {
212  FCITX_D();
213  assert(&state != &out);
214  if (!d->model()) {
215  return d->unknown_;
216  }
217  return d->model()->Score(lmState(state), node.idx(), lmState(out)) +
218  (node.idx() == unknown() ? d->unknown_ : 0.0F);
219 }
220 
221 bool LanguageModel::isUnknown(WordIndex idx, std::string_view /*word*/) const {
222  return idx == unknown();
223 }
224 
225 unsigned int
226 LanguageModel::maxNgramLength(const std::vector<std::string> &words) const {
227  FCITX_D();
228  if (!d->model()) {
229  return 0;
230  }
231  State state = nullState();
232  State outState;
233 
234  unsigned int maxNgramLength = 0;
235  std::vector<WordNode> nodes;
236  for (const auto &word : words) {
237  const auto idx = index(word);
238  lm::FullScoreReturn full =
239  d->model()->FullScore(lmState(state), idx, lmState(outState));
240  unsigned int ngramLength = full.ngram_length;
241  if (ngramLength == 1 && idx == unknown()) {
242  ngramLength = 0;
243  }
244 
245  maxNgramLength = std::max(maxNgramLength, ngramLength);
246  state = outState;
247  }
248  return maxNgramLength;
249 }
250 
251 void LanguageModel::setUnknownPenalty(float unknown) {
252  FCITX_D();
253  d->unknown_ = unknown;
254 }
255 
256 float LanguageModel::unknownPenalty() const {
257  FCITX_D();
258  return d->unknown_;
259 }
260 
262 public:
263  std::unordered_map<std::string,
264  std::weak_ptr<const StaticLanguageModelFile>>
265  files_;
266 };
267 
268 LanguageModelResolver::LanguageModelResolver()
269  : d_ptr(std::make_unique<LanguageModelResolverPrivate>()) {}
270 
271 FCITX_DEFINE_DEFAULT_DTOR_AND_MOVE(LanguageModelResolver)
272 
273 std::shared_ptr<const StaticLanguageModelFile>
274 LanguageModelResolver::languageModelFileForLanguage(
275  const std::string &language) {
276  FCITX_D();
277  auto iter = d->files_.find(language);
278  std::shared_ptr<const StaticLanguageModelFile> file;
279  if (iter != d->files_.end()) {
280  file = iter->second.lock();
281  if (file) {
282  return file;
283  }
284  d->files_.erase(iter);
285  }
286 
287  auto fileName = languageModelFileNameForLanguage(language);
288  if (fileName.empty()) {
289  return nullptr;
290  }
291 
292  file = std::make_shared<StaticLanguageModelFile>(fileName.data());
293  d->files_.emplace(language, file);
294  return file;
295 }
296 
297 DefaultLanguageModelResolver::DefaultLanguageModelResolver() = default;
298 DefaultLanguageModelResolver::~DefaultLanguageModelResolver() = default;
299 
300 DefaultLanguageModelResolver &DefaultLanguageModelResolver::instance() {
301  static DefaultLanguageModelResolver resolver;
302  return resolver;
303 }
304 
305 std::string DefaultLanguageModelResolver::languageModelFileNameForLanguage(
306  const std::string &language) {
307  if (language.empty() || language.find('/') != std::string::npos) {
308  return {};
309  }
310 
311  const char *modelDirs = getenv("LIBIME_MODEL_DIRS");
312  std::vector<std::string> dirs;
313  if (modelDirs && modelDirs[0]) {
314  dirs = fcitx::stringutils::split(modelDirs, ":");
315  } else {
316  dirs.push_back(LIBIME_INSTALL_LIBDATADIR);
317  }
318 
319  for (const auto &dir : dirs) {
320  auto file = fcitx::stringutils::joinPath(dir, language + ".lm");
321  if (fcitx::fs::isreg(file)) {
322  return file;
323  }
324  }
325  return {};
326 }
327 } // namespace libime
// A class that provides language model data for different languages.
// Provides a DATrie implementation.