FlexBox - A Flexible Primal-Dual ToolBox
flexConcatOperator.h
1 #ifndef flexConcatOperator_H
2 #define flexConcatOperator_H
3 
4 
#include <algorithm>
#ifdef __CUDACC__
#include <thrust/functional.h>
#include <thrust/transform.h>
#endif
5 #include "vector"
6 #include "tools.h"
7 #include "flexLinearOperator.h"
8 
10 template<typename T>
12 {
13 
// Vector type used for operator inputs/outputs and for the scratch buffer:
// thrust device storage when compiled with nvcc, std::vector otherwise.
14 #ifdef __CUDACC__
15  typedef thrust::device_vector<T> Tdata;
16 #else
17  typedef std::vector<T> Tdata;
18 #endif
19 
20 private:
// Selects how A and B are combined; the apply methods of this class handle
// PLUS (A + B), MINUS (A - B) and COMPOSE (A applied to the result of B).
23  mySign s;
// Scratch buffer for the intermediate product in COMPOSE mode
// (B*x, or A^T*y when transposed); sized to A's column count by the constructor.
24  Tdata tmpVec;
25 
26 public:
27 
29 
35  flexConcatOperator(flexLinearOperator<T>* aA, flexLinearOperator<T>* aB, mySign aS, bool aMinus) : A(aA), B(aB), s(aS), tmpVec(aA->getNumCols()), flexLinearOperator<T>(aA->getNumRows(), aB->getNumCols(), concatOp, aMinus)
36  {
37 
38  }
39 
41  {
// Creates a new operator with the same combination mode (s) and negation flag.
// This is a shallow copy: the new operator shares the very same A and B
// pointers (A->copy()/B->copy() are not called).
// NOTE(review): confirm ownership of A/B so the copy and the original cannot
// double-delete or dangle. The returned object is allocated with new —
// presumably the caller takes ownership.
// tmpVec is not copied; the constructor allocates a fresh buffer of A's column count.
42  auto cpOp = new flexConcatOperator<T>(this->A, this->B, this->s, this->isMinus);
43 
44  return cpOp;
45  }
46 
47  //to implement
48  void times(bool transposed, const Tdata &input, Tdata &output)
49  {
50  }
51 
52  void timesPlus(bool transposed, const Tdata &input, Tdata &output)
53  {
54  switch (this->s)
55  {
56  case PLUS:
57  {
58  if (this->isMinus)
59  {
60  A->timesMinus(transposed, input, output);
61  B->timesMinus(transposed, input, output);
62  }
63  else
64  {
65  A->timesPlus(transposed, input, output);
66  B->timesPlus(transposed, input, output);
67  }
68  break;
69  }
70  case MINUS:
71  {
72  if (this->isMinus)
73  {
74  A->timesMinus(transposed, input, output);
75  B->timesPlus(transposed, input, output);
76  }
77  else
78  {
79  A->timesPlus(transposed, input, output);
80  B->timesMinus(transposed, input, output);
81  }
82  break;
83  }
84  case COMPOSE:
85  {
86  if (transposed)
87  {
88  //apply A first
89  #ifdef __CUDACC__
90  thrust::fill(this->tmpVec.begin(), this->tmpVec.end(), (T)0);
91  #else
92  std::fill(this->tmpVec.begin(), this->tmpVec.end(), (T)0);
93  #endif
94 
95  A->timesPlus(true, input, this->tmpVec);
96 
97  if (this->isMinus)
98  {
99  B->timesMinus(true, this->tmpVec, output);
100  }
101  else
102  {
103  B->timesPlus(true, this->tmpVec, output);
104  }
105  }
106  else
107  {
108  //apply B first
109  #ifdef __CUDACC__
110  thrust::fill(this->tmpVec.begin(), this->tmpVec.end(), (T)0);
111  #else
112  std::fill(this->tmpVec.begin(), this->tmpVec.end(), (T)0);
113  #endif
114 
115  B->timesPlus(false, input, this->tmpVec);
116 
117  if (this->isMinus)
118  {
119  A->timesMinus(false, this->tmpVec, output);
120  }
121  else
122  {
123  A->timesPlus(false, this->tmpVec, output);
124  }
125  }
126  break;
127  }
128  }
129  }
130 
131  void timesMinus(bool transposed, const Tdata &input, Tdata &output)
132  {
133  switch (this->s)
134  {
135  case PLUS:
136  {
137  if (this->isMinus)
138  {
139  A->timesPlus(transposed, input, output);
140  B->timesPlus(transposed, input, output);
141  }
142  else
143  {
144  A->timesMinus(transposed, input, output);
145  B->timesMinus(transposed, input, output);
146  }
147  break;
148  }
149  case MINUS:
150  {
151  if (this->isMinus)
152  {
153  A->timesPlus(transposed, input, output);
154  B->timesMinus(transposed, input, output);
155  }
156  else
157  {
158  A->timesMinus(transposed, input, output);
159  B->timesPlus(transposed, input, output);
160  }
161  break;
162  }
163  case COMPOSE:
164  {
165  if (transposed)
166  {
167  //apply A first
168  #ifdef __CUDACC__
169  thrust::fill(this->tmpVec.begin(), this->tmpVec.end(), (T)0);
170  #else
171  std::fill(this->tmpVec.begin(), this->tmpVec.end(), (T)0);
172  #endif
173 
174  A->timesPlus(true, input, tmpVec);
175  if (this->isMinus)
176  {
177  B->timesPlus(true, this->tmpVec, output);
178  }
179  else
180  {
181  B->timesMinus(true, this->tmpVec, output);
182  }
183  }
184  else
185  {
186  //apply B first
187  #ifdef __CUDACC__
188  thrust::fill(this->tmpVec.begin(), this->tmpVec.end(), (T)0);
189  #else
190  std::fill(this->tmpVec.begin(), this->tmpVec.end(), (T)0);
191  #endif
192 
193  B->timesPlus(false, input, this->tmpVec);
194  if (this->isMinus)
195  {
196  A->timesPlus(false, this->tmpVec, output);
197  }
198  else
199  {
200  A->timesMinus(false, this->tmpVec, output);
201  }
202  }
203  break;
204  }
205  }
206  }
207 
208  //TODO
209  T getMaxRowSumAbs(bool transposed)
210  {
211  return static_cast<T>(1);
212  }
213 
214  std::vector<T> getAbsRowSum(bool transposed)
215  {
216  std::vector<T> result;
217 
218  auto rowSumA = A->getAbsRowSum(transposed);
219  auto rowSumB = B->getAbsRowSum(transposed);
220 
221  switch (this->s)
222  {
223  case PLUS:
224  result.resize(rowSumA.size());
225 
226  #pragma omp parallel for
227  for (int k = 0; k < result.size(); ++k)
228  {
229  result[k] = rowSumA[k] + rowSumB[k];
230  }
231  break;
232  case MINUS:
233  {
234  result.resize(rowSumA.size());
235 
236  #pragma omp parallel for
237  for (int k = 0; k < result.size(); ++k)
238  {
239  result[k] = rowSumA[k] + rowSumB[k];
240  }
241  break;
242  }
243  case COMPOSE:
244  {
245  T maxA = *std::max_element(rowSumA.begin(), rowSumA.end());
246  T maxB = *std::max_element(rowSumB.begin(), rowSumB.end());
247  T maxProd = maxA * maxB;
248 
249  if(transposed)
250  result.resize(this->B->getNumCols(), maxProd);
251  else
252  result.resize(this->A->getNumRows(), maxProd);
253 
254  break;
255  }
256  }
257 
258  return result;
259  }
260 
#ifdef __CUDACC__
/**
 * Same as getAbsRowSum() but computed on thrust device vectors.
 *
 *   PLUS / MINUS : rowSumA[k] + rowSumB[k]  (triangle inequality: |A +/- B| <= |A| + |B|)
 *   COMPOSE      : max(rowSumA) * max(rowSumB) for every row of A*B
 *                  (A's row count, or B's column count when transposed)
 *
 * @param transposed evaluate the transposed operator if true
 * @return per-row bounds on the device; empty if s is none of PLUS/MINUS/COMPOSE
 */
thrust::device_vector<T> getAbsRowSumCUDA(bool transposed)
{
    Tdata result;

    auto rowSumA = A->getAbsRowSumCUDA(transposed);
    auto rowSumB = B->getAbsRowSumCUDA(transposed);

    switch (this->s)
    {
    case PLUS:
    case MINUS:
    {
        // One device-side transform instead of a host loop that read and wrote
        // device_vector elements individually (one device transfer per element).
        // NOTE(review): assumes rowSumB has at least as many entries as rowSumA.
        result.resize(rowSumA.size());
        thrust::transform(rowSumA.begin(), rowSumA.end(), rowSumB.begin(),
                          result.begin(), thrust::plus<T>());
        break;
    }
    case COMPOSE:
    {
        // Dereferencing max_element of an empty range is undefined; treat an
        // empty operand as contributing a zero bound.
        T maxA = static_cast<T>(0);
        if (!rowSumA.empty())
            maxA = *thrust::max_element(rowSumA.begin(), rowSumA.end());

        T maxB = static_cast<T>(0);
        if (!rowSumB.empty())
            maxB = *thrust::max_element(rowSumB.begin(), rowSumB.end());

        const T maxProd = maxA * maxB;

        // A*B has A's rows; its transpose has B's columns as rows.
        if (transposed)
            result.resize(this->B->getNumCols(), maxProd);
        else
            result.resize(this->A->getNumRows(), maxProd);
        break;
    }
    }

    return result;
}
#endif
311 };
312 
313 #endif
int getNumRows() const
returns number of rows of the linear operator
Definition: flexLinearOperator.h:57
void timesPlus(bool transposed, const Tdata &input, Tdata &output)
applies linear operator on vector and adds its result to y
Definition: flexConcatOperator.h:52
void timesMinus(bool transposed, const Tdata &input, Tdata &output)
applies linear operator on vector and subtracts its result from y
Definition: flexConcatOperator.h:131
bool isMinus
determines if operator is negated
Definition: flexLinearOperator.h:25
int getNumCols() const
returns number of columns of the linear operator
Definition: flexLinearOperator.h:48
thrust::device_vector< T > getAbsRowSumCUDA(bool transposed)
same function as getAbsRowSum() but implemented in CUDA
Definition: flexConcatOperator.h:262
represents a concatenation operator
Definition: flexConcatOperator.h:11
flexConcatOperator(flexLinearOperator< T > *aA, flexLinearOperator< T > *aB, mySign aS, bool aMinus)
initializes the concatenation operator
Definition: flexConcatOperator.h:35
virtual void timesPlus(bool transposed, const Tdata &input, Tdata &output)=0
applies linear operator on vector and adds its result to y
std::vector< T > getAbsRowSum(bool transposed)
returns a vector of sum of absolute values per row used for preconditioning
Definition: flexConcatOperator.h:214
T getMaxRowSumAbs(bool transposed)
returns the maximum sum of absolute values per row used for preconditioning
Definition: flexConcatOperator.h:209
flexConcatOperator< T > * copy()
copies the linear operator
Definition: flexConcatOperator.h:40
mySign
enum representing the type of concatenation
Definition: tools.h:56
virtual thrust::device_vector< T > getAbsRowSumCUDA(bool transposed)=0
same function as getAbsRowSum() but implemented in CUDA
virtual std::vector< T > getAbsRowSum(bool transposed)=0
returns a vector of sum of absolute values per row used for preconditioning
void times(bool transposed, const Tdata &input, Tdata &output)
applies linear operator on vector
Definition: flexConcatOperator.h:48
virtual void timesMinus(bool transposed, const Tdata &input, Tdata &output)=0
applies linear operator on vector and subtracts its result from y
abstract base class for linear operators
Definition: flexLinearOperator.h:12