// NOTE(review): this listing appears to be extracted from a rendered source
// view. Original line numbers (e.g. "10", "11", "26") are fused into the code
// text, and several lines (closing braces, member declarations, class heads,
// loop headers) are missing, so it does not compile as shown. The comments
// below describe only what the visible lines do.
10 #ifndef EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H 11 #define EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H 26 template <
typename Tensor,
bool HasRawAccess>
// CoeffLoader (primary template): every scalar and packet read is delegated
// to the wrapped tensor object (m_tensor).
struct CoeffLoader {
// Keeps the tensor it is given (m_tensor's declaration is not visible here).
31 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
CoeffLoader(
const Tensor& tensor) : m_tensor(tensor) { }
// Moving the start of the buffer is not supported by this loader: any call
// trips eigen_assert.
33 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void offsetBuffer(
typename Tensor::Index) {
34 eigen_assert(
false &&
"unsupported");
// Returns the scalar at linear position `index`, read via m_tensor.coeff().
37 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
typename Tensor::Scalar coeff(
typename Tensor::Index index)
const {
return m_tensor.coeff(index); }
// Returns the packet starting at linear position `index`, read via
// m_tensor.packet(); LoadMode is forwarded as its template argument.
39 template<
int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
40 typename Tensor::PacketReturnType packet(
typename Tensor::Index index)
const 42 return m_tensor.template packet<LoadMode>(index);
// Raw-pointer loader variant: keeps the pointer returned by tensor.data()
// and reads memory through it directly.
// NOTE(review): the class head for this variant is not visible here;
// presumably the HasRawAccess == true specialization of CoeffLoader — confirm.
55 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
CoeffLoader(
const Tensor& tensor) : m_data(tensor.data()) {}
// NOTE(review): the body of offsetBuffer is not visible in this view;
// presumably it advances m_data by `offset` (the generic loader asserts
// instead) — confirm.
57 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void offsetBuffer(
typename Tensor::Index offset) {
// Returns the scalar at m_data + index, read through loadConstant().
61 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
typename Tensor::Scalar coeff(
typename Tensor::Index index)
const {
return loadConstant(m_data+index); }
// Returns the packet starting at m_data + index, loaded with
// internal::ploadt_ro using LoadMode as the load mode.
63 template<
int LoadMode> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
64 typename Tensor::PacketReturnType packet(
typename Tensor::Index index)
const 66 return internal::ploadt_ro<typename Tensor::PacketReturnType, LoadMode>(m_data + index);
// Element type alias taken from the tensor.
69 typedef typename Tensor::Scalar Scalar;
// Template parameters of the mapper that turns (row, col) coordinates into
// linear tensor indices (see computeIndex).
// NOTE(review): the class head and the beginning of the constructor are not
// visible in this view. Presumably this is SimpleTensorContractionMapper,
// which the packet_size == 1 BaseTensorContractionMapper specialization
// further down derives from. A `typename Tensor` parameter appears to be
// missing from the visible list — confirm.
73 template<
typename Scalar,
typename Index,
int side,
75 typename nocontract_t,
typename contract_t,
76 int packet_size,
bool inner_dim_contiguous,
int Alignment>
// Constructor (tail): stores the four stride tables.
// m_ij_strides / m_k_strides split a flat non-contracting / contracting value
// into per-dimension indices. m_nocontract_strides / m_contract_strides turn
// those indices into tensor offsets.
81 const nocontract_t& nocontract_strides,
82 const nocontract_t& ij_strides,
83 const contract_t& contract_strides,
84 const contract_t& k_strides) :
86 m_nocontract_strides(nocontract_strides),
87 m_ij_strides(ij_strides),
88 m_contract_strides(contract_strides),
89 m_k_strides(k_strides) { }
// Forwards a buffer-offset request to the wrapped m_tensor object.
95 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void offsetBuffer(
typename Tensor::Index offset) {
96 m_tensor.offsetBuffer(offset);
// Prefetch hint: does nothing in this mapper.
100 EIGEN_STRONG_INLINE
void prefetch(Index ) { }
// One-argument access: same as (row, 0).
103 EIGEN_STRONG_INLINE Scalar operator()(Index row)
const {
105 return operator()(row, 0);
// Returns the element at (row, col). computeIndex() maps the coordinate to a
// linear index, which is then read through m_tensor.coeff().
109 EIGEN_STRONG_INLINE Scalar operator()(Index row, Index col)
const {
110 return m_tensor.coeff(computeIndex(row, col));
// Maps a (row, col) coordinate to a linear index into the tensor.
// On the Lhs side `row` is the non-contracting value and `col` the
// contracting one; on the Rhs side the roles are swapped.
114 EIGEN_STRONG_INLINE Index computeIndex(Index row, Index col)
const {
115 const bool left = (side == Lhs);
116 Index nocontract_val = left ? row : col;
// NOTE(review): the declaration of `linidx` and the loop over `i` are not
// visible in this view. Each step does three things: the quotient of the
// remaining value by m_ij_strides[i] is that dimension's index; that index,
// scaled by m_nocontract_strides[i], is added to linidx; and the remainder is
// kept for the next step.
119 const Index idx = nocontract_val / m_ij_strides[i];
120 linidx += idx * m_nocontract_strides[i];
121 nocontract_val -= idx * m_ij_strides[i];
// The leftover value is the dimension-0 index. With a contiguous Lhs inner
// dimension, its stride must be 1 (asserted) and it is added as-is; otherwise
// it is scaled by m_nocontract_strides[0]. The branch delimiters are not
// visible here.
124 if (side == Lhs && inner_dim_contiguous) {
125 eigen_assert(m_nocontract_strides[0] == 1);
126 linidx += nocontract_val;
128 linidx += nocontract_val * m_nocontract_strides[0];
// Same decomposition for the contracting value, using m_k_strides and
// m_contract_strides. Here the contiguous shortcut applies on the Rhs side.
// NOTE(review): the loop header and the final return are not visible;
// presumably the function returns linidx — confirm.
132 Index contract_val = left ? col : row;
135 const Index idx = contract_val / m_k_strides[i];
136 linidx += idx * m_contract_strides[i];
137 contract_val -= idx * m_k_strides[i];
140 if (side == Rhs && inner_dim_contiguous) {
141 eigen_assert(m_contract_strides[0] == 1);
142 linidx += contract_val;
144 linidx += contract_val * m_contract_strides[0];
// Computes the linear indices of two coordinates in one pass, using the same
// decomposition as computeIndex:
// - linidx[0] accumulates the index of (row, col);
// - linidx[1] accumulates the index of (row + distance, col).
// On the Lhs side, `distance` is added to the non-contracting value; on the
// Rhs side, to the contracting value.
// NOTE(review): the loop headers, branch delimiters and the return statement
// are not visible in this view. Presumably (linidx[0], linidx[1]) is returned
// as the pair — confirm.
152 EIGEN_STRONG_INLINE
IndexPair<Index> computeIndexPair(Index row, Index col,
const Index distance)
const {
153 const bool left = (side == Lhs);
154 Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
155 Index linidx[2] = {0, 0};
// Per-dimension split of both non-contracting values.
158 const Index idx0 = nocontract_val[0] / m_ij_strides[i];
159 const Index idx1 = nocontract_val[1] / m_ij_strides[i];
160 linidx[0] += idx0 * m_nocontract_strides[i];
161 linidx[1] += idx1 * m_nocontract_strides[i];
162 nocontract_val[0] -= idx0 * m_ij_strides[i];
163 nocontract_val[1] -= idx1 * m_ij_strides[i];
// Dimension-0 term: added directly (stride asserted to be 1) for a
// contiguous Lhs inner dimension, otherwise scaled by m_nocontract_strides[0].
165 if (side == Lhs && inner_dim_contiguous) {
166 eigen_assert(m_nocontract_strides[0] == 1);
167 linidx[0] += nocontract_val[0];
168 linidx[1] += nocontract_val[1];
170 linidx[0] += nocontract_val[0] * m_nocontract_strides[0];
171 linidx[1] += nocontract_val[1] * m_nocontract_strides[0];
// Same for the contracting values; the contiguous shortcut applies on the Rhs
// side.
175 Index contract_val[2] = {left ? col : row, left ? col : row + distance};
178 const Index idx0 = contract_val[0] / m_k_strides[i];
179 const Index idx1 = contract_val[1] / m_k_strides[i];
180 linidx[0] += idx0 * m_contract_strides[i];
181 linidx[1] += idx1 * m_contract_strides[i];
182 contract_val[0] -= idx0 * m_k_strides[i];
183 contract_val[1] -= idx1 * m_k_strides[i];
186 if (side == Rhs && inner_dim_contiguous) {
187 eigen_assert(m_contract_strides[0] == 1);
188 linidx[0] += contract_val[0];
189 linidx[1] += contract_val[1];
191 linidx[0] += contract_val[0] * m_contract_strides[0];
192 linidx[1] += contract_val[1] * m_contract_strides[0];
// Returns 0 when Alignment == Aligned, side == Lhs and the inner dimension is
// contiguous; returns `size` otherwise. Because && binds tighter than ?:, all
// three tests together form the condition.
198 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index firstAligned(Index size)
const {
202 return (Alignment ==
Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
// NOTE(review): the body of stride() is not visible in this view.
204 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride()
const {
// Stride tables captured by the constructor and read by computeIndex /
// computeIndexPair.
210 const nocontract_t m_nocontract_strides;
211 const nocontract_t m_ij_strides;
212 const contract_t m_contract_strides;
213 const contract_t m_k_strides;
// Template parameters of the packet-loading mapper layer.
// NOTE(review): the class head is not visible. Presumably this is the primary
// BaseTensorContractionMapper template (its packet_size == 1 specialization
// appears further down). A `typename Tensor` parameter seems to be missing
// from the visible list — confirm.
217 template<
typename Scalar,
typename Index,
int side,
219 typename nocontract_t,
typename contract_t,
220 int packet_size,
bool inner_dim_contiguous,
221 bool inner_dim_reordered,
int Alignment>
// Constructor (tail): forwards the tensor and all stride tables unchanged to
// ParentMapper.
229 const nocontract_t& nocontract_strides,
230 const nocontract_t& ij_strides,
231 const contract_t& contract_strides,
232 const contract_t& k_strides) :
233 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
// Packet type produced by the wrapped tensor.
235 typedef typename Tensor::PacketReturnType Packet;
// Loads rows i .. i + packet_size - 1 of column j as one packet.
// AlignmentType is forwarded to the tensor's packet() calls.
// packet_size must be even (static assert): the gather fallback below fills
// the first and last elements separately and the interior two at a time.
238 template <
int AlignmentType>
240 EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j)
const {
245 EIGEN_STATIC_ASSERT(packet_size % 2 == 0, YOU_MADE_A_PROGRAMMING_MISTAKE);
// Fast path 1: needs packet access and a contiguous, non-reordered inner
// dimension. The rows are then expected to map to consecutive linear indices
// (the last one is checked by eigen_assert), so a single packet read at
// computeIndex(i, j) is used.
247 if (Tensor::PacketAccess && inner_dim_contiguous && !inner_dim_reordered) {
248 const Index index = this->computeIndex(i, j);
249 eigen_assert(this->computeIndex(i+packet_size-1, j) == index + packet_size-1);
250 return this->m_tensor.template packet<AlignmentType>(index);
// Otherwise compute the linear indices of the first and last rows together.
253 const IndexPair<Index> indexPair = this->computeIndexPair(i, j, packet_size - 1);
254 const Index first = indexPair.first;
255 const Index last = indexPair.second;
// Fast path 2: with packet access, if the last index is exactly
// packet_size - 1 past the first, read one packet starting at `first`.
// NOTE(review): one line of this condition is not visible in this view.
261 if (Tensor::PacketAccess &&
263 (last - first) == (packet_size - 1)) {
265 return this->m_tensor.template packet<AlignmentType>(first);
// Fallback: gather the coefficients into an aligned stack buffer and load it
// as a packet.
// NOTE(review): `internal_pair` is computed on a line not visible here;
// presumably it is the index pair for rows i + k and i + k + 1 — confirm.
268 EIGEN_ALIGN_MAX Scalar data[packet_size];
270 data[0] = this->m_tensor.coeff(first);
271 for (Index k = 1; k < packet_size - 1; k += 2) {
273 data[k] = this->m_tensor.coeff(internal_pair.first);
274 data[k + 1] = this->m_tensor.coeff(internal_pair.second);
276 data[packet_size - 1] = this->m_tensor.coeff(last);
278 return pload<Packet>(data);
// Loads half_packet_size consecutive rows starting at i in column j.
// When a half packet is as wide as a full packet, this just calls loadPacket.
// Otherwise the elements are read one by one via operator() into an aligned
// buffer, which is then loaded as a HalfPacket.
// NOTE(review): HalfPacket and half_packet_size are defined on lines not
// visible here, as are the closing braces of the branches.
281 template <
int AlignmentType>
283 EIGEN_STRONG_INLINE HalfPacket loadHalfPacket(Index i, Index j)
const {
288 if (half_packet_size == packet_size) {
289 return loadPacket<AlignmentType>(i, j);
291 EIGEN_ALIGN_MAX Scalar data[half_packet_size];
292 for (Index k = 0; k < half_packet_size; k++) {
293 data[k] = operator()(i + k, j);
295 return pload<HalfPacket>(data);
// Specialization of BaseTensorContractionMapper for packet_size == 1. A
// "packet" here holds a single coefficient, so packet loads reduce to scalar
// reads.
300 template<
typename Scalar,
typename Index,
int side,
302 typename nocontract_t,
typename contract_t,
303 bool inner_dim_contiguous,
304 bool inner_dim_reordered,
int Alignment>
305 class BaseTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, Alignment> :
public SimpleTensorContractionMapper<Scalar, Index, side, Tensor, nocontract_t, contract_t, 1, inner_dim_contiguous, Alignment>
// Constructor (tail): forwards everything unchanged to ParentMapper.
// NOTE(review): the class body opening, the ParentMapper typedef and the head
// of the constructor signature are not visible here.
312 const nocontract_t& nocontract_strides,
313 const nocontract_t& ij_strides,
314 const contract_t& contract_strides,
315 const contract_t& k_strides) :
316 ParentMapper(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
// Packet type produced by the wrapped tensor.
318 typedef typename Tensor::PacketReturnType Packet;
// Reads the coefficient at (i, j) into a one-element aligned buffer and loads
// it as a packet. The int template argument is unused.
319 template <
int> EIGEN_DEVICE_FUNC
320 EIGEN_STRONG_INLINE Packet loadPacket(Index i, Index j)
const {
321 EIGEN_ALIGN_MAX Scalar data[1];
322 data[0] = this->m_tensor.coeff(this->computeIndex(i, j));
323 return pload<typename Tensor::PacketReturnType>(data);
// Returns a full Packet; this specialization does not use a separate half
// packet type.
// NOTE(review): loadPacket is a member template whose int parameter cannot be
// deduced from (i, j). This unqualified call therefore looks like it would
// fail to compile if loadHalfPacket were ever instantiated. Passing an
// explicit template argument (e.g. naming this function's parameter and
// forwarding it) may be needed — confirm.
325 template <
int> EIGEN_DEVICE_FUNC
326 EIGEN_STRONG_INLINE Packet loadHalfPacket(Index i, Index j)
const {
327 return loadPacket(i, j);
// Sub-mapper: a view of a base mapper shifted by (vert_offset, horiz_offset).
// NOTE(review): the class head, some template parameters and the start of the
// constructor signature are not visible in this view.
332 template<
typename Scalar,
typename Index,
int side,
334 typename nocontract_t,
typename contract_t,
336 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment>
// Packet type produced by the wrapped tensor.
339 typedef typename Tensor::PacketReturnType Packet;
// Constructor (tail). m_base_mapper is a member held by value, so the offset
// applied below changes only this copy.
// - With UseDirectOffsets, the shift is applied once here by moving the
//   copy's buffer by vert_offset + horiz_offset * stride(); the accessors then
//   pass unshifted coordinates.
// - Without it, each accessor adds the two offsets itself.
353 : m_base_mapper(base_mapper), m_vert_offset(vert_offset), m_horiz_offset(horiz_offset) {
356 if (UseDirectOffsets) {
357 Index stride = m_base_mapper.stride();
358 m_base_mapper.offsetBuffer(vert_offset + horiz_offset * stride);
// Element (i) of this view, i.e. (i, 0) relative to the view's origin.
// Offsets are added here unless they were folded into the base buffer at
// construction.
362 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i)
const {
363 if (UseDirectOffsets) {
364 return m_base_mapper(i, 0);
366 return m_base_mapper(i + m_vert_offset, m_horiz_offset);
// Element (i, j) of this view, i.e. (i + vert_offset, j + horiz_offset) of
// the unshifted base mapper.
368 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, Index j)
const {
369 if (UseDirectOffsets) {
370 return m_base_mapper(i, j);
372 return m_base_mapper(i + m_vert_offset, j + m_horiz_offset);
// Packet at row i, column 0 of this view; the class's Alignment parameter is
// passed to the base mapper as the alignment argument.
375 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i)
const {
376 if (UseDirectOffsets) {
377 return m_base_mapper.template loadPacket<Alignment>(i, 0);
379 return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
// Packet at (i, j) of this view.
381 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j)
const {
382 if (UseDirectOffsets) {
383 return m_base_mapper.template loadPacket<Alignment>(i, j);
385 return m_base_mapper.template loadPacket<Alignment>(i + m_vert_offset, j + m_horiz_offset);
// Half packet at row i, column 0 of this view.
388 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i)
const {
389 if (UseDirectOffsets) {
390 return m_base_mapper.template loadHalfPacket<Alignment>(i, 0);
392 return m_base_mapper.template loadHalfPacket<Alignment>(i + m_vert_offset, m_horiz_offset);
// Stores packet p at row i, column 0 of this view through the base mapper.
// NOTE(review): no storePacket is defined in the base mappers visible in this
// view — confirm the base type actually provides it.
395 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
void storePacket(Index i, Packet p)
const {
396 if (UseDirectOffsets) {
397 m_base_mapper.storePacket(i, 0, p);
399 m_base_mapper.storePacket(i + m_vert_offset, m_horiz_offset, p);
// Builds a LinearMapper over the base mapper at (i, j) of this view. With
// UseDirectOffsets, the already-shifted base copy is passed with unshifted
// coordinates; otherwise the offsets are added here.
402 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j)
const {
403 if (UseDirectOffsets) {
404 return LinearMapper(m_base_mapper, i, j);
406 return LinearMapper(m_base_mapper, i + m_vert_offset, j + m_horiz_offset);
// Loads a packet at row i, column 0 of this view using ActualAlignment.
// NOTE(review): the definition of ActualAlignment is not visible here.
// loadPacket's result (of type Packet) is returned as PacketT, so this only
// works when PacketT is Packet or convertible from it — confirm.
409 template <
typename PacketT,
int AlignmentType>
410 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE PacketT load(Index i)
const {
413 if (UseDirectOffsets) {
414 return m_base_mapper.template loadPacket<ActualAlignment>(i, 0);
416 return m_base_mapper.template loadPacket<ActualAlignment>(i + m_vert_offset, m_horiz_offset);
// NOTE(review): the body of aligned() is not visible in this view.
419 template <
typename Packet>
420 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
bool aligned(Index)
const {
// State: this view's own copy of the base mapper (possibly with a shifted
// buffer) and the view's row/column offsets.
425 ParentMapper m_base_mapper;
426 const Index m_vert_offset;
427 const Index m_horiz_offset;
// Mapper that inherits all coefficient/packet access from
// BaseTensorContractionMapper (with the same parameters) and adds factories
// for sub-mappers and vector mappers.
// NOTE(review): the class head, the Base typedef and the SubMapper /
// VectorMapper typedefs are not visible in this view.
431 template<
typename Scalar_,
typename Index,
int side,
433 typename nocontract_t,
typename contract_t,
435 bool inner_dim_contiguous,
bool inner_dim_reordered,
int Alignment>
437 :
public BaseTensorContractionMapper<Scalar_, Index, side, Tensor, nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment> {
// Alias for the Scalar_ template parameter.
440 typedef Scalar_ Scalar;
// Constructor (tail): forwards the tensor and all stride tables unchanged to
// Base.
446 const nocontract_t& nocontract_strides,
447 const nocontract_t& ij_strides,
448 const contract_t& contract_strides,
449 const contract_t& k_strides)
450 : Base(tensor, nocontract_strides, ij_strides, contract_strides, k_strides) { }
// Returns a SubMapper over this mapper with origin (i, j).
453 EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j)
const {
454 return SubMapper(*
this, i, j);
// Returns a VectorMapper constructed from this mapper and (i, j).
457 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j)
const {
458 return VectorMapper(*
this, i, j);
467 #endif // EIGEN_CXX11_TENSOR_TENSOR_CONTRACTION_MAPPER_H Definition: TensorMeta.h:154
Definition: TensorContractionMapper.h:77
Definition: TensorContractionMapper.h:26
Definition: XprHelper.h:158
Namespace containing all symbols from the Eigen library.
Definition: bench_norm.cpp:85
Data pointer has no specific alignment.
Definition: Constants.h:228
Definition: TensorContractionMapper.h:222
EIGEN_DEFAULT_DENSE_INDEX_TYPE Index
The Index type as used for the API.
Definition: Meta.h:33
Definition: Constants.h:235
Definition: PacketMath.h:48
Definition: BandTriangularSolver.h:13
AlignmentType
Enum for indicating whether a buffer is aligned or not.
Definition: Constants.h:227
Definition: TensorContractionMapper.h:337
Definition: EmulateArray.h:203
The tensor class.
Definition: Tensor.h:63