Actual source code: matdiagonalcupm.hpp

  1: #pragma once

  3: #include <petscmat.h>

  5: #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"

  7: #include <petsc/private/cupminterface.hpp>
  8: #include <petsc/private/cupmobject.hpp>
  9: #include <petsc/private/deviceimpl.h>
 10: #include <petsc/private/vecimpl.h>
 11: #include <petsc/private/veccupmimpl.h>
 12: #include <petsc/private/matimpl.h>

 14: #include <thrust/device_ptr.h>
 15: #include <thrust/iterator/zip_iterator.h>
 16: #include <thrust/transform_reduce.h>

 18: namespace Petsc
 19: {

 21: namespace device
 22: {

 24: namespace cupm
 25: {

 27: namespace impl
 28: {

 30: template <DeviceType T, typename VecType>
 31: struct MatDiagonal_CUPM : vec::cupm::impl::Vec_CUPMBase<T, VecType> {
 32:   PETSC_CUPMOBJECT_HEADER(T);
 33:   using base_type = ::Petsc::vec::cupm::impl::Vec_CUPMBase<T, VecType>;
 34:   friend base_type;

 36:   static PetscErrorCode ADot(Mat A, Vec x, Vec y, PetscScalar *z) noexcept;
 37:   static PetscErrorCode ANormSq(Mat A, Vec x, PetscReal *z) noexcept;
 38: };

 40: namespace detail
 41: {
 42: struct adot_transform {
 43:   using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar>;

 45:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const argument_type &tup) const noexcept { return PetscConj(thrust::get<1>(tup)) * thrust::get<2>(tup) * thrust::get<0>(tup); }
 46: };
 47: } // namespace detail

 49: template <Petsc::device::cupm::DeviceType T, typename VecType>
 50: inline PetscErrorCode MatDiagonal_CUPM<T, VecType>::ADot(Mat A, Vec x, Vec y, PetscScalar *z) noexcept
 51: {
 52:   PetscDeviceContext dctx;
 53:   cupmStream_t       stream;
 54:   Mat_Diagonal      *ctx  = (Mat_Diagonal *)A->data;
 55:   PetscScalar        zero = 0.;
 56:   const PetscInt     n    = x->map->n;

 58:   PetscFunctionBegin;
 59:   PetscCall(GetHandles_(&dctx, &stream));

 61:   const auto xdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, x).data());
 62:   const auto ydptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, y).data());
 63:   const auto wdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, ctx->diag).data());

 65:   // clang-format off
 66:     PetscCallThrust(
 67:       *z = THRUST_CALL(
 68:         thrust::transform_reduce,
 69:         stream,
 70:         thrust::make_zip_iterator(thrust::make_tuple(xdptr, ydptr, wdptr)),
 71:         thrust::make_zip_iterator(thrust::make_tuple(xdptr + n, ydptr + n, wdptr + n)),
 72:         detail::adot_transform{},
 73:         zero,
 74:         thrust::plus<PetscScalar>()
 75:       )
 76:     );
 77:   // clang-format on
 78:   if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
 79:   PetscFunctionReturn(PETSC_SUCCESS);
 80: }

 82: namespace detail
 83: {
 84: struct anorm_transform {
 85:   using argument_type = thrust::tuple<PetscScalar, PetscScalar>;

 87:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const argument_type &tup) const noexcept { return thrust::get<1>(tup) * PetscConj(thrust::get<0>(tup)) * thrust::get<0>(tup); }
 88: };
 89: } // namespace detail

 91: template <Petsc::device::cupm::DeviceType T, typename VecType>
 92: inline PetscErrorCode MatDiagonal_CUPM<T, VecType>::ANormSq(Mat A, Vec x, PetscReal *z) noexcept
 93: {
 94:   PetscDeviceContext dctx;
 95:   cupmStream_t       stream;
 96:   Mat_Diagonal      *ctx  = (Mat_Diagonal *)A->data;
 97:   PetscScalar        zero = 0., res;
 98:   const PetscInt     n    = x->map->n;

100:   PetscFunctionBegin;
101:   PetscCall(GetHandles_(&dctx, &stream));

103:   const auto xdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, x).data());
104:   const auto wdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, ctx->diag).data());

106:   // clang-format off
107:   PetscCallThrust(
108:     res = THRUST_CALL(
109:       thrust::transform_reduce,
110:       stream,
111:       thrust::make_zip_iterator(thrust::make_tuple(xdptr, wdptr)),
112:       thrust::make_zip_iterator(thrust::make_tuple(xdptr + n, wdptr + n)),
113:       detail::anorm_transform{},
114:       zero,
115:       thrust::plus<PetscScalar>()
116:     )
117:   );
118:   // clang-format on
119:   *z = PetscRealPart(res);
120:   if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
121:   PetscFunctionReturn(PETSC_SUCCESS);
122: }

124: } // namespace impl

126: } // namespace cupm

128: } // namespace device

130: } // namespace Petsc