Actual source code: taotermmapping.c

  1: #include <petsc/private/taoimpl.h>
  2: #include <petsc/private/matimpl.h>

  4: PETSC_INTERN PetscErrorCode TaoTermMappingSetData(TaoTermMapping *mt, const char *prefix, PetscReal scale, TaoTerm term, Mat map)
  5: {
  6:   PetscBool same_name;

  8:   PetscFunctionBegin;
  9:   PetscCall(PetscStrcmp(prefix, mt->prefix, &same_name));
 10:   if (!same_name) {
 11:     PetscCall(PetscFree(mt->prefix));
 12:     PetscCall(PetscStrallocpy(prefix, &mt->prefix));
 13:   }
 14:   if (term != mt->term) {
 15:     PetscCall(VecDestroy(&mt->_unmapped_gradient));
 16:     PetscCall(MatDestroy(&mt->_unmapped_H));
 17:     PetscCall(MatDestroy(&mt->_unmapped_Hpre));
 18:     PetscCall(MatDestroy(&mt->_mapped_H));
 19:     PetscCall(MatDestroy(&mt->_mapped_Hpre));
 20:     PetscCall(MatDestroy(&mt->_mapped_H_work));
 21:     PetscCall(MatDestroy(&mt->_mapped_Hpre_work));
 22:   }
 23:   PetscCall(PetscObjectReference((PetscObject)term));
 24:   PetscCall(TaoTermDestroy(&mt->term));
 25:   mt->term  = term;
 26:   mt->scale = scale;
 27:   if (map != mt->map) PetscCall(VecDestroy(&mt->_map_output));
 28:   PetscCall(PetscObjectReference((PetscObject)map));
 29:   PetscCall(MatDestroy(&mt->map));
 30:   mt->map = map;
 31:   PetscFunctionReturn(PETSC_SUCCESS);
 32: }

 34: PETSC_INTERN PetscErrorCode TaoTermMappingReset(TaoTermMapping *mt)
 35: {
 36:   PetscFunctionBegin;
 37:   PetscCall(TaoTermMappingSetData(mt, NULL, 0.0, NULL, NULL));
 38:   PetscCall(VecDestroy(&mt->_mapped_gradient));
 39:   PetscCall(MatDestroy(&mt->_unmapped_H));
 40:   PetscCall(MatDestroy(&mt->_unmapped_Hpre));
 41:   PetscCall(MatDestroy(&mt->_mapped_H));
 42:   PetscCall(MatDestroy(&mt->_mapped_Hpre));
 43:   PetscCall(MatDestroy(&mt->_mapped_H_work));
 44:   PetscCall(MatDestroy(&mt->_mapped_Hpre_work));
 45:   mt->mask = TAOTERM_MASK_NONE;
 46:   PetscFunctionReturn(PETSC_SUCCESS);
 47: }

 49: PETSC_INTERN PetscErrorCode TaoTermMappingGetData(TaoTermMapping *mt, const char **prefix, PetscReal *scale, TaoTerm *term, Mat *map)
 50: {
 51:   PetscFunctionBegin;
 52:   if (prefix) *prefix = mt->prefix;
 53:   if (term) *term = mt->term;
 54:   if (scale) *scale = mt->scale;
 55:   if (map) *map = mt->map;
 56:   PetscFunctionReturn(PETSC_SUCCESS);
 57: }

 59: #define TaoTermMappingCheckInsertMode(mt, mode) \
 60:   do { \
 61:     PetscCheck((mt)->term, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "TaoTermMapping has no TaoTerm set"); \
 62:     PetscCheck((mode) == INSERT_VALUES || (mode) == ADD_VALUES, PetscObjectComm((PetscObject)(mt)->term), PETSC_ERR_ARG_OUTOFRANGE, "insert mode must be INSERT_VALUES or ADD_VALUES"); \
 63:   } while (0)

 65: static PetscErrorCode TaoTermMappingMap(TaoTermMapping *mt, Vec x, Vec *Ax)
 66: {
 67:   PetscFunctionBegin;
 68:   *Ax = x;
 69:   if (mt->map) {
 70:     if (!mt->_map_output) PetscCall(MatCreateVecs(mt->map, NULL, &mt->_map_output));
 71:     PetscCall(MatMult(mt->map, x, mt->_map_output));
 72:     *Ax = mt->_map_output;
 73:   }
 74:   PetscFunctionReturn(PETSC_SUCCESS);
 75: }

 77: PETSC_INTERN PetscErrorCode TaoTermMappingComputeObjective(TaoTermMapping *mt, Vec x, Vec params, InsertMode mode, PetscReal *value)
 78: {
 79:   Vec       Ax;
 80:   PetscReal v;

 82:   PetscFunctionBegin;
 83:   TaoTermMappingCheckInsertMode(mt, mode);
 84:   if (TaoTermObjectiveMasked(mt->mask)) {
 85:     if (mode == INSERT_VALUES) *value = 0.0;
 86:     PetscFunctionReturn(PETSC_SUCCESS);
 87:   }
 88:   PetscCall(TaoTermMappingMap(mt, x, &Ax));
 89:   PetscCall(TaoTermComputeObjective(mt->term, Ax, params, &v));
 90:   if (mode == ADD_VALUES) *value += mt->scale * v;
 91:   else *value = mt->scale * v;
 92:   PetscFunctionReturn(PETSC_SUCCESS);
 93: }

 95: static PetscErrorCode TaoTermMappingGetGradients(TaoTermMapping *mt, InsertMode mode, Vec g, Vec *mapped_g, Vec *unmapped_g)
 96: {
 97:   PetscFunctionBegin;
 98:   *mapped_g = g;
 99:   if (mode == ADD_VALUES) {
100:     if (!mt->_mapped_gradient) PetscCall(VecDuplicate(g, &mt->_mapped_gradient));
101:     *mapped_g = mt->_mapped_gradient;
102:   }
103:   *unmapped_g = *mapped_g;
104:   if (mt->map) {
105:     if (!mt->_unmapped_gradient) PetscCall(TaoTermCreateSolutionVec(mt->term, &mt->_unmapped_gradient));
106:     *unmapped_g = mt->_unmapped_gradient;
107:   }
108:   PetscFunctionReturn(PETSC_SUCCESS);
109: }

111: static PetscErrorCode TaoTermMappingSetGradients(TaoTermMapping *mt, InsertMode mode, Vec g, Vec mapped_g, Vec unmapped_g)
112: {
113:   PetscFunctionBegin;
114:   if (mt->map) PetscCall(MatMultHermitianTranspose(mt->map, unmapped_g, mapped_g));
115:   else PetscAssert(mapped_g == unmapped_g, PETSC_COMM_SELF, PETSC_ERR_PLIB, "gradient not written to the right place");
116:   if (mode == ADD_VALUES) PetscCall(VecAXPY(g, mt->scale, mapped_g));
117:   else {
118:     PetscAssert(mapped_g == g, PETSC_COMM_SELF, PETSC_ERR_PLIB, "gradient not written to the right place");
119:     if (mt->scale != 1.0) PetscCall(VecScale(g, mt->scale));
120:   }
121:   PetscFunctionReturn(PETSC_SUCCESS);
122: }

124: PETSC_INTERN PetscErrorCode TaoTermMappingComputeGradient(TaoTermMapping *mt, Vec x, Vec params, InsertMode mode, Vec g)
125: {
126:   Vec Ax, mapped_g, unmapped_g = NULL;

128:   PetscFunctionBegin;
129:   TaoTermMappingCheckInsertMode(mt, mode);
130:   if (TaoTermGradientMasked(mt->mask)) {
131:     if (mode == INSERT_VALUES) PetscCall(VecZeroEntries(g));
132:     PetscFunctionReturn(PETSC_SUCCESS);
133:   }
134:   PetscCall(TaoTermMappingGetGradients(mt, mode, g, &mapped_g, &unmapped_g));
135:   PetscCall(TaoTermMappingMap(mt, x, &Ax));
136:   PetscCall(TaoTermComputeGradient(mt->term, Ax, params, unmapped_g));
137:   PetscCall(TaoTermMappingSetGradients(mt, mode, g, mapped_g, unmapped_g));
138:   PetscFunctionReturn(PETSC_SUCCESS);
139: }

141: PETSC_INTERN PetscErrorCode TaoTermMappingComputeObjectiveAndGradient(TaoTermMapping *mt, Vec x, Vec params, InsertMode mode, PetscReal *value, Vec g)
142: {
143:   Vec       Ax, mapped_g, unmapped_g = NULL;
144:   PetscReal v;

146:   PetscFunctionBegin;
147:   TaoTermMappingCheckInsertMode(mt, mode);
148:   if (TaoTermObjectiveMasked(mt->mask) && TaoTermGradientMasked(mt->mask)) {
149:     if (mode == INSERT_VALUES) {
150:       *value = 0.0;
151:       PetscCall(VecZeroEntries(g));
152:     }
153:     PetscFunctionReturn(PETSC_SUCCESS);
154:   }
155:   if (TaoTermObjectiveMasked(mt->mask)) {
156:     if (mode == INSERT_VALUES) *value = 0.0;
157:     PetscCall(TaoTermMappingComputeGradient(mt, x, params, mode, g));
158:     PetscFunctionReturn(PETSC_SUCCESS);
159:   }
160:   if (TaoTermGradientMasked(mt->mask)) {
161:     if (mode == INSERT_VALUES) PetscCall(VecZeroEntries(g));
162:     PetscCall(TaoTermMappingComputeObjective(mt, x, params, mode, value));
163:     PetscFunctionReturn(PETSC_SUCCESS);
164:   }
165:   PetscCall(TaoTermMappingGetGradients(mt, mode, g, &mapped_g, &unmapped_g));
166:   PetscCall(TaoTermMappingMap(mt, x, &Ax));
167:   PetscCall(TaoTermComputeObjectiveAndGradient(mt->term, Ax, params, &v, unmapped_g));
168:   PetscCall(TaoTermMappingSetGradients(mt, mode, g, mapped_g, unmapped_g));
169:   if (mode == ADD_VALUES) *value += mt->scale * v;
170:   else *value = mt->scale * v;
171:   PetscFunctionReturn(PETSC_SUCCESS);
172: }

174: static PetscErrorCode TaoTermMappingMatPtAP(Mat unmapped_H, Mat map, Mat mapped_H, Mat work)
175: {
176:   PetscBool is_uH_diag, is_map_diag, is_uH_cdiag, is_map_cdiag;

178:   PetscFunctionBegin;
179:   PetscCall(PetscObjectTypeCompare((PetscObject)unmapped_H, MATDIAGONAL, &is_uH_diag));
180:   PetscCall(PetscObjectTypeCompare((PetscObject)unmapped_H, MATCONSTANTDIAGONAL, &is_uH_cdiag));
181:   PetscCall(PetscObjectTypeCompare((PetscObject)map, MATDIAGONAL, &is_map_diag));
182:   PetscCall(PetscObjectTypeCompare((PetscObject)map, MATCONSTANTDIAGONAL, &is_map_cdiag));

184:   if (is_map_diag) {
185:     Vec m_diag;

187:     PetscCall(MatDiagonalGetDiagonal(map, &m_diag));
188:     if (is_uH_cdiag) {
189:       Vec         mapped_diag;
190:       PetscScalar cc;

192:       // mapped_H \gets cc map * map
193:       PetscCall(MatConstantDiagonalGetConstant(unmapped_H, &cc));
194:       PetscCall(MatDiagonalGetDiagonal(mapped_H, &mapped_diag));
195:       PetscCall(VecPointwiseMult(mapped_diag, m_diag, m_diag));
196:       PetscCall(VecScale(mapped_diag, cc));
197:       PetscCall(MatDiagonalRestoreDiagonal(mapped_H, &mapped_diag));
198:     } else if (is_uH_diag) {
199:       Vec mapped_diag, unmapped_diag;

201:       PetscCall(MatDiagonalGetDiagonal(mapped_H, &mapped_diag));
202:       PetscCall(MatDiagonalGetDiagonal(unmapped_H, &unmapped_diag));
203:       PetscCall(VecPointwiseMult(mapped_diag, m_diag, m_diag));
204:       PetscCall(VecPointwiseMult(mapped_diag, unmapped_diag, mapped_diag));
205:       PetscCall(MatDiagonalRestoreDiagonal(mapped_H, &mapped_diag));
206:       PetscCall(MatDiagonalRestoreDiagonal(unmapped_H, &unmapped_diag));
207:     } else {
208:       PetscCall(MatCopy(unmapped_H, mapped_H, SAME_NONZERO_PATTERN));
209:       PetscCall(MatDiagonalScale(mapped_H, m_diag, m_diag));
210:     }
211:     PetscCall(MatDiagonalRestoreDiagonal(map, &m_diag));
212:   } else if (is_map_cdiag) {
213:     PetscScalar cc;

215:     PetscCall(MatConstantDiagonalGetConstant(map, &cc));
216:     PetscCall(MatCopy(unmapped_H, mapped_H, SAME_NONZERO_PATTERN));
217:     PetscCall(MatScale(mapped_H, cc * cc));
218:   } else if (is_uH_diag) {
219:     Vec unmapped_diag;

221:     // TODO inefficient. Remove when diag PtAP gets implemented
222:     PetscCall(MatDiagonalGetDiagonal(unmapped_H, &unmapped_diag));
223:     PetscCall(MatCopy(map, work, SAME_NONZERO_PATTERN));
224:     PetscCall(MatDiagonalScale(work, unmapped_diag, NULL));
225:     PetscCall(MatTransposeMatMult(map, work, MAT_REUSE_MATRIX, PETSC_DETERMINE, &mapped_H));
226:     PetscCall(MatDiagonalRestoreDiagonal(unmapped_H, &unmapped_diag));
227:   } else if (is_uH_cdiag) {
228:     // cc * A^T A
229:     PetscScalar cc;

231:     PetscCall(MatConstantDiagonalGetConstant(unmapped_H, &cc));
232:     PetscCall(MatTransposeMatMult(map, map, MAT_REUSE_MATRIX, PETSC_DETERMINE, &mapped_H));
233:     PetscCall(MatScale(mapped_H, cc));
234:   } else PetscCall(MatPtAP(unmapped_H, map, MAT_REUSE_MATRIX, PETSC_DETERMINE, &mapped_H));
235:   PetscFunctionReturn(PETSC_SUCCESS);
236: }

/*
  Chooses the matrices a Hessian evaluation writes into (the counterpart of TaoTermMappingSetHessians()).

    mapped_H / mapped_Hpre:     Hessian in the mapped (outer) space before it is combined into H / Hpre.
                                This is H / Hpre themselves for INSERT_VALUES, or the cached _mapped_H /
                                _mapped_Hpre (duplicated from H / Hpre if absent) for ADD_VALUES.
    unmapped_H / unmapped_Hpre: where the term writes its Hessian. This is the cached _unmapped_H /
                                _unmapped_Hpre when a map is set, otherwise the same matrices as
                                mapped_H / mapped_Hpre.

  A NULL H (or Hpre) yields NULL for the corresponding outputs.
*/
static PetscErrorCode TaoTermMappingGetHessians(TaoTermMapping *mt, InsertMode mode, Mat H, Mat Hpre, Mat *mapped_H, Mat *mapped_Hpre, Mat *unmapped_H, Mat *unmapped_Hpre)
{
  PetscFunctionBegin;
  *mapped_H    = H;
  *mapped_Hpre = Hpre;
  if (mode == ADD_VALUES || mt->map) {
    // we will need _unmapped_H / _unmapped_Hpre
    if (!mt->_unmapped_H) {
      PetscBool is_defined = PETSC_FALSE;

      // Create the term's matrices only when TaoTermIsCreateHessianMatricesDefined() says it can; otherwise
      // _unmapped_H stays NULL.
      // NOTE(review): with a map and no such routine, the term is later handed a NULL unmapped_H -- confirm intended
      PetscCall(TaoTermIsCreateHessianMatricesDefined(mt->term, &is_defined));
      if (is_defined) {
        // Drop any leftover Hpre so H and Hpre are created together
        PetscCall(MatDestroy(&mt->_unmapped_Hpre));
        PetscCall(TaoTermCreateHessianMatrices(mt->term, &mt->_unmapped_H, &mt->_unmapped_Hpre));
      }
      if (!mt->map) {
        // Without a map the two spaces coincide: the mapped caches alias the unmapped ones, each holding its own
        // reference. If nothing was created above this resets _mapped_H / _mapped_Hpre to NULL, and they are
        // duplicated from H / Hpre below.
        PetscCall(PetscObjectReference((PetscObject)mt->_unmapped_H));
        PetscCall(MatDestroy(&mt->_mapped_H));
        mt->_mapped_H = mt->_unmapped_H;

        PetscCall(PetscObjectReference((PetscObject)mt->_unmapped_Hpre));
        PetscCall(MatDestroy(&mt->_mapped_Hpre));
        mt->_mapped_Hpre = mt->_unmapped_Hpre;
      }
    }
  }
  if (mode == ADD_VALUES) {
    // ADD_VALUES: compute into a separate matrix so H / Hpre can be accumulated into afterwards
    if (H) {
      if (!mt->_mapped_H) PetscCall(MatDuplicate(H, MAT_DO_NOT_COPY_VALUES, &mt->_mapped_H));
      *mapped_H = mt->_mapped_H;
    }
    if (Hpre) {
      if (!mt->_mapped_Hpre) PetscCall(MatDuplicate(Hpre, MAT_DO_NOT_COPY_VALUES, &mt->_mapped_Hpre));
      *mapped_Hpre = mt->_mapped_Hpre;
    }
  }
  *unmapped_H    = *mapped_H;
  *unmapped_Hpre = *mapped_Hpre;
  if (mt->map) {
    // With a map the term works in its own space, so it writes into the cached unmapped matrices
    if (H) *unmapped_H = mt->_unmapped_H;
    if (Hpre) *unmapped_Hpre = mt->_unmapped_Hpre;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

283: // if (map) mapped_H \gets map^T @ unmapped_H @ map
284: // else (assumes that unmapped == mapped.
285: //
286: // if INSERT
287: //   H \gets mapped_H
288: // else if ADD
289: //   H \gets H + scale * mapped_H
290: static PetscErrorCode TaoTermMappingSetHessians(TaoTermMapping *mt, InsertMode mode, Mat H, Mat Hpre, Mat mapped_H, Mat mapped_Hpre, Mat unmapped_H, Mat unmapped_Hpre)
291: {
292:   PetscFunctionBegin;
293:   if (mt->map) {
294:     // currently only implements Gauss-Newton Hessian approximation
295:     if (mapped_H) PetscCall(TaoTermMappingMatPtAP(unmapped_H, mt->map, mapped_H, mt->_mapped_H_work));
296:     if (mapped_Hpre && (mapped_Hpre != mapped_H)) PetscCall(TaoTermMappingMatPtAP(unmapped_Hpre, mt->map, mapped_Hpre, mt->_mapped_Hpre_work));
297:   }
298:   if (mode == ADD_VALUES) {
299:     if (H) PetscCall(MatAXPY(H, mt->scale, mapped_H, UNKNOWN_NONZERO_PATTERN));
300:     if (Hpre) PetscCall(MatAXPY(Hpre, mt->scale, mapped_Hpre, UNKNOWN_NONZERO_PATTERN));
301:   } else {
302:     if (H) PetscCall(MatCopy(mapped_H, H, DIFFERENT_NONZERO_PATTERN));
303:     if (Hpre && (H != Hpre)) PetscCall(MatCopy(mapped_Hpre, Hpre, DIFFERENT_NONZERO_PATTERN));
304:     if (mt->scale != 1.0) {
305:       if (H) PetscCall(MatScale(H, mt->scale));
306:       if (Hpre && Hpre != H) PetscCall(MatScale(Hpre, mt->scale));
307:     }
308:   }
309:   PetscFunctionReturn(PETSC_SUCCESS);
310: }

312: // Either called by TaoComputeHessian (one term in Tao), or by TAOTERMSUM
313: //
314: // First case: (one term in Tao)
315: // TaoComputeHessian
316: //   -> TaoTermMappingComputeHessian
317: //      (unmapped_H == mapped_H)
318: //
319: // Second case: TAOTERMSUM, (more than one term in Tao)
320: // TaoComputeHessian
321: //   -> TaoTermMappingComputeHessian
322: //     -> (mt->_unmapped_H == mt->_mapped_H == tao->hessian) (SUM does not take mapping)
323: //     -> TaoTermComputeHessian
324: //       -> TaoTermComputeHessian_Sum
325: //         -> for(i:n_terms)
326: //         -> TaoTermMappingComputeHessian
327: //           -> (unmapped_H may not == mapped_H)
328: PETSC_INTERN PetscErrorCode TaoTermMappingComputeHessian(TaoTermMapping *mt, Vec x, Vec params, InsertMode mode, Mat H, Mat Hpre)
329: {
330:   Vec Ax;
331:   Mat mapped_H, mapped_Hpre, unmapped_H = NULL, unmapped_Hpre = NULL;

333:   PetscFunctionBegin;
334:   TaoTermMappingCheckInsertMode(mt, mode);
335:   if (TaoTermHessianMasked(mt->mask)) {
336:     if (mode == INSERT_VALUES) {
337:       if (H) {
338:         PetscCall(MatZeroEntries(H));
339:         PetscCall(MatAssemblyBegin(H, MAT_FINAL_ASSEMBLY));
340:         PetscCall(MatAssemblyEnd(H, MAT_FINAL_ASSEMBLY));
341:       }
342:       if (Hpre && Hpre != H) {
343:         PetscCall(MatZeroEntries(Hpre));
344:         PetscCall(MatAssemblyBegin(Hpre, MAT_FINAL_ASSEMBLY));
345:         PetscCall(MatAssemblyEnd(Hpre, MAT_FINAL_ASSEMBLY));
346:       }
347:     }
348:     PetscFunctionReturn(PETSC_SUCCESS);
349:   }
350:   PetscCall(TaoTermMappingMap(mt, x, &Ax));
351:   PetscCall(TaoTermMappingGetHessians(mt, mode, H, Hpre, &mapped_H, &mapped_Hpre, &unmapped_H, &unmapped_Hpre));
352:   PetscCall(TaoTermComputeHessian(mt->term, Ax, params, unmapped_H, unmapped_Hpre));
353:   PetscCall(TaoTermMappingSetHessians(mt, mode, H, Hpre, mapped_H, mapped_Hpre, unmapped_H, unmapped_Hpre));
354:   PetscFunctionReturn(PETSC_SUCCESS);
355: }

357: PETSC_INTERN PetscErrorCode TaoTermMappingSetUp(TaoTermMapping *mt)
358: {
359:   PetscFunctionBegin;
360:   PetscCall(TaoTermSetUp(mt->term));
361:   if (mt->map) PetscCall(MatSetUp(mt->map));
362:   PetscFunctionReturn(PETSC_SUCCESS);
363: }

365: PETSC_INTERN PetscErrorCode TaoTermMappingCreateSolutionVec(TaoTermMapping *mt, Vec *solution)
366: {
367:   PetscFunctionBegin;
368:   if (mt->map) PetscCall(MatCreateVecs(mt->map, solution, NULL));
369:   else PetscCall(TaoTermCreateSolutionVec(mt->term, solution));
370:   PetscFunctionReturn(PETSC_SUCCESS);
371: }

373: PETSC_INTERN PetscErrorCode TaoTermMappingCreateParametersVec(TaoTermMapping *mt, Vec *params)
374: {
375:   PetscFunctionBegin;
376:   PetscCall(TaoTermCreateParametersVec(mt->term, params));
377:   PetscFunctionReturn(PETSC_SUCCESS);
378: }

380: static PetscErrorCode TaoTermMappingCreateAPWorkMatrix(Mat map, Mat unmapped, Mat *mapped_work)
381: {
382:   PetscBool is_uH_diag, is_map_diag;

384:   PetscFunctionBegin;
385:   PetscCall(PetscObjectBaseTypeCompareAny((PetscObject)map, &is_map_diag, MATDIAGONAL, MATCONSTANTDIAGONAL, ""));
386:   PetscCall(PetscObjectTypeCompare((PetscObject)unmapped, MATDIAGONAL, &is_uH_diag));
387:   if (is_uH_diag && !is_map_diag) PetscCall(MatDuplicate(map, MAT_DO_NOT_COPY_VALUES, mapped_work));
388:   PetscFunctionReturn(PETSC_SUCCESS);
389: }

391: // This function takes in unmapped_H, map, and returns matrix for mapped_H, which is PtAP
392: static PetscErrorCode TaoTermMappingCreatePtAP(Mat unmapped_H, Mat map, Mat *H)
393: {
394:   PetscBool is_uH_diag, is_uH_cdiag, is_map_diag, is_map_cdiag;

396:   PetscFunctionBegin;
397:   PetscCall(PetscObjectTypeCompare((PetscObject)unmapped_H, MATDIAGONAL, &is_uH_diag));
398:   PetscCall(PetscObjectTypeCompare((PetscObject)unmapped_H, MATCONSTANTDIAGONAL, &is_uH_cdiag));
399:   PetscCall(PetscObjectTypeCompare((PetscObject)map, MATDIAGONAL, &is_map_diag));
400:   PetscCall(PetscObjectTypeCompare((PetscObject)map, MATCONSTANTDIAGONAL, &is_map_cdiag));

402:   // TODO support for PtAP with diagonal would be ideal
403:   if (is_map_diag && is_uH_diag) {
404:     PetscCall(MatDuplicate(unmapped_H, MAT_DO_NOT_COPY_VALUES, H));
405:   } else if (is_map_cdiag && is_uH_cdiag) {
406:     // MatDiagonal does not support setvalues, thus AIJ
407:     PetscLayout rlayout;
408:     PetscInt    m, M;

410:     PetscCall(MatGetLayouts(map, &rlayout, NULL));
411:     PetscCall(MatGetSize(map, &M, NULL));
412:     PetscCall(MatGetLocalSize(map, &m, NULL));
413:     PetscCall(MatCreate(PetscObjectComm((PetscObject)map), H));
414:     PetscCall(MatSetSizes(*H, m, m, M, M));
415:     PetscCall(MatSetLayouts(*H, rlayout, rlayout));
416:     PetscCall(MatSetType(*H, MATAIJ));
417:     PetscCall(MatSetUp(*H));
418:   } else if ((is_map_diag && !is_uH_diag && !is_uH_cdiag)) {
419:     PetscCall(MatDuplicate(unmapped_H, MAT_DO_NOT_COPY_VALUES, H));
420:   } else if (is_map_cdiag && is_uH_diag) {
421:     // MatDiagonal does not support setvalues, thus AIJ
422:     PetscLayout rlayout;
423:     PetscInt    m, M;

425:     PetscCall(MatGetLayouts(map, &rlayout, NULL));
426:     PetscCall(MatGetSize(map, &M, NULL));
427:     PetscCall(MatGetLocalSize(map, &m, NULL));
428:     PetscCall(MatCreate(PetscObjectComm((PetscObject)map), H));
429:     PetscCall(MatSetSizes(*H, m, m, M, M));
430:     PetscCall(MatSetLayouts(*H, rlayout, rlayout));
431:     PetscCall(MatSetType(*H, MATAIJ));
432:     PetscCall(MatSetUp(*H));
433:   } else if (is_map_diag && is_uH_cdiag) {
434:     PetscCall(MatDuplicate(map, MAT_DO_NOT_COPY_VALUES, H));
435:   } else if ((is_uH_diag && !is_map_diag && !is_map_cdiag) || (is_uH_cdiag && !is_map_diag && !is_map_cdiag)) {
436:     PetscCall(MatTransposeMatMult(map, map, MAT_INITIAL_MATRIX, PETSC_DETERMINE, H));
437:   } else if (!is_uH_diag && !is_uH_cdiag && is_map_cdiag) {
438:     PetscScalar cc;

440:     PetscCall(MatConstantDiagonalGetConstant(map, &cc));
441:     PetscCall(MatDuplicate(unmapped_H, MAT_COPY_VALUES, H));
442:     PetscCall(MatScale(*H, cc * cc));
443:   } else {
444:     PetscCall(MatProductCreate(unmapped_H, map, NULL, H));
445:     PetscCall(MatProductSetType(*H, MATPRODUCT_PtAP));
446:     PetscCall(MatProductSetFromOptions(*H));
447:     // TODO Some other default fallback?
448:     if ((*H)->ops->productsymbolic) PetscCall(MatProductSymbolic(*H));
449:     else SETERRQ(PetscObjectComm((PetscObject)map), PETSC_ERR_SUP, "Currently does not support PtAP routines for given pair of matrices");
450:     PetscCall(MatProductNumeric(*H));
451:     PetscCall(MatZeroEntries(*H));
452:     PetscCall(MatAssemblyBegin(*H, MAT_FINAL_ASSEMBLY));
453:     PetscCall(MatAssemblyEnd(*H, MAT_FINAL_ASSEMBLY));
454:   }
455:   PetscFunctionReturn(PETSC_SUCCESS);
456: }

/*
 * Internal function to create the Hessian matrices a TaoTermMapping hands to its caller.
 *
 * map: m x n (when present)
 *
 * Creates (if not already cached) the unmapped and mapped H and Hpre, and returns
 * *H <- mt->_mapped_H and *Hpre <- mt->_mapped_Hpre. The caller receives a new reference,
 * except that no extra reference is taken when *H / *Hpre already point at the returned matrix.
 * *Hpre is always set, even when it is the same matrix as *H. The H and Hpre pointers must not be NULL.
 *
 * if (mt->map)
 *   mapped:   n x n (structure from TaoTermMappingCreatePtAP())
 *   unmapped: m x m (created by the term)
 *   It also creates internal work matrices to support PtAP with a diagonal matrix, which is currently unsupported natively.
 *
 * else
 *   mapped and unmapped are the same n x n matrices (aliased; each cache field holds its own reference).
 */
PETSC_INTERN PetscErrorCode TaoTermMappingCreateHessianMatrices(TaoTermMapping *mt, Mat *H, Mat *Hpre)
{
  Mat       uH, uHpre, mH, mHpre;
  PetscBool is_sum;

  PetscFunctionBegin;
  uH    = mt->_unmapped_H;
  uHpre = mt->_unmapped_Hpre;
  mH    = mt->_mapped_H;
  mHpre = mt->_mapped_Hpre;
  PetscCall(PetscObjectTypeCompare((PetscObject)mt->term, TAOTERMSUM, &is_sum));
  // NOTE(review): despite the "Ignoring it" message, the map is still used by the mapped branch below -- confirm intent
  if (is_sum && mt->map) PetscCall(PetscInfo(mt->term, "%s: TaoTermType is TAOTERMSUM, but Map is given. Ignoring it.\n", ((PetscObject)mt->term)->prefix));
  PetscCheck(H, PetscObjectComm((PetscObject)mt->term), PETSC_ERR_SUP, "TaoTermMappingCreateHessianMatrices does not take NULL input for H");
  PetscCheck(Hpre, PetscObjectComm((PetscObject)mt->term), PETSC_ERR_SUP, "TaoTermMappingCreateHessianMatrices does not take NULL input Hpre");
  if (!mt->map) {
    // Invariant without a map: mt->_mapped_{H,Hpre} == mt->_unmapped_{H,Hpre}
    if (uH && mH) PetscCheck(uH == mH, PetscObjectComm((PetscObject)mt->term), PETSC_ERR_USER, "For unmapped TaoTerm, mapped Hessian and unmapped Hessian must be same");
    if (uHpre && mHpre) PetscCheck(uHpre == mHpre, PetscObjectComm((PetscObject)mt->term), PETSC_ERR_USER, "For unmapped TaoTerm, mapped Hessian preconditioner and unmapped Hessian preconditioner needs to be same");

    // If only the mapped matrices are present, reuse them as the unmapped matrices
    if (mt->_mapped_H && !mt->_unmapped_H) {
      PetscCall(PetscObjectReference((PetscObject)mt->_mapped_H));
      mt->_unmapped_H = mt->_mapped_H;
    }
    if (mt->_mapped_Hpre && !mt->_unmapped_Hpre) {
      PetscCall(PetscObjectReference((PetscObject)mt->_mapped_Hpre));
      mt->_unmapped_Hpre = mt->_mapped_Hpre;
    }
    // create _unmapped only if they are empty (a NULL output pointer tells the term to skip that matrix)
    PetscCall(TaoTermCreateHessianMatrices(mt->term, (mt->_unmapped_H) ? NULL : &mt->_unmapped_H, (mt->_unmapped_Hpre) ? NULL : &mt->_unmapped_Hpre));
    // If the mapped matrices are still NULL, alias them to the unmapped matrices
    if (mt->_unmapped_H && !mt->_mapped_H) {
      PetscCall(PetscObjectReference((PetscObject)mt->_unmapped_H));
      mt->_mapped_H = mt->_unmapped_H;
    }
    if (mt->_unmapped_Hpre && !mt->_mapped_Hpre) {
      PetscCall(PetscObjectReference((PetscObject)mt->_unmapped_Hpre));
      mt->_mapped_Hpre = mt->_unmapped_Hpre;
    }

    // always returns Hpre, even if same as H
    if (*H != mt->_unmapped_H) PetscCall(PetscObjectReference((PetscObject)mt->_unmapped_H));
    if (*Hpre != mt->_unmapped_Hpre) PetscCall(PetscObjectReference((PetscObject)mt->_unmapped_Hpre));
    *H    = mt->_unmapped_H;
    *Hpre = mt->_unmapped_Hpre;
  } else {
    // create _unmapped only if they are empty
    PetscCall(TaoTermCreateHessianMatrices(mt->term, (mt->_unmapped_H) ? NULL : &mt->_unmapped_H, (mt->_unmapped_Hpre) ? NULL : &mt->_unmapped_Hpre));
    // Hack to support AIJ.... TODO
    // Assembles the term's Hessian and adds 1 to its diagonal, presumably so the symbolic PtAP below sees an
    // assembled matrix with diagonal entries -- TODO confirm.
    // NOTE(review): this runs on every call (shifting an existing _unmapped_H again) and, unlike the next
    // statement, does not check _unmapped_H for NULL.
    PetscCall(MatAssemblyBegin(mt->_unmapped_H, MAT_FINAL_ASSEMBLY));
    PetscCall(MatAssemblyEnd(mt->_unmapped_H, MAT_FINAL_ASSEMBLY));
    PetscCall(MatShift(mt->_unmapped_H, 1.));
    // Create PtAP only if mt->_mapped_H is empty
    if (mt->_unmapped_H && !mt->_mapped_H) PetscCall(TaoTermMappingCreatePtAP(mt->_unmapped_H, mt->map, &mt->_mapped_H));
    // Creating expensive work matrix to store AP TODO remove when diag PtAP gets implemented
    if (!mt->_mapped_H_work) PetscCall(TaoTermMappingCreateAPWorkMatrix(mt->map, mt->_unmapped_H, &mt->_mapped_H_work));
    if (*H != mt->_mapped_H) PetscCall(PetscObjectReference((PetscObject)mt->_mapped_H));
    *H = mt->_mapped_H;
    if (mt->_unmapped_Hpre == mt->_unmapped_H) {
      // Hpre_is_H true, so mapped_H = mapped_Hpre
      if (!mt->_mapped_Hpre) {
        PetscCall(PetscObjectReference((PetscObject)mt->_mapped_H));
        mt->_mapped_Hpre = mt->_mapped_H;
      }
      if (*Hpre != mt->_mapped_Hpre) PetscCall(PetscObjectReference((PetscObject)*H));
      *Hpre = *H;
    } else {
      // Distinct preconditioner: build its own PtAP structure and work matrix
      if (!mt->_mapped_Hpre) PetscCall(TaoTermMappingCreatePtAP(mt->_unmapped_Hpre, mt->map, &mt->_mapped_Hpre));
      if (!mt->_mapped_Hpre_work) PetscCall(TaoTermMappingCreateAPWorkMatrix(mt->map, mt->_unmapped_Hpre, &mt->_mapped_Hpre_work));
      if (*Hpre != mt->_mapped_Hpre) PetscCall(PetscObjectReference((PetscObject)mt->_mapped_Hpre));
      *Hpre = mt->_mapped_Hpre;
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}