7 #include "userlanguagemodel.h" 16 #include <string_view> 19 #include <fcitx-utils/macros.h> 20 #include "constants.h" 21 #include "historybigram.h" 22 #include "languagemodel.h" 23 #include "lm/state.hh" 32 bool useOnlyUnigram_ =
false;
// NOTE(review): the members below belong to the private implementation of
// UserLanguageModel (reached as `d->...` in the methods of this file); the
// enclosing class declaration is not visible in this excerpt.
// Extractor handed to history_.scoreWithCode() by score().
35 ValidationCodeExtractor extractor_;
// Weight of the user history in score(); setHistoryWeight() asserts its
// argument lies in [0, 1].
36 float weight_ = DEFAULT_USER_LANGUAGE_MODEL_USER_WEIGHT;
// Cached log-space weights used by score(): wa_ = log10(1 - weight_) for the
// base model and wb_ = log10(weight_) for the user history. They must be
// recomputed whenever weight_ changes (setHistoryWeight() does this).
39 float wa_ = std::log10(1 - weight_), wb_ = std::log10(weight_);
// Reads the WordNode pointer stored in `state` directly after its
// lm::ngram::State prefix. The constructor stores nullptr there for
// beginState_ and nullState_.
41 const WordNode *wordFromState(
const State &state)
const {
42 return loadNative<const WordNode *>(
reinterpret_cast<const char *
>(
43 state.data() +
sizeof(lm::ngram::State)));
// Writes a WordNode pointer into the same slot of `state` that
// wordFromState() reads (presumably `node`; the final argument and closing
// brace are not visible in this excerpt).
46 void setWordToState(State &state,
const WordNode *node)
const {
47 storeNative<const WordNode *>(
48 reinterpret_cast<char *
>(state.data() +
sizeof(lm::ngram::State)),
// Constructs the model from `file` (presumably a path to the static language
// model file). NOTE(review): the initializer list and body are not visible in
// this excerpt.
52 UserLanguageModel::UserLanguageModel(
const char *file)
// Constructs the model on top of a shared static language model file and
// allocates the private implementation state.
// NOTE(review): the base-class initializer and the line that introduces `d`
// are not visible in this excerpt.
55 UserLanguageModel::UserLanguageModel(
56 std::shared_ptr<const StaticLanguageModelFile> file)
58 d_ptr(std::make_unique<UserLanguageModelPrivate>()) {
// Cache the base model's begin/null states, extended with a nullptr
// previous-word slot that wordFromState() later reads back.
61 d->beginState_ = LanguageModel::beginState();
62 d->setWordToState(d->beginState_,
nullptr);
63 d->nullState_ = LanguageModel::nullState();
64 d->setWordToState(d->nullState_,
nullptr);
67 UserLanguageModel::~UserLanguageModel() {}
// Loads user history data from `in`.
// NOTE(review): how the local `history` is read from `in` is not visible in
// this excerpt; the visible tail replaces d->history_ wholesale by
// move-assignment.
79 void UserLanguageModel::load(std::istream &in) {
84 d->history_ = std::move(history);
// Writes the user history to `out` by delegating to history_.save().
86 void UserLanguageModel::save(std::ostream &out) {
88 d->history_.save(out);
// Sets the weight of the user history in score(); `w` must lie in [0, 1].
// Refreshes the cached log-space weights wa_ = log10(1 - weight_) and
// wb_ = log10(weight_). At either endpoint one of them becomes log10(0),
// i.e. -infinity, which removes that term from the interpolation.
// NOTE(review): the assignment of `w` to d->weight_ is not visible in this
// excerpt; it must run before the two recomputations below.
91 void UserLanguageModel::setHistoryWeight(
float w) {
93 assert(w >= 0.0 && w <= 1.0);
95 d->wa_ = std::log10(1 - d->weight_);
96 d->wb_ = std::log10(d->weight_);
// Returns the cached begin state: the base model's begin state with a
// nullptr previous-word slot (set up in the constructor).
99 const State &UserLanguageModel::beginState()
const {
101 return d->beginState_;
// Returns the cached null state: the base model's null state with a nullptr
// previous-word slot (set up in the constructor).
104 const State &UserLanguageModel::nullState()
const {
106 return d->nullState_;
109 static const float log_10 = std::log(10);
115 inline float log1p10exp(
float x) {
116 return x < MIN_FLOAT_LOG10 ? 0. : std::log1p(std::pow(10, x)) / log_10;
118 inline float sum_log_prob(
float a,
float b) {
119 return a > b ? (a + log1p10exp(b - a)) : (b + log1p10exp(a - b));
// Scores `word` following `state` and writes the successor state to `out`.
// Scores are combined in log10 space. The result is
//   max(score, log10((1 - weight_) * 10^score + weight_ * 10^userScore)),
// where `score` comes from the base model and `userScore` from the user
// history. It therefore never drops below the base model score.
// NOTE(review): several lines of this function (remaining parameters, local
// declarations, else branches, branch conditions, closing braces) are not
// visible in this excerpt.
122 float UserLanguageModel::score(
const State &state,
const WordNode &word,
// Unigram-only mode: the base model scores from nullState_, ignoring the
// context carried by `state`.
126 if (d->useOnlyUnigram_) {
127 score = LanguageModel::score(d->nullState_, word, out);
// Context-aware base score (presumably the else branch, not visible here).
129 score = LanguageModel::score(state, word, out);
// Previous word stored in the state by an earlier score() call (nullptr
// for the cached begin/null states).
131 const auto *prev = d->wordFromState(state);
// User history score. The condition choosing scoreWithCode() (which uses
// extractor_) over score() is not visible in this excerpt.
134 userScore = d->history_.scoreWithCode(prev, &word, d->extractor_);
136 userScore = d->history_.score(prev, &word);
// Record the scored word in the successor state so the next call can read
// it back as `prev`.
138 d->setWordToState(out, &word);
// wa_ = log10(1 - weight_), wb_ = log10(weight_); sum_log_prob adds the two
// weighted probabilities in log space.
139 return std::max(score, sum_log_prob(score + d->wa_, userScore + d->wb_));
// A word counts as unknown only when the base model maps it to unknown()
// AND the user history also reports it as unknown.
142 bool UserLanguageModel::isUnknown(WordIndex idx, std::string_view view)
const {
144 return idx == unknown() && d->history_.isUnknown(view);
// Getter counterpart of setHistoryWeight().
// NOTE(review): the body is not visible in this excerpt (presumably it
// returns d->weight_).
147 float UserLanguageModel::historyWeight()
const {
// Toggles unigram-only mode for this model (consulted by score() and
// containsNonUnigram()) and forwards the same flag to the user history.
152 void UserLanguageModel::setUseOnlyUnigram(
bool useOnlyUnigram) {
154 d->useOnlyUnigram_ = useOnlyUnigram;
155 d->history_.setUseOnlyUnigram(useOnlyUnigram);
// Returns whether unigram-only mode is enabled.
158 bool UserLanguageModel::useOnlyUnigram()
const {
160 return d->useOnlyUnigram_;
// Checks whether `words` involves an n-gram longer than a unigram. It looks
// first for an adjacent pair in the user history, then defers to the base
// model.
// NOTE(review): the return statements of the early-exit branch and of the
// loop body are not visible in this excerpt.
163 bool UserLanguageModel::containsNonUnigram(
164 const std::vector<std::string> &words)
const {
// Zero or one word, or unigram-only mode: no longer n-gram applies.
166 if (words.size() <= 1 || d->useOnlyUnigram_) {
// Walk adjacent pairs (words[i], words[i + 1]). std::prev(words.end()) is
// only valid for a non-empty vector, so this relies on the branch above
// returning.
170 for (
auto iter = words.begin(); iter != std::prev(words.end()); ++iter) {
171 if (d->history_.containsBigram(*iter, *(std::next(iter)))) {
// Otherwise defer to the base model.
176 return LanguageModel::maxNgramLength(words) > 1;
// Installs the extractor that score() passes to history_.scoreWithCode().
// The parameter is taken by value and moved into place.
179 void UserLanguageModel::setCodeExtractor(ValidationCodeExtractor extractor) {
181 d->extractor_ = std::move(extractor);
void setUnknownPenalty(float unknown)
Set unknown probability penalty.