20 #define TENSOR_LOAD_IMPL
22 #include <tensor/tensor.h>
23 #include <tensor/io.h>
24 #include <tensor/tensor_lapack.h>
// do_fold<elt_t, do_conj>: implementation behind fold(). Combines tensors `a`
// and `b` over index `_ndx1` of `a` and index `_ndx2` of `b` and stores the
// result in `output`, delegating the arithmetic to gemm().
//
// Template parameters:
//   elt_t    element type of all three tensors.
//   do_conj  when true, the transpose flag handed to gemm() is 'C' instead of
//            'T' (presumably BLAS-style "conjugate transpose" -- confirm
//            against the gemm() wrapper).
// Parameters:
//   output   overwritten with a newly constructed Tensor<elt_t>(new_dims).
//   a, b     input tensors; only read here.
//   _ndx1    index of `a` to fold over; passed through normalize_index().
//   _ndx2    index of `b` to fold over; passed through normalize_index().
//
// NOTE(review): this listing is incomplete. The leading numbers are residue of
// a numbered source view, and several lines are not visible: the function's
// braces, the declarations of `i`, `rank` and `transb`, most branch
// conditions, and the bodies of some statements. Comments below describe only
// what is visible.
31 template<
typename elt_t,
bool do_conj>
33 do_fold(Tensor<elt_t> &output,
34 const Tensor<elt_t> &a,
int _ndx1,
const Tensor<elt_t> &b,
int _ndx2)
// Extents of the index groups: i/j are used with a's indices before/after
// ndx1, k/m with b's indices before/after ndx2, and l is the folded extent.
36 index i_len,j_len,k_len,l_len,m_len;
38 const index ranka = a.rank();
39 const index rankb = b.rank();
// normalize_index() is defined elsewhere; presumably it maps negative indices
// into [0, rank) -- TODO confirm.
40 index ndx1 = normalize_index(_ndx1, ranka);
41 index ndx2 = normalize_index(_ndx2, rankb);
// One index of each tensor is removed by the fold; the size is clamped to at
// least 1 so the result always has at least one dimension entry.
42 Indices new_dims(std::max<index>(ranka + rankb - 2, 1));
// Copy the dimensions of `a` that precede ndx1 into the output dimensions.
// i_len is seeded with 1 here; its accumulation is not visible in this listing.
53 for (i = 0, rank = 0, i_len=1; i < ndx1; i++) {
54 index di = a.dimension(i);
55 new_dims.at(rank++) = di;
// Extent of the folded index of `a` (the loop condition above leaves i at
// ndx1); the post-increment moves i past it.
58 l_len = a.dimension(i++);
// Diagnostic for a folded index the message calls "empty". The guarding
// condition and whatever follows the message are not visible here.
60 std::cerr <<
"Unable to fold() tensors with dimensions" << std::endl
61 <<
"\t" << a.dimensions() <<
" and "
62 << b.dimensions() << std::endl
63 <<
"\tbecause indices " << ndx1 <<
" and " << ndx2
64 <<
" are empty" << std::endl;
// Copy the dimensions of `a` that follow ndx1 (i continues from ndx1 + 1).
67 for (j_len = 1; i < ranka; i++) {
68 index di = a.dimension(i);
69 new_dims.at(rank++) = di;
// Copy the dimensions of `b` that precede ndx2.
72 for (i = 0, k_len=1; i < ndx2; i++) {
73 index di = b.dimension(i);
74 new_dims.at(rank++) = di;
// The folded index of `b` must have the same extent as that of `a`;
// otherwise report the mismatch. What happens after the message is not
// visible in this listing.
77 if (l_len != b.dimension(i++)) {
78 std::cerr <<
"Unable to fold() tensors with dimensions" << std::endl
79 <<
"\t" << a.dimensions() <<
" and "
80 << b.dimensions() << std::endl
81 <<
"\tbecause indices " << ndx1 <<
" and " << ndx2
82 <<
" have different sizes" << std::endl;
// Copy the dimensions of `b` that follow ndx2.
85 for (m_len = 1; i < rankb; i++) {
86 index di = b.dimension(i);
87 new_dims.at(rank++) = di;
// Replace `output` with a new tensor of the combined dimensions.
97 output = Tensor<elt_t>(new_dims);
// Empty-result check; the statement it guards is not visible here
// (presumably an early return -- TODO confirm).
98 if (output.size() == 0)
// Raw element pointers and the scalar constants handed to gemm(); `one` and
// `zero` sit in what look like the BLAS alpha/beta positions -- confirm
// against the gemm() wrapper's signature.
101 elt_t *pC = output.begin();
102 const elt_t zero = number_zero<elt_t>();
103 const elt_t one = number_one<elt_t>();
104 const elt_t *pA = a.begin();
105 const elt_t *pB = b.begin();
// Special cases handled by a single gemm() call each. The conditions that
// select these branches and the definition of `transb` are not visible here.
// In the first two, `a` is passed with flag 'T' or 'C' depending on do_conj.
109 char transa = do_conj?
'C' :
'T';
111 gemm(transa, transb, j_len, m_len, l_len, one,
112 pA, l_len, pB, l_len, zero, pC, j_len);
117 char transa = do_conj?
'C' :
'T';
119 gemm(transa, transb, j_len, k_len, l_len, one,
120 pA, l_len, pB, k_len, zero, pC, j_len);
// Branch taken only when j_len == 1 and no conjugation is requested; the
// transa/transb values used in these two calls are not visible here.
123 }
else if (j_len == 1 && !do_conj) {
128 gemm(transa, transb, i_len, m_len, l_len, one,
129 pA, i_len, pB, l_len, zero, pC, i_len);
136 gemm(transa, transb, i_len, k_len, l_len, one,
137 pA, i_len, pB, k_len, zero, pC, i_len);
// Remaining case: one gemm() call per (m, j) pair. `a` is passed with 'N',
// `b` with 'T' or 'C' depending on do_conj.
141 const char op1 =
'N';
142 const char op2 = do_conj?
'C' :
'T';
// Products of extents used as pointer offsets and as the output leading
// dimension in the calls below.
143 const index ij_len = i_len*j_len;
144 const index il_len = i_len*l_len;
145 const index kl_len = k_len*l_len;
146 const index jk_len = j_len*k_len;
150 for (index m = 0; m < m_len; m++) {
151 for (index j = 0; j < j_len; j++) {
// Each call reads the slice of `a` at offset il_len*j and the slice of `b`
// at offset kl_len*m, and writes at offset i_len*(j + jk_len*m) in the output.
152 gemm(op1, op2, i_len, k_len, l_len, one,
153 pA + il_len*j, i_len, pB + kl_len*m, k_len,
154 zero, pC + i_len*(j + jk_len*m), ij_len);
// Walks pC across every element of `output`. The loop body and any enclosing
// condition are not visible in this listing.
158 for (index i = output.size(); i; i--, pC++)
const RTensor conj(const RTensor &r)
Complex conjugate of a real tensor.