17 #include <string_view> 19 #include <unordered_map> 20 #include <unordered_set> 23 #include <boost/container_hash/hash.hpp> 24 #include <boost/functional/hash.hpp> 25 #include <boost/ptr_container/ptr_vector.hpp> 26 #include <fcitx-utils/macros.h> 27 #include "languagemodel.h" 29 #include "lattice_p.h" 30 #include "segmentgraph.h" 36 constexpr
// Limit on the work done by the backward (N-best) search. backwardSearch
// compares its step counter `acc` against this value. Presumably it stops
// expanding candidates once the limit is reached; the action taken is not
// visible in this excerpt.
int MAX_BACKWARD_SEARCH_SIZE = 10000;
// NOTE(review): the enclosing struct headers are not visible here. These
// fields appear to belong to NBestNode (they are used as node->fn_ and
// node->next_ in backwardSearch and concatNBest).
// Priority of this partial path. backwardSearch sets it to gn_ plus the
// lattice node's forward score. It starts at the lowest finite float.
44 float fn_ = -std::numeric_limits<float>::max();
// Next node on the partial path, i.e. the one closer to the end of the
// sentence (backwardSearch sets parent->next_ = node).
45 std::shared_ptr<NBestNode> next_;
// Compares pointer-like NBestNode handles by ascending fn_. When used as a
// std::priority_queue comparator (presumably for PriorityQueueType in
// backwardSearch; that template argument is not visible), the entry with
// the largest fn_ is on top.
50 bool operator()(
const T &lhs,
const T &rhs)
const {
51 return lhs->fn_ < rhs->fn_;
// Store the dictionary and the language model. The dictionary is used for
// prefix matching in buildLattice; the model is used for node creation and
// scoring.
58 : dict_(dict), model_(model) {}
// buildLattice (fragment of its declaration; the leading parameters are not
// visible).
// ignore: graph nodes that already have lattice entries from a previous
// decode (see Decoder::decode). It is passed through to dictionary matching.
63 const std::unordered_set<const SegmentGraphNode *> &ignore,
// frameSize: caps candidates per (first, last) graph-node span; 0 disables the cap.
65 size_t frameSize,
void *helper)
const;
// forwardSearch (fragment of its declaration).
69 const std::unordered_set<const SegmentGraphNode *> &ignore,
// beamSize: limit on how many predecessor lattice nodes are scanned.
70 size_t beamSize)
const;
// backwardSearch (fragment of its declaration).
// max: a candidate is pruned when eos score minus its score exceeds max.
// min: an expansion step scoring below min is discarded (except into bos).
72 float max,
float min,
size_t beamSize)
const;
// Build (or extend) the lattice for `graph`. Steps:
//  1. Seed the begin-of-sentence node.
//  2. Create a lattice node for every dictionary match along a graph path,
//     grouped by the (first, last) graph node of the path.
//  3. For spans not starting at the graph start, cap each group at
//     frameSize (part of the trimming step is on lines not shown).
//  4. Move the groups into the lattice and append the end-of-sentence node.
// NOTE(review): this excerpt is incomplete. Return statements and the heads
// of some declarations (e.g. the `frames` map, `adjust`) are on lines not
// shown.
78 bool DecoderPrivate::buildLattice(
80 const std::unordered_set<const SegmentGraphNode *> &ignore,
81 const State &state,
const SegmentGraph &graph,
size_t frameSize,
83 LatticeMap &lattice = l.d_ptr->lattice_;
// Seed the start node with one begin-of-sentence node whose path is
// {nullptr, start}. Skip this when a previous decode left one in place.
86 if (!lattice.contains(&graph.start())) {
87 lattice[&graph.start()].push_back(
88 q->createLatticeNode(graph, model_,
"", model_->beginSentence(),
89 {
nullptr, &graph.start()}, state, 0));
// Candidate nodes keyed by the (first, last) graph node of their path.
94 std::pair<const SegmentGraphNode *, const SegmentGraphNode *>,
95 std::vector<std::unique_ptr<LatticeNode>>,
97 std::pair<const SegmentGraphNode *, const SegmentGraphNode *>>>
// Invoked by dict_->matchPrefix for each dictionary word found on a path.
100 auto dictMatchCallback = [
this, &graph, &frames, q, frameSize](
101 const SegmentGraphPath &path,
WordNode &word,
103 std::unique_ptr<LatticeNodeData> data) {
// The word node has no index yet: look it up in the language model.
104 if (InvalidWordIndex == word.idx()) {
105 auto idx = model_->index(word.word());
108 assert(path.front());
109 auto &frame = frames[std::make_pair(path.front(), path.back())];
// Spans that begin at the graph start are never capped.
110 const bool applyFrameSize =
111 path.front() != &graph.start() && frameSize > 0;
112 auto *node = q->createLatticeNode(
113 graph, model_, word.word(), word.idx(), path, model_->nullState(),
114 adjust, std::move(data), frame.empty());
119 frame.emplace_back(node);
120 if (!applyFrameSize) {
// "Greater" comparator. With the std heap algorithms it keeps the
// lowest-scoring node at frame[0] (a min-heap).
124 auto scoreGreaterThan = [](
const auto &lhs,
const auto &rhs) {
125 return lhs->score() > rhs->score();
// The frame just became full: score each node as singleWordScore + cost,
// then heapify.
128 if (frame.size() == frameSize) {
129 for (
auto &n : frame) {
131 n->setScore(model_->singleWordScore(n->word()) + n->cost());
133 std::make_heap(frame.begin(), frame.end(), scoreGreaterThan);
134 }
else if (frame.size() == frameSize + 1) {
// The frame overflowed by one. If the newcomer beats the current worst,
// push it into the heap and pop the worst to the back. Presumably that
// node is then dropped on lines not shown.
136 node->setScore(model_->singleWordScore(node->word()) +
139 if (scoreGreaterThan(node, frame[0])) {
140 std::push_heap(frame.begin(), frame.end(), scoreGreaterThan);
141 std::pop_heap(frame.begin(), frame.end(), scoreGreaterThan);
147 dict_->matchPrefix(graph, dictMatchCallback, ignore, helper);
// Move the surviving candidates into the lattice, keyed by the last graph
// node of their path. Ownership leaves the unique_ptrs via release(). The
// lattice container presumably takes ownership (boost ptr_container is
// included), but its type is not shown here.
149 for (
auto &[path, nodes] : frames) {
150 auto &latticeUnit = lattice[path.second];
151 for (
auto &node : nodes) {
152 latticeUnit.push_back(node.release());
// No candidate reached the graph end. The handling (presumably a failure
// return) is on lines not shown.
155 if (!lattice.contains(&graph.end())) {
// Append the end-of-sentence node under the nullptr key, with path
// {end, nullptr}.
160 lattice[
nullptr].push_back(
161 q->createLatticeNode(graph, model_,
"", model_->endSentence(),
162 {&graph.end(),
nullptr}, model_->nullState()));
// Forward pass. For every lattice node, visited in BFS order from the graph
// start:
//  - pick the predecessor (among the nodes ending at its `from` graph node)
//    that maximises predecessor score + model transition score;
//  - store the resulting score (+ own cost), the prev pointer and the model
//    state on the node.
166 void DecoderPrivate::forwardSearch(
168 const std::unordered_set<const SegmentGraphNode *> &ignore,
169 size_t beamSize)
const {
171 LatticeMap &lattice = l.d_ptr->lattice_;
// Cache per `from` node of (best score, best predecessor, resulting state).
// It appears to be used only for nodes that model_->isNodeUnknown()
// reports; the guarding condition is not visible.
173 std::tuple<float, LatticeNode *, State>>
175 const auto *start = &graph.start();
178 const SegmentGraphNode *graphNode) {
// Skip:
//  - the start node (seeded by buildLattice);
//  - nodes with no lattice entries;
//  - nodes whose entries are reused from a previous decode.
179 if (graphNode == start || !lattice.contains(graphNode) ||
180 ignore.contains(graphNode)) {
183 auto &latticeNodes = lattice[graphNode];
184 for (
auto &node : latticeNodes) {
185 const auto *from = node.from();
186 assert(graph.checkNodeInGraph(from));
187 float maxScore = -std::numeric_limits<float>::max();
190 bool isUnknown = model_->isNodeUnknown(node);
// Reuse a best predecessor already computed for this `from` node.
192 auto iter = unknownIdCache.find(from);
193 if (iter != unknownIdCache.end()) {
194 std::tie(maxScore, maxNode, maxState) = iter->second;
199 auto iter = lattice.find(from);
201 if (iter == lattice.end()) {
204 auto &searchFrom = iter->second;
// Limit how many predecessor nodes are scanned. The condition choosing
// between the two assignments below is on a line not shown.
// NOTE(review): lattice[from] repeats the lookup already held in
// searchFrom; searchFrom.size() would avoid it.
205 auto searchSize = beamSize;
207 searchSize = std::min(searchSize, lattice[from].size());
209 searchSize = lattice[from].size();
211 for (
auto &parent : searchFrom | std::views::take(searchSize)) {
// Candidate score = predecessor score + model score of moving from the
// predecessor's state to this node.
212 auto score = parent.score() +
213 model_->score(parent.state(), node, state);
214 if (score > maxScore) {
// Remember the result for this `from` node (presumably only for unknown
// nodes; the condition is not visible).
222 unknownIdCache.emplace(
223 std::piecewise_construct, std::forward_as_tuple(from),
224 std::forward_as_tuple(maxScore, maxNode, maxState));
// Node score = best incoming path score + the node's own cost.
229 node.setScore(maxScore + node.cost());
230 node.setPrev(maxNode);
231 node.state() = maxState;
// Optionally sort this node list by descending score. Presumably this lets
// the take(searchSize) beam on later nodes see the best entries first.
233 if (q->needSort(graph, graphNode)) {
236 return lhs.score() > rhs.score();
// Visit graph nodes in BFS order from the start. Then handle the
// end-of-sentence entry, which buildLattice stores under the nullptr key.
242 graph.bfs(start, updateForNode);
243 updateForNode(graph,
nullptr);
// Join the words along an N-best chain, starting at `node` and following
// next_ links. `sep` is appended after each word as far as the visible code
// shows; the loop bounds are on lines not shown. backwardSearch uses the
// result as a key for de-duplicating candidate sentences.
246 std::string concatNBest(
NBestNode *node, std::string_view sep =
"") {
249 result.append(node->node_->word());
250 result.append(sep.data(), sep.size());
251 node = node->next_.get();
// Backward N-best search. It first records the end-of-sentence node's own
// sentence result. It then grows partial paths backward from eos toward
// bos, using a priority queue ordered by fn_. Completed, previously unseen
// sentences are collected (presumably up to `nbest`) and appended to
// nbests_.
// Parameters:
//  - max: prune when eos->score() minus a candidate's score exceeds max.
//  - min: discard an expansion step scoring below min, unless it reaches bos.
//  - beamSize: limit on predecessor nodes scanned per expansion.
257 size_t nbest,
float max,
float min,
258 size_t beamSize)
const {
259 auto &lattice = l.d_ptr->lattice_;
// Expect exactly one begin-of-sentence and one end-of-sentence node.
262 assert(lattice[&graph.start()].size() == 1);
263 assert(lattice[
nullptr].size() == 1);
264 auto *pos = &lattice[
nullptr][0];
// First result: the eos node's sentence (presumably the 1-best path via the
// prev links set in forwardSearch).
265 l.d_ptr->nbests_.push_back(pos->toSentenceResult());
// Sentences already emitted, seeded with the first result.
// NOTE(review): this assumes SentenceResult::toString() and concatNBest()
// with the default separator produce comparable strings. Confirm.
267 std::unordered_set<std::string> dup;
268 dup.insert(l.d_ptr->nbests_[0].toString());
269 using PriorityQueueType =
270 std::priority_queue<std::shared_ptr<NBestNode>,
271 std::vector<std::shared_ptr<NBestNode>>,
274 PriorityQueueType result;
276 auto *eos = &lattice[
nullptr][0];
278 return std::make_shared<NBestNode>(node);
// Begin the search from the end-of-sentence node.
280 q.push(newNBestNode(eos));
282 auto *bos = &lattice[&graph.start()][0];
284 std::shared_ptr<NBestNode> node = q.top();
// A chain that reached bos is a complete sentence candidate.
286 if (bos == node->node_) {
287 auto sentence = concatNBest(node.get());
288 if (dup.contains(sentence)) {
292 if (eos->score() - node->fn_ > max) {
296 if (result.size() >= nbest) {
299 dup.insert(sentence);
// Hard limit on search effort.
301 if (acc >= MAX_BACKWARD_SEARCH_SIZE) {
304 auto searchSize = beamSize;
306 searchSize = std::min(searchSize,
307 lattice[node->node_->from()].size());
309 searchSize = lattice[node->node_->from()].size();
// Expand: each lattice node ending at this node's `from` becomes a parent.
// NOTE(review): operator[] inserts an empty entry if `from` is absent from
// the lattice; find() would avoid mutating it during the search.
311 for (
auto &from : lattice[node->node_->from()] |
312 std::views::take(searchSize)) {
314 model_->score(from.state(), *node->node_, state) +
316 if (&from != bos && score < min) {
// Scores on the new parent:
//  - gn_: accumulated backward score;
//  - fn_: gn_ plus the parent's forward score.
319 std::shared_ptr<NBestNode> parent = newNBestNode(&from);
320 parent->gn_ = score + node->gn_;
321 parent->fn_ = parent->gn_ + parent->node_->score();
322 parent->next_ = node;
// Only queue parents whose gn_ is still within `max` of the eos score.
324 if (eos->score() - parent->gn_ <= max) {
325 q.push(std::move(parent));
327 if (acc >= MAX_BACKWARD_SEARCH_SIZE) {
// Turn the collected chains into SentenceResults. Each walk starts after the
// chain's head node.
335 while (!result.empty()) {
336 auto node = result.top();
341 auto pivot = node->next_;
343 pivot = pivot->next_;
// NOTE(review): this local `result` shadows the priority queue of the same
// name, which makes the loop hard to follow.
346 SentenceResult::Sentence result;
347 result.reserve(count);
// Skip nodes whose path ends at nullptr (presumably only the eos node).
350 if (pivot->node_->to()) {
351 result.emplace_back(pivot->node_);
353 pivot = pivot->next_;
355 l.d_ptr->nbests_.emplace_back(std::move(result), node->fn_);
// The decoder owns its private implementation, built from the dictionary
// and the language model.
361 : d_ptr(std::make_unique<DecoderPrivate>(dict, model)) {}
// Out-of-line destructor, presumably so that DecoderPrivate is a complete
// type where the unique_ptr deleter is instantiated.
363 Decoder::~Decoder() {}
// Decode `graph` into lattice `l`:
//  1. Keep lattice entries left over from a previous decode.
//  2. Build the lattice for the remaining nodes.
//  3. Run the forward and backward searches to refill the N-best list.
376 const State &beginState,
float max,
float min,
377 size_t beamSize,
size_t frameSize,
void *helper)
const {
379 LatticeMap &lattice = l.d_ptr->lattice_;
381 l.d_ptr->nbests_.clear();
// Drop the old end-of-sentence entry; buildLattice appends a fresh one.
383 lattice.erase(
nullptr);
// Every graph node that still has lattice entries is passed as `ignore` to
// buildLattice and forwardSearch (forwardSearch skips such nodes).
384 std::unordered_set<const SegmentGraphNode *> ignore;
386 for (
auto &p : lattice) {
387 ignore.insert(p.first);
390 auto t0 = std::chrono::high_resolution_clock::now();
392 if (!d->buildLattice(
this, l, ignore, beginState, graph, frameSize,
// NOTE(review): t0 is not reset between phases in the visible lines. Unless
// millisecondsTill() updates it, each message below reports the cumulative
// time since t0, not that phase alone.
396 LIBIME_DEBUG() <<
"Build Lattice: " << millisecondsTill(t0);
397 d->forwardSearch(
this, graph, l, ignore, beamSize);
398 LIBIME_DEBUG() <<
"Forward Search: " << millisecondsTill(t0);
399 d->backwardSearch(graph, l, nbest, max, min, beamSize);
400 LIBIME_DEBUG() <<
"Backward Search: " << millisecondsTill(t0);
// Lattice-node factory called by buildLattice (through q->createLatticeNode).
//  - Allocates a plain LatticeNode from word, index, path, state and cost.
//  - The std::unique_ptr<LatticeNodeData> parameter is unnamed, so any data
//    passed in is destroyed when this call returns.
//  - Returns an owning raw pointer. buildLattice stores it in a unique_ptr
//    or pushes it straight into the lattice.
406 std::string_view word, WordIndex idx, SegmentGraphPath path,
407 const State &state,
float cost, std::unique_ptr<LatticeNodeData> ,
409 return new LatticeNode(word, idx, std::move(path), state, cost);