#ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H
#define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H

namespace Eigen {

namespace internal {

// Traits of the contraction expression. The result scalar, storage kind and
// index types are promoted from the two operand expressions.
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
{
  // ...
  typedef typename LhsXprType::Nested LhsNested;
  typedef typename RhsXprType::Nested RhsNested;
  // ...
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, Eigen::Dense>
{
  typedef const TensorContractionOp<Dimensions, LhsXprType, RhsXprType>& type;
};
template<typename Dimensions, typename LhsXprType, typename RhsXprType>
struct nested<TensorContractionOp<Dimensions, LhsXprType, RhsXprType>, 1,
              typename eval<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >::type>
{
  typedef TensorContractionOp<Dimensions, LhsXprType, RhsXprType> type;
};
template<typename Indices_, typename LeftArgType_, typename RightArgType_, typename Device_>
struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_, RightArgType_>, Device_> > {
  typedef Indices_ Indices;
  typedef LeftArgType_ LeftArgType;
  typedef RightArgType_ RightArgType;
  typedef Device_ Device;
  // ...
};

}  // end namespace internal
template<typename Indices, typename LhsXprType, typename RhsXprType>
class TensorContractionOp : public TensorBase<TensorContractionOp<Indices, LhsXprType, RhsXprType>, ReadOnlyAccessors>
{
  public:
    // ...
    typedef typename internal::gebp_traits<typename LhsXprType::CoeffReturnType,
                                           typename RhsXprType::CoeffReturnType>::ResScalar CoeffReturnType;
    // ...

    EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionOp(
        const LhsXprType& lhs, const RhsXprType& rhs, const Indices& dims)
        : m_lhs_xpr(lhs), m_rhs_xpr(rhs), m_indices(dims) {}

    EIGEN_DEVICE_FUNC
    const Indices& indices() const { return m_indices; }

    /** \returns the nested expressions */
    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename LhsXprType::Nested>::type&
    lhsExpression() const { return m_lhs_xpr; }

    EIGEN_DEVICE_FUNC
    const typename internal::remove_all<typename RhsXprType::Nested>::type&
    rhsExpression() const { return m_rhs_xpr; }

  protected:
    typename LhsXprType::Nested m_lhs_xpr;
    typename RhsXprType::Nested m_rhs_xpr;
    const Indices m_indices;
};
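// Usage example (sketch): a contraction op is normally built through
// TensorBase::contract() rather than constructed directly. Contracting
// dimension 1 of a 2x3 tensor with dimension 0 of a 3x4 tensor is an
// ordinary matrix product:
//
//   Eigen::Tensor<float, 2> a(2, 3), b(3, 4);
//   a.setRandom(); b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 1> dims = { Eigen::IndexPair<int>(1, 0) };
//   Eigen::Tensor<float, 2> c = a.contract(b, dims);  // c is 2x4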
template<typename Derived>
struct TensorContractionEvaluatorBase
{
  typedef typename internal::traits<Derived>::Indices Indices;
  typedef typename internal::traits<Derived>::LeftArgType LeftArgType;
  typedef typename internal::traits<Derived>::RightArgType RightArgType;
  typedef typename internal::traits<Derived>::Device Device;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;

  enum {
    Layout = TensorEvaluator<LeftArgType, Device>::Layout,
    // ...
  };

  // Most of the code below assumes that both input tensors are ColMajor. If the
  // inputs are RowMajor, we "cheat" by swapping the LHS and RHS: to compute
  // A * B = C, the code pretends B is the LHS and A is the RHS.
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), LeftArgType, RightArgType>::type EvalLeftArgType;
  typedef typename internal::conditional<
    static_cast<int>(Layout) == static_cast<int>(ColMajor), RightArgType, LeftArgType>::type EvalRightArgType;

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;

  typedef array<Index, ContractDims> contract_t;
  typedef array<Index, LDims - ContractDims> left_nocontract_t;
  typedef array<Index, RDims - ContractDims> right_nocontract_t;

  typedef DSizes<Index, NumDims> Dimensions;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  TensorContractionEvaluatorBase(const XprType& op, const Device& device)
      : m_leftImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                          op.lhsExpression(), op.rhsExpression()), device),
        m_rightImpl(choose(Cond<static_cast<int>(Layout) == static_cast<int>(ColMajor)>(),
                           op.rhsExpression(), op.lhsExpression()), device),
        m_device(device),
        m_result(NULL) {
    EIGEN_STATIC_ASSERT((static_cast<int>(TensorEvaluator<LeftArgType, Device>::Layout) ==
                         static_cast<int>(TensorEvaluator<RightArgType, Device>::Layout)),
                        YOU_MADE_A_PROGRAMMING_MISTAKE);
    DSizes<Index, LDims> eval_left_dims;
    DSizes<Index, RDims> eval_right_dims;
    array<IndexPair<Index>, ContractDims> eval_op_indices;
    if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
      // For ColMajor, keep the existing dimensions and contraction pairs as-is.
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[i];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[i];
      }
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = op.indices()[i].first;
        eval_op_indices[i].second = op.indices()[i].second;
      }
    } else {
      // For RowMajor, reverse the existing dimensions...
      for (int i = 0; i < LDims; i++) {
        eval_left_dims[i] = m_leftImpl.dimensions()[LDims - i - 1];
      }
      for (int i = 0; i < RDims; i++) {
        eval_right_dims[i] = m_rightImpl.dimensions()[RDims - i - 1];
      }
      // ...and flip each pair of contracting indices, since the lhs and rhs
      // were swapped in the constructor above.
      for (int i = 0; i < ContractDims; i++) {
        eval_op_indices[i].first = LDims - 1 - op.indices()[ContractDims - 1 - i].second;
        eval_op_indices[i].second = RDims - 1 - op.indices()[ContractDims - 1 - i].first;
      }
    }
    // Check for duplicate axes and make sure the first index in eval_op_indices
    // is increasing. O(n^2) sorting is fine here since ContractDims is small.
    for (int i = 0; i < ContractDims; i++) {
      for (int j = i + 1; j < ContractDims; j++) {
        eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
                     eval_op_indices[j].second != eval_op_indices[i].second &&
                     "contraction axes should be unique");
        if (eval_op_indices[j].first < eval_op_indices[i].first) {
          numext::swap(eval_op_indices[j], eval_op_indices[i]);
        }
      }
    }
    // Column-major strides of the lhs and rhs in their evaluation order.
    array<Index, LDims> lhs_strides;
    lhs_strides[0] = 1;
    for (int i = 0; i < LDims-1; ++i) {
      lhs_strides[i+1] = lhs_strides[i] * eval_left_dims[i];
    }

    array<Index, RDims> rhs_strides;
    rhs_strides[0] = 1;
    for (int i = 0; i < RDims-1; ++i) {
      rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
    }
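    // Example (column-major): for eval_left_dims = (2, 3, 4) the loop above
    // yields lhs_strides = (1, 2, 6); stepping along dimension d moves the
    // linear index by the product of all earlier dimension sizes.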
    if (m_i_strides.size() > 0) m_i_strides[0] = 1;
    if (m_j_strides.size() > 0) m_j_strides[0] = 1;
    if (m_k_strides.size() > 0) m_k_strides[0] = 1;
    m_i_size = 1;
    m_j_size = 1;
    m_k_size = 1;

    // To compute the output dimensions, concatenate the non-contracting
    // dimensions of the left tensor followed by those of the right tensor,
    // computing the corresponding strides as we go.
    m_lhs_inner_dim_contiguous = true;
    int dim_idx = 0;
    unsigned int nocontract_idx = 0;

    for (int i = 0; i < LDims; i++) {
      // find out whether we are contracting on index i of the left tensor
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].first == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        // add the dimension size to the output dimensions
        m_dimensions[dim_idx] = eval_left_dims[i];
        m_left_nocontract_strides[nocontract_idx] = lhs_strides[i];
        if (dim_idx != i) {
          m_lhs_inner_dim_contiguous = false;
        }
        if (nocontract_idx+1 < internal::array_size<left_nocontract_t>::value) {
          m_i_strides[nocontract_idx+1] =
              m_i_strides[nocontract_idx] * eval_left_dims[i];
        } else {
          m_i_size = m_i_strides[nocontract_idx] * eval_left_dims[i];
        }
        dim_idx++;
        nocontract_idx++;
      }
    }
    nocontract_idx = 0;
    for (int i = 0; i < RDims; i++) {
      // find out whether we are contracting on index i of the right tensor
      bool contracting = false;
      for (int j = 0; j < ContractDims; j++) {
        if (eval_op_indices[j].second == i) {
          contracting = true;
          break;
        }
      }
      if (!contracting) {
        m_dimensions[dim_idx] = eval_right_dims[i];
        if (nocontract_idx+1 < internal::array_size<right_nocontract_t>::value) {
          m_j_strides[nocontract_idx+1] =
              m_j_strides[nocontract_idx] * eval_right_dims[i];
        } else {
          m_j_size = m_j_strides[nocontract_idx] * eval_right_dims[i];
        }
        m_right_nocontract_strides[nocontract_idx] = rhs_strides[i];
        dim_idx++;
        nocontract_idx++;
      }
    }
    // Now compute the strides corresponding to the contracting dimensions.
    m_rhs_inner_dim_contiguous = true;
    m_rhs_inner_dim_reordered = false;
    for (int i = 0; i < ContractDims; i++) {
      Index left = eval_op_indices[i].first;
      Index right = eval_op_indices[i].second;

      Index size = eval_left_dims[left];
      eigen_assert(size == eval_right_dims[right] &&
                   "Contraction axes must be same size");

      if (i+1 < static_cast<int>(internal::array_size<contract_t>::value)) {
        m_k_strides[i+1] = m_k_strides[i] * size;
      } else {
        m_k_size = m_k_strides[i] * size;
      }
      m_left_contracting_strides[i] = lhs_strides[left];
      m_right_contracting_strides[i] = rhs_strides[right];

      if (i > 0 && right < eval_op_indices[i-1].second) {
        m_rhs_inner_dim_reordered = true;
      }
      if (right != i) {
        m_rhs_inner_dim_contiguous = false;
      }
    }
    // If the layout is RowMajor, the output dimensions must be reversed back.
    if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
      for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
        numext::swap(m_dimensions[i], m_dimensions[j]);
      }
    }
  }
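  // Worked example: contracting a (2,3,4) tensor with a (4,5) tensor over the
  // index pair (2,0) leaves the non-contracting dimensions (2,3) from the left
  // and (5) from the right, so m_dimensions = (2,3,5), m_i_size = 6,
  // m_j_size = 5 and m_k_size = 4 (a 6x4 matrix times a 4x5 matrix).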
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* data) {
    m_leftImpl.evalSubExprsIfNeeded(NULL);
    m_rightImpl.evalSubExprsIfNeeded(NULL);
    if (data) {
      evalTo(data);
      return false;
    } else {
      m_result = static_cast<Scalar*>(m_device.allocate(dimensions().TotalSize() * sizeof(Scalar)));
      evalTo(m_result);
      return true;
    }
  }
  EIGEN_DEVICE_FUNC void evalTo(Scalar* buffer) const {
    if (this->m_lhs_inner_dim_contiguous) {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, true, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<true, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<true, false, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<true, false, false, Unaligned>(buffer);
        }
      }
    } else {
      if (this->m_rhs_inner_dim_contiguous) {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, true, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<false, true, false, Unaligned>(buffer);
        }
      } else {
        if (this->m_rhs_inner_dim_reordered) {
          static_cast<const Derived*>(this)->template evalProduct<false, false, true, Unaligned>(buffer);
        } else {
          static_cast<const Derived*>(this)->template evalProduct<false, false, false, Unaligned>(buffer);
        }
      }
    }
  }
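  // Note: the dispatch above promotes the three runtime layout flags to
  // compile-time template parameters so the input mappers can specialize
  // their inner loops; e.g. a fully contiguous, non-reordered operand pair
  // takes the evalProduct<true, true, false, Unaligned> path.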
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemv(Scalar* buffer) const {
    const Index rows = m_i_size;
    const Index cols = m_k_size;

    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;
    const int lhs_alignment = LeftEvaluator::IsAligned ? Aligned : Unaligned;
    const int rhs_alignment = RightEvaluator::IsAligned ? Aligned : Unaligned;
    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, lhs_alignment> LhsMapper;
    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, rhs_alignment> RhsMapper;

    LhsMapper lhs(m_leftImpl, m_left_nocontract_strides, m_i_strides,
                  m_left_contracting_strides, m_k_strides);
    RhsMapper rhs(m_rightImpl, m_right_nocontract_strides, m_j_strides,
                  m_right_contracting_strides, m_k_strides);

    const Scalar alpha(1);
    const Index resIncr(1);

    // zero out the result buffer (which must be of size at least rows * sizeof(Scalar))
    m_device.memset(buffer, 0, rows * sizeof(Scalar));

    internal::general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor,
                                            false, RhsScalar, RhsMapper, false>::run(
        rows, cols, lhs, rhs,
        buffer, resIncr, alpha);
  }
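  // When the right operand contributes no output dimensions (m_j_size == 1),
  // the contraction reduces to a matrix-vector product of shape
  // (m_i_size x m_k_size) * (m_k_size x 1), which is the case handled above.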
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalGemm(Scalar* buffer) const {
    // columns in left side, rows in right side
    const Index k = this->m_k_size;
    // rows in left side
    const Index m = this->m_i_size;
    // columns in right side
    const Index n = this->m_j_size;

    // zero out the result buffer (which must be of size at least m * n * sizeof(Scalar))
    this->m_device.memset(buffer, 0, m * n * sizeof(Scalar));

    // define mr, nr, and all of the data mapper types
    typedef typename internal::remove_const<typename EvalLeftArgType::Scalar>::type LhsScalar;
    typedef typename internal::remove_const<typename EvalRightArgType::Scalar>::type RhsScalar;
    typedef typename internal::gebp_traits<LhsScalar, RhsScalar> Traits;

    const Index nr = Traits::nr;
    const Index mr = Traits::mr;

    typedef TensorEvaluator<EvalLeftArgType, Device> LeftEvaluator;
    typedef TensorEvaluator<EvalRightArgType, Device> RightEvaluator;
    const Index lhs_packet_size = internal::unpacket_traits<typename LeftEvaluator::PacketReturnType>::size;
    const Index rhs_packet_size = internal::unpacket_traits<typename RightEvaluator::PacketReturnType>::size;

    typedef internal::TensorContractionInputMapper<LhsScalar, Index, internal::Lhs,
                                                   LeftEvaluator, left_nocontract_t,
                                                   contract_t, lhs_packet_size,
                                                   lhs_inner_dim_contiguous,
                                                   false, Unaligned> LhsMapper;
    typedef internal::TensorContractionInputMapper<RhsScalar, Index, internal::Rhs,
                                                   RightEvaluator, right_nocontract_t,
                                                   contract_t, rhs_packet_size,
                                                   rhs_inner_dim_contiguous,
                                                   rhs_inner_dim_reordered, Unaligned> RhsMapper;

    typedef internal::blas_data_mapper<Scalar, Index, ColMajor> OutputMapper;

    // declare GEBP packing and kernel structs
    internal::gemm_pack_lhs<LhsScalar, Index, typename LhsMapper::SubMapper, mr, Traits::LhsProgress, ColMajor> pack_lhs;
    internal::gemm_pack_rhs<RhsScalar, Index, typename RhsMapper::SubMapper, nr, ColMajor> pack_rhs;
    internal::gebp_kernel<LhsScalar, RhsScalar, Index, OutputMapper, mr, nr, false, false> gebp;

    // initialize data mappers
    LhsMapper lhs(this->m_leftImpl, this->m_left_nocontract_strides, this->m_i_strides,
                  this->m_left_contracting_strides, this->m_k_strides);
    RhsMapper rhs(this->m_rightImpl, this->m_right_nocontract_strides, this->m_j_strides,
                  this->m_right_contracting_strides, this->m_k_strides);

    OutputMapper output(buffer, m);

    // compute block sizes (which depend on the number of threads)
    internal::TensorContractionBlocking<LhsMapper, RhsMapper, Index, internal::ShardByCol> blocking(k, m, n, 1);
    const Index kc = blocking.kc();
    const Index mc = numext::mini(m, blocking.mc());
    const Index nc = numext::mini(n, blocking.nc());
    const Index sizeA = mc * kc;
    const Index sizeB = kc * nc;

    LhsScalar* blockA = static_cast<LhsScalar*>(this->m_device.allocate(sizeA * sizeof(LhsScalar)));
    RhsScalar* blockB = static_cast<RhsScalar*>(this->m_device.allocate(sizeB * sizeof(RhsScalar)));

    for (Index i2 = 0; i2 < m; i2 += mc) {
      const Index actual_mc = numext::mini(i2 + mc, m) - i2;
      for (Index k2 = 0; k2 < k; k2 += kc) {
        // make sure we don't overshoot the edge of the left matrix, then pack the vertical panel
        const Index actual_kc = numext::mini(k2 + kc, k) - k2;
        pack_lhs(blockA, lhs.getSubMapper(i2, k2), actual_kc, actual_mc, 0, 0);

        // series of horizontal blocks
        for (Index j2 = 0; j2 < n; j2 += nc) {
          // make sure we don't overshoot the edge of the right matrix, then pack the block
          const Index actual_nc = numext::mini(j2 + nc, n) - j2;
          pack_rhs(blockB, rhs.getSubMapper(k2, j2), actual_kc, actual_nc, 0, 0);

          // call gebp (the matrix kernel)
          gebp(output.getSubMapper(i2, j2), blockA, blockB, actual_mc, actual_kc, actual_nc, Scalar(1), -1, -1, 0, 0);
        }
      }
    }

    this->m_device.deallocate(blockA);
    this->m_device.deallocate(blockB);
  }
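  // Blocking example: with m = n = k = 1024 and mc = nc = kc = 256, each
  // packed 256x256 lhs panel is reused against four rhs blocks before the
  // next panel is packed, so both panels stay resident in the CPU caches
  // while gebp runs.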
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
    m_leftImpl.cleanup();
    m_rightImpl.cleanup();

    if (m_result != NULL) {
      m_device.deallocate(m_result);
      m_result = NULL;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
    return m_result[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost costPerCoeff(bool) const {
    return TensorOpCost(sizeof(CoeffReturnType), 0, 0);
  }
  template<int LoadMode>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
    return internal::ploadt<PacketReturnType, LoadMode>(m_result + index);
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar* data() const { return m_result; }
  protected:
    // Prevent assignment
    TensorContractionEvaluatorBase& operator = (const TensorContractionEvaluatorBase&);

    Dimensions m_dimensions;

    contract_t m_k_strides;
    contract_t m_left_contracting_strides;
    contract_t m_right_contracting_strides;

    bool m_lhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_contiguous;
    bool m_rhs_inner_dim_reordered;

    left_nocontract_t m_i_strides;
    right_nocontract_t m_j_strides;
    left_nocontract_t m_left_nocontract_strides;
    right_nocontract_t m_right_nocontract_strides;

    Index m_i_size;
    Index m_j_size;
    Index m_k_size;

    TensorEvaluator<EvalLeftArgType, Device> m_leftImpl;
    TensorEvaluator<EvalRightArgType, Device> m_rightImpl;
    const Device& m_device;
    Scalar* m_result;
};
// evaluator for the default device
template<typename Indices, typename LeftArgType, typename RightArgType, typename Device>
struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> :
    public TensorContractionEvaluatorBase<
      TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> > {
  typedef TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgType>, Device> Self;
  typedef TensorContractionEvaluatorBase<Self> Base;

  typedef TensorContractionOp<Indices, LeftArgType, RightArgType> XprType;
  typedef typename internal::remove_const<typename XprType::Scalar>::type Scalar;
  typedef typename XprType::Index Index;
  typedef typename XprType::CoeffReturnType CoeffReturnType;
  // ...

  static const int LDims =
      internal::array_size<typename TensorEvaluator<EvalLeftArgType, Device>::Dimensions>::value;
  static const int RDims =
      internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
  static const int ContractDims = internal::array_size<Indices>::value;
  static const int NumDims = LDims + RDims - 2 * ContractDims;
  EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
      Base(op, device) { }
  template <bool lhs_inner_dim_contiguous, bool rhs_inner_dim_contiguous,
            bool rhs_inner_dim_reordered, int Alignment>
  EIGEN_DEVICE_FUNC void evalProduct(Scalar* buffer) const {
    // Use the matrix-vector kernel when the rhs contributes no output
    // dimensions; otherwise use the blocked matrix-matrix kernel.
    if (this->m_j_size == 1) {
      this->template evalGemv<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                              rhs_inner_dim_reordered, Alignment>(buffer);
      return;
    }

    this->template evalGemm<lhs_inner_dim_contiguous, rhs_inner_dim_contiguous,
                            rhs_inner_dim_reordered, Alignment>(buffer);
  }
};

} // end namespace Eigen
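// Usage example (sketch): a rank-3 by rank-3 contraction over two axis pairs.
//
//   Eigen::Tensor<float, 3> a(2, 3, 4);
//   Eigen::Tensor<float, 3> b(3, 4, 5);
//   a.setRandom(); b.setRandom();
//   Eigen::array<Eigen::IndexPair<int>, 2> dims =
//       { Eigen::IndexPair<int>(1, 0), Eigen::IndexPair<int>(2, 1) };
//   Eigen::Tensor<float, 2> c = a.contract(b, dims);  // c is 2x5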
#endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_H