libime
userlanguagemodel.cpp
1 /*
2  * SPDX-FileCopyrightText: 2017-2017 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 
7 #include "userlanguagemodel.h"
8 #include <algorithm>
9 #include <cassert>
10 #include <cmath>
11 #include <istream>
12 #include <iterator>
13 #include <memory>
14 #include <ostream>
15 #include <string>
16 #include <string_view>
17 #include <utility>
18 #include <vector>
19 #include <fcitx-utils/macros.h>
20 #include "constants.h"
21 #include "historybigram.h"
22 #include "languagemodel.h"
23 #include "lm/state.hh"
24 #include "utils_p.h"
25 
26 namespace libime {
27 
29 public:
30  State beginState_;
31  State nullState_;
32  bool useOnlyUnigram_ = false;
33 
34  HistoryBigram history_;
35  ValidationCodeExtractor extractor_;
36  float weight_ = DEFAULT_USER_LANGUAGE_MODEL_USER_WEIGHT;
37  // log(wa * exp(a) + wb * exp(b))
38  // log(exp(log(wa) + a) + exp(b + log(wb))
39  float wa_ = std::log10(1 - weight_), wb_ = std::log10(weight_);
40 
41  const WordNode *wordFromState(const State &state) const {
42  return loadNative<const WordNode *>(reinterpret_cast<const char *>(
43  state.data() + sizeof(lm::ngram::State)));
44  }
45 
46  void setWordToState(State &state, const WordNode *node) const {
47  storeNative<const WordNode *>(
48  reinterpret_cast<char *>(state.data() + sizeof(lm::ngram::State)),
49  node);
50  }
51 };
52 UserLanguageModel::UserLanguageModel(const char *file)
53  : UserLanguageModel(std::make_shared<StaticLanguageModelFile>(file)) {}
54 
55 UserLanguageModel::UserLanguageModel(
56  std::shared_ptr<const StaticLanguageModelFile> file)
57  : LanguageModel(std::move(file)),
58  d_ptr(std::make_unique<UserLanguageModelPrivate>()) {
59  FCITX_D();
60  // resize will fill remaining with zero
61  d->beginState_ = LanguageModel::beginState();
62  d->setWordToState(d->beginState_, nullptr);
63  d->nullState_ = LanguageModel::nullState();
64  d->setWordToState(d->nullState_, nullptr);
65 }
66 
67 UserLanguageModel::~UserLanguageModel() {}
68 
69 HistoryBigram &UserLanguageModel::history() {
70  FCITX_D();
71  return d->history_;
72 }
73 
74 const HistoryBigram &UserLanguageModel::history() const {
75  FCITX_D();
76  return d->history_;
77 }
78 
79 void UserLanguageModel::load(std::istream &in) {
80  FCITX_D();
81  HistoryBigram history;
82  history.setUnknownPenalty(d->history_.unknownPenalty());
83  history.load(in);
84  d->history_ = std::move(history);
85 }
86 void UserLanguageModel::save(std::ostream &out) {
87  FCITX_D();
88  d->history_.save(out);
89 }
90 
91 void UserLanguageModel::setHistoryWeight(float w) {
92  FCITX_D();
93  assert(w >= 0.0 && w <= 1.0);
94  d->weight_ = w;
95  d->wa_ = std::log10(1 - d->weight_);
96  d->wb_ = std::log10(d->weight_);
97 }
98 
99 const State &UserLanguageModel::beginState() const {
100  FCITX_D();
101  return d->beginState_;
102 }
103 
104 const State &UserLanguageModel::nullState() const {
105  FCITX_D();
106  return d->nullState_;
107 }
108 
109 static const float log_10 = std::log(10);
110 
111 // log10(exp10(a) + exp10(b))
112 // = log10(exp10(b) * (1 + exp10(a - b)))
113 // = b + log10(1 + exp10(a - b))
114 // = b + log1p(exp10(a - b)) / log(10)
115 inline float log1p10exp(float x) {
116  return x < MIN_FLOAT_LOG10 ? 0. : std::log1p(std::pow(10, x)) / log_10;
117 }
118 inline float sum_log_prob(float a, float b) {
119  return a > b ? (a + log1p10exp(b - a)) : (b + log1p10exp(a - b));
120 }
121 
122 float UserLanguageModel::score(const State &state, const WordNode &word,
123  State &out) const {
124  FCITX_D();
125  float score;
126  if (d->useOnlyUnigram_) {
127  score = LanguageModel::score(d->nullState_, word, out);
128  } else {
129  score = LanguageModel::score(state, word, out);
130  }
131  const auto *prev = d->wordFromState(state);
132  float userScore;
133  if (d->extractor_) {
134  userScore = d->history_.scoreWithCode(prev, &word, d->extractor_);
135  } else {
136  userScore = d->history_.score(prev, &word);
137  }
138  d->setWordToState(out, &word);
139  return std::max(score, sum_log_prob(score + d->wa_, userScore + d->wb_));
140 }
141 
142 bool UserLanguageModel::isUnknown(WordIndex idx, std::string_view view) const {
143  FCITX_D();
144  return idx == unknown() && d->history_.isUnknown(view);
145 }
146 
147 float UserLanguageModel::historyWeight() const {
148  FCITX_D();
149  return d->weight_;
150 }
151 
152 void UserLanguageModel::setUseOnlyUnigram(bool useOnlyUnigram) {
153  FCITX_D();
154  d->useOnlyUnigram_ = useOnlyUnigram;
155  d->history_.setUseOnlyUnigram(useOnlyUnigram);
156 }
157 
158 bool UserLanguageModel::useOnlyUnigram() const {
159  FCITX_D();
160  return d->useOnlyUnigram_;
161 }
162 
163 bool UserLanguageModel::containsNonUnigram(
164  const std::vector<std::string> &words) const {
165  FCITX_D();
166  if (words.size() <= 1 || d->useOnlyUnigram_) {
167  return false;
168  }
169 
170  for (auto iter = words.begin(); iter != std::prev(words.end()); ++iter) {
171  if (d->history_.containsBigram(*iter, *(std::next(iter)))) {
172  return true;
173  }
174  }
175 
176  return LanguageModel::maxNgramLength(words) > 1;
177 }
178 
179 void UserLanguageModel::setCodeExtractor(ValidationCodeExtractor extractor) {
180  FCITX_D();
181  d->extractor_ = std::move(extractor);
182 }
183 
184 } // namespace libime
void setUnknownPenalty(float unknown)
Set unknown probability penalty.