tensor-0.1.0
 All Data Structures Namespaces Functions Variables Typedefs Enumerations Enumerator Groups Pages
tensor_fold.cc
1 // -*- mode: c++; fill-column: 80; c-basic-offset: 2; indent-tabs-mode: nil -*-
2 /*
3  Copyright (c) 2010 Juan Jose Garcia Ripoll
4 
5  Tensor is free software; you can redistribute it and/or modify it
6  under the terms of the GNU Library General Public License as published
7  by the Free Software Foundation; either version 2 of the License, or
8  (at your option) any later version.
9 
10  This program is distributed in the hope that it will be useful,
11  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  GNU Library General Public License for more details.
14 
15  You should have received a copy of the GNU General Public License along
16  with this program; if not, write to the Free Software Foundation, Inc.,
17  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19 
20 #define TENSOR_LOAD_IMPL
21 #include <iostream>
22 #include <tensor/tensor.h>
23 #include <tensor/io.h>
24 #include <tensor/tensor_lapack.h>
25 #include "gemm.cc"
26 
27 namespace tensor {
28 
29  using namespace blas;
30 
31  template<typename elt_t, bool do_conj>
32  void
33  do_fold(Tensor<elt_t> &output,
34  const Tensor<elt_t> &a, int _ndx1, const Tensor<elt_t> &b, int _ndx2)
35  {
36  index i_len,j_len,k_len,l_len,m_len;
37  index rank, i;
38  const index ranka = a.rank();
39  const index rankb = b.rank();
40  index ndx1 = normalize_index(_ndx1, ranka);
41  index ndx2 = normalize_index(_ndx2, rankb);
42  Indices new_dims(std::max<index>(ranka + rankb - 2, 1));
43  /*
44  * Since we use row-major order, in which the first
45  * index varies faster, we nest the loops beginning with the last index,
46  * and the loop what does is
47  * c(i,j,k,m) = a(i,l,j) * b(k,l,m)
48  * where there is a sum over the repeated index "l". In the first part of
49  * the code we find out the size of the contracted (l_len,l_len) and
50  * uncontracted (new_dims, i_len,j_len,k_len,m_len) dimensions of the
51  * tensors.
52  */
53  for (i = 0, rank = 0, i_len=1; i < ndx1; i++) {
54  index di = a.dimension(i);
55  new_dims.at(rank++) = di;
56  i_len *= di;
57  }
58  l_len = a.dimension(i++);
59  if (l_len == 0) {
60  std::cerr << "Unable to fold() tensors with dimensions" << std::endl
61  << "\t" << a.dimensions() << " and "
62  << b.dimensions() << std::endl
63  << "\tbecause indices " << ndx1 << " and " << ndx2
64  << " are empty" << std::endl;
65  abort();
66  }
67  for (j_len = 1; i < ranka; i++) {
68  index di = a.dimension(i);
69  new_dims.at(rank++) = di;
70  j_len *= di;
71  }
72  for (i = 0, k_len=1; i < ndx2; i++) {
73  index di = b.dimension(i);
74  new_dims.at(rank++) = di;
75  k_len *= di;
76  }
77  if (l_len != b.dimension(i++)) {
78  std::cerr << "Unable to fold() tensors with dimensions" << std::endl
79  << "\t" << a.dimensions() << " and "
80  << b.dimensions() << std::endl
81  << "\tbecause indices " << ndx1 << " and " << ndx2
82  << " have different sizes" << std::endl;
83  abort();
84  }
85  for (m_len = 1; i < rankb; i++) {
86  index di = b.dimension(i);
87  new_dims.at(rank++) = di;
88  m_len *= di;
89  }
90  /*
91  * Create the output tensor. Sometimes it is just a number.
92  */
93  if (rank == 0) {
94  rank = 1;
95  new_dims.at(0) = 1;
96  }
97  output = Tensor<elt_t>(new_dims);
98  if (output.size() == 0)
99  return;
100 
101  elt_t *pC = output.begin();
102  const elt_t zero = number_zero<elt_t>();
103  const elt_t one = number_one<elt_t>();
104  const elt_t *pA = a.begin();
105  const elt_t *pB = b.begin();
106  if (i_len == 1) {
107  if (k_len == 1) {
108  // C(j_len,m_len) = A(l_len,j_len)*B(l_len,m_len);
109  char transa = do_conj? 'C' : 'T';
110  char transb = 'N';
111  gemm(transa, transb, j_len, m_len, l_len, one,
112  pA, l_len, pB, l_len, zero, pC, j_len);
113  return;
114  }
115  if (m_len == 1) {
116  // C(j_len,k_len) = A(l_len,j_len)*B(k_len,l_len);
117  char transa = do_conj? 'C' : 'T';
118  char transb = 'T';
119  gemm(transa, transb, j_len, k_len, l_len, one,
120  pA, l_len, pB, k_len, zero, pC, j_len);
121  return;
122  }
123  } else if (j_len == 1 && !do_conj) {
124  if (k_len == 1) {
125  // C(i_len,m_len) = A(i_len,l_len)*B(l_len,m_len);
126  char transa = 'N';
127  char transb = 'N';
128  gemm(transa, transb, i_len, m_len, l_len, one,
129  pA, i_len, pB, l_len, zero, pC, i_len);
130  return;
131  }
132  if (m_len == 1) {
133  // C(i_len,k_len) = A(i_len,l_len)*B(k_len,l_len);
134  char transa = 'N';
135  char transb = 'T';
136  gemm(transa, transb, i_len, k_len, l_len, one,
137  pA, i_len, pB, k_len, zero, pC, i_len);
138  return;
139  }
140  }
141  const char op1 = 'N';
142  const char op2 = do_conj? 'C' : 'T';
143  const index ij_len = i_len*j_len;
144  const index il_len = i_len*l_len;
145  const index kl_len = k_len*l_len;
146  const index jk_len = j_len*k_len;
147  /*
148  * C(i,j,k,m) = A(i,l,j) * B(k,l,m)
149  */
150  for (index m = 0; m < m_len; m++) {
151  for (index j = 0; j < j_len; j++) {
152  gemm(op1, op2, i_len, k_len, l_len, one,
153  pA + il_len*j, i_len, pB + kl_len*m, k_len,
154  zero, pC + i_len*(j + jk_len*m), ij_len);
155  }
156  }
157  if (do_conj) {
158  for (index i = output.size(); i; i--, pC++)
159  *pC = tensor::conj(*pC);
160  }
161  }
162 
163 } // namespace tensor
const RTensor conj(const RTensor &r)
Complex conjugate of a real tensor.
Definition: tensor.h:461