// libime: decoder.cpp
/*
 * SPDX-FileCopyrightText: 2017-2017 CSSlayer <wengxt@gmail.com>
 *
 * SPDX-License-Identifier: LGPL-2.1-or-later
 */
6 
7 #include "decoder.h"
8 #include <algorithm>
9 #include <cassert>
10 #include <chrono>
11 #include <cstddef>
12 #include <limits>
13 #include <memory>
14 #include <queue>
15 #include <ranges>
16 #include <string>
17 #include <string_view>
18 #include <tuple>
19 #include <unordered_map>
20 #include <unordered_set>
21 #include <utility>
22 #include <vector>
23 #include <boost/container_hash/hash.hpp>
24 #include <boost/functional/hash.hpp>
25 #include <boost/ptr_container/ptr_vector.hpp>
26 #include <fcitx-utils/macros.h>
27 #include "languagemodel.h"
28 #include "lattice.h"
29 #include "lattice_p.h"
30 #include "segmentgraph.h"
31 #include "utils.h"
32 #include "utils_p.h"
33 
34 namespace libime {
35 
// Upper bound on the number of partial paths the n-best backward search may
// push onto its work queue. Once the bound is reached no further nodes are
// expanded; paths already queued are still drained.
constexpr int MAX_BACKWARD_SEARCH_SIZE = 10000;
37 
38 struct NBestNode {
39  NBestNode(const LatticeNode *node) : node_(node) {}
40 
41  const LatticeNode *node_;
42  // for nbest
43  float gn_ = 0.0F;
44  float fn_ = -std::numeric_limits<float>::max();
45  std::shared_ptr<NBestNode> next_;
46 };
47 
// Orders pointer-like handles (raw or smart pointers) by the fn_ member of
// the object they refer to: lhs comes before rhs when its fn_ is smaller.
// Suitable as the Compare argument of std::priority_queue, which then yields
// the element with the largest fn_ first.
template <typename T>
struct NBestNodeLess {
    bool operator()(const T &lhs, const T &rhs) const {
        return lhs->fn_ < rhs->fn_;
    }
};
54 
56 public:
57  DecoderPrivate(const Dictionary *dict, const LanguageModelBase *model)
58  : dict_(dict), model_(model) {}
59 
60  // Try to update lattice based on existing data.
61  bool
62  buildLattice(const Decoder *q, Lattice &l,
63  const std::unordered_set<const SegmentGraphNode *> &ignore,
64  const State &state, const SegmentGraph &graph,
65  size_t frameSize, void *helper) const;
66 
67  void
68  forwardSearch(const Decoder *q, const SegmentGraph &graph, Lattice &lattice,
69  const std::unordered_set<const SegmentGraphNode *> &ignore,
70  size_t beamSize) const;
71  void backwardSearch(const SegmentGraph &graph, Lattice &l, size_t nbest,
72  float max, float min, size_t beamSize) const;
73 
74  const Dictionary *dict_;
75  const LanguageModelBase *model_;
76 };
77 
78 bool DecoderPrivate::buildLattice(
79  const Decoder *q, Lattice &l,
80  const std::unordered_set<const SegmentGraphNode *> &ignore,
81  const State &state, const SegmentGraph &graph, size_t frameSize,
82  void *helper) const {
83  LatticeMap &lattice = l.d_ptr->lattice_;
84 
85  // Create the root node.
86  if (!lattice.contains(&graph.start())) {
87  lattice[&graph.start()].push_back(
88  q->createLatticeNode(graph, model_, "", model_->beginSentence(),
89  {nullptr, &graph.start()}, state, 0));
90  }
91 
92  // std::vector is used here to make sure std::make_heap works.
93  std::unordered_map<
94  std::pair<const SegmentGraphNode *, const SegmentGraphNode *>,
95  std::vector<std::unique_ptr<LatticeNode>>,
96  boost::hash<
97  std::pair<const SegmentGraphNode *, const SegmentGraphNode *>>>
98  frames;
99 
100  auto dictMatchCallback = [this, &graph, &frames, q, frameSize](
101  const SegmentGraphPath &path, WordNode &word,
102  float adjust,
103  std::unique_ptr<LatticeNodeData> data) {
104  if (InvalidWordIndex == word.idx()) {
105  auto idx = model_->index(word.word());
106  word.setIdx(idx);
107  }
108  assert(path.front());
109  auto &frame = frames[std::make_pair(path.front(), path.back())];
110  const bool applyFrameSize =
111  path.front() != &graph.start() && frameSize > 0;
112  auto *node = q->createLatticeNode(
113  graph, model_, word.word(), word.idx(), path, model_->nullState(),
114  adjust, std::move(data), frame.empty());
115  if (!node) {
116  return;
117  }
118 
119  frame.emplace_back(node);
120  if (!applyFrameSize) {
121  return;
122  }
123  // Make a maximum heap.
124  auto scoreGreaterThan = [](const auto &lhs, const auto &rhs) {
125  return lhs->score() > rhs->score();
126  };
127  // Just reach the limit, initialize the heap.
128  if (frame.size() == frameSize) {
129  for (auto &n : frame) {
130  // Cache the score here.
131  n->setScore(model_->singleWordScore(n->word()) + n->cost());
132  }
133  std::make_heap(frame.begin(), frame.end(), scoreGreaterThan);
134  } else if (frame.size() == frameSize + 1) {
135  // Cache the score here.
136  node->setScore(model_->singleWordScore(node->word()) +
137  node->cost());
138  // Take a short cut, check if node score greater than minimum
139  if (scoreGreaterThan(node, frame[0])) {
140  std::push_heap(frame.begin(), frame.end(), scoreGreaterThan);
141  std::pop_heap(frame.begin(), frame.end(), scoreGreaterThan);
142  }
143  frame.pop_back();
144  }
145  };
146 
147  dict_->matchPrefix(graph, dictMatchCallback, ignore, helper);
148 
149  for (auto &[path, nodes] : frames) {
150  auto &latticeUnit = lattice[path.second];
151  for (auto &node : nodes) {
152  latticeUnit.push_back(node.release());
153  }
154  }
155  if (!lattice.contains(&graph.end())) {
156  return false;
157  }
158 
159  // Create the node for end.
160  lattice[nullptr].push_back(
161  q->createLatticeNode(graph, model_, "", model_->endSentence(),
162  {&graph.end(), nullptr}, model_->nullState()));
163  return true;
164 }
165 
166 void DecoderPrivate::forwardSearch(
167  const Decoder *q, const SegmentGraph &graph, Lattice &l,
168  const std::unordered_set<const SegmentGraphNode *> &ignore,
169  size_t beamSize) const {
170  State state;
171  LatticeMap &lattice = l.d_ptr->lattice_;
172  std::unordered_map<const SegmentGraphNode *,
173  std::tuple<float, LatticeNode *, State>>
174  unknownIdCache;
175  const auto *start = &graph.start();
176  // forward search
177  auto updateForNode = [&](const SegmentGraphBase &,
178  const SegmentGraphNode *graphNode) {
179  if (graphNode == start || !lattice.contains(graphNode) ||
180  ignore.contains(graphNode)) {
181  return true;
182  }
183  auto &latticeNodes = lattice[graphNode];
184  for (auto &node : latticeNodes) {
185  const auto *from = node.from();
186  assert(graph.checkNodeInGraph(from));
187  float maxScore = -std::numeric_limits<float>::max();
188  LatticeNode *maxNode = nullptr;
189  State maxState;
190  bool isUnknown = model_->isNodeUnknown(node);
191  if (isUnknown) {
192  auto iter = unknownIdCache.find(from);
193  if (iter != unknownIdCache.end()) {
194  std::tie(maxScore, maxNode, maxState) = iter->second;
195  }
196  }
197 
198  if (!maxNode) {
199  auto iter = lattice.find(from);
200  // assert(iter != lattice.end());
201  if (iter == lattice.end()) {
202  continue;
203  }
204  auto &searchFrom = iter->second;
205  auto searchSize = beamSize;
206  if (searchSize) {
207  searchSize = std::min(searchSize, lattice[from].size());
208  } else {
209  searchSize = lattice[from].size();
210  }
211  for (auto &parent : searchFrom | std::views::take(searchSize)) {
212  auto score = parent.score() +
213  model_->score(parent.state(), node, state);
214  if (score > maxScore) {
215  maxScore = score;
216  maxNode = &parent;
217  maxState = state;
218  }
219  }
220 
221  if (isUnknown) {
222  unknownIdCache.emplace(
223  std::piecewise_construct, std::forward_as_tuple(from),
224  std::forward_as_tuple(maxScore, maxNode, maxState));
225  }
226  }
227 
228  assert(maxNode);
229  node.setScore(maxScore + node.cost());
230  node.setPrev(maxNode);
231  node.state() = maxState;
232  }
233  if (q->needSort(graph, graphNode)) {
234  latticeNodes.sort(
235  [](const LatticeNode &lhs, const LatticeNode &rhs) {
236  return lhs.score() > rhs.score();
237  });
238  }
239  return true;
240  };
241 
242  graph.bfs(start, updateForNode);
243  updateForNode(graph, nullptr);
244 }
245 
246 std::string concatNBest(NBestNode *node, std::string_view sep = "") {
247  std::string result;
248  while (node) {
249  result.append(node->node_->word());
250  result.append(sep.data(), sep.size());
251  node = node->next_.get();
252  }
253  return result;
254 }
255 
256 void DecoderPrivate::backwardSearch(const SegmentGraph &graph, Lattice &l,
257  size_t nbest, float max, float min,
258  size_t beamSize) const {
259  auto &lattice = l.d_ptr->lattice_;
260  State state;
261  // backward search
262  assert(lattice[&graph.start()].size() == 1);
263  assert(lattice[nullptr].size() == 1);
264  auto *pos = &lattice[nullptr][0];
265  l.d_ptr->nbests_.push_back(pos->toSentenceResult());
266  if (nbest > 1) {
267  std::unordered_set<std::string> dup;
268  dup.insert(l.d_ptr->nbests_[0].toString());
269  using PriorityQueueType =
270  std::priority_queue<std::shared_ptr<NBestNode>,
271  std::vector<std::shared_ptr<NBestNode>>,
273  PriorityQueueType q;
274  PriorityQueueType result;
275 
276  auto *eos = &lattice[nullptr][0];
277  auto newNBestNode = [](const LatticeNode *node) {
278  return std::make_shared<NBestNode>(node);
279  };
280  q.push(newNBestNode(eos));
281  int acc = 0;
282  auto *bos = &lattice[&graph.start()][0];
283  while (!q.empty()) {
284  std::shared_ptr<NBestNode> node = q.top();
285  q.pop();
286  if (bos == node->node_) {
287  auto sentence = concatNBest(node.get());
288  if (dup.contains(sentence)) {
289  continue;
290  }
291 
292  if (eos->score() - node->fn_ > max) {
293  break;
294  }
295  result.push(node);
296  if (result.size() >= nbest) {
297  break;
298  }
299  dup.insert(sentence);
300  } else {
301  if (acc >= MAX_BACKWARD_SEARCH_SIZE) {
302  continue;
303  }
304  auto searchSize = beamSize;
305  if (searchSize) {
306  searchSize = std::min(searchSize,
307  lattice[node->node_->from()].size());
308  } else {
309  searchSize = lattice[node->node_->from()].size();
310  }
311  for (auto &from : lattice[node->node_->from()] |
312  std::views::take(searchSize)) {
313  auto score =
314  model_->score(from.state(), *node->node_, state) +
315  node->node_->cost();
316  if (&from != bos && score < min) {
317  continue;
318  }
319  std::shared_ptr<NBestNode> parent = newNBestNode(&from);
320  parent->gn_ = score + node->gn_;
321  parent->fn_ = parent->gn_ + parent->node_->score();
322  parent->next_ = node;
323 
324  if (eos->score() - parent->gn_ <= max) {
325  q.push(std::move(parent));
326  acc++;
327  if (acc >= MAX_BACKWARD_SEARCH_SIZE) {
328  break;
329  }
330  }
331  }
332  }
333  }
334 
335  while (!result.empty()) {
336  auto node = result.top();
337  result.pop();
338  // loop twice to avoid problem
339  size_t count = 0;
340  // skip bos
341  auto pivot = node->next_;
342  while (pivot) {
343  pivot = pivot->next_;
344  count++;
345  }
346  SentenceResult::Sentence result;
347  result.reserve(count);
348  pivot = node->next_;
349  while (pivot) {
350  if (pivot->node_->to()) {
351  result.emplace_back(pivot->node_);
352  }
353  pivot = pivot->next_;
354  }
355  l.d_ptr->nbests_.emplace_back(std::move(result), node->fn_);
356  }
357  }
358 }
359 
360 Decoder::Decoder(const Dictionary *dict, const LanguageModelBase *model)
361  : d_ptr(std::make_unique<DecoderPrivate>(dict, model)) {}
362 
363 Decoder::~Decoder() {}
364 
365 const Dictionary *Decoder::dict() const {
366  FCITX_D();
367  return d->dict_;
368 }
369 
370 const LanguageModelBase *Decoder::model() const {
371  FCITX_D();
372  return d->model_;
373 }
374 
375 bool Decoder::decode(Lattice &l, const SegmentGraph &graph, size_t nbest,
376  const State &beginState, float max, float min,
377  size_t beamSize, size_t frameSize, void *helper) const {
378  FCITX_D();
379  LatticeMap &lattice = l.d_ptr->lattice_;
380  // Clear the result.
381  l.d_ptr->nbests_.clear();
382  // Remove end node.
383  lattice.erase(nullptr);
384  std::unordered_set<const SegmentGraphNode *> ignore;
385  // Add existing SegmentGraphNode to ignore set.
386  for (auto &p : lattice) {
387  ignore.insert(p.first);
388  }
389 
390  auto t0 = std::chrono::high_resolution_clock::now();
391 
392  if (!d->buildLattice(this, l, ignore, beginState, graph, frameSize,
393  helper)) {
394  return false;
395  }
396  LIBIME_DEBUG() << "Build Lattice: " << millisecondsTill(t0);
397  d->forwardSearch(this, graph, l, ignore, beamSize);
398  LIBIME_DEBUG() << "Forward Search: " << millisecondsTill(t0);
399  d->backwardSearch(graph, l, nbest, max, min, beamSize);
400  LIBIME_DEBUG() << "Backward Search: " << millisecondsTill(t0);
401  return true;
402 }
403 
404 LatticeNode *Decoder::createLatticeNodeImpl(
405  const SegmentGraphBase & /*unused*/, const LanguageModelBase * /*unused*/,
406  std::string_view word, WordIndex idx, SegmentGraphPath path,
407  const State &state, float cost, std::unique_ptr<LatticeNodeData> /*unused*/,
408  bool /*unused*/) const {
409  return new LatticeNode(word, idx, std::move(path), state, cost);
410 }
411 } // namespace libime