/* Source file: letkf_local_analysis.kokkos.cxx */

  1: #include "../src/ml/da/impls/ensemble/letkf/letkf.h"
  2: #include <Kokkos_Core.hpp>
  3: #include <KokkosBlas.hpp>

  5: #if defined(KOKKOS_ENABLE_CUDA)
  6:   #include <cusolverDn.h>
  7:   #include <cuda_runtime.h>
  8: #include <petscdevice_cuda.h>
  9: #elif defined(KOKKOS_ENABLE_HIP)
 10:   #include <rocsolver/rocsolver.h>
 11:   #include <hip/hip_runtime.h>
 12: #include <petscdevice_hip.h>
 13: #elif defined(KOKKOS_ENABLE_SYCL)
 14:   #include <oneapi/mkl.hpp>
 15:   #include <sycl/sycl.hpp>
 16: #endif

 18: /* ========================================================================== */
 19: /*                    Batched Eigendecomposition for LETKF                    */
 20: /* ========================================================================== */

 22: /* Structure to hold reusable workspace for eigensolvers */
 23: struct EigenWorkspace {
 24:   /* Tracking for reuse */
 25:   PetscInt max_chunk_size;
 26:   PetscInt m;
 27:   PetscInt n_obs_vertex;

 29:   /* Persistent Kokkos Views */
 30:   using exec_space = Kokkos::DefaultExecutionSpace;
 31:   using view_3d    = Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, exec_space>;
 32:   using view_2d    = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>;

 34:   view_3d Z_batch;
 35:   view_3d S_batch;
 36:   view_3d T_batch;
 37:   view_3d V_batch;
 38:   view_2d Lambda_batch;
 39:   view_3d T_sqrt_batch;
 40:   view_2d w_batch;
 41:   view_2d delta_batch;
 42:   view_2d y_batch;
 43:   view_2d y_mean_batch;
 44:   view_2d r_inv_sqrt_batch;
 45:   view_2d temp1_batch;
 46:   view_2d temp2_batch;
 47:   view_2d inv_sqrt_lambda_batch;

 49:   /* Host workspace */
 50:   PetscScalar *all_v;
 51:   PetscReal   *all_lambda;
 52:   PetscScalar *all_work;
 53: #if defined(PETSC_USE_COMPLEX)
 54:   PetscReal *all_rwork;
 55: #endif
 56:   PetscBLASInt lwork;
 57:   PetscBLASInt n_blas;

 59:   /* Device workspace */
 60: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
 61:   #if defined(KOKKOS_ENABLE_CUDA)
 62:   syevjInfo_t  syevj_params;
 63:   PetscScalar *d_work;
 64:   int         *d_info;
 65:   PetscScalar *d_A_contig;
 66:   PetscScalar *d_W_contig;
 67:   int          lwork_device;
 68:   #elif defined(KOKKOS_ENABLE_HIP)
 69:   PetscScalar *d_work;
 70:   int         *d_info;
 71:   PetscScalar *d_A_contig;
 72:   PetscScalar *d_W_contig;
 73:   int          lwork_device;
 74:   #elif defined(KOKKOS_ENABLE_SYCL)
 75:   PetscScalar *d_work;
 76:   int         *d_info;
 77:   PetscScalar *d_A_contig;
 78:   PetscScalar *d_W_contig;
 79:   int          lwork_device;
 80:   #endif
 81: #endif

 83:   EigenWorkspace() : max_chunk_size(0), m(0), n_obs_vertex(0), all_v(nullptr), all_lambda(nullptr), all_work(nullptr)
 84:   {
 85: #if defined(PETSC_USE_COMPLEX)
 86:     all_rwork = nullptr;
 87: #endif
 88: #if defined(KOKKOS_ENABLE_CUDA)
 89:     d_work       = nullptr;
 90:     d_info       = nullptr;
 91:     d_A_contig   = nullptr;
 92:     d_W_contig   = nullptr;
 93:     syevj_params = nullptr;
 94: #elif defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
 95:     d_work     = nullptr;
 96:     d_info     = nullptr;
 97:     d_A_contig = nullptr;
 98:     d_W_contig = nullptr;
 99: #endif
100:   }
101: };

103: /*
104:   BatchedEigenSolve_Host - Compute eigendecomposition for a batch of symmetric matrices (CPU version)

106:   Input Parameters:
107: + T_batch      - batch of symmetric matrices (n_batch x n_size x n_size)
108: . n_batch      - number of matrices in the batch
109: - n_size       - size of each matrix (m x m)
110: - work         - reusable workspace structure

112:   Output Parameters:
113: + Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
114: - V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

116:   Notes:
117:   Uses LAPACK's syev routine to compute eigendecomposition sequentially on host.
118: */
#if !defined(KOKKOS_ENABLE_CUDA) && !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_SYCL)
#include <petscblaslapack.h>
/* CPU fallback: eigendecompose each m x m symmetric matrix with LAPACK syev,
   parallelized over the batch with a host-space Kokkos parallel_for.
   Only compiled when no device backend is enabled, so KOKKOS_LAMBDA is a
   plain host lambda here. */
static PetscErrorCode BatchedEigenSolve_Host(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Create host mirrors and copy data in one operation */
  /* This is required for HIP+complex where create_mirror_view + deep_copy fails */
  auto T_host      = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), T_batch);
  auto Lambda_host = Kokkos::create_mirror_view(Kokkos::HostSpace(), Lambda_batch);
  auto V_host      = Kokkos::create_mirror_view(Kokkos::HostSpace(), V_batch);

  /* Use pre-allocated workspace; each batch entry i gets a disjoint slice
     (offsets i*n_size*n_size, i*n_size, i*lwork), so iterations are independent */
  PetscScalar *all_v      = work->all_v;
  PetscReal   *all_lambda = work->all_lambda;
  PetscScalar *all_work   = work->all_work;
  PetscBLASInt lwork      = work->lwork;
  PetscBLASInt n_blas     = work->n_blas;
  #if defined(PETSC_USE_COMPLEX)
  PetscReal *all_rwork = work->all_rwork;
  #endif

  /* Process each matrix in parallel on host using LAPACK.
     NOTE(review): this calls LAPACKsyev_ concurrently from multiple host
     threads; assumes the linked BLAS/LAPACK is thread-safe — confirm. */
  Kokkos::parallel_for(
    "BatchedEigenSolve_Host", Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      PetscBLASInt n   = n_blas;
      PetscBLASInt lda = n;
      PetscBLASInt info;
      PetscBLASInt lw = lwork;

      /* Pointers for this matrix */
      PetscScalar *v_ptr      = all_v + i * n_size * n_size;
      PetscReal   *lambda_ptr = all_lambda + i * n_size;
      PetscScalar *work_ptr   = all_work + i * lwork;
  #if defined(PETSC_USE_COMPLEX)
      /* 3n-2 is the LAPACK-documented rwork size for complex *heev */
      PetscReal *rwork_ptr = all_rwork + i * (3 * n_size - 2);
  #endif

      /* Copy T_host(i, :, :) to v_ptr (column-major) */
      for (PetscInt j = 0; j < n_size; j++) {
        for (PetscInt k = 0; k < n_size; k++) v_ptr[k + j * n_size] = T_host(i, k, j);
      }

    /* Compute eigendecomposition: T = V * Lambda * V^T (syev overwrites v_ptr
       with the eigenvectors and fills lambda_ptr with ascending eigenvalues) */
  #if defined(PETSC_USE_COMPLEX)
      LAPACKsyev_("V", "U", &n, v_ptr, &lda, lambda_ptr, work_ptr, &lw, rwork_ptr, &info);
  #else
      LAPACKsyev_("V", "U", &n, v_ptr, &lda, lambda_ptr, work_ptr, &lw, &info);
  #endif

      if (info != 0) {
        /* We cannot return error code from lambda, so we just abort or ignore.
           In production code, we should use a reduction to report errors. */
        Kokkos::abort("LAPACK eigendecomposition failed in parallel region");
      }

      /* Copy results back to host views */
      for (PetscInt j = 0; j < n_size; j++) {
        Lambda_host(i, j) = (PetscScalar)lambda_ptr[j];
        for (PetscInt k = 0; k < n_size; k++) V_host(i, k, j) = v_ptr[k + j * n_size];
      }
    });

  /* Copy results back to device (a no-op layout-wise on pure host builds) */
  Kokkos::deep_copy(Lambda_batch, Lambda_host);
  Kokkos::deep_copy(V_batch, V_host);
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif

188: /*
189:   BatchedEigenSolve_Device - Compute eigendecomposition for a batch of symmetric matrices (Device version)

191:   Input Parameters:
192: + T_batch      - batch of symmetric matrices (n_batch x n_size x n_size)
193: . n_batch      - number of matrices in the batch
194: - n_size       - size of each matrix (m x m)
195: - device_handle - device-specific solver handle (cusolverDnHandle_t, rocblas_handle, or sycl::queue*)
196: - work         - reusable workspace structure

198:   Output Parameters:
199: + Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
200: - V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

202:   Notes:
203:   Uses vendor-specific batched symmetric eigensolvers:
204:   - CUDA: cuSOLVER's syevjBatched
205:   - HIP: rocSOLVER's rocsolver_dsyevj_batched
206:   - SYCL: oneMKL's syevd_batch
207: */
208: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
209:   #if defined(KOKKOS_ENABLE_CUDA)
210: static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, cusolverDnHandle_t cusolverH, EigenWorkspace *work)
211: {
212:   cusolverStatus_t cusolver_status;

214:   PetscFunctionBegin;
215:   /* Use pre-allocated workspace */
216:   syevjInfo_t  syevj_params = work->syevj_params;
217:   PetscScalar *d_work       = work->d_work;
218:   int         *d_info       = work->d_info;
219:   PetscScalar *d_A_contig   = work->d_A_contig;
220:   PetscScalar *d_W_contig   = work->d_W_contig;
221:   int          lwork        = work->lwork_device;

223:   /* Copy T_batch to contiguous layout for cuSOLVER */
224:   Kokkos::parallel_for(
225:     "ReorganizeForCuSOLVER", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
226:       for (int j = 0; j < n_size; j++) {
227:         for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
228:       }
229:     });
230:   Kokkos::fence();

232:     /* Solve batched eigendecomposition */
233:     #if defined(PETSC_USE_REAL_SINGLE)
234:   cusolver_status = cusolverDnSsyevjBatched(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_size, d_A_contig, n_size, d_W_contig, d_work, lwork, d_info, syevj_params, n_batch);
235:     #else
236:   cusolver_status = cusolverDnDsyevjBatched(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_size, d_A_contig, n_size, d_W_contig, d_work, lwork, d_info, syevj_params, n_batch);
237:     #endif
238:   PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDn*syevjBatched failed");

240:   /* Check info */
241:   int *h_info;
242:   PetscCall(PetscMalloc1(n_batch, &h_info));
243:   PetscCallCUDA(cudaMemcpy(h_info, d_info, sizeof(int) * n_batch, cudaMemcpyDeviceToHost));
244:   for (PetscInt i = 0; i < n_batch; i++) {
245:     if (h_info[i] != 0) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "cuSOLVER eigendecomposition failed for matrix %" PetscInt_FMT ": info=%d", i, h_info[i]);
246:   }
247:   PetscCall(PetscFree(h_info));

249:   /* Copy results back from contiguous layout to V_batch */
250:   Kokkos::parallel_for(
251:     "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
252:       for (int j = 0; j < n_size; j++) {
253:         for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
254:         Lambda_batch(i, j) = d_W_contig[i * n_size + j]; // CUDA-12.6 nvcc compiler hangs if we put this line before the V_batch loop
255:       }
256:     });
257:   Kokkos::fence();
258:   PetscFunctionReturn(PETSC_SUCCESS);
259: }
260:   #elif defined(KOKKOS_ENABLE_HIP)
261: static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, rocblas_handle rocblasH, EigenWorkspace *work)
262: {
263:   PetscFunctionBegin;
264:   /* Use pre-allocated workspace */
265:   PetscScalar *d_work = work->d_work;
266:   (void)d_work;
267:   int         *d_info     = work->d_info;
268:   PetscScalar *d_A_contig = work->d_A_contig;
269:   PetscScalar *d_W_contig = work->d_W_contig;

271:   /* Copy T_batch to contiguous layout for rocSOLVER */
272:   Kokkos::parallel_for(
273:     "ReorganizeForRocSOLVER", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
274:       for (int j = 0; j < n_size; j++) {
275:         for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
276:       }
277:     });
278:   Kokkos::fence();

280:     /* rocSOLVER doesn't have a native batched syevj, so we loop over batch */
281:     /* Use rocsolver_dsyevd which is more efficient than calling syev in a loop */
282:     #if defined(PETSC_USE_COMPLEX)
283:   SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Complex numbers not supported on HIP backend for LETKF");
284:     #else
285:   for (int i = 0; i < n_batch; i++) {
286:     PetscScalar   *A_ptr    = d_A_contig + i * n_size * n_size;
287:     PetscScalar   *W_ptr    = d_W_contig + i * n_size;
288:     int           *info_ptr = d_info + i;
289:     rocblas_status hip_status;

291:       #if defined(PETSC_USE_REAL_SINGLE)
292:     hip_status = rocsolver_ssyevd(rocblasH, rocblas_evect_original, rocblas_fill_upper, n_size, A_ptr, n_size, W_ptr, d_work, info_ptr);
293:       #else
294:     hip_status = rocsolver_dsyevd(rocblasH, rocblas_evect_original, rocblas_fill_upper, n_size, A_ptr, n_size, W_ptr, d_work, info_ptr);
295:       #endif
296:     PetscCheck(hip_status == rocblas_status_success, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocsolver_*syevd failed for batch %" PetscInt_FMT, i);
297:   }
298:     #endif

300:   /* Check info */
301:   int *h_info;
302:   PetscCall(PetscMalloc1(n_batch, &h_info));
303:   PetscCallHIP(hipMemcpy(h_info, d_info, sizeof(int) * n_batch, hipMemcpyDeviceToHost));
304:   for (PetscInt i = 0; i < n_batch; i++) {
305:     if (h_info[i] != 0) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocSOLVER eigendecomposition failed for matrix %" PetscInt_FMT ": info=%d", i, h_info[i]);
306:   }
307:   PetscCall(PetscFree(h_info));

309:   /* Copy results back from contiguous layout to V_batch */
310:   Kokkos::parallel_for(
311:     "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
312:       for (int j = 0; j < n_size; j++) {
313:         for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
314:         Lambda_batch(i, j) = d_W_contig[i * n_size + j];
315:       }
316:     });
317:   Kokkos::fence();
318:   PetscFunctionReturn(PETSC_SUCCESS);
319: }
320:   #elif defined(KOKKOS_ENABLE_SYCL)
/* SYCL backend: symmetric eigendecomposition via oneMKL syevd, one matrix at
   a time (no batched variant is used). Errors from oneMKL surface as
   sycl::exception, not through d_info. */
static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, sycl::queue *q, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Use pre-allocated workspace */
  PetscScalar *d_work     = work->d_work;
  int         *d_info     = work->d_info;
  PetscScalar *d_A_contig = work->d_A_contig;
  PetscScalar *d_W_contig = work->d_W_contig;

  /* Copy T_batch to contiguous column-major layout for oneMKL */
  Kokkos::parallel_for(
    "ReorganizeForOneMKL", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
      }
    });
  Kokkos::fence();

  /* oneMKL doesn't have a native batched syevd, so we loop over batch */
  /* Use oneapi::mkl::lapack::syevd which computes eigenvalues and eigenvectors */
  for (int i = 0; i < n_batch; i++) {
    PetscScalar *A_ptr = d_A_contig + i * n_size * n_size; /* eigenvectors overwrite A_ptr */
    PetscScalar *W_ptr = d_W_contig + i * n_size;          /* eigenvalues land here */
    // int         *info_ptr = d_info + i;

    try {
    #if defined(PETSC_USE_REAL_SINGLE)
      // oneapi::mkl::lapack::syevd(*q, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, n_size, A_ptr, n_size, W_ptr, d_work, work->lwork_device, info_ptr);
      oneapi::mkl::lapack::syevd(*q, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, n_size, A_ptr, n_size, W_ptr, d_work, work->lwork_device);
    #else
      // oneapi::mkl::lapack::syevd(*q, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, n_size, A_ptr, n_size, W_ptr, d_work, work->lwork_device, info_ptr);
      oneapi::mkl::lapack::syevd(*q, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, n_size, A_ptr, n_size, W_ptr, d_work, work->lwork_device);
    #endif
      /* wait so any asynchronous MKL failure is raised inside this try block */
      q->wait();
    } catch (sycl::exception const &e) {
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "oneMKL syevd failed for batch %d: %s", i, e.what());
    }
  }

  /* Check info */
  /* NOTE(review): this syevd overload never writes d_info (the info_ptr
     argument is commented out above), so the values copied here are whatever
     d_info already contained — verify d_info is zero-initialized at
     allocation, or this check can fail/pass spuriously. */
  int *h_info;
  PetscCall(PetscMalloc1(n_batch, &h_info));
  q->memcpy(h_info, d_info, sizeof(int) * n_batch).wait();
  for (PetscInt i = 0; i < n_batch; i++) {
    if (h_info[i] != 0) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "oneMKL eigendecomposition failed for matrix %" PetscInt_FMT ": info=%d", i, h_info[i]);
  }
  PetscCall(PetscFree(h_info));

  /* Copy results back from contiguous layout to V_batch */
  Kokkos::parallel_for(
    "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
        Lambda_batch(i, j) = d_W_contig[i * n_size + j];
      }
    });
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}
380:   #endif
381: #endif

383: /*
384:   BatchedEigenSolve - Compute eigendecomposition for a batch of symmetric matrices

386:   Input Parameters:
387: + T_batch      - batch of symmetric matrices (n_batch x n_size x n_size)
388: . n_batch      - number of matrices in the batch
389: - n_size       - size of each matrix (m x m)
390: - device_handle - device-specific solver handle (only for device builds)
391: - work         - reusable workspace structure

393:   Output Parameters:
394: + Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
395: - V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

397:   Notes:
398:   Dispatcher function that calls the appropriate backend (Device or Host).
399: */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
  #if defined(KOKKOS_ENABLE_CUDA)
/* CUDA build: forward to the cuSOLVER-based device implementation */
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, cusolverDnHandle_t device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
  #elif defined(KOKKOS_ENABLE_HIP)
/* HIP build: forward to the rocSOLVER-based device implementation */
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, rocblas_handle device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
  #elif defined(KOKKOS_ENABLE_SYCL)
/* SYCL build: forward to the oneMKL-based device implementation */
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, sycl::queue *device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
  #endif
#else
/* Host-only build: forward to the LAPACK-based host implementation */
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Host(T_batch, Lambda_batch, V_batch, n_batch, n_size, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif

432: /*
433:   PetscDALETKFSetupLocalization_Kokkos - Prepares device views for localization matrix Q
434: */
435: PetscErrorCode PetscDALETKFSetupLocalization_Kokkos(PetscDA_LETKF *impl, Mat H)
436: {
437:   PetscInt nrows;

439:   PetscFunctionBegin;
440:   PetscCheck(impl->Q, PETSC_COMM_SELF, PETSC_ERR_LIB, "impl->Q = 0");
441:   PetscCall(PetscKokkosInitializeCheck());

443:   /* Get CSR data */
444:   PetscInt rstart, rend, i, nnz;
445:   PetscCall(MatGetOwnershipRange(impl->Q, &rstart, &rend));
446:   nrows = rend - rstart;

448:   /* Create IS for local observations needed by this process */
449:   /* We need to find all unique column indices in the local rows of Q */
450:   {
451:     PetscInt     *obs_indices;
452:     PetscInt      n_obs_local_total = 0;
453:     PetscInt      max_obs           = nrows * impl->n_obs_vertex;
454:     PetscInt      count             = 0;
455:     PetscHMapI    ht;
456:     PetscHashIter iter;
457:     PetscBool     missing;

459:     PetscCall(PetscHMapICreate(&ht));
460:     PetscCall(PetscMalloc1(max_obs, &obs_indices));

462:     for (i = 0; i < nrows; i++) {
463:       const PetscInt    *cols;
464:       const PetscScalar *vals;
465:       PetscCall(MatGetRow(impl->Q, rstart + i, &nnz, &cols, &vals));
466:       for (PetscInt k = 0; k < nnz; k++) {
467:         PetscCall(PetscHMapIPut(ht, cols[k], &iter, &missing));
468:         if (missing) {
469:           obs_indices[count] = cols[k];
470:           count++;
471:         }
472:       }
473:       PetscCall(MatRestoreRow(impl->Q, rstart + i, &nnz, &cols, &vals));
474:     }
475:     n_obs_local_total = count;

477:     /* Sort indices for consistent ordering */
478:     PetscCall(PetscSortInt(n_obs_local_total, obs_indices));

480:     /* Create IS and VecScatter */
481:     PetscCall(ISCreateGeneral(PETSC_COMM_SELF, n_obs_local_total, obs_indices, PETSC_COPY_VALUES, &impl->obs_is_local));

483:     /* Create global-to-local map for observations */
484:     PetscCall(PetscHMapICreate(&impl->obs_g2l));
485:     for (i = 0; i < n_obs_local_total; i++) {
486:       PetscCall(PetscHMapIPut(impl->obs_g2l, obs_indices[i], &iter, &missing));
487:       PetscCall(PetscHMapIIterSet(impl->obs_g2l, iter, i));
488:     }

490:     PetscCall(PetscFree(obs_indices));
491:     PetscCall(PetscHMapIDestroy(&ht));
492:   }

494:   /* Create work vectors and scatter context */
495:   {
496:     PetscInt n_obs_local_total;
497:     PetscCall(ISGetLocalSize(impl->obs_is_local, &n_obs_local_total));

499:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->obs_work));
500:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->y_mean_work));
501:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->r_inv_sqrt_work));

503:     Vec gvec;
504:     IS  is_to;
505:     PetscCall(MatCreateVecs(H, NULL, &gvec)); /* Create template global vector (left vector = rows = observations) */
506:     PetscCall(ISCreateStride(PETSC_COMM_SELF, n_obs_local_total, 0, 1, &is_to));
507:     PetscCall(VecScatterCreate(gvec, impl->obs_is_local, impl->obs_work, is_to, &impl->obs_scat));
508:     PetscCall(VecDestroy(&gvec));
509:     PetscCall(ISDestroy(&is_to));
510:   }

512:   /* Define View types */
513:   using view_1d_int    = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
514:   using view_1d_scalar = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;

516:   /* Allocate device views */
517:   view_1d_int    *d_Q_i = new view_1d_int("Q_i", nrows + 1);
518:   view_1d_int    *d_Q_j = new view_1d_int("Q_j", nrows * impl->n_obs_vertex);
519:   view_1d_scalar *d_Q_a = new view_1d_scalar("Q_a", nrows * impl->n_obs_vertex);

521:   /* Create host mirrors */
522:   auto h_Q_i = Kokkos::create_mirror_view(*d_Q_i);
523:   auto h_Q_j = Kokkos::create_mirror_view(*d_Q_j);
524:   auto h_Q_a = Kokkos::create_mirror_view(*d_Q_a);

526:   /* Fill host mirrors with LOCAL indices into obs_work */
527:   h_Q_i(0) = 0;
528:   for (i = 0; i < nrows; i++) {
529:     const PetscInt    *cols;
530:     const PetscScalar *vals;
531:     PetscCall(MatGetRow(impl->Q, rstart + i, &nnz, &cols, &vals));
532:     h_Q_i(i + 1) = h_Q_i(i) + nnz;
533:     for (PetscInt k = 0; k < nnz; k++) {
534:       PetscInt local_idx;
535:       PetscCall(ISLocate(impl->obs_is_local, cols[k], &local_idx));
536:       PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local IS", cols[k]);
537:       h_Q_j(h_Q_i(i) + k) = local_idx;
538:       h_Q_a(h_Q_i(i) + k) = vals[k];
539:     }
540:     PetscCall(MatRestoreRow(impl->Q, rstart + i, &nnz, &cols, &vals));
541:   }

543:   /* Copy to device */
544:   Kokkos::deep_copy(*d_Q_i, h_Q_i);
545:   Kokkos::deep_copy(*d_Q_j, h_Q_j);
546:   Kokkos::deep_copy(*d_Q_a, h_Q_a);

548:   /* Store in impl */
549:   PetscCheck(!impl->Q_device_i, PETSC_COMM_SELF, PETSC_ERR_LIB, "impl->Q = 0");
550:   impl->Q_device_i = static_cast<void *>(d_Q_i);
551:   impl->Q_device_j = static_cast<void *>(d_Q_j);
552:   impl->Q_device_a = static_cast<void *>(d_Q_a);
553:   PetscFunctionReturn(PETSC_SUCCESS);
554: }

/* Tears down everything PetscDALETKFSetupLocalization_Kokkos and the
   eigensolver setup created: work vectors, scatter, device CSR views,
   eigensolver workspace, and the vendor solver handle. Safe to call with
   already-NULL members (each branch checks before freeing). */
PetscErrorCode PetscDALETKFDestroyLocalization_Kokkos(PetscDA_LETKF *impl)
{
  PetscFunctionBegin;
  /* Work vectors and scatter context (VecDestroy/VecScatterDestroy accept NULL-valued handles) */
  PetscCall(VecDestroy(&impl->obs_work));
  PetscCall(VecDestroy(&impl->y_mean_work));
  PetscCall(VecDestroy(&impl->r_inv_sqrt_work));
  PetscCall(VecScatterDestroy(&impl->obs_scat));
  PetscCall(MatDestroy(&impl->Z_work));
  PetscCall(PetscHMapIDestroy(&impl->obs_g2l));
  /* Device CSR views were heap-allocated in Setup and stored as void*;
     deleting the View object releases its device allocation */
  if (impl->Q_device_i) {
    using view_1d_int = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_int *>(impl->Q_device_i);
    impl->Q_device_i = NULL;
  }
  if (impl->Q_device_j) {
    using view_1d_int = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_int *>(impl->Q_device_j);
    impl->Q_device_j = NULL;
  }
  if (impl->Q_device_a) {
    using view_1d_scalar = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_scalar *>(impl->Q_device_a);
    impl->Q_device_a = NULL;
  }

  /* Destroy solver handle and workspace */
  if (impl->eigen_work) {
    EigenWorkspace *work = static_cast<EigenWorkspace *>(impl->eigen_work);

#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
  #if defined(KOKKOS_ENABLE_CUDA)
    PetscCallCUDA(cudaFree(work->d_A_contig));
    PetscCallCUDA(cudaFree(work->d_W_contig));
    PetscCallCUDA(cudaFree(work->d_work));
    PetscCallCUDA(cudaFree(work->d_info));
    /* NOTE(review): return status of cusolverDnDestroySyevjInfo is ignored here;
       presumably intentional during teardown — confirm */
    if (work->syevj_params) cusolverDnDestroySyevjInfo(work->syevj_params);
  #elif defined(KOKKOS_ENABLE_HIP)
    PetscCallHIP(hipFree(work->d_A_contig));
    PetscCallHIP(hipFree(work->d_W_contig));
    PetscCallHIP(hipFree(work->d_work));
    PetscCallHIP(hipFree(work->d_info));
  #elif defined(KOKKOS_ENABLE_SYCL)
    /* SYCL USM frees need the owning queue; NOTE(review): if solver_handle is
       NULL here while the d_* pointers are live, those allocations leak —
       verify Setup cannot leave the workspace in that state */
    if (impl->solver_handle) {
      sycl::queue *q = static_cast<sycl::queue *>(impl->solver_handle);
      if (work->d_A_contig) sycl::free(work->d_A_contig, *q);
      if (work->d_W_contig) sycl::free(work->d_W_contig, *q);
      if (work->d_work) sycl::free(work->d_work, *q);
      if (work->d_info) sycl::free(work->d_info, *q);
    }
  #endif
#else
  /* Host build: the LAPACK workspace arrays were allocated as one group */
  #if defined(PETSC_USE_COMPLEX)
    PetscCall(PetscFree4(work->all_v, work->all_lambda, work->all_work, work->all_rwork));
  #else
    PetscCall(PetscFree3(work->all_v, work->all_lambda, work->all_work));
  #endif
#endif

    delete work;
    impl->eigen_work = NULL;
  }

  /* Vendor solver handle (cuSOLVER handle / rocBLAS handle / sycl::queue) */
  if (impl->solver_handle) {
#if defined(KOKKOS_ENABLE_CUDA)
    cusolverDnDestroy(static_cast<cusolverDnHandle_t>(impl->solver_handle));
#elif defined(KOKKOS_ENABLE_HIP)
    rocblas_destroy_handle(static_cast<rocblas_handle>(impl->solver_handle));
#elif defined(KOKKOS_ENABLE_SYCL)
    delete static_cast<sycl::queue *>(impl->solver_handle);
#endif
    impl->solver_handle = NULL;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

631: /* ========================================================================== */
632: /*                    LETKF Local Analysis (Main Function)                    */
633: /* ========================================================================== */

635: /*
636:   PetscDALETKFLocalAnalysis_GPU - Performs local LETKF analysis for all grid points (Kokkos version)

638:   Input Parameters:
639: + da             - the PetscDA context
640: . impl           - LETKF implementation data
641: . m              - ensemble size
642: . n_vertices     - number of grid points
643: . X              - global anomaly matrix (state_size x m)
644: . observation    - observation vector
645: . Z_global       - global observation ensemble (obs_size x m)
646: . y_mean_global  - global observation mean
647: - r_inv_sqrt_global - global R^{-1/2}

649:   Output:
650: . da->ensemble - updated with analysis ensemble

652:   Notes:
653:   This function performs the local analysis loop for LETKF, processing each grid point
654:   independently using its local observations defined by the localization matrix Q.
655:   This is the CPU version that does not use Kokkos acceleration.

657:   All local analysis workspace objects (Z_local, S_local, T_sqrt_local, G_local, y_local,
658:   y_mean_local, delta_scaled_local, r_inv_sqrt_local, w_local, s_transpose_delta) are
659:   created with PETSC_COMM_SELF because the analysis at each vertex is serial and independent.
660: */
661: PetscErrorCode PetscDALETKFLocalAnalysis_GPU(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
662: {
663:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
664:   PetscInt          ndof;
665:   PetscReal         sqrt_m_minus_1, scale, inflation_inv;

667:   PetscFunctionBegin;
668:   ndof           = da->ndof;
669:   scale          = 1.0 / PetscSqrtReal((PetscReal)(m - 1));
670:   sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
671:   inflation_inv  = 1.0 / en->inflation; /* (1/rho) for T matrix: T = (1/rho)I + S^T*S */

673:   /* ===================================================================== */
674:   /* Step 2.1.1: Create batched workspace for ALL grid points            */
675:   /* ===================================================================== */
676:   /*
677:      NOTE ON PARALLELISM STRATEGY:
678:      We use Kokkos::RangePolicy over grid points (n_vertices) combined with KokkosBatched::Serial kernels.
679:      Since the data layout is LayoutLeft (Column-Major) to match PETSc/LAPACK, the index 'i' (grid point)
680:      is the fastest varying index (stride 1).

682:      RangePolicy maps consecutive threads to consecutive 'i', ensuring perfect memory coalescing
683:      when accessing arrays like S_batch(i, p, j).

685:      Using TeamPolicy/TeamVectorRange to parallelize inner loops (m or p) would assign a team to 'i',
686:      causing threads within the team to access S_batch with stride 'n_vertices', which leads to
687:      uncoalesced memory access and poor performance on GPUs.

689:      Therefore, RangePolicy + SerialGemm is the optimal strategy for this data layout.
690:   */
691:   using exec_space = Kokkos::DefaultExecutionSpace;
692:   using view_3d    = Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, exec_space>;
693:   using view_2d    = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>;

695:   /* ===================================================================== */
696:   /* Step 2.1.2a: Pre-extract Q matrix CSR data for device access        */
697:   /* ===================================================================== */
698:   using view_1d_int_const    = Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
699:   using view_1d_scalar_const = Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
700:   using view_1d_int          = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
701:   using view_1d_scalar       = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;

703:   view_1d_int_const    Q_i_view;
704:   view_1d_int_const    Q_j_view;
705:   view_1d_scalar_const Q_a_view;

707:   if (impl->Q_device_i) {
708:     /* Use pre-allocated device views */
709:     view_1d_int    *d_Q_i = static_cast<view_1d_int *>(impl->Q_device_i);
710:     view_1d_int    *d_Q_j = static_cast<view_1d_int *>(impl->Q_device_j);
711:     view_1d_scalar *d_Q_a = static_cast<view_1d_scalar *>(impl->Q_device_a);

713:     Q_i_view = view_1d_int_const(d_Q_i->data(), d_Q_i->extent(0));
714:     Q_j_view = view_1d_int_const(d_Q_j->data(), d_Q_j->extent(0));
715:     Q_a_view = view_1d_scalar_const(d_Q_a->data(), d_Q_a->extent(0));
716:   } else {
717:     /* Fallback to host pointers (unsafe if not UVM) */
718:     PetscCheck(PETSC_FALSE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Q matrix must be setup with PetscDALETKFSetupLocalization_Kokkos");
719:   }

721:   /* Get global observation data arrays */
722:   const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
723:   PetscInt           lda_z_global;
724:   PetscMemType       z_mem_type, y_mem_type, y_mean_mem_type, r_inv_sqrt_mem_type;

726:   PetscCall(MatDenseGetArrayReadAndMemType(Z_global, &z_global_array, &z_mem_type));
727:   PetscCall(VecGetArrayReadAndMemType(observation, &y_global_array, &y_mem_type));
728:   PetscCall(VecGetArrayReadAndMemType(y_mean_global, &y_mean_global_array, &y_mean_mem_type));
729:   PetscCall(VecGetArrayReadAndMemType(r_inv_sqrt_global, &r_inv_sqrt_global_array, &r_inv_sqrt_mem_type));
730:   PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));

732:   /* Handle memory mirroring for observation data */
733:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> z_managed;
734:   Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  y_managed;
735:   Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  y_mean_managed;
736:   Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  r_inv_sqrt_managed;

738:   const PetscScalar *z_ptr          = z_global_array;
739:   const PetscScalar *y_ptr          = y_global_array;
740:   const PetscScalar *y_mean_ptr     = y_mean_global_array;
741:   const PetscScalar *r_inv_sqrt_ptr = r_inv_sqrt_global_array;

743:   if (z_mem_type == PETSC_MEMTYPE_HOST) {
744:     z_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("z_managed", lda_z_global, m);
745:     Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(z_global_array, lda_z_global, m);
746:     Kokkos::deep_copy(z_managed, src);
747:     z_ptr = z_managed.data();
748:   }
749:   if (y_mem_type == PETSC_MEMTYPE_HOST) {
750:     y_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("y_managed", lda_z_global);
751:     Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(y_global_array, lda_z_global);
752:     Kokkos::deep_copy(y_managed, src);
753:     y_ptr = y_managed.data();
754:   }
755:   if (y_mean_mem_type == PETSC_MEMTYPE_HOST) {
756:     y_mean_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("y_mean_managed", lda_z_global);
757:     Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(y_mean_global_array, lda_z_global);
758:     Kokkos::deep_copy(y_mean_managed, src);
759:     y_mean_ptr = y_mean_managed.data();
760:   }
761:   if (r_inv_sqrt_mem_type == PETSC_MEMTYPE_HOST) {
762:     r_inv_sqrt_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("r_inv_sqrt_managed", lda_z_global);
763:     Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(r_inv_sqrt_global_array, lda_z_global);
764:     Kokkos::deep_copy(r_inv_sqrt_managed, src);
765:     r_inv_sqrt_ptr = r_inv_sqrt_managed.data();
766:   }

768:   /* Create unmanaged Kokkos views for global observation data */
769:   using view_2d_unmanaged = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
770:   using view_1d_unmanaged = Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

772:   view_2d_unmanaged Z_global_view(z_ptr, lda_z_global, m);
773:   view_1d_unmanaged y_global_view(y_ptr, lda_z_global);
774:   view_1d_unmanaged y_mean_global_view(y_mean_ptr, lda_z_global);
775:   view_1d_unmanaged r_inv_sqrt_global_view(r_inv_sqrt_ptr, lda_z_global);

777:   /* Get access to global X matrix and mean vector */
778:   const PetscScalar *x_array, *mean_array;
779:   PetscScalar       *e_array;
780:   PetscInt           lda_x, lda_e;
781:   PetscMemType       x_mem_type, mean_mem_type, e_mem_type;

783:   PetscCall(MatDenseGetArrayReadAndMemType(X, &x_array, &x_mem_type));
784:   PetscCall(VecGetArrayReadAndMemType(impl->mean, &mean_array, &mean_mem_type));
785:   PetscCall(MatDenseGetArrayWriteAndMemType(en->ensemble, &e_array, &e_mem_type));
786:   PetscCall(MatDenseGetLDA(X, &lda_x));
787:   PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));

789:   /* Handle memory mirroring for state data */
790:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> x_managed;
791:   Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  mean_managed;
792:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> e_managed;

794:   const PetscScalar *x_ptr     = x_array;
795:   const PetscScalar *mean_ptr  = mean_array;
796:   PetscScalar       *e_ptr     = e_array;
797:   bool               e_is_copy = false;

799:   if (x_mem_type == PETSC_MEMTYPE_HOST) {
800:     x_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("x_managed", lda_x, m);
801:     Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(x_array, lda_x, m);
802:     Kokkos::deep_copy(x_managed, src);
803:     x_ptr = x_managed.data();
804:   }
805:   if (mean_mem_type == PETSC_MEMTYPE_HOST) {
806:     mean_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("mean_managed", lda_x);
807:     Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(mean_array, lda_x);
808:     Kokkos::deep_copy(mean_managed, src);
809:     mean_ptr = mean_managed.data();
810:   }
811:   if (e_mem_type == PETSC_MEMTYPE_HOST) {
812:     e_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("e_managed", lda_e, m);
813:     Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(e_array, lda_e, m);
814:     Kokkos::deep_copy(e_managed, src);
815:     e_ptr     = e_managed.data();
816:     e_is_copy = true;
817:   }

819:   /* Create unmanaged Kokkos views for global data */
820:   using view_2d_unmanaged_write = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
821:   view_2d_unmanaged       X_view(const_cast<PetscScalar *>(x_ptr), lda_x, m);
822:   view_1d_unmanaged       mean_view(mean_ptr, lda_x);
823:   view_2d_unmanaged_write E_view(e_ptr, lda_e, m);

825:   /* Determine chunk size to avoid OOM on large grids */
826:   PetscInt chunk_size;
827:   if (impl->batch_size > 0) {
828:     chunk_size = impl->batch_size;
829:   } else {
830:     /* Target ~2GB workspace. Approx memory per point: m*m*8 (T) + p*m*8 (Z) */
831:     /* With reuse: m*m*8 + p*m*8 */
832:     PetscInt mem_per_point = sizeof(PetscScalar) * (m * m + impl->n_obs_vertex * m);
833:     chunk_size             = (PetscInt)(2.0 * 1024 * 1024 * 1024 / mem_per_point);
834:     /* Clamp to reasonable max to avoid huge allocations even if memory allows */
835:     if (chunk_size > 32768) chunk_size = 32768;
836:   }

838:   if (chunk_size < 1) chunk_size = 1;
839:   if (chunk_size > n_vertices) chunk_size = n_vertices;

841:   /* OPTIMIZATION: Create device solver handle once, reuse across chunks */
842: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
843:   #if defined(KOKKOS_ENABLE_CUDA)
844:   cusolverDnHandle_t device_handle = nullptr;
845:   cusolverStatus_t   cusolver_status;
846:   if (impl->solver_handle) {
847:     device_handle = static_cast<cusolverDnHandle_t>(impl->solver_handle);
848:   } else {
849:     cusolver_status = cusolverDnCreate(&device_handle);
850:     PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDnCreate failed");
851:     impl->solver_handle = static_cast<void *>(device_handle);
852:   }
853:   #elif defined(KOKKOS_ENABLE_HIP)
854:   rocblas_handle device_handle = nullptr;
855:   if (impl->solver_handle) {
856:     device_handle = static_cast<rocblas_handle>(impl->solver_handle);
857:   } else {
858:     rocblas_status hip_status = rocblas_create_handle(&device_handle);
859:     PetscCheck(hip_status == rocblas_status_success, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocblas_create_handle failed");
860:     impl->solver_handle = static_cast<void *>(device_handle);
861:   }
862:   #elif defined(KOKKOS_ENABLE_SYCL)
863:   sycl::queue *device_handle = nullptr;
864:   if (impl->solver_handle) {
865:     device_handle = static_cast<sycl::queue *>(impl->solver_handle);
866:   } else {
867:     device_handle       = new sycl::queue(sycl::gpu_selector_v);
868:     impl->solver_handle = static_cast<void *>(device_handle);
869:   }
870:   #endif
871: #endif

873:   /* ===================================================================== */
874:   /* OPTIMIZATION: Hoist allocations outside the chunk loop                */
875:   /* ===================================================================== */
876:   /* Allocate Kokkos Views once for the maximum chunk size */
877:   PetscInt n_obs_vertex_copy = impl->n_obs_vertex;

879:   EigenWorkspace *eigen_work = static_cast<EigenWorkspace *>(impl->eigen_work);
880:   if (!eigen_work) {
881:     eigen_work       = new EigenWorkspace();
882:     impl->eigen_work = static_cast<void *>(eigen_work);
883:   }

885:   /* Check if reallocation is needed */
886:   if (eigen_work->max_chunk_size < chunk_size || eigen_work->m != m || eigen_work->n_obs_vertex != n_obs_vertex_copy) {
887:     /* Free old device workspace if exists */
888: #if defined(KOKKOS_ENABLE_CUDA)
889:     PetscCallCUDA(cudaFree(eigen_work->d_work));
890:     PetscCallCUDA(cudaFree(eigen_work->d_info));
891:     PetscCallCUDA(cudaFree(eigen_work->d_A_contig));
892:     PetscCallCUDA(cudaFree(eigen_work->d_W_contig));
893:     if (eigen_work->syevj_params) cusolverDnDestroySyevjInfo(eigen_work->syevj_params);
894:     eigen_work->syevj_params = nullptr;
895: #elif defined(KOKKOS_ENABLE_HIP)
896:     PetscCallHIP(hipFree(eigen_work->d_work));
897:     PetscCallHIP(hipFree(eigen_work->d_info));
898:     PetscCallHIP(hipFree(eigen_work->d_A_contig));
899:     PetscCallHIP(hipFree(eigen_work->d_W_contig));
900: #elif defined(KOKKOS_ENABLE_SYCL)
901:     if (eigen_work->d_work) sycl::free(eigen_work->d_work, *device_handle);
902:     if (eigen_work->d_info) sycl::free(eigen_work->d_info, *device_handle);
903:     if (eigen_work->d_A_contig) sycl::free(eigen_work->d_A_contig, *device_handle);
904:     if (eigen_work->d_W_contig) sycl::free(eigen_work->d_W_contig, *device_handle);
905: #endif

907: #if !defined(KOKKOS_ENABLE_CUDA) && !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_SYCL)
908:   #if defined(PETSC_USE_COMPLEX)
909:     if (eigen_work->all_v) PetscCall(PetscFree4(eigen_work->all_v, eigen_work->all_lambda, eigen_work->all_work, eigen_work->all_rwork));
910:   #else
911:     if (eigen_work->all_v) PetscCall(PetscFree3(eigen_work->all_v, eigen_work->all_lambda, eigen_work->all_work));
912:   #endif
913: #endif

915:     /* Update dimensions */
916:     eigen_work->max_chunk_size = chunk_size;
917:     eigen_work->m              = m;
918:     eigen_work->n_obs_vertex   = n_obs_vertex_copy;

920:     /* Allocate Kokkos Views */
921:     eigen_work->Z_batch               = view_3d("Z_batch", chunk_size, n_obs_vertex_copy, m);
922:     eigen_work->S_batch               = eigen_work->Z_batch;
923:     eigen_work->T_batch               = view_3d("T_batch", chunk_size, m, m);
924:     eigen_work->V_batch               = eigen_work->T_batch;
925:     eigen_work->Lambda_batch          = view_2d("Lambda_batch", chunk_size, m);
926:     eigen_work->T_sqrt_batch          = view_3d("T_sqrt_batch", chunk_size, m, m);
927:     eigen_work->w_batch               = view_2d("w_batch", chunk_size, m);
928:     eigen_work->delta_batch           = view_2d("delta_batch", chunk_size, n_obs_vertex_copy);
929:     eigen_work->y_batch               = view_2d("y_batch", chunk_size, n_obs_vertex_copy);
930:     eigen_work->y_mean_batch          = view_2d("y_mean_batch", chunk_size, n_obs_vertex_copy);
931:     eigen_work->r_inv_sqrt_batch      = view_2d("r_inv_sqrt_batch", chunk_size, n_obs_vertex_copy);
932:     eigen_work->temp1_batch           = view_2d("temp1_batch", chunk_size, m);
933:     eigen_work->temp2_batch           = view_2d("temp2_batch", chunk_size, m);
934:     eigen_work->inv_sqrt_lambda_batch = view_2d("inv_sqrt_lambda_batch", chunk_size, m);

936:     /* Allocate solver workspace */
937: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
938:   #if defined(KOKKOS_ENABLE_CUDA)
939:     {
940:       /* Create syevj params */
941:       cusolver_status = cusolverDnCreateSyevjInfo(&eigen_work->syevj_params);
942:       PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDnCreateSyevjInfo failed");

944:       /* Set default params */
945:       cusolverDnXsyevjSetTolerance(eigen_work->syevj_params, 1e-7);
946:       cusolverDnXsyevjSetMaxSweeps(eigen_work->syevj_params, 100);
947:       cusolverDnXsyevjSetSortEig(eigen_work->syevj_params, 1); /* Sort eigenvalues */

949:       /* Query workspace size */
950:       PetscScalar *d_A = eigen_work->T_batch.data();
951:       PetscScalar *d_W = eigen_work->Lambda_batch.data();
952:       int          lwork;
953:     #if defined(PETSC_USE_REAL_SINGLE)
954:       cusolver_status = cusolverDnSsyevjBatched_bufferSize(device_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, d_A, m, d_W, &lwork, eigen_work->syevj_params, chunk_size);
955:     #else
956:       cusolver_status = cusolverDnDsyevjBatched_bufferSize(device_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, d_A, m, d_W, &lwork, eigen_work->syevj_params, chunk_size);
957:     #endif
958:       PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDn*syevjBatched_bufferSize failed");
959:       eigen_work->lwork_device = lwork;

961:       /* Allocate workspace */
962:       PetscCallCUDA(cudaMalloc(&eigen_work->d_work, sizeof(PetscScalar) * lwork));
963:       PetscCallCUDA(cudaMalloc(&eigen_work->d_info, sizeof(int) * chunk_size));
964:       PetscCallCUDA(cudaMalloc(&eigen_work->d_A_contig, sizeof(PetscScalar) * chunk_size * m * m));
965:       PetscCallCUDA(cudaMalloc(&eigen_work->d_W_contig, sizeof(PetscScalar) * chunk_size * m));
966:     }
967:   #elif defined(KOKKOS_ENABLE_HIP)
968:     {
969:         /* rocsolver_dsyevd does not support size query via -1.
970:          We use a safe upper bound estimate based on LAPACK dsyevd requirements.
971:       */
972:     #if defined(PETSC_USE_COMPLEX)
973:       int lwork = 0; /* Complex not supported on device */
974:     #else
975:       int lwork = 1 + 6 * m + 2 * m * m;
976:     #endif
977:       eigen_work->lwork_device = lwork;

979:       /* Allocate workspace */
980:       if (lwork > 0) {
981:         PetscCallHIP(hipMalloc(&eigen_work->d_work, sizeof(PetscScalar) * lwork));
982:         PetscCallHIP(hipMalloc(&eigen_work->d_info, sizeof(int) * chunk_size));
983:         PetscCallHIP(hipMalloc(&eigen_work->d_A_contig, sizeof(PetscScalar) * chunk_size * m * m));
984:         PetscCallHIP(hipMalloc(&eigen_work->d_W_contig, sizeof(PetscScalar) * chunk_size * m));
985:       }
986:     }
987:   #elif defined(KOKKOS_ENABLE_SYCL)
988:     {
989:       /* Query workspace size for oneapi::mkl::lapack::syevd */
990:       /* For syevd, workspace size is typically: */
991:       /* lwork >= 1 + 6*n + 2*n*n for real, or */
992:       /* lwork >= 2*n + n*n for complex */
993:       int lwork;
994:     #if defined(PETSC_USE_COMPLEX)
995:       lwork = 2 * m + m * m;
996:     #else
997:       lwork = 1 + 6 * m + 2 * m * m;
998:     #endif
999:       eigen_work->lwork_device = lwork;

1001:       /* Allocate workspace using SYCL malloc_device */
1002:       eigen_work->d_work     = sycl::malloc_device<PetscScalar>(lwork, *device_handle);
1003:       eigen_work->d_info     = sycl::malloc_device<int>(chunk_size, *device_handle);
1004:       eigen_work->d_A_contig = sycl::malloc_device<PetscScalar>(chunk_size * m * m, *device_handle);
1005:       eigen_work->d_W_contig = sycl::malloc_device<PetscScalar>(chunk_size * m, *device_handle);
1006:       PetscCheck(eigen_work->d_work && eigen_work->d_info && eigen_work->d_A_contig && eigen_work->d_W_contig, PETSC_COMM_SELF, PETSC_ERR_MEM, "SYCL memory allocation failed");
1007:     }
1008:   #endif
1009: #else
1010:     {
1011:       PetscBLASInt n_blas;
1012:       PetscCall(PetscBLASIntCast(m, &n_blas));
1013:       eigen_work->n_blas = n_blas;

1015:       /* Query workspace size */
1016:       PetscBLASInt lwork_query = -1;
1017:       PetscScalar  work_query;
1018:       PetscBLASInt info;
1019:   #if defined(PETSC_USE_COMPLEX)
1020:       PetscReal rwork_query;
1021:       LAPACKsyev_("V", "U", &n_blas, &work_query, &n_blas, &rwork_query, &work_query, &lwork_query, &rwork_query, &info);
1022:   #else
1023:       LAPACKsyev_("V", "U", &n_blas, &work_query, &n_blas, &work_query, &work_query, &lwork_query, &info);
1024:   #endif
1025:       PetscCheck(info == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "LAPACK workspace query failed");
1026:       eigen_work->lwork = (PetscBLASInt)PetscRealPart(work_query);

1028:       /* Allocate workspace */
1029:   #if defined(PETSC_USE_COMPLEX)
1030:       PetscCall(PetscMalloc4(chunk_size * m * m, &eigen_work->all_v, chunk_size * m, &eigen_work->all_lambda, chunk_size * eigen_work->lwork, &eigen_work->all_work, chunk_size * (3 * m - 2), &eigen_work->all_rwork));
1031:   #else
1032:       PetscCall(PetscMalloc3(chunk_size * m * m, &eigen_work->all_v, chunk_size * m, &eigen_work->all_lambda, chunk_size * eigen_work->lwork, &eigen_work->all_work));
1033:   #endif
1034:     }
1035: #endif
1036:   }

1038:   /* Create aliases for current function use */
1039:   view_3d Z_batch_alloc               = eigen_work->Z_batch;
1040:   view_3d S_batch_alloc               = eigen_work->S_batch;
1041:   view_3d T_batch_alloc               = eigen_work->T_batch;
1042:   view_3d V_batch_alloc               = eigen_work->V_batch;
1043:   view_2d Lambda_batch_alloc          = eigen_work->Lambda_batch;
1044:   view_3d T_sqrt_batch_alloc          = eigen_work->T_sqrt_batch;
1045:   view_2d w_batch_alloc               = eigen_work->w_batch;
1046:   view_2d delta_batch_alloc           = eigen_work->delta_batch;
1047:   view_2d y_batch_alloc               = eigen_work->y_batch;
1048:   view_2d y_mean_batch_alloc          = eigen_work->y_mean_batch;
1049:   view_2d r_inv_sqrt_batch_alloc      = eigen_work->r_inv_sqrt_batch;
1050:   view_2d temp1_batch_alloc           = eigen_work->temp1_batch;
1051:   view_2d temp2_batch_alloc           = eigen_work->temp2_batch;
1052:   view_2d inv_sqrt_lambda_batch_alloc = eigen_work->inv_sqrt_lambda_batch;

1054:   /* Loop over chunks */
1055:   for (PetscInt chunk_start = 0; chunk_start < n_vertices; chunk_start += chunk_size) {
1056:     PetscInt chunk_end       = (chunk_start + chunk_size > n_vertices) ? n_vertices : chunk_start + chunk_size;
1057:     PetscInt n_batch_current = chunk_end - chunk_start;

1059:     /* Create subviews for current batch size */
1060:     auto Z_batch               = Kokkos::subview(Z_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1061:     auto S_batch               = Kokkos::subview(S_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1062:     auto T_batch               = Kokkos::subview(T_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1063:     auto V_batch               = Kokkos::subview(V_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1064:     auto Lambda_batch          = Kokkos::subview(Lambda_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1065:     auto T_sqrt_batch          = Kokkos::subview(T_sqrt_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1066:     auto w_batch               = Kokkos::subview(w_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1067:     auto delta_batch           = Kokkos::subview(delta_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1068:     auto y_batch               = Kokkos::subview(y_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1069:     auto y_mean_batch          = Kokkos::subview(y_mean_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1070:     auto r_inv_sqrt_batch      = Kokkos::subview(r_inv_sqrt_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1071:     auto temp1_batch           = Kokkos::subview(temp1_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1072:     auto temp2_batch           = Kokkos::subview(temp2_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1073:     auto inv_sqrt_lambda_batch = Kokkos::subview(inv_sqrt_lambda_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());

1075:     /* ===================================================================== */
1076:     /* Step 2.1.2: Fused observation extraction and S/Delta computation     */
1077:     /* ===================================================================== */
1078:     /* Extract local observations and immediately compute S and delta       */
1079:     /* This fusion eliminates one kernel launch and improves cache locality */
1080:     Kokkos::parallel_for(
1081:       "ExtractAndComputeSAndDelta", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i_local) {
1082:         PetscInt i_global = chunk_start + i_local;
1083:         /* Get Q row for this grid point using CSR format */
1084:         PetscInt row_start = Q_i_view(i_global);
1085:         PetscInt row_end   = Q_i_view(i_global + 1);
1086:         PetscInt ncols     = row_end - row_start;

1088:         /* Extract observations and compute S/delta for this grid point */
1089:         for (PetscInt k = 0; k < ncols; k++) {
1090:           PetscInt    obs_idx = Q_j_view(row_start + k);
1091:           PetscScalar weight  = Q_a_view(row_start + k);

1093:           /* Extract observation vectors */
1094:           PetscScalar y_val      = y_global_view(obs_idx);
1095:           PetscScalar y_mean_val = y_mean_global_view(obs_idx);
1096:           PetscScalar r_inv_sqrt = r_inv_sqrt_global_view(obs_idx) * Kokkos::sqrt(PetscRealPart(weight));

1098:           /* Store for later use if needed */
1099:           y_batch(i_local, k)          = y_val;
1100:           y_mean_batch(i_local, k)     = y_mean_val;
1101:           r_inv_sqrt_batch(i_local, k) = r_inv_sqrt;

1103:           /* Compute delta immediately: delta = R^{-1/2}(y - y_mean) */
1104:           delta_batch(i_local, k) = (y_val - y_mean_val) * r_inv_sqrt;

1106:           /* Compute S row: S = R^{-1/2}(Z - y_mean * 1')/sqrt(m-1) */
1107:           PetscScalar scale_factor = scale * r_inv_sqrt;
1108:           for (int j = 0; j < m; j++) {
1109:             PetscScalar z_val      = Z_global_view(obs_idx, j);
1110:             Z_batch(i_local, k, j) = z_val; /* Store Z for potential later use */
1111:             S_batch(i_local, k, j) = (z_val - y_mean_val) * scale_factor;
1112:           }
1113:         }
1114:       });
1115:     Kokkos::fence();

1117:     /* DEBUG: Check S for NaNs */
1118:     if (PetscDefined(USE_DEBUG)) {
1119:       PetscInt nan_count = 0;
1120:       Kokkos::parallel_reduce(
1121:         "CheckS", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
1122:         KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
1123:           for (int j = 0; j < n_obs_vertex_copy; j++) {
1124:             for (int k = 0; k < m; k++) {
1125:               if (S_batch(i, j, k) != S_batch(i, j, k)) l_count++;
1126:             }
1127:           }
1128:         },
1129:         nan_count);
1130:       PetscCheck(nan_count == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " NaNs in S_batch at chunk_start %" PetscInt_FMT, nan_count, chunk_start);
1131:     }

1133:     /* ===================================================================== */
1134:     /* Step 2.1.4: Optimized T matrix formation (T = (1/rho)I + S^T * S)    */
1135:     /* ===================================================================== */
1136:     /* Compute T_i = (1/rho)I + S_i^T * S_i for current chunk */
1137:     /* Exploit symmetry: only compute upper triangle, then copy to lower */
1138:     /* This reduces operations by ~50% */
1139:     Kokkos::parallel_for(
1140:       "ComputeAllTMatrices", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i) {
1141:         auto S_i = Kokkos::subview(S_batch, i, Kokkos::ALL(), Kokkos::ALL());
1142:         auto T_i = Kokkos::subview(T_batch, i, Kokkos::ALL(), Kokkos::ALL());

1144:         /* Compute upper triangle of T_i = (1/rho)I + S_i^T * S_i */
1145:         /* T_i(j,k) = (1/rho)*delta_jk + sum_p S_i(p,j) * S_i(p,k) for j <= k */
1146:         for (int j = 0; j < m; j++) {
1147:           for (int k = j; k < m; k++) {
1148:             PetscScalar sum = (j == k) ? inflation_inv : 0.0;
1149:             for (int p = 0; p < n_obs_vertex_copy; p++) sum += S_i(p, j) * S_i(p, k);
1150:             T_i(j, k) = sum;
1151:           }
1152:         }

1154:         /* Copy upper triangle to lower triangle (T is symmetric) */
1155:         for (int j = 0; j < m; j++) {
1156:           for (int k = 0; k < j; k++) T_i(j, k) = T_i(k, j);
1157:         }
1158:       });
1159:     Kokkos::fence();

1161:     /* DEBUG: Check T for NaNs */
1162:     if (PetscDefined(USE_DEBUG)) {
1163:       PetscInt nan_count = 0;
1164:       Kokkos::parallel_reduce(
1165:         "CheckT", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
1166:         KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
1167:           for (int j = 0; j < m; j++) {
1168:             for (int k = 0; k < m; k++) {
1169:               if (T_batch(i, j, k) != T_batch(i, j, k)) l_count++;
1170:             }
1171:           }
1172:         },
1173:         nan_count);
1174:       PetscCheck(nan_count == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " NaNs in T_batch at chunk_start %" PetscInt_FMT, nan_count, chunk_start);
1175:     }

1177:     /* ===================================================================== */
1178:     /* Step 3.1.1: Batched eigendecomposition for current chunk            */
1179:     /* ===================================================================== */
1180:     /* Compute T_i = V_i * Lambda_i * V_i^T for current chunk */
1181: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
1182:     PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch_current, m, device_handle, eigen_work));
1183: #else
1184:     PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch_current, m, eigen_work));
1185: #endif

1187:     /* DEBUG: Check Lambda for NaNs or negative values */
1188:     if (PetscDefined(USE_DEBUG)) {
1189:       PetscInt bad_lambda = 0;
1190:       Kokkos::parallel_reduce(
1191:         "CheckLambda", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
1192:         KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
1193:           for (int k = 0; k < m; k++) {
1194:             if (Lambda_batch(i, k) != Lambda_batch(i, k) || PetscRealPart(Lambda_batch(i, k)) < -1e-8) l_count++; /* x != x is the portable device-side NaN test; -1e-8 tolerance absorbs eigensolver roundoff on near-zero eigenvalues */
1195:           }
1196:         },
1197:       bad_lambda);
1198:       PetscCheck(bad_lambda == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " bad eigenvalues (NaN or negative) at chunk_start %" PetscInt_FMT, bad_lambda, chunk_start);
1199:     }

1201:     /* ===================================================================== */
1202:     /* Step 3.1.2: Precompute w and inv_sqrt_lambda for ensemble update    */
1203:     /* ===================================================================== */
1204:     /* Compute w_i = T_i^{-1} * (S_i^T * delta_i) using eigendecomposition */
1205:     /* Precompute 1/sqrt(Lambda) for use in ensemble update */
1206:     Kokkos::parallel_for(
1207:       "ComputeWeightsAndInvSqrtLambda", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i) {
1208:         auto S_i               = Kokkos::subview(S_batch, i, Kokkos::ALL(), Kokkos::ALL());
1209:         auto V_i               = Kokkos::subview(V_batch, i, Kokkos::ALL(), Kokkos::ALL());
1210:         auto Lambda_i          = Kokkos::subview(Lambda_batch, i, Kokkos::ALL());
1211:         auto delta_i           = Kokkos::subview(delta_batch, i, Kokkos::ALL());
1212:         auto w_i               = Kokkos::subview(w_batch, i, Kokkos::ALL());
1213:         auto inv_sqrt_lambda_i = Kokkos::subview(inv_sqrt_lambda_batch, i, Kokkos::ALL());
1214:         auto temp1             = Kokkos::subview(temp1_batch, i, Kokkos::ALL());
1215:         auto temp2             = Kokkos::subview(temp2_batch, i, Kokkos::ALL());

1217:         /* 1. Compute w_i = V * L^-1 * V^T * S^T * delta */
1218:         /* Step 1a: temp1 = S^T * delta using KokkosBlas::gemv for better vectorization */
1219:         KokkosBlas::SerialGemv<KokkosBlas::Trans::Transpose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, S_i, delta_i, 0.0, temp1);

1221:         /* Step 1b: temp2 = V^T * temp1 using KokkosBlas::gemv for better vectorization */
1222:         KokkosBlas::SerialGemv<KokkosBlas::Trans::Transpose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, V_i, temp1, 0.0, temp2);

1224:         /* Step 1c: temp2 = temp2 / Lambda (the 1e-14 shift regularizes near-zero eigenvalues that passed the -1e-8 check above) */
1225:         for (int j = 0; j < m; j++) temp2(j) /= (Lambda_i(j) + 1.0e-14);

1227:         /* Step 1d: w = V * temp2 using KokkosBlas::gemv for better vectorization */
1228:         KokkosBlas::SerialGemv<KokkosBlas::Trans::NoTranspose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, V_i, temp2, 0.0, w_i);

1230:         /* 2. Precompute 1/sqrt(Lambda) for ensemble update; same 1e-14 regularization as Step 1c. NOTE(review): Step 1c divides by the (possibly complex) Lambda_i while this line takes PetscRealPart first — confirm both treatments are intended under PETSC_USE_COMPLEX */
1231:         for (int p = 0; p < m; p++) inv_sqrt_lambda_i(p) = 1.0 / Kokkos::sqrt(PetscRealPart(Lambda_i(p)) + 1.0e-14);
1232:       });
1233:     Kokkos::fence();

1235:     /* ===================================================================== */
1236:     /* Step 3.1.3: Fused G computation and ensemble update                  */
1237:     /* ===================================================================== */
1238:     /* Compute E[i,:] = mean[i] + X[i,:] * G_i on-the-fly */
1239:     /* G_i is computed column-by-column and immediately applied */
1240:     /* This eliminates the need to store G_batch, saving m*m*n_batch memory */
1241:     Kokkos::parallel_for(
1242:       "FusedGComputeAndEnsembleUpdate", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i_local) {
1243:         PetscInt i_global = chunk_start + i_local;

1245:         auto X_i    = Kokkos::subview(X_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof), Kokkos::ALL());
1246:         auto E_i    = Kokkos::subview(E_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof), Kokkos::ALL());
1247:         auto mean_i = Kokkos::subview(mean_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof));

1249:         auto V_i               = Kokkos::subview(V_batch, i_local, Kokkos::ALL(), Kokkos::ALL());
1250:         auto w_i               = Kokkos::subview(w_batch, i_local, Kokkos::ALL());
1251:         auto inv_sqrt_lambda_i = Kokkos::subview(inv_sqrt_lambda_batch, i_local, Kokkos::ALL());
1252:         auto T_sqrt_i          = Kokkos::subview(T_sqrt_batch, i_local, Kokkos::ALL(), Kokkos::ALL());

1254:         /* Initialize E_i with mean (broadcast mean down each of the m ensemble columns) */
1255:         for (int row = 0; row < ndof; row++) {
1256:           PetscScalar m_val = mean_i(row);
1257:           for (int col = 0; col < m; col++) E_i(row, col) = m_val;
1258:         }

1260:         /* Compute T_sqrt = V * diag(1/sqrt(Lambda)) * V^T */
1261:         /* Optimized: Exploit symmetry - only compute upper triangle, then copy to lower */
1262:         /* T_sqrt(j,k) = sum_p V(j,p) * V(k,p) / sqrt(Lambda(p)) for j <= k. NOTE(review): no conjugation of V(k,p) here — for PETSC_USE_COMPLEX this forms V*D*V^T, not V*D*V^H; confirm V is effectively real or that the transpose is intended */
1263:         for (int j = 0; j < m; j++) {
1264:           for (int k = j; k < m; k++) {
1265:             PetscScalar sum = 0.0;
1266:             for (int p = 0; p < m; p++) sum += V_i(j, p) * V_i(k, p) * inv_sqrt_lambda_i(p);
1267:             T_sqrt_i(j, k) = sum;
1268:           }
1269:         }
1270:         /* Copy upper triangle to lower triangle (T_sqrt is symmetric) */
1271:         for (int j = 0; j < m; j++) {
1272:           for (int k = 0; k < j; k++) T_sqrt_i(j, k) = T_sqrt_i(k, j);
1273:         }

1275:         /* Compute E_i += X_i * G_i column-by-column */
1276:         /* G_i(:,k) = w_i + sqrt(m-1) * T_sqrt_i(:,k) */
1277:         for (int k = 0; k < m; k++) {
1278:           /* Compute column k of G on-the-fly (never materialized as a full m x m matrix) */
1279:           for (int row = 0; row < ndof; row++) {
1280:             PetscScalar sum = 0.0;
1281:             for (int j = 0; j < m; j++) {
1282:               /* G_i(j,k) = w_i(j) + sqrt(m-1) * T_sqrt_i(j,k) */
1283:               PetscScalar G_jk = w_i(j) + sqrt_m_minus_1 * T_sqrt_i(j, k);
1284:               sum += X_i(row, j) * G_jk;
1285:             }
1286:             E_i(row, k) += sum;
1287:           }
1288:         }
1289:       });
1290:     Kokkos::fence();
1291:   }

1293:   /* Cleanup workspace */
1294:   /* NOTE: Workspace is now persistent in impl->eigen_work and impl->solver_handle */
1295:   /* It will be destroyed in PetscDALETKFDestroyLocalization_Kokkos */

1297:   /* Copy back updated ensemble if needed (e_managed is a device-side managed copy; presumably e_is_copy is set when the raw array was not device-accessible — confirm where the flag is set above) */
1298:   if (e_is_copy) {
1299:     Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> dst(e_array, lda_e, m);
1300:     Kokkos::deep_copy(dst, e_managed);
1301:   }

1303:   /* Restore arrays (write access for the updated ensemble, read access for mean and X) */
1304:   PetscCall(MatDenseRestoreArrayWriteAndMemType(en->ensemble, &e_array));
1305:   PetscCall(VecRestoreArrayReadAndMemType(impl->mean, &mean_array));
1306:   PetscCall(MatDenseRestoreArrayReadAndMemType(X, &x_array));

1308:   /* Restore global observation arrays */
1309:   PetscCall(VecRestoreArrayReadAndMemType(r_inv_sqrt_global, &r_inv_sqrt_global_array));
1310:   PetscCall(VecRestoreArrayReadAndMemType(y_mean_global, &y_mean_global_array));
1311:   PetscCall(VecRestoreArrayReadAndMemType(observation, &y_global_array));
1312:   PetscCall(MatDenseRestoreArrayReadAndMemType(Z_global, &z_global_array));

1314:   /* Ensemble has been updated in batched form above */
1315:   PetscCall(MatAssemblyBegin(en->ensemble, MAT_FINAL_ASSEMBLY));
1316:   PetscCall(MatAssemblyEnd(en->ensemble, MAT_FINAL_ASSEMBLY));

1318:   {
1319:     MatInfo   info;
1320:     PetscReal flops = 0.0;
1321:     PetscReal n_obs_total;

1323:     if (impl->Q) {
1324:       PetscCall(MatGetInfo(impl->Q, MAT_LOCAL, &info));
1325:       n_obs_total = info.nz_used; /* nonzeros of the localization/selection operator stand in for the total local observation count */
1326:     } else {
1327:       n_obs_total = 0.0;
1328:     }

1330:     /* Step 2.1.2: Fused observation extraction and S/Delta computation */
1331:     flops += n_obs_total * (2.0 + 2.0 * m);

1333:     /* Step 2.1.4: Optimized T matrix formation */
1334:     flops += (PetscReal)n_vertices * m * (m + 1) * impl->n_obs_vertex;

1336:     /* Step 3.1.2: Precompute w and inv_sqrt_lambda */
1337:     flops += (PetscReal)n_vertices * (2.0 * m * impl->n_obs_vertex + 4.0 * m * m + 3.0 * m);

1339:     /* Step 3.1.3: Fused G computation and ensemble update */
1340:     /* T_sqrt: 1.5*m^3 + 1.5*m^2 (upper triangle only, then mirrored) */
1341:     flops += (PetscReal)n_vertices * (1.5 * m * m * m + 1.5 * m * m);
1342:     /* E update: ndof * m * (4*m + 1) */
1343:     /* Note: G_jk computation (2 flops) is inside the inner loop, so it's 2*m*ndof*m */
1344:     /* Matrix product X*G (2 flops) is also 2*m*ndof*m */
1345:     flops += (PetscReal)n_vertices * ndof * m * (4.0 * m + 1.0);

1347:     PetscCall(PetscLogGpuFlops(flops));
1348:   }
1349:   PetscFunctionReturn(PETSC_SUCCESS);
1350: }