PLaSK library
Loading...
Searching...
No Matches
gauss_matrix.hpp
Go to the documentation of this file.
1/*
2 * This file is part of PLaSK (https://plask.app) by Photonics Group at TUL
3 * Copyright (c) 2022 Lodz University of Technology
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, version 3.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 */
14#ifndef PLASK_COMMON_FEM_GAUSS_MATRIX_H
15#define PLASK_COMMON_FEM_GAUSS_MATRIX_H
16
17#include <cstddef>
18
19#include "matrix.hpp"
20
21// BLAS routine to multiply matrix by vector
22#define dgbmv F77_GLOBAL(dgbmv, DGBMV)
23F77SUB dgbmv(const char& trans,
24 const int& m,
25 const int& n,
26 const int& kl,
27 const int& ku,
28 const double& alpha,
29 double* a,
30 const int& lda,
31 const double* x,
32 int incx,
33 const double& beta,
34 double* y,
35 int incy);
36
37// LAPACK routines to solve set of linear equations
38#define dgbtrf F77_GLOBAL(dgbtrf, DGBTRF)
39F77SUB dgbtrf(const int& m, const int& n, const int& kl, const int& ku, double* ab, const int& ldab, int* ipiv, int& info);
40
41#define dgbtrs F77_GLOBAL(dgbtrs, DGBTRS)
42F77SUB dgbtrs(const char& trans,
43 const int& n,
44 const int& kl,
45 const int& ku,
46 const int& nrhs,
47 double* ab,
48 const int& ldab,
49 int* ipiv,
50 double* b,
51 const int& ldb,
52 int& info);
53
54namespace plask {
55
61 const size_t shift;
62
64
70 DgbMatrix(const Solver* solver, size_t rank, size_t band)
71 : BandMatrix(solver, rank, band, ((3 * band + 1 + (15 / sizeof(double))) & ~size_t(15 / sizeof(double))) - 1),
72 shift(2 * band) {}
73
74 DgbMatrix(const DgbMatrix&) = delete;
75
76 size_t index(size_t r, size_t c) {
77 assert(r < rank && c < rank);
78 if (r < c) {
79 assert(c - r <= kd);
80 // AB(kl+ku+1+i-j,j) = A(i,j)
81 return shift + r + ld * c;
82 } else {
83 assert(r - c <= kd);
84 return shift + c + ld * r;
85 }
86 }
87
88 double& operator()(size_t r, size_t c) override { return data[index(r, c)]; }
89
90 void factorize() override {
91 solver->writelog(LOG_DETAIL, "Factorizing system");
92
93 int info = 0;
95
96 mirror();
97
98 // Factorize matrix
99 dgbtrf(int(rank), int(rank), int(kd), int(kd), data, int(ld + 1), ipiv.get(), info);
100 if (info < 0) {
101 throw CriticalException("{0}: Argument {1} of `dgbtrf` has illegal value", solver->getId(), -info);
102 } else if (info > 0) {
103 throw ComputationError(solver->getId(), "matrix is singular (at {0})", info);
104 }
105 }
106
108 solver->writelog(LOG_DETAIL, "Solving matrix system");
109
110 int info = 0;
111 dgbtrs('N', int(rank), int(kd), int(kd), 1, data, int(ld + 1), ipiv.get(), B.data(), int(B.size()), info);
112 if (info < 0) throw CriticalException("{0}: Argument {1} of `dgbtrs` has illegal value", solver->getId(), -info);
113
114 std::swap(B, X);
115 }
116
123 mirror();
124 dgbmv('N', int(rank), int(rank), int(kd), int(kd), 1.0, data, int(ld) + 1, vector.data(), 1, 0.0, result.data(), 1);
125 }
126
133 mirror();
134 dgbmv('N', int(rank), int(rank), int(kd), int(kd), 1.0, data, int(ld) + 1, vector.data(), 1, 1.0, result.data(), 1);
135 }
136
137 private:
139 void mirror() {
140 for (size_t i = 0; i < rank; ++i) {
141 size_t ldi = shift + (ld + 1) * i;
142 size_t knd = min(kd, rank - 1 - i);
143 for (size_t j = 1; j <= knd; ++j) data[ldi + j] = data[ldi + ld * j];
144 }
145 }
146};
147
148} // namespace plask
149
150#endif // PLASK_COMMON_FEM_GAUSS_MATRIX_H