// NOTE(review): the line below fuses this file's #include list with the tail
// of what appears to be the StaticLanguageModelFilePrivate constructor
// signature (that class is built from (file, config) in the
// StaticLanguageModelFile constructor); the class head and the first
// constructor parameter are not visible in this excerpt.
7 #include "languagemodel.h" 16 #include <string_view> 17 #include <type_traits> 18 #include <unordered_map> 21 #include <fcitx-utils/fs.h> 22 #include <fcitx-utils/macros.h> 23 #include <fcitx-utils/stringutils.h> 25 #include "constants.h" 28 #include "lm/config.hh" 29 #include "lm/lm_exception.hh" 30 #include "lm/model.hh" 31 #include "lm/return.hh" 32 #include "lm/state.hh" 33 #include "lm/word_index.hh" 34 #include "util/string_piece.hh" 42 const lm::ngram::Config &config)
// Loads the KenLM model from (file, config) and keeps the path in file_,
// which predictionTrie() uses to locate the "<file>.predict" companion file.
43 : model_(file, config), file_(file) {}
// The KenLM model itself (type lm::ngram::QuantArrayTrieModel).
44 lm::ngram::QuantArrayTrieModel model_;
// Set by predictionTrie() once it has attempted to load the prediction trie;
// mutable because that accessor is a const member function.
46 mutable bool predictionLoaded_ =
false;
// Opens the KenLM model stored at `file`. The config sets
// sentence_marker_missing to lm::SILENT, i.e. KenLM is told not to report
// a model that lacks <s>/</s> markers.
50 StaticLanguageModelFile::StaticLanguageModelFile(
const char *file) {
51 lm::ngram::Config config;
52 config.sentence_marker_missing = lm::SILENT;
53 d_ptr = std::make_unique<StaticLanguageModelFilePrivate>(file, config);
56 StaticLanguageModelFile::~StaticLanguageModelFile() {}
// Returns the prediction trie stored next to the model as "<file>.predict".
// The load is attempted lazily on the first call. predictionLoaded_ is set
// before the file is opened, so presumably a failed load is not retried on
// later calls. Confirm this against the lines missing from this excerpt
// (d-pointer setup, stream declaration, trie construction, error handling).
// NOTE(review): this const accessor mutates cached state with no visible
// locking; confirm callers never invoke it concurrently.
58 const DATrie<float> &StaticLanguageModelFile::predictionTrie()
const {
60 if (!d->predictionLoaded_) {
61 d->predictionLoaded_ =
true;
// The companion file is opened in binary mode.
64 fin.open(d->file_ +
".predict", std::ios::in | std::ios::binary);
// Store the freshly read trie (how `trie` is filled is not visible here).
68 d->prediction_ = std::move(trie);
// Returned by reference; the trie is owned by this object's private data.
73 return d->prediction_;
76 static_assert(
sizeof(
void *) +
sizeof(lm::ngram::State) <= StateSize,
"Size");
78 LanguageModelBase::~LanguageModelBase() {}
// A lattice node counts as unknown when isUnknown() reports so for the
// node's word index and word text.
80 bool LanguageModelBase::isNodeUnknown(
const LatticeNode &node)
const {
81 return isUnknown(node.idx(), node.word());
// Scores `word` in isolation, starting from nullState().
// NOTE(review): the construction of `node` and `dummy` is not visible in this
// excerpt; `dummy` presumably just receives the unused output state.
84 float LanguageModelBase::singleWordScore(std::string_view word)
const {
85 auto idx = index(word);
88 return score(nullState(), node, dummy);
// Scores `word` following the context in `state` by delegating to
// wordsScore() with a one-element word list.
91 float LanguageModelBase::singleWordScore(
const State &state,
92 std::string_view word)
const {
93 return wordsScore(state, std::vector<std::string_view>{word});
// Accumulates score() over `words`, starting from the context `_state`.
// NOTE(review): the declarations of `state`, `outState` and the accumulator
// `s`, the state hand-off between iterations, and anything after the loop
// are not visible in this excerpt.
96 float LanguageModelBase::wordsScore(
97 const State &_state,
const std::vector<std::string_view> &words)
const {
101 std::vector<WordNode> nodes;
102 for (
auto word : words) {
// Look the word up, wrap it in a WordNode kept in `nodes`, and score that
// node against the current state.
103 auto idx = index(word);
104 nodes.emplace_back(word, idx);
105 s += score(state, nodes.back(), outState);
111 static_assert(std::is_standard_layout_v<lm::ngram::State> &&
112 std::is_trivial_v<lm::ngram::State>,
113 "State should be pod");
114 static_assert(std::is_same_v<WordIndex, lm::WordIndex>,
115 "word index should be same type");
// Reinterprets libime's opaque State buffer as the KenLM state stored inside
// it. The static_asserts in this file (trivial, standard-layout
// lm::ngram::State; size check against StateSize) back this cast.
// NOTE(review): confirm state.data() is suitably aligned for lm::ngram::State.
117 static inline lm::ngram::State &lmState(State &state) {
118 return *
reinterpret_cast<lm::ngram::State *
>(state.data());
// Read-only view of the KenLM state stored inside a const State buffer.
// NOTE(review): same alignment assumption as the mutable overload; confirm
// state.data() is suitably aligned for lm::ngram::State.
120 static inline const lm::ngram::State &lmState(
const State &state) {
121 return *
reinterpret_cast<const lm::ngram::State *
>(state.data());
// Tail of the LanguageModelPrivate constructor (its head is not visible in
// this excerpt): moves the shared model-file handle into file_.
127 : file_(std::move(file)) {}
129 auto *model() {
return file_ ? &file_->d_func()->model_ :
nullptr; }
// Read-only access to the KenLM model inside the shared model file, or
// nullptr when no file is attached.
130 const auto *model()
const {
131 return file_ ? &file_->d_func()->model_ :
nullptr;
// Shared, read-only handle to the loaded model file. It may be null; model()
// checks for that.
134 std::shared_ptr<const StaticLanguageModelFile> file_;
// NOTE(review): the start of this declaration is not visible. It appears to
// initialize the unknown-word penalty (unknown_, added in score()) as the
// log10 of the default penalty probability.
138 std::log10(DEFAULT_LANGUAGE_MODEL_UNKNOWN_PROBABILITY_PENALTY);
141 LanguageModel::LanguageModel(
const char *file)
142 :
LanguageModel(std::make_shared<StaticLanguageModelFile>(file)) {}
// Binds this LanguageModel to an already-loaded model file, which other
// LanguageModel instances may share. Caches KenLM's begin-of-sentence and
// null-context states in beginState_ / nullState_.
// NOTE(review): model() returns nullptr for a null `file`. The lines between
// the initializer and the dereferences below are not visible, so confirm a
// null file is rejected there.
144 LanguageModel::LanguageModel(
145 std::shared_ptr<const StaticLanguageModelFile> file)
146 : d_ptr(std::make_unique<LanguageModelPrivate>(std::move(file))) {
149 lmState(d->beginState_) = d->model()->BeginSentenceState();
150 lmState(d->nullState_) = d->model()->NullContextState();
154 LanguageModel::~LanguageModel() {}
156 size_t LanguageModel::maxOrder() {
return KENLM_MAX_ORDER; }
// Returns the shared model file this LanguageModel reads from. The body is
// not visible in this excerpt; presumably it returns d->file_.
158 std::shared_ptr<const StaticLanguageModelFile>
159 LanguageModel::languageModelFile()
const {
// Vocabulary index of the begin-of-sentence marker, as reported by the
// KenLM vocabulary.
164 WordIndex LanguageModel::beginSentence()
const {
169 const auto &v = d->model()->GetVocabulary();
170 return v.BeginSentence();
// Vocabulary index of the end-of-sentence marker, as reported by the KenLM
// vocabulary.
173 WordIndex LanguageModel::endSentence()
const {
178 const auto &v = d->model()->GetVocabulary();
179 return v.EndSentence();
// Vocabulary index that stands for out-of-vocabulary words (score() and
// isUnknown() compare against it). The return statement is not visible in
// this excerpt.
182 WordIndex LanguageModel::unknown()
const {
187 const auto &v = d->model()->GetVocabulary();
// Looks `word` up in the KenLM vocabulary. The string_view is passed to
// KenLM as a StringPiece over the same bytes, so no std::string is built.
191 WordIndex LanguageModel::index(std::string_view word)
const {
196 const auto &v = d->model()->GetVocabulary();
197 return v.Index(StringPiece{word.data(), word.size()});
// Cached begin-of-sentence state, filled in by the constructor from KenLM's
// BeginSentenceState().
200 const State &LanguageModel::beginState()
const {
202 return d->beginState_;
// Cached empty-context state, filled in by the constructor from KenLM's
// NullContextState().
205 const State &LanguageModel::nullState()
const {
207 return d->nullState_;
// Scores `node` after the context `state` and writes the successor state to
// `out` (the `out` parameter line is not visible in this excerpt). When the
// node's index equals unknown(), the unknown-word penalty unknown_ is added
// to KenLM's score.
210 float LanguageModel::score(
const State &state,
const WordNode &node,
// `state` and `out` must be different objects.
213 assert(&state != &out);
217 return d->model()->Score(lmState(state), node.idx(), lmState(out)) +
218 (node.idx() == unknown() ? d->unknown_ : 0.0F);
// A word is unknown exactly when its index is the vocabulary's unknown
// index; the word text is ignored.
221 bool LanguageModel::isUnknown(WordIndex idx, std::string_view )
const {
222 return idx == unknown();
// Walks `words` starting from nullState() and returns the largest
// ngram_length that KenLM's FullScore reports along the way. This is subject
// to the unknown-word branch, whose body is not visible.
// NOTE(review): the return-type line, the declaration of `outState`, the body
// of the unknown-word branch and the state hand-off between iterations are
// not visible in this excerpt. `nodes` is not used in the visible lines.
226 LanguageModel::maxNgramLength(
const std::vector<std::string> &words)
const {
231 State state = nullState();
234 unsigned int maxNgramLength = 0;
235 std::vector<WordNode> nodes;
236 for (
const auto &word : words) {
237 const auto idx = index(word);
238 lm::FullScoreReturn full =
239 d->model()->FullScore(lmState(state), idx, lmState(outState));
240 unsigned int ngramLength = full.ngram_length;
// A unigram match on the unknown word is special-cased; the branch body is
// not visible here.
241 if (ngramLength == 1 && idx == unknown()) {
245 maxNgramLength = std::max(maxNgramLength, ngramLength);
248 return maxNgramLength;
// Replaces the penalty that score() adds for unknown words. It is in the
// same units as KenLM's scores; the default is a log10 value.
251 void LanguageModel::setUnknownPenalty(
float unknown) {
253 d->unknown_ = unknown;
// Current unknown-word penalty. The body is not visible in this excerpt;
// presumably it returns d->unknown_.
256 float LanguageModel::unknownPenalty()
const {
// Cache of model files keyed by language. Entries are weak_ptrs, so the cache
// does not by itself keep a model loaded. languageModelFileForLanguage()
// lock()s them and erases entries it cannot reuse.
263 std::unordered_map<std::string,
264 std::weak_ptr<const StaticLanguageModelFile>>
268 LanguageModelResolver::LanguageModelResolver()
269 : d_ptr(std::make_unique<LanguageModelResolverPrivate>()) {}
// Returns the model file for `language`. A cached instance is reused while
// some caller still holds it; otherwise the file name is resolved through
// languageModelFileNameForLanguage() and a new StaticLanguageModelFile is
// loaded and cached.
// NOTE(review): several lines (d-pointer setup, early returns, the final
// return) are not visible in this excerpt.
273 std::shared_ptr<const StaticLanguageModelFile>
274 LanguageModelResolver::languageModelFileForLanguage(
275 const std::string &language) {
277 auto iter = d->files_.find(language);
278 std::shared_ptr<const StaticLanguageModelFile> file;
279 if (iter != d->files_.end()) {
280 file = iter->second.lock();
// Presumably reached only when lock() returned null (the cached file has
// been released); the stale entry is dropped before reloading.
284 d->files_.erase(iter);
287 auto fileName = languageModelFileNameForLanguage(language);
// No model file for this language; handled in lines not visible here.
288 if (fileName.empty()) {
292 file = std::make_shared<StaticLanguageModelFile>(fileName.data());
293 d->files_.emplace(language, file);
297 DefaultLanguageModelResolver::DefaultLanguageModelResolver() =
default;
298 DefaultLanguageModelResolver::~DefaultLanguageModelResolver() =
default;
// Resolves "<dir>/<language>.lm" by checking each candidate directory for a
// regular file with that name.
// NOTE(review): the early-return bodies and the final fallback return are
// not visible in this excerpt.
305 std::string DefaultLanguageModelResolver::languageModelFileNameForLanguage(
306 const std::string &language) {
// Empty names and names containing '/' get special handling (branch body not
// visible), presumably so the name cannot point outside the model
// directories.
307 if (language.empty() || language.find(
'/') != std::string::npos) {
// Candidate directories: the colon-separated entries of $LIBIME_MODEL_DIRS
// when it is set and non-empty, plus LIBIME_INSTALL_LIBDATADIR. The line in
// between is not visible; confirm the latter is appended unconditionally.
311 const char *modelDirs = getenv(
"LIBIME_MODEL_DIRS");
312 std::vector<std::string> dirs;
313 if (modelDirs && modelDirs[0]) {
314 dirs = fcitx::stringutils::split(modelDirs,
":");
316 dirs.push_back(LIBIME_INSTALL_LIBDATADIR);
// Directories are checked in order; presumably the first regular file found
// is returned.
319 for (
const auto &dir : dirs) {
320 auto file = fcitx::stringutils::joinPath(dir, language +
".lm");
321 if (fcitx::fs::isreg(file)) {
A class that provides language model data for different languages.
Provides a DATrie implementation.