PLaSK library
Loading...
Searching...
No Matches
temp_matrix.hpp
Go to the documentation of this file.
1/*
2 * This file is part of PLaSK (https://plask.app) by Photonics Group at TUL
3 * Copyright (c) 2022 Lodz University of Technology
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, version 3.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * GNU General Public License for more details.
13 */
14#ifndef PLASK__SOLVER_OPTICAL_MODAL_TEMPMATRIX_H
15#define PLASK__SOLVER_OPTICAL_MODAL_TEMPMATRIX_H
16
17#include <plask/plask.hpp>
18
19#ifdef OPENMP_FOUND
20# include <omp.h>
21#endif
22
23#include "matrices.hpp"
24
25namespace plask { namespace optical { namespace modal {
26
27struct TempMatrix;
28
30 private:
31 cmatrix* tmpmx;
32
33 #ifdef OPENMP_FOUND
35 #endif
36
37 friend struct TempMatrix;
38
39 public:
40
42 #ifdef OPENMP_FOUND
43 const int nthr = omp_get_max_threads();
44 tmpmx = new cmatrix[nthr];
46 for (int i = 0; i != nthr; ++i) {
48 }
49 #else
50 tmpmx = new cmatrix();
51 #endif
52 }
53
55 #ifdef OPENMP_FOUND
56 write_debug("destroying temporary matrices");
57 const int nthr = omp_get_max_threads();
58 for (int i = 0; i != nthr; ++i) {
60 }
61 delete[] tmpmx;
62 delete[] tmplx;
63 #else
64 write_debug("destroying temporary matrix");
65 delete tmpmx;
66 #endif
67 }
68
69 TempMatrix get(size_t rows, size_t cols);
70
71 void reset() {
72 #ifdef OPENMP_FOUND
73 write_debug("freeing temporary matrices");
74 const int nthr = omp_get_max_threads();
75 for (int i = 0; i != nthr; ++i) {
76 tmpmx[i].reset();
77 }
78 #else
79 write_debug("freeing temporary matrix");
80 tmpmx->reset();
81 #endif
82 }
83
84
85};
86
87
88struct TempMatrix {
90 size_t rows, cols;
91 #ifdef OPENMP_FOUND
92 int mn;
93 #endif
94
95 #ifdef OPENMP_FOUND
96 TempMatrix(TempMatrixPool* pool, size_t rows, size_t cols): pool(pool), rows(rows), cols(cols) {
97 const int nthr = omp_get_max_threads();
98 int l;
99 for (mn = 0; mn != nthr; ++mn) {
100 l = omp_test_nest_lock(pool->tmplx+mn);
101 if (l) break;
102 }
103 assert(mn != nthr);
104 size_t NN = rows * cols;
105 if (pool->tmpmx[mn].rows() * pool->tmpmx[mn].cols() < NN) {
106 write_debug("allocating temporary matrix {}", mn);
107 pool->tmpmx[mn].reset(rows, cols);
108 }
109 write_debug("acquiring temporary matrix {} in thread {} ({})", mn, omp_get_thread_num(), l);
110 }
111 TempMatrix(TempMatrix&& src): pool(src.pool), mn(src.mn) { src.pool = nullptr; }
112 TempMatrix(const TempMatrix& src) = delete;
113 #else
115 if (pool->tmpmx->rows() * pool->tmpmx->cols() < rows * cols) {
116 write_debug("allocating temporary matrix");
117 pool->tmpmx->reset(rows, cols);
118 }
119 }
120 #endif
121
122 #ifdef OPENMP_FOUND
123 ~TempMatrix() {
124 if (pool) {
125 write_debug("releasing temporary matrix {} in thread {}", mn, omp_get_thread_num());
126 omp_unset_nest_lock(pool->tmplx+mn);
127 }
128 }
129 #endif
130
131 #ifdef OPENMP_FOUND
132 operator cmatrix() {
133 if (pool->tmpmx[mn].rows() == rows && pool->tmpmx[mn].cols() == cols)
134 return pool->tmpmx[mn];
135 else
136 return cmatrix(rows, cols, pool->tmpmx[mn].data());
137 }
138
139 dcomplex* data() { return pool->tmpmx[mn].data(); }
140 #else
141 operator cmatrix() {
142 if (pool->tmpmx->rows() == rows && pool->tmpmx->cols() == cols)
143 return *pool->tmpmx;
144 else
145 return cmatrix(rows, cols, pool->tmpmx->data());
146 }
147
148 dcomplex* data() { return pool->tmpmx->data(); }
149 #endif
150};
151
152inline TempMatrix TempMatrixPool::get(size_t rows, size_t cols) {
153 return TempMatrix(this, rows, cols);
154}
155
156
157}}} // namespace plask
158
159#endif // PLASK__SOLVER_OPTICAL_MODAL_TEMPMATRIX_H