STXXL  1.4-dev
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
matrix_low_level.h
Go to the documentation of this file.
1 /***************************************************************************
2  * include/stxxl/bits/containers/matrix_low_level.h
3  *
4  * Part of the STXXL. See http://stxxl.sourceforge.net
5  *
6  * Copyright (C) 2010-2011 Raoul Steffen <[email protected]>
7  *
8  * Distributed under the Boost Software License, Version 1.0.
9  * (See accompanying file LICENSE_1_0.txt or copy at
10  * http://www.boost.org/LICENSE_1_0.txt)
11  **************************************************************************/
12 
13 #ifndef STXXL_CONTAINERS_MATRIX_LOW_LEVEL_HEADER
14 #define STXXL_CONTAINERS_MATRIX_LOW_LEVEL_HEADER
15 
16 #ifndef STXXL_BLAS
17 #define STXXL_BLAS 0
18 #endif
19 
20 #include <complex>
21 
23 #include <stxxl/bits/parallel.h>
24 
26 
27 //! \addtogroup matrix
28 //! \{
29 
30 namespace matrix_local {
31 
32 // forward declaration
33 template <typename ValueType, unsigned BlockSideLength>
34 struct matrix_operations;
35 
36 // generic declaration
37 template <unsigned BlockSideLength, bool transposed>
39 
40 // row-major specialization
41 template <unsigned BlockSideLength>
42 struct switch_major_index<BlockSideLength, false>
43 {
44  inline switch_major_index(const int_type row, const int_type col) : i(row * BlockSideLength + col) { }
45  inline operator int_type& () { return i; }
46 
47 private:
49 };
50 
51 //column-major specialization
52 template <unsigned BlockSideLength>
53 struct switch_major_index<BlockSideLength, true>
54 {
55  inline switch_major_index(const int_type row, const int_type col) : i(row + col * BlockSideLength) { }
56  inline operator int_type& () { return i; }
57 
58 private:
60 };
61 
62 //! c = a [op] b; for arbitrary entries
63 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, bool b_transposed, class Op>
65 {
66  low_level_matrix_binary_ass_op(ValueType* c, const ValueType* a, const ValueType* b, Op op = Op())
67  {
68  if (a)
69  if (b)
70  #if STXXL_PARALLEL
71  #pragma omp parallel for
72  #endif
73  for (int_type row = 0; row < int_type(BlockSideLength); ++row)
74  for (int_type col = 0; col < int_type(BlockSideLength); ++col)
78  else
79  #if STXXL_PARALLEL
80  #pragma omp parallel for
81  #endif
82  for (int_type row = 0; row < int_type(BlockSideLength); ++row)
83  for (int_type col = 0; col < int_type(BlockSideLength); ++col)
86  else
87  {
88  assert(b /* do not add nothing to nothing */);
89  #if STXXL_PARALLEL
90  #pragma omp parallel for
91  #endif
92  for (int_type row = 0; row < int_type(BlockSideLength); ++row)
93  for (int_type col = 0; col < int_type(BlockSideLength); ++col)
96  }
97  }
98 };
99 
100 //! c [op]= a; for arbitrary entries
101 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, class Op>
103 {
104  low_level_matrix_unary_ass_op(ValueType* c, const ValueType* a, Op op = Op())
105  {
106  if (a)
107  #if STXXL_PARALLEL
108  #pragma omp parallel for
109  #endif
110  for (int_type row = 0; row < int_type(BlockSideLength); ++row)
111  for (int_type col = 0; col < int_type(BlockSideLength); ++col)
114  }
115 };
116 
117 //! c =[op] a; for arbitrary entries
118 template <typename ValueType, unsigned BlockSideLength, bool a_transposed, class Op>
120 {
121  low_level_matrix_unary_op(ValueType* c, const ValueType* a, Op op = Op())
122  {
123  assert(a);
124  #if STXXL_PARALLEL
125  #pragma omp parallel for
126  #endif
127  for (int_type row = 0; row < int_type(BlockSideLength); ++row)
128  for (int_type col = 0; col < int_type(BlockSideLength); ++col)
131  }
132 };
133 
134 //! multiplies matrices A and B, adds result to C, for arbitrary entries
135 //! param pointer to blocks of A,B,C; elements in blocks have to be in row-major
136 /* designated usage as:
137  * void
138  * low_level_matrix_multiply_and_add(const double * a, bool a_in_col_major,
139  const double * b, bool b_in_col_major,
140  double * c, const bool c_in_col_major) */
141 template <typename ValueType, unsigned BlockSideLength>
143 {
144  low_level_matrix_multiply_and_add(const ValueType* a, bool a_in_col_major,
145  const ValueType* b, bool b_in_col_major,
146  ValueType* c, const bool c_in_col_major)
147  {
148  if (c_in_col_major)
149  {
150  std::swap(a, b);
151  bool a_cm = ! b_in_col_major;
152  b_in_col_major = ! a_in_col_major;
153  a_in_col_major = a_cm;
154  }
155  if (! a_in_col_major)
156  {
157  if (! b_in_col_major)
158  { // => both row-major
159  #if STXXL_PARALLEL
160  #pragma omp parallel for
161  #endif
162  for (int_type i = 0; i < int_type(BlockSideLength); ++i) //OpenMP does not like unsigned iteration variables
163  for (unsigned_type k = 0; k < BlockSideLength; ++k)
164  for (unsigned_type j = 0; j < BlockSideLength; ++j)
165  c[i * BlockSideLength + j] += a[i * BlockSideLength + k] * b[k * BlockSideLength + j];
166  }
167  else
168  { // => a row-major, b col-major
169  #if STXXL_PARALLEL
170  #pragma omp parallel for
171  #endif
172  for (int_type i = 0; i < int_type(BlockSideLength); ++i) //OpenMP does not like unsigned iteration variables
173  for (unsigned_type j = 0; j < BlockSideLength; ++j)
174  for (unsigned_type k = 0; k < BlockSideLength; ++k)
175  c[i * BlockSideLength + j] += a[i * BlockSideLength + k] * b[k + j * BlockSideLength];
176  }
177  }
178  else
179  {
180  if (! b_in_col_major)
181  { // => a col-major, b row-major
182  #if STXXL_PARALLEL
183  #pragma omp parallel for
184  #endif
185  for (int_type i = 0; i < int_type(BlockSideLength); ++i) //OpenMP does not like unsigned iteration variables
186  for (unsigned_type k = 0; k < BlockSideLength; ++k)
187  for (unsigned_type j = 0; j < BlockSideLength; ++j)
188  c[i * BlockSideLength + j] += a[i + k * BlockSideLength] * b[k * BlockSideLength + j];
189  }
190  else
191  { // => both col-major
192  #if STXXL_PARALLEL
193  #pragma omp parallel for
194  #endif
195  for (int_type i = 0; i < int_type(BlockSideLength); ++i) //OpenMP does not like unsigned iteration variables
196  for (unsigned_type k = 0; k < BlockSideLength; ++k)
197  for (unsigned_type j = 0; j < BlockSideLength; ++j)
198  c[i * BlockSideLength + j] += a[i + k * BlockSideLength] * b[k + j * BlockSideLength];
199  }
200  }
201  }
202 };
203 
204 #if STXXL_BLAS
// Integer and complex element types used in the BLAS prototypes of this file.
// NOTE(review): blas_int aliases int_type, which is declared outside this file
// (apparently a pointer-sized integer). Many BLAS builds expect 32-bit INTEGER
// arguments -- confirm that int_type matches the integer width of the BLAS
// library that is linked.
205 typedef int_type blas_int;
206 typedef std::complex<double> blas_double_complex;
207 typedef std::complex<float> blas_single_complex;
208 
209 // --- vector add (used as matrix-add) -----------------
210 
// Prototypes of the external BLAS vector routines used for block addition,
// subtraction and copying. One variant per element type: d = double,
// s = float, z = complex<double>, c = complex<float>. Every argument,
// including the scalars, is passed by pointer.
// ?axpy_: the callers in this file use it as y += alpha * x over n elements.
211 extern "C" void daxpy_(const blas_int* n, const double* alpha, const double* x, const blas_int* incx, double* y, const blas_int* incy);
212 extern "C" void saxpy_(const blas_int* n, const float* alpha, const float* x, const blas_int* incx, float* y, const blas_int* incy);
213 extern "C" void zaxpy_(const blas_int* n, const blas_double_complex* alpha, const blas_double_complex* x, const blas_int* incx, blas_double_complex* y, const blas_int* incy);
214 extern "C" void caxpy_(const blas_int* n, const blas_single_complex* alpha, const blas_single_complex* x, const blas_int* incx, blas_single_complex* y, const blas_int* incy);
// ?copy_: the callers in this file use it as y = x over n elements.
215 extern "C" void dcopy_(const blas_int* n, const double* x, const blas_int* incx, double* y, const blas_int* incy);
216 extern "C" void scopy_(const blas_int* n, const float* x, const blas_int* incx, float* y, const blas_int* incy);
217 extern "C" void zcopy_(const blas_int* n, const blas_double_complex* x, const blas_int* incx, blas_double_complex* y, const blas_int* incy);
218 extern "C" void ccopy_(const blas_int* n, const blas_single_complex* x, const blas_int* incx, blas_single_complex* y, const blas_int* incy);
219 
220 //! c = a + b; for double entries
221 template <unsigned BlockSideLength>
222 struct low_level_matrix_binary_ass_op<double, BlockSideLength, false, false, typename matrix_operations<double, BlockSideLength>::addition>
223 {
224  low_level_matrix_binary_ass_op(double* c, const double* a, const double* b, typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition())
225  {
226  if (a)
227  if (b)
228  {
229  low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
230  (c, a);
231  low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
232  (c, b);
233  }
234  else
235  low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
236  (c, a);
237  else
238  {
239  assert(b /* do not add nothing to nothing */);
240  low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
241  (c, b);
242  }
243  }
244 };
245 //! c = a - b; for double entries
246 template <unsigned BlockSideLength>
247 struct low_level_matrix_binary_ass_op<double, BlockSideLength, false, false, typename matrix_operations<double, BlockSideLength>::subtraction>
248 {
249  low_level_matrix_binary_ass_op(double* c, const double* a, const double* b,
250  typename matrix_operations<double, BlockSideLength>::subtraction = typename matrix_operations<double, BlockSideLength>::subtraction())
251  {
252  if (a)
253  if (b)
254  {
255  low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
256  (c, a);
257  low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction>
258  (c, b);
259  }
260  else
261  low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
262  (c, a);
263  else
264  {
265  assert(b /* do not add nothing to nothing */);
266  low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction>
267  (c, b);
268  }
269  }
270 };
271 //! c += a; for double entries
272 template <unsigned BlockSideLength>
273 struct low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
274 {
275  low_level_matrix_unary_ass_op(double* c, const double* a,
276  typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition())
277  {
278  const blas_int size = BlockSideLength * BlockSideLength;
279  const blas_int int_one = 1;
280  const double one = 1.0;
281  if (a)
282  daxpy_(&size, &one, a, &int_one, c, &int_one);
283  }
284 };
285 //! c -= a; for double entries
286 template <unsigned BlockSideLength>
287 struct low_level_matrix_unary_ass_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::subtraction>
288 {
289  low_level_matrix_unary_ass_op(double* c, const double* a,
290  typename matrix_operations<double, BlockSideLength>::subtraction = typename matrix_operations<double, BlockSideLength>::subtraction())
291  {
292  const blas_int size = BlockSideLength * BlockSideLength;
293  const blas_int int_one = 1;
294  const double minusone = -1.0;
295  if (a)
296  daxpy_(&size, &minusone, a, &int_one, c, &int_one);
297  }
298 };
299 //! c = a; for double entries
300 template <unsigned BlockSideLength>
301 struct low_level_matrix_unary_op<double, BlockSideLength, false, typename matrix_operations<double, BlockSideLength>::addition>
302 {
303  low_level_matrix_unary_op(double* c, const double* a,
304  typename matrix_operations<double, BlockSideLength>::addition = typename matrix_operations<double, BlockSideLength>::addition())
305  {
306  const blas_int size = BlockSideLength * BlockSideLength;
307  const blas_int int_one = 1;
308  dcopy_(&size, a, &int_one, c, &int_one);
309  }
310 };
311 
312 //! c = a + b; for float entries
313 template <unsigned BlockSideLength>
314 struct low_level_matrix_binary_ass_op<float, BlockSideLength, false, false, typename matrix_operations<float, BlockSideLength>::addition>
315 {
316  low_level_matrix_binary_ass_op(float* c, const float* a, const float* b, typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition())
317  {
318  if (a)
319  if (b)
320  {
321  low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
322  (c, a);
323  low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
324  (c, b);
325  }
326  else
327  low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
328  (c, a);
329  else
330  {
331  assert(b /* do not add nothing to nothing */);
332  low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
333  (c, b);
334  }
335  }
336 };
337 //! c = a - b; for float entries
338 template <unsigned BlockSideLength>
339 struct low_level_matrix_binary_ass_op<float, BlockSideLength, false, false, typename matrix_operations<float, BlockSideLength>::subtraction>
340 {
341  low_level_matrix_binary_ass_op(float* c, const float* a, const float* b,
342  typename matrix_operations<float, BlockSideLength>::subtraction = typename matrix_operations<float, BlockSideLength>::subtraction())
343  {
344  if (a)
345  if (b)
346  {
347  low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
348  (c, a);
349  low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction>
350  (c, b);
351  }
352  else
353  low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
354  (c, a);
355  else
356  {
357  assert(b /* do not add nothing to nothing */);
358  low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction>
359  (c, b);
360  }
361  }
362 };
363 //! c += a; for float entries
364 template <unsigned BlockSideLength>
365 struct low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
366 {
367  low_level_matrix_unary_ass_op(float* c, const float* a,
368  typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition())
369  {
370  const blas_int size = BlockSideLength * BlockSideLength;
371  const blas_int int_one = 1;
372  const float one = 1.0;
373  if (a)
374  saxpy_(&size, &one, a, &int_one, c, &int_one);
375  }
376 };
377 //! c -= a; for float entries
378 template <unsigned BlockSideLength>
379 struct low_level_matrix_unary_ass_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::subtraction>
380 {
381  low_level_matrix_unary_ass_op(float* c, const float* a,
382  typename matrix_operations<float, BlockSideLength>::subtraction = typename matrix_operations<float, BlockSideLength>::subtraction())
383  {
384  const blas_int size = BlockSideLength * BlockSideLength;
385  const blas_int int_one = 1;
386  const float minusone = -1.0;
387  if (a)
388  saxpy_(&size, &minusone, a, &int_one, c, &int_one);
389  }
390 };
391 //! c = a; for float entries
392 template <unsigned BlockSideLength>
393 struct low_level_matrix_unary_op<float, BlockSideLength, false, typename matrix_operations<float, BlockSideLength>::addition>
394 {
395  low_level_matrix_unary_op(float* c, const float* a,
396  typename matrix_operations<float, BlockSideLength>::addition = typename matrix_operations<float, BlockSideLength>::addition())
397  {
398  const blas_int size = BlockSideLength * BlockSideLength;
399  const blas_int int_one = 1;
400  scopy_(&size, a, &int_one, c, &int_one);
401  }
402 };
403 
404 //! c = a + b; for blas_double_complex entries
405 template <unsigned BlockSideLength>
406 struct low_level_matrix_binary_ass_op<blas_double_complex, BlockSideLength, false, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
407 {
408  low_level_matrix_binary_ass_op(blas_double_complex* c, const blas_double_complex* a, const blas_double_complex* b, typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition())
409  {
410  if (a)
411  if (b)
412  {
413  low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
414  (c, a);
415  low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
416  (c, b);
417  }
418  else
419  low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
420  (c, a);
421  else
422  {
423  assert(b /* do not add nothing to nothing */);
424  low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
425  (c, b);
426  }
427  }
428 };
429 //! c = a - b; for blas_double_complex entries
430 template <unsigned BlockSideLength>
431 struct low_level_matrix_binary_ass_op<blas_double_complex, BlockSideLength, false, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction>
432 {
433  low_level_matrix_binary_ass_op(blas_double_complex* c, const blas_double_complex* a, const blas_double_complex* b,
434  typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction())
435  {
436  if (a)
437  if (b)
438  {
439  low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
440  (c, a);
441  low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction>
442  (c, b);
443  }
444  else
445  low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
446  (c, a);
447  else
448  {
449  assert(b /* do not add nothing to nothing */);
450  low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction>
451  (c, b);
452  }
453  }
454 };
455 //! c += a; for blas_double_complex entries
456 template <unsigned BlockSideLength>
457 struct low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
458 {
459  low_level_matrix_unary_ass_op(blas_double_complex* c, const blas_double_complex* a,
460  typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition())
461  {
462  const blas_int size = BlockSideLength * BlockSideLength;
463  const blas_int int_one = 1;
464  const blas_double_complex one = 1.0;
465  if (a)
466  zaxpy_(&size, &one, a, &int_one, c, &int_one);
467  }
468 };
469 //! c -= a; for blas_double_complex entries
470 template <unsigned BlockSideLength>
471 struct low_level_matrix_unary_ass_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction>
472 {
473  low_level_matrix_unary_ass_op(blas_double_complex* c, const blas_double_complex* a,
474  typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_double_complex, BlockSideLength>::subtraction())
475  {
476  const blas_int size = BlockSideLength * BlockSideLength;
477  const blas_int int_one = 1;
478  const blas_double_complex minusone = -1.0;
479  if (a)
480  zaxpy_(&size, &minusone, a, &int_one, c, &int_one);
481  }
482 };
483 //! c = a; for blas_double_complex entries
484 template <unsigned BlockSideLength>
485 struct low_level_matrix_unary_op<blas_double_complex, BlockSideLength, false, typename matrix_operations<blas_double_complex, BlockSideLength>::addition>
486 {
487  low_level_matrix_unary_op(blas_double_complex* c, const blas_double_complex* a,
488  typename matrix_operations<blas_double_complex, BlockSideLength>::addition = typename matrix_operations<blas_double_complex, BlockSideLength>::addition())
489  {
490  const blas_int size = BlockSideLength * BlockSideLength;
491  const blas_int int_one = 1;
492  zcopy_(&size, a, &int_one, c, &int_one);
493  }
494 };
495 
496 //! c = a + b; for blas_single_complex entries
497 template <unsigned BlockSideLength>
498 struct low_level_matrix_binary_ass_op<blas_single_complex, BlockSideLength, false, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
499 {
500  low_level_matrix_binary_ass_op(blas_single_complex* c, const blas_single_complex* a, const blas_single_complex* b, typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition())
501  {
502  if (a)
503  if (b)
504  {
505  low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
506  (c, a);
507  low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
508  (c, b);
509  }
510  else
511  low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
512  (c, a);
513  else
514  {
515  assert(b /* do not add nothing to nothing */);
516  low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
517  (c, b);
518  }
519  }
520 };
521 //! c = a - b; for blas_single_complex entries
522 template <unsigned BlockSideLength>
523 struct low_level_matrix_binary_ass_op<blas_single_complex, BlockSideLength, false, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction>
524 {
525  low_level_matrix_binary_ass_op(blas_single_complex* c, const blas_single_complex* a, const blas_single_complex* b,
526  typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction())
527  {
528  if (a)
529  if (b)
530  {
531  low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
532  (c, a);
533  low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction>
534  (c, b);
535  }
536  else
537  low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
538  (c, a);
539  else
540  {
541  assert(b /* do not add nothing to nothing */);
542  low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction>
543  (c, b);
544  }
545  }
546 };
547 //! c += a; for blas_single_complex entries
548 template <unsigned BlockSideLength>
549 struct low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
550 {
551  low_level_matrix_unary_ass_op(blas_single_complex* c, const blas_single_complex* a,
552  typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition())
553  {
554  const blas_int size = BlockSideLength * BlockSideLength;
555  const blas_int int_one = 1;
556  const blas_single_complex one = 1.0;
557  if (a)
558  caxpy_(&size, &one, a, &int_one, c, &int_one);
559  }
560 };
561 //! c -= a; for blas_single_complex entries
562 template <unsigned BlockSideLength>
563 struct low_level_matrix_unary_ass_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction>
564 {
565  low_level_matrix_unary_ass_op(blas_single_complex* c, const blas_single_complex* a,
566  typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction = typename matrix_operations<blas_single_complex, BlockSideLength>::subtraction())
567  {
568  const blas_int size = BlockSideLength * BlockSideLength;
569  const blas_int int_one = 1;
570  const blas_single_complex minusone = -1.0;
571  if (a)
572  caxpy_(&size, &minusone, a, &int_one, c, &int_one);
573  }
574 };
575 //! c = a; for blas_single_complex entries
576 template <unsigned BlockSideLength>
577 struct low_level_matrix_unary_op<blas_single_complex, BlockSideLength, false, typename matrix_operations<blas_single_complex, BlockSideLength>::addition>
578 {
579  low_level_matrix_unary_op(blas_single_complex* c, const blas_single_complex* a,
580  typename matrix_operations<blas_single_complex, BlockSideLength>::addition = typename matrix_operations<blas_single_complex, BlockSideLength>::addition())
581  {
582  const blas_int size = BlockSideLength * BlockSideLength;
583  const blas_int int_one = 1;
584  ccopy_(&size, a, &int_one, c, &int_one);
585  }
586 };
587 
588 // --- matrix-matrix multiplication ---------------
589 
// Prototypes of the external BLAS matrix-matrix multiplication routines, one
// per element type: d = double, s = float, z = complex<double>,
// c = complex<float>. Every argument, including the scalars and the
// dimensions, is passed by pointer.
// NOTE(review): per the comments in gemm_wrapper these compute
// c = alpha * a * b + beta * c on column-major data, with transa/transb
// selecting transposed operands -- confirm against the BLAS reference.
590 extern "C" void dgemm_(const char* transa, const char* transb,
591  const blas_int* m, const blas_int* n, const blas_int* k,
592  const double* alpha, const double* a, const blas_int* lda,
593  const double* b, const blas_int* ldb,
594  const double* beta, double* c, const blas_int* ldc);
595 
596 extern "C" void sgemm_(const char* transa, const char* transb,
597  const blas_int* m, const blas_int* n, const blas_int* k,
598  const float* alpha, const float* a, const blas_int* lda,
599  const float* b, const blas_int* ldb,
600  const float* beta, float* c, const blas_int* ldc);
601 
602 extern "C" void zgemm_(const char* transa, const char* transb,
603  const blas_int* m, const blas_int* n, const blas_int* k,
604  const blas_double_complex* alpha, const blas_double_complex* a, const blas_int* lda,
605  const blas_double_complex* b, const blas_int* ldb,
606  const blas_double_complex* beta, blas_double_complex* c, const blas_int* ldc);
607 
608 extern "C" void cgemm_(const char* transa, const char* transb,
609  const blas_int* m, const blas_int* n, const blas_int* k,
610  const blas_single_complex* alpha, const blas_single_complex* a, const blas_int* lda,
611  const blas_single_complex* b, const blas_int* ldb,
612  const blas_single_complex* beta, blas_single_complex* c, const blas_int* ldc);
613 
// Type-generic entry point with the same parameter list as the ?gemm_ BLAS
// prototypes. Only declared here: each supported ValueType gets an explicit
// specialization in this file that forwards to the matching BLAS routine.
614 template <typename ValueType>
615 void gemm_(const char* transa, const char* transb,
616  const blas_int* m, const blas_int* n, const blas_int* k,
617  const ValueType* alpha, const ValueType* a, const blas_int* lda,
618  const ValueType* b, const blas_int* ldb,
619  const ValueType* beta, ValueType* c, const blas_int* ldc);
620 
621 //! calculates c = alpha * a * b + beta * c
622 //! \tparam ValueType type of elements
623 //! \param n height of a and c
624 //! \param l width of a and height of b
625 //! \param m width of b and c
626 //! \param a_in_col_major if a is stored in column-major rather than row-major
627 //! \param b_in_col_major if b is stored in column-major rather than row-major
628 //! \param c_in_col_major if c is stored in column-major rather than row-major
629 template <typename ValueType>
630 void gemm_wrapper(const blas_int n, const blas_int l, const blas_int m,
631  const ValueType alpha, const bool a_in_col_major, const ValueType* a,
632  const bool b_in_col_major, const ValueType* b,
633  const ValueType beta, const bool c_in_col_major, ValueType* c)
634 {
635  const blas_int& stride_in_a = a_in_col_major ? n : l;
636  const blas_int& stride_in_b = b_in_col_major ? l : m;
637  const blas_int& stride_in_c = c_in_col_major ? n : m;
638  const char transa = a_in_col_major xor c_in_col_major ? 'T' : 'N';
639  const char transb = b_in_col_major xor c_in_col_major ? 'T' : 'N';
640  if (c_in_col_major)
641  // blas expects matrices in column-major unless specified via transa rsp. transb
642  gemm_(&transa, &transb, &n, &m, &l, &alpha, a, &stride_in_a, b, &stride_in_b, &beta, c, &stride_in_c);
643  else
644  // blas expects matrices in column-major, so we calculate c^T = alpha * b^T * a^T + beta * c^T
645  gemm_(&transb, &transa, &m, &n, &l, &alpha, b, &stride_in_b, a, &stride_in_a, &beta, c, &stride_in_c);
646 }
647 
648 template <>
649 void gemm_(const char* transa, const char* transb,
650  const blas_int* m, const blas_int* n, const blas_int* k,
651  const double* alpha, const double* a, const blas_int* lda,
652  const double* b, const blas_int* ldb,
653  const double* beta, double* c, const blas_int* ldc)
654 {
655  dgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
656 }
657 
658 template <>
659 void gemm_(const char* transa, const char* transb,
660  const blas_int* m, const blas_int* n, const blas_int* k,
661  const float* alpha, const float* a, const blas_int* lda,
662  const float* b, const blas_int* ldb,
663  const float* beta, float* c, const blas_int* ldc)
664 {
665  sgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
666 }
667 
668 template <>
669 void gemm_(const char* transa, const char* transb,
670  const blas_int* m, const blas_int* n, const blas_int* k,
671  const blas_double_complex* alpha, const blas_double_complex* a, const blas_int* lda,
672  const blas_double_complex* b, const blas_int* ldb,
673  const blas_double_complex* beta, blas_double_complex* c, const blas_int* ldc)
674 {
675  zgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
676 }
677 
678 template <>
679 void gemm_(const char* transa, const char* transb,
680  const blas_int* m, const blas_int* n, const blas_int* k,
681  const blas_single_complex* alpha, const blas_single_complex* a, const blas_int* lda,
682  const blas_single_complex* b, const blas_int* ldb,
683  const blas_single_complex* beta, blas_single_complex* c, const blas_int* ldc)
684 {
685  cgemm_(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
686 }
687 
688 //! multiplies matrices A and B, adds result to C, for double entries
689 template <unsigned BlockSideLength>
690 struct low_level_matrix_multiply_and_add<double, BlockSideLength>
691 {
692  low_level_matrix_multiply_and_add(const double* a, bool a_in_col_major,
693  const double* b, bool b_in_col_major,
694  double* c, const bool c_in_col_major)
695  {
696  gemm_wrapper<double>(BlockSideLength, BlockSideLength, BlockSideLength,
697  1.0, a_in_col_major, a,
698  /**/ b_in_col_major, b,
699  1.0, c_in_col_major, c);
700  }
701 };
702 
703 //! multiplies matrices A and B, adds result to C, for float entries
704 template <unsigned BlockSideLength>
705 struct low_level_matrix_multiply_and_add<float, BlockSideLength>
706 {
707  low_level_matrix_multiply_and_add(const float* a, bool a_in_col_major,
708  const float* b, bool b_in_col_major,
709  float* c, const bool c_in_col_major)
710  {
711  gemm_wrapper<float>(BlockSideLength, BlockSideLength, BlockSideLength,
712  1.0, a_in_col_major, a,
713  /**/ b_in_col_major, b,
714  1.0, c_in_col_major, c);
715  }
716 };
717 
718 //! multiplies matrices A and B, adds result to C, for complex<float> entries
719 template <unsigned BlockSideLength>
720 struct low_level_matrix_multiply_and_add<blas_single_complex, BlockSideLength>
721 {
722  low_level_matrix_multiply_and_add(const blas_single_complex* a, bool a_in_col_major,
723  const blas_single_complex* b, bool b_in_col_major,
724  blas_single_complex* c, const bool c_in_col_major)
725  {
726  gemm_wrapper<blas_single_complex>(BlockSideLength, BlockSideLength, BlockSideLength,
727  1.0, a_in_col_major, a,
728  /**/ b_in_col_major, b,
729  1.0, c_in_col_major, c);
730  }
731 };
732 
733 //! multiplies matrices A and B, adds result to C, for complex<double> entries
734 template <unsigned BlockSideLength>
735 struct low_level_matrix_multiply_and_add<blas_double_complex, BlockSideLength>
736 {
737  low_level_matrix_multiply_and_add(const blas_double_complex* a, bool a_in_col_major,
738  const blas_double_complex* b, bool b_in_col_major,
739  blas_double_complex* c, const bool c_in_col_major)
740  {
741  gemm_wrapper<blas_double_complex>(BlockSideLength, BlockSideLength, BlockSideLength,
742  1.0, a_in_col_major, a,
743  /**/ b_in_col_major, b,
744  1.0, c_in_col_major, c);
745  }
746 };
747 #endif
748 
749 } // namespace matrix_local
750 
751 //! \}
752 
754 
755 #endif // !STXXL_CONTAINERS_MATRIX_LOW_LEVEL_HEADER
multiplies matrices A and B, adds result to C, for arbitrary entries param pointer to blocks of A...
c = a [op] b; for arbitrary entries
low_level_matrix_multiply_and_add(const ValueType *a, bool a_in_col_major, const ValueType *b, bool b_in_col_major, ValueType *c, const bool c_in_col_major)
choose_int_types< my_pointer_size >::int_type int_type
Definition: types.h:63
#define STXXL_BEGIN_NAMESPACE
Definition: namespace.h:16
choose_int_types< my_pointer_size >::unsigned_type unsigned_type
Definition: types.h:64
switch_major_index(const int_type row, const int_type col)
#define STXXL_END_NAMESPACE
Definition: namespace.h:17