libime
zstdfilter.h
1 /*
2  * SPDX-FileCopyrightText: 2023-2023 CSSlayer <wengxt@gmail.com>
3  *
4  * SPDX-License-Identifier: LGPL-2.1-or-later
5  */
6 
7 #ifndef LIBIME_ZSTDFILTER_H
8 #define LIBIME_ZSTDFILTER_H
9 
10 #include <cstddef>
11 #include <cstring>
12 #include <ios>
13 #include <istream>
14 #include <ostream>
15 #include <stdexcept>
16 #include <boost/iostreams/categories.hpp>
17 #include <boost/iostreams/constants.hpp>
18 #include <boost/iostreams/filter/symmetric.hpp>
19 #include <boost/iostreams/filtering_streambuf.hpp>
20 #include <boost/iostreams/pipeline.hpp>
21 #include <boost/throw_exception.hpp>
22 #include <fcitx-utils/log.h>
23 #include <fcitx-utils/misc.h>
24 #include <zstd.h>
25 
26 namespace libime {
27 
28 class ZSTDError : public std::ios::failure {
29 public:
30  explicit ZSTDError(size_t error)
31  : std::ios::failure(ZSTD_getErrorName(error)), error_(error) {}
32 
33  size_t error() const { return error_; }
34  static void check(size_t error) {
35  if (ZSTD_isError(error)) {
36  boost::throw_exception(ZSTDError(error));
37  }
38  }
39 
40 private:
41  size_t error_;
42 };
43 
44 namespace details {
45 
// Outcome of one compression / decompression step.
enum class ZSTDResult {
    // The stream is finished; the filter() wrappers report this as `false`.
    StreamEnd = 0,
    // The step succeeded and the stream is still open.
    Okay = 1,
};
50 
52 public:
53  using char_type = char;
54  ZSTDFilterBase(const ZSTDFilterBase &) = delete;
55 
56 protected:
57  ZSTDFilterBase() = default;
58  ~ZSTDFilterBase() {}
59  void before(const char *&src_begin, const char *src_end, char *&dest_begin,
60  const char *dest_end) {
61  in_.src = src_begin;
62  in_.size = static_cast<size_t>(src_end - src_begin);
63  in_.pos = 0;
64  out_.dst = dest_begin;
65  out_.size = static_cast<size_t>(dest_end - dest_begin);
66  out_.pos = 0;
67  }
68  void after(const char *&src_begin, char *&dest_begin,
69  bool /*unused*/) const {
70  src_begin = reinterpret_cast<const char *>(in_.src) + in_.pos;
71  dest_begin = reinterpret_cast<char *>(out_.dst) + out_.pos;
72  }
73  void reset() {
74  memset(&in_, 0, sizeof(in_));
75  memset(&out_, 0, sizeof(out_));
76  eof_ = 0;
77  }
78 
79  ZSTD_inBuffer in_;
80  ZSTD_outBuffer out_;
81  int eof_ = 0;
82 };
83 
85 public:
86  ZSTDCompressorImpl() : cstream_(ZSTD_createCStream()) { reset(); }
87  bool filter(const char *&src_begin, const char *src_end, char *&dest_begin,
88  char *dest_end, bool flush) {
89  before(src_begin, src_end, dest_begin, dest_end);
90  auto result = deflate(flush);
91  after(src_begin, dest_begin, true);
92  return result != ZSTDResult::StreamEnd;
93  }
94  void close() { reset(); }
95 
96 private:
97  void reset() {
98  ZSTDFilterBase::reset();
99  ZSTDError::check(ZSTD_initCStream(cstream_.get(), 0));
100  // Enable checksum.
101  ZSTDError::check(
102  ZSTD_CCtx_setParameter(cstream_.get(), ZSTD_c_checksumFlag, 1));
103  }
104 
105  ZSTDResult deflate(bool finish) {
106  // Ignore spurious extra calls.
107  // Note size > 0 will trigger an error in this case.
108  if (eof_ && in_.size == 0) {
109  return ZSTDResult::StreamEnd;
110  }
111  size_t result = ZSTD_compressStream(cstream_.get(), &out_, &in_);
112  ZSTDError::check(result);
113  if (finish) {
114  result = ZSTD_endStream(cstream_.get(), &out_);
115  ZSTDError::check(result);
116  eof_ = result == 0;
117  return eof_ ? ZSTDResult::StreamEnd : ZSTDResult::Okay;
118  }
119  return ZSTDResult::Okay;
120  }
121 
122  fcitx::UniqueCPtr<ZSTD_CStream, &ZSTD_freeCStream> cstream_;
123 };
124 
126 public:
127  ZSTDDecompressorImpl() : dstream_(ZSTD_createDStream()) { reset(); }
128  bool filter(const char *&src_begin, const char *src_end, char *&dest_begin,
129  char *dest_end, bool flush) {
130  before(src_begin, src_end, dest_begin, dest_end);
131  auto result = inflate(flush);
132  after(src_begin, dest_begin, false);
133  return result != ZSTDResult::StreamEnd;
134  }
135  void close() { reset(); }
136 
137 private:
138  void reset() {
139  ZSTDFilterBase::reset();
140  ZSTDError::check(ZSTD_initDStream(dstream_.get()));
141  }
142 
143  ZSTDResult inflate(bool finish) {
144  // need loop since iostream code cannot handle short reads
145  do {
146  size_t result = ZSTD_decompressStream(dstream_.get(), &out_, &in_);
147  ZSTDError::check(result);
148  } while (in_.pos < in_.size && out_.pos < out_.size);
149  return finish && in_.size == 0 && out_.pos == 0 ? ZSTDResult::StreamEnd
150  : ZSTDResult::Okay;
151  }
152 
153  fcitx::UniqueCPtr<ZSTD_DStream, &ZSTD_freeDStream> dstream_;
154 };
155 
156 } // namespace details
157 
159  : boost::iostreams::symmetric_filter<details::ZSTDCompressorImpl> {
160 private:
162  using base_type = symmetric_filter<impl_type>;
163 
164 public:
165  using char_type = typename base_type::char_type;
166  using category = typename base_type::category;
167  ZSTDCompressor(std::streamsize buffer_size =
168  boost::iostreams::default_device_buffer_size)
169  : base_type(buffer_size) {}
170 };
171 BOOST_IOSTREAMS_PIPABLE(ZSTDCompressor, 0)
172 
174  : boost::iostreams::symmetric_filter<details::ZSTDDecompressorImpl> {
175 private:
177  using base_type = symmetric_filter<impl_type>;
178 
179 public:
180  using char_type = typename base_type::char_type;
181  using category = typename base_type::category;
182  ZSTDDecompressor(std::streamsize buffer_size =
183  boost::iostreams::default_device_buffer_size)
184  : base_type(buffer_size) {}
185 };
186 BOOST_IOSTREAMS_PIPABLE(ZSTDDecompressor, 0)
187 
188 template <typename Callback>
189 inline void readZSTDCompressed(std::istream &in, const Callback &callback) {
190  boost::iostreams::filtering_istreambuf compressBuf;
191  compressBuf.push(ZSTDDecompressor());
192  compressBuf.push(in);
193  std::istream compressIn(&compressBuf);
194  callback(compressIn);
195  // We don't want to read any data, but only trigger the zstd footer
196  // handling, which validates CRC.
197  compressIn.peek();
198  if (compressIn.bad()) {
199  throw std::invalid_argument("Failed to load dict data");
200  }
201 }
202 
203 template <typename Callback>
204 inline void writeZSTDCompressed(std::ostream &out, const Callback &callback) {
205  boost::iostreams::filtering_streambuf<boost::iostreams::output> compressBuf;
206  compressBuf.push(ZSTDCompressor());
207  compressBuf.push(out);
208  std::ostream compressOut(&compressBuf);
209  callback(compressOut);
210 }
211 
212 } // namespace libime
213 
214 #endif