21 #include <tensor/tensor.h>
23 template<
typename elt_t>
25 do_diag(elt_t *output,
const elt_t *input, tensor::index a1, tensor::index a2,
26 tensor::index a2b, tensor::index a3, tensor::index a4, tensor::index a5,
38 for (tensor::index m = 0; m < a5; m++) {
39 for (tensor::index l = 0; l < a2b; l++) {
40 for (tensor::index k = 0; k < a3; k++) {
41 for (tensor::index i = 0; i < a1; i++) {
42 output[i + a1 * (l + a2b * (k + a3 * m))] =
43 input[i + a1 * ((o1+l) + a2 * (k + a3 * ((o2 + l) + a4 * m)))];
51 template<
typename elt_t>
56 assert((ndx1 < a.
rank()) && (ndx1 >= 0));
59 assert((ndx2 < a.
rank()) && (ndx2 >= 0));
63 tensor::index i, a1, a2, a3, a4, a5, a2b;
65 std::swap(ndx1, ndx2);
68 for (i = 0, a1 = 1; i < ndx1; i++) {
70 new_dims.at(rank++) = di;
74 new_dims.at(rank++) = a2;
75 for (a3 = 1; i < ndx2; i++) {
77 new_dims.at(rank++) = di;
81 for (a5 = 1; i < (tensor::index)a.
rank(); i++) {
83 new_dims.at(rank++) = di;
86 if (which <= -a2 || which >= a4) {
87 std::cerr <<
"In take_diag(M, which, ...), WHICH has a value " << which <<
" which exceeds the size of the tensor";
90 if (a2 == 1 && a4 == 1) {
94 a2b = std::max((tensor::index)0, std::min(a2 + which, a4));
96 a2b = std::max((tensor::index)0, std::min(a2, a4 - which));
98 new_dims.at(ndx1) = a2b;
102 do_diag<elt_t>(output.begin(), a.
begin(), a1,a2,a2b,a3,a3,a5,which);
int rank() const
Number of Tensor indices.
iterator begin()
Iterator at the beginning.
Vector of 'index' type, where 'index' fits the indices of a tensor.
index dimension(int which) const
Length of a given Tensor index.