tensor-0.1.0
 All Data Structures Namespaces Functions Variables Typedefs Enumerations Enumerator Groups Pages
mmult_tensor_sparse.h
1 // -*- mode: c++; fill-column: 80; c-basic-offset: 2; indent-tabs-mode: nil -*-
2 /*
3  Copyright (c) 2010 Juan Jose Garcia Ripoll
4 
5  Tensor is free software; you can redistribute it and/or modify it
6  under the terms of the GNU Library General Public License as published
7  by the Free Software Foundation; either version 2 of the License, or
8  (at your option) any later version.
9 
10  This program is distributed in the hope that it will be useful,
11  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  GNU Library General Public License for more details.
14 
15  You should have received a copy of the GNU General Public License along
16  with this program; if not, write to the Free Software Foundation, Inc.,
17  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
18 */
19 
20 #ifndef TENSOR_MMULT_TENSOR_SPARSE_H
21 #define TENSOR_MMULT_TENSOR_SPARSE_H
22 
24 // RAW ROUTINE FOR THE TENSOR-SPARSE PRODUCT
25 //
26 
27 template<typename elt_t>
28 static void
29 mult_t_sp(elt_t *dest,
30  const elt_t *vector,
31  const index *row_start, const index *column, const elt_t *matrix,
32  index i_len, index j_len, index k_len, index l_len)
33 {
34  // dest(i,k,l) = vector(i,j,k) matrix(j,l)
35  for (index j = 0; j < j_len; j++, vector += i_len) {
36  for (index x = row_start[j]; x < row_start[j+1]; x++) {
37  index l = column[x];
38  elt_t *d = dest + l * (k_len*i_len);
39  const elt_t *v = vector;
40  elt_t m = matrix[x];
41  for (index k = 0; k < k_len; k++) {
42  for (index i = 0; i < i_len; i++, d++) {
43  *d += *(v++) * m;
44  }
45  v += (j_len-1)*i_len;
46  }
47  }
48  }
49 }
50 
52 // HIGHER LEVEL INTERFACE
53 //
54 
55 template<typename elt_t>
56 static inline const Tensor<elt_t>
57 do_mmult(const Tensor<elt_t> &m1, const Sparse<elt_t> &m2)
58 {
59  index N = m1.rank();
60  index i_len = 1;
61  Indices dims(N);
62  for (index k = 0; k < N-1; k++) {
63  dims.at(k) = m1.dimension(k);
64  i_len *= dims[k];
65  }
66  index j_len = m1.dimension(-1);
67  index l_len = dims.at(N-1) = m2.columns();
68 
69  if (j_len != m2.rows()) {
70  std::cerr <<
71  "In mmult(T,S), the last index of tensor T does not match the number of rows\n"
72  "in sparse matrix S.";
73  abort();
74  }
75 
76  Tensor<elt_t> output = Tensor<elt_t>::zeros(dims);
77 
78  mult_t_sp<elt_t>(output.begin(),
79  m1.begin(),
80  m2.priv_row_start().begin(),
81  m2.priv_column().begin(), m2.priv_data().begin(),
82  i_len, j_len, 1, l_len);
83 
84  return output;
85 }
86 
87 #endif /* TENSOR_MMULT_TENSOR_SPARSE_H */
static const Tensor< elt_t > zeros(index rows)
Matrix of zeros.
Definition: tensor.h:235