tensor-0.1.0
 All Data Structures Namespaces Functions Variables Typedefs Enumerations Enumerator Groups Pages
tensor_take_diag.cc
1 // -*- mode: c++; fill-column: 80; c-basic-offset: 2; indent-tabs-mode: nil -*-
2 /*
3  Copyright (c) 2010 Juan Jose Garcia Ripoll
4 
5  Tensor is free software; you can redistribute it and/or modify it
6  under the terms of the GNU Library General Public License as published
7  by the Free Software Foundation; either version 2 of the License, or
8  (at your option) any later version.
9 
10  This program is distributed in the hope that it will be useful,
11  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  GNU Library General Public License for more details.
14 
15  You should have received a copy of the GNU General Public License along
16  with this program; if not, write to the Free Software Foundation, Inc.,
17  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19 
20 #include <algorithm>
21 #include <tensor/tensor.h>
22 
23 template<typename elt_t>
24 static void
25 do_diag(elt_t *output, const elt_t *input, tensor::index a1, tensor::index a2,
26  tensor::index a2b, tensor::index a3, tensor::index a4, tensor::index a5,
27  tensor::index which)
28 {
29  // output(a1,a2b,a3,a5), input(a1,a2,a3,a4,a5)
30  tensor::index o1, o2;
31  if (which < 0) {
32  o2 = - which;
33  o1 = 0;
34  } else {
35  o2 = 0;
36  o1 = which;
37  }
38  for (tensor::index m = 0; m < a5; m++) {
39  for (tensor::index l = 0; l < a2b; l++) {
40  for (tensor::index k = 0; k < a3; k++) {
41  for (tensor::index i = 0; i < a1; i++) {
42  output[i + a1 * (l + a2b * (k + a3 * m))] =
43  input[i + a1 * ((o1+l) + a2 * (k + a3 * ((o2 + l) + a4 * m)))];
44  }
45  }
46  }
47  }
48 }
49 
50 /* Extract a diagonal from a matrix. */
51 template<typename elt_t>
52 const tensor::Tensor<elt_t> do_take_diag(const tensor::Tensor<elt_t> &a, int which, int ndx1, int ndx2)
53 {
54  if (ndx1 < 0)
55  ndx1 += a.rank();
56  assert((ndx1 < a.rank()) && (ndx1 >= 0));
57  if (ndx2 < 0)
58  ndx2 += a.rank();
59  assert((ndx2 < a.rank()) && (ndx2 >= 0));
60 
61  tensor::Indices new_dims(std::max((int)a.rank()-1,(int)1));
62  size_t rank = 0;
63  tensor::index i, a1, a2, a3, a4, a5, a2b;
64  if (ndx1 > ndx2) {
65  std::swap(ndx1, ndx2);
66  which = -which;
67  }
68  for (i = 0, a1 = 1; i < ndx1; i++) {
69  size_t di = a.dimension(i);
70  new_dims.at(rank++) = di;
71  a1 *= di;
72  }
73  a2 = a.dimension(i++);
74  new_dims.at(rank++) = a2;
75  for (a3 = 1; i < ndx2; i++) {
76  size_t di = a.dimension(i);
77  new_dims.at(rank++) = di;
78  a3 *= di;
79  }
80  a4 = a.dimension(i++);
81  for (a5 = 1; i < (tensor::index)a.rank(); i++) {
82  size_t di = a.dimension(i);
83  new_dims.at(rank++) = di;
84  a5 *= di;
85  }
86  if (which <= -a2 || which >= a4) {
87  std::cerr << "In take_diag(M, which, ...), WHICH has a value " << which << " which exceeds the size of the tensor";
88  abort();
89  }
90  if (a2 == 1 && a4 == 1) {
91  return tensor::Tensor<elt_t>(new_dims, a);
92  }
93  if (which < 0) {
94  a2b = std::max((tensor::index)0, std::min(a2 + which, a4));
95  } else {
96  a2b = std::max((tensor::index)0, std::min(a2, a4 - which));
97  }
98  new_dims.at(ndx1) = a2b;
99  tensor::Tensor<elt_t> output(new_dims);
100  which = -which;
101  if (a2b) {
102  do_diag<elt_t>(output.begin(), a.begin(), a1,a2,a2b,a3,a3,a5,which);
103  }
104  return output;
105 }
int rank() const
Number of Tensor indices.
Definition: tensor.h:119
iterator begin()
Iterator at the beginning.
Definition: tensor.h:256
Vector of 'index' type, where 'index' fits the indices of a tensor.
Definition: indices.h:35
index dimension(int which) const
Length of a given Tensor index.