11 #ifndef EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H 12 #define EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H 24 template<
typename XprType>
28 typedef typename XprTraits::StorageKind StorageKind;
29 typedef typename XprTraits::Index
Index;
31 typedef typename XprType::Nested Nested;
33 static const int NumDimensions = XprTraits::NumDimensions;
34 static const int Layout = XprTraits::Layout;
37 template<
typename XprType>
43 template<
typename XprType>
52 template<
typename XprType>
68 expression()
const {
return m_xpr; }
71 typename XprType::Nested m_xpr;
75 template<
typename ArgType,
typename Device>
79 typedef typename XprType::Index
Index;
80 typedef typename XprType::Scalar Scalar;
95 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorEvaluator(
const XprType& op,
const Device& device)
96 : m_impl(op.expression(), device) { }
98 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
99 return m_impl.dimensions();
102 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(Scalar* ) {
103 m_impl.evalSubExprsIfNeeded(NULL);
106 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void cleanup() {
110 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const 112 return CoeffReturnType(index, m_impl.coeff(index));
116 costPerCoeff(
bool vectorized)
const {
117 return m_impl.costPerCoeff(vectorized) +
TensorOpCost(0, 0, 1);
120 EIGEN_DEVICE_FUNC Scalar* data()
const {
return NULL; }
134 template<
typename ReduceOp,
typename Dims,
typename XprType>
138 typedef typename XprTraits::StorageKind StorageKind;
139 typedef typename XprTraits::Index
Index;
140 typedef Index Scalar;
141 typedef typename XprType::Nested Nested;
144 static const int Layout = XprTraits::Layout;
147 template<
typename ReduceOp,
typename Dims,
typename XprType>
153 template<
typename ReduceOp,
typename Dims,
typename XprType>
162 template<
typename ReduceOp,
typename Dims,
typename XprType>
171 typedef Index CoeffReturnType;
174 const ReduceOp& reduce_op,
175 const int return_dim,
176 const Dims& reduce_dims)
177 : m_xpr(expr), m_reduce_op(reduce_op), m_return_dim(return_dim), m_reduce_dims(reduce_dims) {}
181 expression()
const {
return m_xpr; }
184 const ReduceOp& reduce_op()
const {
return m_reduce_op; }
187 const Dims& reduce_dims()
const {
return m_reduce_dims; }
190 int return_dim()
const {
return m_return_dim; }
193 typename XprType::Nested m_xpr;
194 const ReduceOp m_reduce_op;
195 const int m_return_dim;
196 const Dims m_reduce_dims;
200 template<
typename ReduceOp,
typename Dims,
typename ArgType,
typename Device>
204 typedef typename XprType::Index
Index;
205 typedef typename XprType::Scalar Scalar;
206 typedef typename XprType::CoeffReturnType CoeffReturnType;
215 PacketAccess =
false,
217 Layout = TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device>::Layout,
222 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
TensorEvaluator(
const XprType& op,
const Device& device)
223 : m_orig_impl(op.expression(), device),
224 m_impl(op.expression().index_tuples().reduce(op.reduce_dims(), op.reduce_op()), device),
225 m_return_dim(op.return_dim()) {
227 gen_strides(m_orig_impl.dimensions(), m_strides);
228 if (Layout == static_cast<int>(
ColMajor)) {
229 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
230 m_stride_mod = (m_return_dim < NumDims - 1) ? m_strides[m_return_dim + 1] : total_size;
232 const Index total_size = internal::array_prod(m_orig_impl.dimensions());
233 m_stride_mod = (m_return_dim > 0) ? m_strides[m_return_dim - 1] : total_size;
235 m_stride_div = m_strides[m_return_dim];
238 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
const Dimensions& dimensions()
const {
239 return m_impl.dimensions();
242 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
bool evalSubExprsIfNeeded(Scalar* ) {
243 m_impl.evalSubExprsIfNeeded(NULL);
246 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
void cleanup() {
250 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index)
const {
251 const TupleType v = m_impl.coeff(index);
252 return (m_return_dim < 0) ? v.first : (v.first % m_stride_mod) / m_stride_div;
255 EIGEN_DEVICE_FUNC Scalar* data()
const {
return NULL; }
258 costPerCoeff(
bool vectorized)
const {
259 const double compute_cost = 1.0 +
260 (m_return_dim < 0 ? 0.0 : (TensorOpCost::ModCost<Index>() + TensorOpCost::DivCost<Index>()));
261 return m_orig_impl.costPerCoeff(vectorized) +
262 m_impl.costPerCoeff(vectorized) +
TensorOpCost(0, 0, compute_cost);
266 EIGEN_DEVICE_FUNC
void gen_strides(
const InputDimensions& dims, StrideDims& strides) {
267 if (m_return_dim < 0) {
270 eigen_assert(m_return_dim < NumDims &&
271 "Asking to convert index to a dimension outside of the rank");
275 if (Layout == static_cast<int>(
ColMajor)) {
277 for (
int i = 1; i < NumDims; ++i) {
278 strides[i] = strides[i-1] * dims[i-1];
281 strides[NumDims-1] = 1;
282 for (
int i = NumDims - 2; i >= 0; --i) {
283 strides[i] = strides[i+1] * dims[i+1];
289 TensorEvaluator<const TensorIndexTupleOp<ArgType>, Device> m_orig_impl;
290 TensorEvaluator<const TensorReductionOp<ReduceOp, Dims, const TensorIndexTupleOp<ArgType> >, Device> m_impl;
291 const int m_return_dim;
292 StrideDims m_strides;
299 #endif // EIGEN_CXX11_TENSOR_TENSOR_ARG_MAX_H Definition: TensorCostModel.h:25
Storage order is column major (see TopicStorageOrders).
Definition: Constants.h:320
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:85
A cost model used to limit the number of threads used for evaluating tensor expression.
Definition: TensorEvaluator.h:28
Definition: TensorArgMax.h:163
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:33
The tensor base class.
Definition: TensorBase.h:827
Definition: BandTriangularSolver.h:13
Definition: TensorTraits.h:170
The type used to identify a dense storage.
Definition: Constants.h:491
Generic expression where a coefficient-wise unary operator is applied to an expression.
Definition: CwiseUnaryOp.h:55
Definition: TensorArgMax.h:53
Definition: TensorMeta.h:110
Definition: ForwardDeclarations.h:17
Definition: XprHelper.h:312
Definition: EmulateArray.h:203