tensor-0.1.0
 All Data Structures Namespaces Functions Variables Typedefs Enumerations Enumerator Groups Pages
tensor_trace.cc
1 // -*- mode: c++; fill-column: 80; c-basic-offset: 2; indent-tabs-mode: nil -*-
2 /*
3  Copyright (c) 2010 Juan Jose Garcia Ripoll
4 
5  Tensor is free software; you can redistribute it and/or modify it
6  under the terms of the GNU Library General Public License as published
7  by the Free Software Foundation; either version 2 of the License, or
8  (at your option) any later version.
9 
10  This program is distributed in the hope that it will be useful,
11  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  GNU Library General Public License for more details.
14 
15  You should have received a copy of the GNU General Public License along
16  with this program; if not, write to the Free Software Foundation, Inc.,
17  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19 
20 #include <tensor/tensor.h>
21 
22 namespace tensor {
23 
24  template<typename n>
25  static void trace_loop(n *C, const n *D, index a1, index a2, index a3, index a4, index a5)
26  {
27  // C(a1,a3,a5), D(a1,a2,a3,a4,a5)
28  // C(i,k,m) = sum D(i,l,k,l,m) over l
29  index a2b = std::min(a2, a4);
30  if (a1 == 1) {
31  for (index m = 0; m < a5; m++) {
32  for (index l = 0; l < a2b; l++) {
33  for (index k = 0; k < a3; k++) {
34  C[k + a3 * m] +=
35  D[l + a2 * (k + a3 * (l + a4 * m))];
36  }
37  }
38  }
39  } else {
40  for (index m = 0; m < a5; m++) {
41  for (index l = 0; l < a2b; l++) {
42  for (index k = 0; k < a3; k++) {
43  for (index i = 0; i < a1; i++) {
44  C[i + a1 * (k + a3 * m)] +=
45  D[i + a1 * (l + a2 * (k + a3 * (l + a4 * m)))];
46  }
47  }
48  }
49  }
50  }
51  }
52 
53  template<typename elt_t>
54  static const Tensor<elt_t> do_trace(const Tensor<elt_t> &D, index i1, index i2)
55  {
56  assert(i1 < D.rank() && i1 > -D.rank());
57  assert(i2 < D.rank() && i2 > -D.rank());
58  if (i1 < 0)
59  i1 = i1 + D.rank();
60  if (i2 < 0)
61  i2 = i2 + D.rank();
62  if (i1 > i2) {
63  std::swap(i1,i2);
64  } else if (i2 == i1) {
65  std::cerr << "In trace(D, i, j), indices 'i' and 'j' are the same." << std::endl;
66  abort();
67  }
68 
69  index a1, a2, a3, a4, a5, i, rank;
70  Indices dimensions(std::max(D.rank() - 2, 1));
71  dimensions.at(rank=0) = 1;
72  for (a1 = 1, i = 0; i < i1; ) {
73  a1 *= (dimensions.at(rank++) = D.dimension(i++));
74  }
75  a2 = D.dimension(i++);
76  for (a3 = 1; i < i2; )
77  a3 *= (dimensions.at(rank++) = D.dimension(i++));
78  a4 = D.dimension(i++);
79  for (a5 = 1; i < D.rank(); )
80  a5 *= (dimensions.at(rank++) = D.dimension(i++));
81 
82  Tensor<elt_t> output = RTensor::zeros(dimensions);
83  trace_loop<elt_t>(output.begin(), D.begin(), a1, a2, a3, a4, a5);
84  return output;
85  }
86 
87 } // namespace tensor
int rank() const
Number of Tensor indices.
Definition: tensor.h:119
static const Tensor< elt_t > zeros(index rows)
Matrix of zeros.
Definition: tensor.h:235