// zeroOrderDependency: for every dependent variable i (m = fun.Range()),
// visits each independent-variable index j that appears in row i of the
// Jacobian sparsity pattern of `fun` (computed by jacobianSparsitySet).
// Preconditions (asserted): vx has at least fun.Domain() entries and vy has
// at least fun.Range() entries.
// NOTE(review): the inner-loop statements are not shown here; presumably
// vy[i] is set when some vx[j] is set (zero-order dependency) — confirm.
1 #ifndef CPPAD_CG_UTIL_INCLUDED 2 #define CPPAD_CG_UTIL_INCLUDED 22 template<
class VectorBool,
class Base>
23 void zeroOrderDependency(ADFun<Base>& fun,
26 size_t m = fun.Range();
27 CPPADCG_ASSERT_KNOWN(vx.size() >= fun.Domain(),
"Invalid vx size")
28 CPPADCG_ASSERT_KNOWN(vy.size() >= m,
"Invalid vy size")
30 using VectorSet = std::vector<std::set<size_t> >;
32 const VectorSet jacSparsity = jacobianSparsitySet<VectorSet, Base>(fun);
34 for (
size_t i = 0; i < m; i++) {
35 for (
size_t j : jacSparsity[i]) {
// isIdentityPattern: tests whether the first mRows rows of `pattern` look
// like an identity matrix, i.e. row i contains exactly one entry, equal
// to i. Rows with index >= mRows are not examined.
// Precondition (asserted): pattern.size() >= mRows.
// NOTE(review): the return statements are not shown; presumably false on
// the first failing row and true otherwise.
44 template<
class VectorSet>
45 inline bool isIdentityPattern(
const VectorSet& pattern,
47 CPPADCG_ASSERT_UNKNOWN(pattern.size() >= mRows)
49 for (
size_t i = 0; i < mRows; i++) {
50 if (pattern[i].size() != 1 || *pattern[i].begin() != i) {
// transposePattern: returns the transpose of a row-wise sparsity pattern
// with mRows rows and nCols columns. Every column index `it` found in row i
// of `pattern` becomes row index i in row `it` of the result, which has
// nCols rows.
// Precondition (asserted): pattern.size() >= mRows.
// NOTE(review): nothing visible checks that column indices are < nCols; a
// larger index would write past the end of `transpose`.
57 template<
class VectorSet>
58 inline VectorSet transposePattern(
const VectorSet& pattern,
61 CPPADCG_ASSERT_UNKNOWN(pattern.size() >= mRows)
63 VectorSet transpose(nCols);
64 for (
size_t i = 0; i < mRows; i++) {
65 for (
size_t it : pattern[i]) {
66 transpose[it].insert(i);
// addTransMatrixSparsity: accumulates the transpose of the first mRows rows
// of `a` into the result pattern (R += A^T), visiting every entry (i, j).
// Precondition (asserted): a.size() >= mRows.
// NOTE(review): the inner-loop body and the `result` parameter are not
// shown; presumably i is inserted into result[j] — confirm.
81 template<
class VectorSet,
class VectorSet2>
82 inline void addTransMatrixSparsity(
const VectorSet& a,
85 CPPADCG_ASSERT_UNKNOWN(a.size() >= mRows)
87 for (
size_t i = 0; i < mRows; i++) {
88 for (
size_t j : a[i]) {
// addTransMatrixSparsity: convenience overload that transposes every row of
// `a` into `result` (R += A^T), i.e. the three-argument form with
// mRows = a.size().
102 template<
class VectorSet,
class VectorSet2>
103 inline void addTransMatrixSparsity(
const VectorSet& a,
104 VectorSet2& result) {
105 addTransMatrixSparsity<VectorSet, VectorSet2>(a, a.size(), result);
// addMatrixSparsity: accumulates the sparsity of matrix `a` into `result`
// (R += A) for rows 0 .. mRows-1; each visited result row receives the
// column indices of the corresponding row of `a` (the branch taken for an
// empty result row is not shown here).
// Preconditions (asserted): `result` and `a` both have at least mRows rows.
// `a` is read at every i < mRows, so its size must be >= mRows. A
// `<=` check would allow out-of-range reads. It would also reject callers
// that pass a longer `a` together with an explicit row count.
116 template<
class VectorSet,
class VectorSet2>
117 inline void addMatrixSparsity(
const VectorSet& a,
119 VectorSet2& result) {
120 CPPADCG_ASSERT_UNKNOWN(result.size() >= mRows)
121 CPPADCG_ASSERT_UNKNOWN(a.size() >= mRows)
123 for (
size_t i = 0; i < mRows; i++) {
124 if (result[i].empty()) {
127 result[i].insert(a[i].begin(), a[i].end());
// addMatrixSparsity: convenience overload for R += A when `a` and `result`
// have the same number of rows (asserted); forwards with mRows = a.size().
139 template<
class VectorSet,
class VectorSet2>
140 inline void addMatrixSparsity(
const VectorSet& a,
141 VectorSet2& result) {
142 CPPADCG_ASSERT_UNKNOWN(result.size() == a.size())
144 addMatrixSparsity<VectorSet, VectorSet2>(a, a.size(), result);
// multMatrixMatrixSparsity: convenience overload for R += A * B that takes
// the row counts from the containers (m = a.size(), n = b.size()) and
// forwards to the full overload; q is the number of columns of B.
// NOTE(review): the parameters b, result and q are declared on lines not
// shown here.
156 template<
class VectorSet,
class VectorSet2>
157 inline void multMatrixMatrixSparsity(
const VectorSet& a,
161 multMatrixMatrixSparsity(a, b, result, a.size(), b.size(), q);
// multMatrixMatrixSparsity: accumulates the sparsity of the product A * B
// into `result` (R += A * B), where `a` has m rows and `b` has n rows and
// q columns.
// Preconditions (asserted): a.size() >= m, b.size() >= n,
// result.size() >= m.
// If the first n rows of `b` form an identity pattern, A * B has the
// pattern of A and the rows of `a` are added directly. Otherwise `b` is
// transposed so that column jj of B is available as a set. Entry (i, jj)
// is then added to the result when row i of A and column jj of B share
// an index.
177 template<
class VectorSet,
class VectorSet2>
178 inline void multMatrixMatrixSparsity(
const VectorSet& a,
184 CPPADCG_ASSERT_UNKNOWN(a.size() >= m)
185 CPPADCG_ASSERT_UNKNOWN(b.size() >= n)
186 CPPADCG_ASSERT_UNKNOWN(result.size() >= m)
190 if (isIdentityPattern(b, n)) {
191 addMatrixSparsity(a, m, result);
196 VectorSet2 bt = transposePattern(b, n, q);
198 for (
size_t jj = 0; jj < q; jj++) {
199 const std::set<size_t>& colB = bt[jj];
201 for (
size_t i = 0; i < m; i++) {
202 const std::set<size_t>& rowA = a[i];
203 for (
size_t rowb : colB) {
204 if (rowA.find(rowb) != rowA.end()) {
205 result[i].insert(jj);
// multMatrixTransMatrixSparsity: accumulates the sparsity of A^T * B into
// `result` (R += A^T * B). `a` and `b` both have m rows; `a` has n columns
// and `b` has q columns, so the result has n rows.
// Preconditions (asserted): a.size() >= m, b.size() >= m,
// result.size() >= n.
// The first loop looks for a non-empty row of `b` (its body is not shown;
// presumably the function returns early when B is empty — confirm).
// Shortcuts: if A is an identity (m == n), R += B; if B is an identity
// (m == q), R += A^T. Otherwise both matrices are transposed, and (i, jj)
// is added when row i of A^T and column jj of B share an index. Empty rows
// of A^T are skipped.
227 template<
class VectorSet,
class VectorSet2>
228 inline void multMatrixTransMatrixSparsity(
const VectorSet& a,
234 CPPADCG_ASSERT_UNKNOWN(a.size() >= m)
235 CPPADCG_ASSERT_UNKNOWN(b.size() >= m)
236 CPPADCG_ASSERT_UNKNOWN(result.size() >= n)
240 for (
size_t i = 0; i < m; i++) {
241 if (b[i].size() > 0) {
251 if (m == n && isIdentityPattern(a, m)) {
252 addMatrixSparsity(b, n, result);
257 if (m == q && isIdentityPattern(b, m)) {
258 addTransMatrixSparsity(a, m, result);
262 VectorSet at = transposePattern(a, m, n);
263 VectorSet2 bt = transposePattern(b, m, q);
265 for (
size_t jj = 0; jj < q; jj++) {
266 const std::set<size_t>& colB = bt[jj];
268 for (
size_t i = 0; i < n; i++) {
269 const std::set<size_t>& rowAt = at[i];
270 if (!rowAt.empty()) {
271 for (
size_t rowb : colB) {
272 if (rowAt.find(rowb) != rowAt.end()) {
273 result[i].insert(jj);
// multMatrixMatrixSparsityTrans: accumulates the transposed product
// (A * B)^T = B^T * A^T into `rT`. Inputs are `aT` (the transpose of A,
// with m rows and q columns) and `b` (m rows, n columns); rT has one row
// per column of B.
// Preconditions (asserted): aT.size() >= m, b.size() >= m.
// The first loop looks for a non-empty row of `b` (body not shown;
// presumably an early return when B is empty — confirm).
// Shortcut: if aT is an m x m identity (m == q), the result is B^T.
// Otherwise A and B^T are formed. For each column jj of B and each row i
// of A, the entries of row i are looked up in bT[jj]. The statement run on
// a match is not shown (presumably rT[jj].insert(i)).
296 template<
class VectorSet,
class VectorSet2>
297 inline void multMatrixMatrixSparsityTrans(
const VectorSet& aT,
303 CPPADCG_ASSERT_UNKNOWN(aT.size() >= m)
304 CPPADCG_ASSERT_UNKNOWN(b.size() >= m)
308 for (
size_t i = 0; i < m; i++) {
309 if (b[i].size() > 0) {
319 if (m == q && isIdentityPattern(aT, m)) {
320 addTransMatrixSparsity(b, m, rT);
324 VectorSet a = transposePattern(aT, m, q);
325 VectorSet2 bT = transposePattern(b, m, n);
327 for (
size_t jj = 0; jj < n; jj++) {
328 for (
size_t i = 0; i < q; i++) {
329 for (
size_t it : a[i]) {
330 if (bT[jj].find(it) != bT[jj].end()) {
// printSparsityPattern (dense overload): prints an m x n boolean pattern
// stored row-major (entry (i, j) is sparsity[i * n + j]) to std::cout.
// Each row is prefixed by its index. Set entries show their column index;
// the other branch prints blank padding of the same width.
// Column width is ceil(log10(max(m, n))) characters. The log10 argument
// is clamped to at least 1: for m == n == 0, log10(0) is -infinity, and
// converting that to size_t is undefined behaviour. For any non-empty
// shape the width is unchanged.
339 template<
class VectorBool>
340 void printSparsityPattern(
const VectorBool& sparsity,
341 const std::string& name,
342 size_t m,
size_t n) {
343 size_t width = std::ceil(std::log10((m > n) ? m : (n > 0 ? n : 1)));
345 std::cout << name <<
" sparsity:\n";
347 for (
size_t i = 0; i < m; i++) {
348 std::cout <<
" " << std::setw(width) << i <<
": ";
349 for (
size_t j = 0; j < n; j++) {
350 if (sparsity[i * n + j]) {
351 std::cout << std::setw(width) << j <<
" ";
353 std::cout << std::setw(width) <<
" " <<
" ";
358 std::cout << std::endl;
// printSparsityPattern (set overload): prints a pattern given as one
// ordered index set per row to std::cout. Each row is prefixed by its
// index and lists its column indices, padded so that each column appears
// at a fixed horizontal position. With printLocationByRow, each entry is
// also prefixed by an entry counter `e` (declared on a line not shown).
// maxDim is the larger of the row count and the largest column index;
// nnz counts all entries.
// Both log10 arguments are clamped to at least 1. An empty pattern
// (maxDim == 0) or one with no entries (nnz == 0) would otherwise give
// log10(0) = -infinity, and converting that to size_t is undefined
// behaviour. Widths are unchanged for all other inputs.
// NOTE(review): width is ceil(log10(maxDim)). This is one digit short
// when the largest column index is itself a power of ten (e.g. 10), which
// only affects alignment.
361 template<
class VectorSet>
362 void printSparsityPattern(
const VectorSet& sparsity,
363 const std::string& name,
364 bool printLocationByRow =
false) {
365 size_t maxDim = sparsity.size();
367 for (
size_t i = 0; i < sparsity.size(); i++) {
368 if (sparsity[i].size() > 0 && *sparsity[i].rbegin() > maxDim) {
369 maxDim = *sparsity[i].rbegin();
371 nnz += sparsity[i].size();
374 size_t width = std::ceil(std::log10(maxDim > 0 ? maxDim : 1));
376 size_t width3 = width;
377 if (printLocationByRow) {
378 width2 = std::ceil(std::log10(nnz > 0 ? nnz : 1));
379 width3 += width2 + 1;
382 std::cout << name <<
" sparsity:\n";
386 for (
size_t i = 0; i < sparsity.size(); i++) {
387 std::cout <<
" " << std::setw(width) << i <<
": ";
389 for (
size_t j : sparsity[i]) {
390 if (j != 0 &&
long(j) != last + 1) {
391 std::cout << std::setw((j - last - 1) * (width3 + 1)) <<
" ";
393 if (printLocationByRow)
394 std::cout << std::setw(width2) << e <<
":";
395 std::cout << std::setw(width) << j <<
" ";
401 std::cout << std::endl;
// printSparsityPattern (coordinate overload): converts a pattern given as
// parallel row/column index vectors into one index set per row (m rows,
// filled by generateSparsitySet) and prints it with the set-based overload.
// NOTE(review): `m` is declared on a line not shown here; presumably the
// row-count parameter.
404 template<
class VectorSize>
405 void printSparsityPattern(
const VectorSize& row,
406 const VectorSize& col,
407 const std::string& name,
409 std::vector<std::set<size_t> > sparsity(m);
410 generateSparsitySet(row, col, sparsity);
411 printSparsityPattern(sparsity, name);
// intersects: tests whether two ordered index sets share at least one
// element.
// Cheap exits come first: an empty set, or value ranges that do not overlap
// (the maximum of one set is below the minimum of the other). Otherwise
// the smaller set is iterated and each element is looked up in the larger
// one, costing O(min(|a|, |b|) * log(max(|a|, |b|))).
// NOTE(review): the return statements are not shown; presumably false for
// the cheap exits and true on the first common element.
414 inline bool intersects(
const std::set<size_t>& a,
415 const std::set<size_t>& b) {
416 if (a.empty() || b.empty()) {
418 }
else if (*a.rbegin() < *b.begin() ||
419 *a.begin() > *b.rbegin()) {
423 if (a.size() < b.size()) {
424 for (
size_t ita : a) {
425 if (b.find(ita) != b.end()) {
430 for (
size_t itb : b) {
431 if (a.find(itb) != a.end()) {
// toSparsityPattern: converts a row-wise pattern (one set of column indices
// per row) into a CppAD::sparse_rc with m rows and n columns.
// The first pass over the rows presumably sums their sizes into nnz (its
// body is not shown). The pattern is then sized to nnz and filled row by
// row, in each row's iteration order, with `e` as the running entry index.
440 template<
class VectorSizet,
class VectorSet>
441 inline CppAD::sparse_rc<VectorSizet> toSparsityPattern(
const VectorSet& inPattern,
442 size_t m,
size_t n) {
444 CppAD::sparse_rc<VectorSizet> pattern;
447 for (
const auto& p: inPattern) {
451 pattern.resize(m, n, nnz);
454 for (
size_t i = 0; i < inPattern.size(); ++i) {
455 for (
size_t j : inPattern[i]) {
456 pattern.set(e++, i, j);
// findHandler overloads: return the CodeHandler of the first element of
// `ty` that has one (non-null getCodeHandler()), scanning from index 0.
// NOTE(review): the value returned when no element has a handler is not
// shown (presumably nullptr).
470 inline CodeHandler<Base>* findHandler(
const std::vector<CG<Base> >& ty) {
471 for (
size_t i = 0; i < ty.size(); i++) {
472 if (ty[i].getCodeHandler() !=
nullptr) {
473 return ty[i].getCodeHandler();
480 inline CodeHandler<Base>* findHandler(
const CppAD::vector<CG<Base> >& ty) {
481 for (
size_t i = 0; i < ty.size(); i++) {
482 if (ty[i].getCodeHandler() !=
nullptr) {
483 return ty[i].getCodeHandler();
// Third overload: same scan; its parameter declaration is on a line not
// shown here.
491 for (
size_t i = 0; i < ty.size(); i++) {
492 if (ty[i].getCodeHandler() !=
nullptr) {
493 return ty[i].getCodeHandler();
// asArgument: wraps a CG value as an Argument. A parameter (isParameter())
// is wrapped by its value; anything else is wrapped by a reference to its
// operation node.
// NOTE(review): the non-parameter branch dereferences getOperationNode()
// without a null check; presumably every non-parameter CG has a node.
500 inline Argument<Base> asArgument(
const CG<Base>& tx) {
501 if (tx.isParameter()) {
502 return Argument<Base>(tx.getValue());
504 return Argument<Base>(*tx.getOperationNode());
// asArguments: applies asArgument to each element of a std::vector of CG
// values, preserving order.
509 inline std::vector<Argument<Base> > asArguments(
const std::vector<CG<Base> >& tx) {
510 std::vector<Argument<Base> > arguments(tx.size());
511 for (
size_t i = 0; i < arguments.size(); i++) {
512 arguments[i] = asArgument(tx[i]);
// asArguments: same as above, for a CppAD::vector of CG values.
518 inline std::vector<Argument<Base> > asArguments(
const CppAD::vector<CG<Base> >& tx) {
519 std::vector<Argument<Base> > arguments(tx.size());
520 for (
size_t i = 0; i < arguments.size(); i++) {
521 arguments[i] = asArgument(tx[i]);
// mapKeys (set overload): inserts every key of `map` into `keys`; existing
// contents of `keys` are kept. Keys arrive in ascending order, so the
// end() insertion hint is accurate whenever each new key is larger than
// everything already in `keys`, e.g. when `keys` starts empty.
536 template<
class Key,
class Value>
537 void mapKeys(
const std::map<Key, Value>& map, std::set<Key>& keys) {
538 for (
const auto& p : map) {
539 keys.insert(keys.end(), p.first);
// mapKeys (vector overload): resizes `keys` to map.size() and walks the map
// in ascending key order. The assignment inside the loop and the
// declaration of the position counter `i` are not shown here.
549 template<
class Key,
class Value>
550 void mapKeys(
const std::map<Key, Value>& map, std::vector<Key>& keys) {
551 keys.resize(map.size());
554 typename std::map<Key, Value>::const_iterator it;
555 for (it = map.begin(); it != map.end(); ++it, i++) {
// compareMapKeys: checks whether the key set of `map` equals `keys`.
// Sizes are compared first. With equal sizes, both ordered sequences are
// walked in lockstep and compared key by key.
// NOTE(review): the return statements are not shown; presumably false on
// the first mismatch and true otherwise.
567 template<
class Key,
class Value>
568 bool compareMapKeys(
const std::map<Key, Value>& map,
const std::set<Key>& keys) {
569 if (map.size() != keys.size())
572 typename std::map<Key, Value>::const_iterator itm = map.begin();
573 typename std::set<Key>::const_iterator itk = keys.begin();
574 for (; itm != map.end(); ++itm, ++itk) {
575 if (itm->first != *itk)
// filterBykeys: returns a new map holding only the entries of `m` whose key
// is in `keys`; keys absent from `m` are ignored.
// NOTE(review): the lookup that sets itM is on a line not shown
// (presumably itM = m.find(k)). Using operator[] requires Value to be
// default-constructible.
589 template<
class Key,
class Value>
590 inline std::map<Key, Value> filterBykeys(
const std::map<Key, Value>& m,
591 const std::set<Key>& keys) {
592 std::map<Key, Value> filtered;
594 typename std::map<Key, Value>::const_iterator itM;
596 for (
const Key& k : keys) {
598 if (itM != m.end()) {
599 filtered[itM->first] = itM->second;
// compare: three-way comparison of two sets. The set with fewer elements
// orders first; sets of equal size are compared element by element in
// iteration order.
// NOTE(review): the returned values are not shown. Callers test
// `compare(...) == -1` for "less", so presumably the results are -1 / 1 / 0.
615 inline int compare(
const std::set<T>& s1,
const std::set<T>& s2) {
616 if (s1.size() < s2.size()) {
618 }
else if (s1.size() > s2.size()) {
621 typename std::set<T>::const_iterator it1, it2;
622 for (it1 = s1.begin(), it2 = s2.begin(); it1 != s1.end(); ++it1, ++it2) {
625 }
else if (*it1 > *it2) {
// Strict "less than" for std::set<T> keys: lhs orders before rhs exactly
// when compare(lhs, rhs) returns -1. That is, the smaller set comes first;
// otherwise the first differing element decides.
636 bool operator() (
const std::set<T>& lhs,
const std::set<T>& rhs)
const {
637 return compare(lhs, rhs) == -1;
// print (scalar overload): prints a single value; its body is not shown
// here.
645 inline void print(
const Base& v) {
// print (map overload): prints "key : " for each entry of `m`. The value
// output and the loop's closing line are not shown, so this excerpt does
// not show whether std::endl ends each entry or the whole map.
649 template<
class Key,
class Value>
650 inline void print(
const std::map<Key, Value>& m) {
651 for (
const auto& p : m) {
652 std::cout << p.first <<
" : ";
654 std::cout << std::endl;
// print (set overload): prints the elements of `s` in iteration order,
// separated by single spaces.
659 inline void print(
const std::set<Base>& s) {
662 for (
auto itj = s.begin(); itj != s.end(); ++itj) {
663 if (itj != s.begin()) std::cout <<
" ";
// print (std::set<Base*> overload): prints the entries of `s` in iteration
// order, separated by single spaces; a null entry `v` is printed as "NULL".
// The binding of `v` (presumably *itj) and the output for non-null entries
// are on lines not shown here.
// The iterator is deliberately non-const: `++itj` cannot be applied to a
// const-qualified iterator object. Declaring it `const auto` made this
// overload fail to compile as soon as it was instantiated.
671 inline void print(
const std::set<Base*>& s) {
674 for (
auto itj = s.begin(); itj != s.end(); ++itj) {
675 if (itj != s.begin()) std::cout <<
" ";
677 if (v ==
nullptr) std::cout <<
"NULL";
// print (vector overload): prints the elements of `v` in index order,
// separated by single spaces.
685 inline void print(
const std::vector<Base>& v) {
688 for (
size_t i = 0; i < v.size(); i++) {
689 if (i != 0) std::cout <<
" ";
// printTripletMatrix: prints a sparse matrix given in coordinate (triplet)
// form, one "[ row, col] -> value" line per stored element, in storage
// order.
// rows, cols and values must have the same length. This is checked with
// assert, so only in builds without NDEBUG.
696 template<
class VectorSize,
class VectorBase>
697 inline void printTripletMatrix(
const VectorSize &rows,
698 const VectorSize &cols,
699 const VectorBase &values) {
700 size_t n = values.size();
701 assert(rows.size() == n);
702 assert(cols.size() == n);
704 for (
size_t i = 0; i < n; ++i) {
705 std::cout <<
"[ " << rows[i] <<
", " << cols[i] <<
"] -> " << values[i] << std::endl;
// makePrintValue: immediately prints `before`, the value of `x` and `after`
// to std::cout. When `x` has an operation node, it also builds a print node
// via handler->makePrintNode that refers to that node (presumably so the
// generated code emits the same text).
// NOTE(review): `x` and `handler` are declared on lines not shown. The
// value returned when x has no operation node is not visible here.
724 inline CG<Base> makePrintValue(
const std::string& before,
726 const std::string& after =
"") {
727 std::cout << before << x << after;
729 if (x.getOperationNode() !=
nullptr) {
731 CG<Base> out(*handler->makePrintNode(before, *x.getOperationNode(), after));
// replaceString: replaces, in place, every non-overlapping occurrence of
// `toReplace` in `text` with `replacement`. Scanning is left to right and
// resumes after each inserted replacement, so replacement text is never
// rescanned. `pos` is the scan position (declared on a line not shown).
// An empty `toReplace` is a no-op. find("") matches at the current
// position, and `pos` only advances past the inserted text, so without
// this guard the loop would never terminate.
751 inline void replaceString(std::string& text,
752 const std::string& toReplace,
753 const std::string& replacement) {
755 while (!toReplace.empty() && (pos = text.find(toReplace, pos)) != std::string::npos) {
756 text.replace(pos, toReplace.length(), replacement);
757 pos += replacement.length();
// explode: splits `text` at every occurrence of `delimiter` and returns
// the pieces in order, without the delimiters.
// A piece found before a delimiter is pushed unconditionally, so empty
// pieces between consecutive delimiters are kept. The trailing piece is
// appended only if start < text.length(), so a delimiter at the very end
// does not produce an empty last element.
// NOTE(review): the lines that advance `start` and handle an empty
// delimiter (dlen == 0) are not shown; confirm they prevent a zero-length
// step.
761 inline std::vector<std::string> explode(
const std::string& text,
762 const std::string& delimiter) {
763 std::vector<std::string> matches;
765 const size_t dlen = delimiter.length();
772 pos = text.find(delimiter, start);
773 if (pos == std::string::npos) {
776 matches.push_back(text.substr(start, pos - start));
780 if (start < text.length()) {
781 matches.push_back(text.substr(start, text.length() - start));
// implode: joins the strings in `text`, inserting `delimiter` between
// consecutive elements. The empty and single-element cases are handled by
// separate branches whose bodies are not shown. Otherwise a first pass
// presumably sums the piece lengths into n, so the output can be reserved
// exactly once. The remaining pieces are then appended from index 1
// onwards.
787 inline std::string implode(
const std::vector<std::string>& text,
788 const std::string& delimiter) {
791 }
else if (text.size() == 1) {
796 for (
const auto& s: text)
798 out.reserve(n + (text.size() - 1) * delimiter.size());
800 for (
size_t i = 1; i < text.size(); ++i) {
// readStringFromFile: returns the entire contents of the file at `path`
// as a string, copied through the stream buffer into a stringstream.
// NOTE(review): opening the stream and any failure handling are on lines
// not shown; confirm that a missing file is reported rather than silently
// yielding an empty string.
808 inline std::string readStringFromFile(
const std::string& path) {
809 std::ifstream iStream;
812 std::stringstream strStream;
813 strStream << iStream.rdbuf();
815 return strStream.str();
const Base & getValue() const
void setValue(const Base &val)
bool isValueDefined() const
CodeHandler< Base > * getCodeHandler() const