PLaSK library
Loading...
Searching...
No Matches
matrices.cpp
Go to the documentation of this file.
1/*
2 * This file is part of PLaSK (https://plask.app) by Photonics Group at TUL
3 * Copyright (c) 2022 Lodz University of Technology
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, version 3.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 */
14#include "matrices.hpp"
15
16namespace plask { namespace optical { namespace modal {
17
19{
// Body of the routine tagged "invmult" in its error messages (matrix right-hand side).
// It validates the shapes of A and B, hands their raw storage to zgesv and returns B.
// NOTE(review): zgesv is presumably the LAPACK linear-solve driver, which would leave the
// solution in B and overwrite A with its factorization -- confirm against the wrapper.
20 // A must be square; otherwise a ComputationError is thrown
21 if (A.rows() != A.cols())
22 throw ComputationError("invmult", "cannot invert rectangular matrix");
23 const std::size_t N = A.rows();
24 // B must have as many rows as A; its column count is the number of right-hand sides
25 if (B.rows() != N)
26 throw ComputationError("invmult", "cannot multiply matrices because of the dimensions mismatch");
27 const std::size_t nrhs = B.cols();
28 // Pivot index buffer (one entry per row) and the status value filled in by zgesv
29 std::unique_ptr<int[]> ipiv(new int[N]);
30 int info;
31 // Perform the calculation; N is used as the leading dimension of both A and B
32 zgesv(int(N), int(nrhs), A.data(), int(N), ipiv.get(), B.data(), int(N), info);
33 // Only a positive status is reported (as a singular matrix); negative values are not checked here
34 if (info > 0) throw ComputationError("invmult", "matrix is singular");
35 return B;
36}
37
38
40{
// Body of the routine tagged "invmult" in its error messages (single right-hand side).
// It validates the shapes of A and B, hands their raw storage to zgesv and returns B.
// NOTE(review): zgesv is presumably the LAPACK linear-solve driver, which would leave the
// solution in B and overwrite A with its factorization -- confirm against the wrapper.
41 // A must be square; otherwise a ComputationError is thrown
42 if (A.rows() != A.cols())
43 throw ComputationError("invmult", "cannot invert rectangular matrix");
44 const std::size_t N = A.rows();
45 // B must have exactly as many elements as A has rows
46 if (B.size() != N)
47 throw ComputationError("invmult", "cannot multiply matrix by vector because of the dimensions mismatch");
48 // Pivot index buffer (one entry per row) and the status value filled in by zgesv
49 std::unique_ptr<int[]> ipiv(new int[N]);
50 int info;
51 // Perform the calculation with one right-hand side; N is the leading dimension of A and B
52 zgesv(int(N), 1, A.data(), int(N), ipiv.get(), B.data(), int(N), info);
53 // Only a positive status is reported (as a singular matrix); negative values are not checked here
54 if (info > 0) throw ComputationError("invmult", "matrix is singular");
55 return B;
56}
57
58
60{
// Body of the routine tagged "inv" in its error message.
// It builds an N x N identity matrix, passes it together with A to invmult and returns it.
// NOTE(review): this relies on invmult storing its result in the second argument, and
// invmult presumably also modifies A -- confirm against the invmult implementation.
61 // A must be square; otherwise a ComputationError is thrown
62 if (A.rows() != A.cols())
63 throw ComputationError("inv", "cannot invert rectangular matrix");
64 const std::size_t N = A.rows();
65
66 // Build the identity: fill with zeros, then put ones on the diagonal
67 cmatrix result(N, N, 0.);
68 for (std::size_t i = 0; i < N; i++) result(i, i) = 1;
69
// The value returned by invmult is ignored; `result` itself is returned below
70 invmult(A, result);
71
72 return result;
73}
74
75
76dcomplex det(cmatrix& A)
77{
78 // Check if the A is a square matrix
79 if (A.rows() != A.cols())
80 throw ComputationError("det", "cannot find the determinant of rectangular matrix");
81 const std::size_t N = A.rows();
82 // Needed variables
83 std::unique_ptr<int[]> ipiv(new int[N]);
84 int info;
85 // Find the LU factorization
86 zgetrf(int(N), int(N), A.data(), int(N), ipiv.get(), info);
87 // Ok, now compute the determinant
88 dcomplex det = 1.; int p = 1;
89 for (std::size_t i = 0; i < N; i++) {
90 det *= A(i,i);
91 if (std::size_t(ipiv[i]) != i+1) p = -p;
92 }
93 // Return the result
94 if (p < 0) return -det; else return det;
95}
96
97
98
100{
// Body of the routine tagged "eigenv" in its error messages.
// It validates the shapes of A, vals and the optional eigenvector matrices, calls zgeev on
// their raw storage and returns the status value set by zgeev without interpreting it.
// leftv / rightv may be null, in which case the corresponding eigenvectors are not requested.
// NOTE(review): zgeev is presumably the LAPACK general eigenproblem driver, which would
// overwrite A -- confirm against the wrapper.
101 // Check the validity of the matrices (all failures throw ComputationError)
102 if (A.rows() != A.cols())
103 throw ComputationError("eigenv", "matrix A should be square");
104 const std::size_t N = A.rows();
105 if (vals.size() != N)
106 throw ComputationError("eigenv", "eigenvalues should have the same number of rows as the original matrix.");
// The messages below say "square", but the checks actually require exactly N x N
107 if (rightv) if (rightv->rows() != N || rightv->cols() != N)
108 throw ComputationError("eigenv", "matrices for right eigenvectors should be square");
109 if (leftv) if (leftv->rows() != N || leftv->cols() != N)
110 throw ComputationError("eigenv", "matrices for left eigenvectors should be square");
111
112 // Determine the task: 'V' requests the eigenvectors, 'N' skips them
113 char jobvl = (leftv==NULL)? 'N' : 'V';
114 char jobvr = (rightv==NULL)? 'N' : 'V';
115
116 // Determine the storage place for eigenvectors (null when they are not requested)
117 dcomplex* vl = (leftv==NULL)? NULL : leftv->data();
118 dcomplex* vr = (rightv==NULL)? NULL : rightv->data();
119
120 // Create the workplace; lwork is the size passed to zgeev for the `work` buffer
121 const std::size_t lwork = 2*N+1;
122 //int lwork = N*N;
// NOTE(review): `work` and `rwork` used in the call below must be declared around here;
// confirm that their sizes match lwork and what zgeev expects for its real workspace.
125
126 // Call the lapack subroutine; N is used as the leading dimension of A, vl and vr
127 int info;
128 zgeev(jobvl, jobvr, int(N), A.data(), int(N), vals.data(), vl, int(N), vr, int(N), work.get(), int(lwork), rwork.get(), info);
129
// The status is handed to the caller as-is; no exception is thrown for a non-zero value
130 return info;
131}
132
133}}} // namespace plask::optical::modal;