// Include guard for the Linear3D layer implementation. The template header
// that follows opens the default constructor Linear3D(); its signature and
// body lines are not present in this excerpt.
12 #ifndef MLPACK_METHODS_ANN_LAYER_LINEAR3D_IMPL_HPP 13 #define MLPACK_METHODS_ANN_LAYER_LINEAR3D_IMPL_HPP 21 template<
typename InputDataType,
typename OutputDataType,
22 typename RegularizerType>
// Constructor that receives the layer configuration; only its trailing
// RegularizerType parameter is visible in this excerpt. It stores the
// regularizer and allocates all trainable parameters as one column vector:
// outSize * inSize weight entries followed by outSize bias entries. That is
// the same layout used when `weight` and `bias` are aliased into `weights`.
30 template<
typename InputDataType,
typename OutputDataType,
31 typename RegularizerType>
35 RegularizerType regularizer) :
38 regularizer(regularizer)
40 weights.set_size(outSize * inSize + outSize, 1);
// Copy constructor: copies the output size, the parameter vector and the
// regularizer from `layer` (the inSize initializer is not visible here).
// NOTE(review): the visible lines do not re-point `weight`/`bias` at the
// freshly copied `weights` buffer; confirm the aliasing routine is run
// before the copy is used.
43 template<
typename InputDataType,
typename OutputDataType,
44 typename RegularizerType>
48 outSize(layer.outSize),
49 weights(layer.weights),
50 regularizer(layer.regularizer)
// Move constructor: takes ownership of `layer`'s parameter vector and
// regularizer via std::move (other member initializers are not visible).
55 template<
typename InputDataType,
typename OutputDataType,
56 typename RegularizerType>
61 weights(
std::move(layer.weights)),
62 regularizer(
std::move(layer.regularizer))
// Copy assignment operator: copies the layer sizes, the parameter vector and
// the regularizer from `layer`. Any self-assignment guard or return
// statement falls on lines not shown in this excerpt.
67 template<
typename InputDataType,
typename OutputDataType,
68 typename RegularizerType>
75 inSize = layer.inSize;
76 outSize = layer.outSize;
77 weights = layer.weights;
78 regularizer = layer.regularizer;
// Move assignment operator: moves the parameter vector and the regularizer
// out of `layer`. Handling of the size members is not visible here.
83 template<
typename InputDataType,
typename OutputDataType,
84 typename RegularizerType>
93 weights = std::move(layer.weights);
94 regularizer = std::move(layer.regularizer);
// Builds `weight` (outSize x inSize) and `bias` (outSize x 1) as matrices
// over the memory of the flat `weights` vector. The function's name line is
// not in this excerpt (presumably Reset()). Both use Armadillo's
// auxiliary-memory constructor with copy_aux_mem = false and strict = false,
// so no data is copied. The weight block occupies the first weight.n_elem
// entries (column-major) and the bias block follows immediately after.
99 template<
typename InputDataType,
typename OutputDataType,
100 typename RegularizerType>
103 typedef typename arma::Mat<typename OutputDataType::elem_type> MatType;
105 weight = MatType(weights.memptr(), outSize, inSize,
false,
false);
106 bias = MatType(weights.memptr() + weight.n_elem, outSize, 1,
false,
false);
// Forward pass: Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output).
// Each column of `input` is one batch item holding nPoints points of inSize
// features each, stored contiguously (viewed as an inSize x nPoints matrix).
// For every batch item the layer computes weight * X + bias (bias added to
// every column). The outSize x nPoints result is flattened column-major into
// the matching column of `output`, which is sized
// (outSize * nPoints) x batchSize.
109 template<
typename InputDataType,
typename OutputDataType,
110 typename RegularizerType>
111 template<
typename eT>
113 const arma::Mat<eT>& input, arma::Mat<eT>& output)
115 typedef typename arma::Mat<eT> MatType;
116 typedef typename arma::Cube<eT> CubeType;
// The row count must split evenly into points of inSize features;
// otherwise Log::Fatal is raised.
118 if (input.n_rows % inSize != 0)
120 Log::Fatal <<
"Number of features in the input must be divisible by inSize." 124 const size_t nPoints = input.n_rows / inSize;
125 const size_t batchSize = input.n_cols;
127 output.set_size(outSize * nPoints, batchSize);
// Non-copying cube view over the input buffer: slice i is batch item i as
// an inSize x nPoints matrix. The const_cast only satisfies the
// auxiliary-memory constructor's non-const pointer parameter; the view
// itself is const and is only read.
129 const CubeType inputTemp(const_cast<MatType&>(input).memptr(), inSize,
130 nPoints, batchSize,
false,
false);
// Process one batch item per iteration.
132 for (
size_t i = 0; i < batchSize; ++i)
136 MatType z = weight * inputTemp.slice(i);
137 z.each_col() += bias;
138 output.col(i) = arma::vectorise(z);
// Backward pass:
// Backward(const arma::Mat<eT>& /* input */, const arma::Mat<eT>& gy,
//          arma::Mat<eT>& g).
// The input argument is unnamed and unused in the visible lines. Each column
// of `gy` is one batch item holding nPoints blocks of outSize error values.
// The propagated error for that item is weight^T * (outSize x nPoints slice),
// flattened column-major into the matching column of `g`. `g` is sized
// (inSize * nPoints) x batchSize.
142 template<
typename InputDataType,
typename OutputDataType,
143 typename RegularizerType>
144 template<
typename eT>
146 const arma::Mat<eT>& ,
147 const arma::Mat<eT>& gy,
150 typedef typename arma::Mat<eT> MatType;
151 typedef typename arma::Cube<eT> CubeType;
// NOTE(review): the message below is continued with a backslash-newline
// inside the string literal. Line splicing keeps any indentation of the
// continuation line inside the printed message. Adjacent string literals
// would avoid that.
153 if (gy.n_rows % outSize != 0)
155 Log::Fatal <<
"Number of rows in propagated error must be divisible by \ 156 outSize." << std::endl;
159 const size_t nPoints = gy.n_rows / outSize;
160 const size_t batchSize = gy.n_cols;
// Non-copying cube view over gy: slice i is batch item i as an
// outSize x nPoints matrix (const_cast only for the constructor signature).
162 const CubeType gyTemp(const_cast<MatType&>(gy).memptr(), outSize,
163 nPoints, batchSize,
false,
false);
165 g.set_size(inSize * nPoints, batchSize);
// gyTemp.n_slices equals batchSize, so this visits every batch item once.
167 for (
size_t i = 0; i < gyTemp.n_slices; ++i)
171 g.col(i) = arma::vectorise(weight.t() * gyTemp.slice(i));
// Parameter gradient for the flat `weights` vector, given the layer input
// and the backpropagated error. The function name line is not visible here;
// presumably Gradient(input, error, gradient).
// Layout of `gradient` (sized like `weights`):
//   entries [0, weight.n_elem): sum over the batch of error_i * input_i^T
//                               (outSize x inSize, column-major);
//   remaining outSize entries:  the error summed over all points and all
//                               batch items (bias gradient).
// Finally `weights` and `gradient` are passed to regularizer.Evaluate()
// (presumably to add the regularization term; confirm in RegularizerType).
175 template<
typename InputDataType,
typename OutputDataType,
176 typename RegularizerType>
177 template<
typename eT>
179 const arma::Mat<eT>& input,
180 const arma::Mat<eT>& error,
181 arma::Mat<eT>& gradient)
183 typedef typename arma::Mat<eT> MatType;
184 typedef typename arma::Cube<eT> CubeType;
186 if (error.n_rows % outSize != 0)
187 Log::Fatal <<
"Propagated error matrix has invalid dimension!" << std::endl;
// NOTE(review): nPoints and batchSize come from `input` only, and
// input.n_rows % inSize is not checked here. errorTemp below views `error`
// with the same nPoints/batchSize, but nothing visible checks
// error.n_rows == outSize * nPoints or error.n_cols == batchSize. The
// auxiliary-memory view cannot bounds-check, so a mismatched `error` could
// be read out of range. Consider validating both shapes.
189 const size_t nPoints = input.n_rows / inSize;
190 const size_t batchSize = input.n_cols;
// Non-copying views: slice i of each cube is batch item i
// (inSize x nPoints for the input, outSize x nPoints for the error).
192 const CubeType inputTemp(const_cast<MatType&>(input).memptr(), inSize,
193 nPoints, batchSize,
false,
false);
194 const CubeType errorTemp(const_cast<MatType&>(error).memptr(), outSize,
195 nPoints, batchSize,
false,
false);
// Per-batch-item weight gradients; summed over slices below.
// NOTE(review): accumulating into a single outSize x inSize matrix would
// avoid allocating one slice per batch item.
197 CubeType dW(outSize, inSize, batchSize);
198 for (
size_t i = 0; i < batchSize; ++i)
202 dW.slice(i) = errorTemp.slice(i) * inputTemp.slice(i).t();
205 gradient.set_size(arma::size(weights));
// Weight part: sum dW across slices (dim 2), flattened column-major to
// match the layout of the `weight` view.
207 gradient.submat(0, 0, weight.n_elem - 1, 0)
208 = arma::vectorise(arma::sum(dW, 2));
// Bias part: sum the error across batch items (dim 2), then across points
// (dim 1), giving an outSize x 1 column.
210 gradient.submat(weight.n_elem, 0, weights.n_elem - 1, 0)
211 = arma::vectorise(arma::sum(arma::sum(errorTemp, 2), 1));
213 regularizer.Evaluate(weights, gradient);
// Serialization: serialize(Archive& ar, const uint32_t /* version */).
// Archives inSize and outSize. When loading, re-allocates `weights` to the
// flat (outSize * inSize + outSize) x 1 layout for the restored sizes.
// Source lines 224-226 and anything after line 228 are not visible here.
216 template<
typename InputDataType,
typename OutputDataType,
217 typename RegularizerType>
218 template<
typename Archive>
220 Archive& ar,
const uint32_t )
222 ar(CEREAL_NVP(inSize));
223 ar(CEREAL_NVP(outSize));
227 if (cereal::is_loading<Archive>())
228 weights.set_size(outSize * inSize + outSize, 1);
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: cv.hpp:1
Definition: pointer_wrapper.hpp:23
OutputDataType const & Gradient() const
Get the gradient.
Definition: linear3d.hpp:137
Implementation of the Linear3D layer class.
Definition: layer_types.hpp:112
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.
Definition: linear3d_impl.hpp:145
Linear3D & operator=(const Linear3D &layer)
Copy assignment operator.
Definition: linear3d_impl.hpp:71
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: linear3d_impl.hpp:112
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: linear3d_impl.hpp:219
Linear3D()
Create the Linear3D object.
Definition: linear3d_impl.hpp:23