Actual source code: cupmblasinterface.hpp
1: #ifndef PETSCCUPMBLASINTERFACE_HPP
2: #define PETSCCUPMBLASINTERFACE_HPP
4: #if defined(__cplusplus)
5: #include <petsc/private/cupminterface.hpp>
6: #include <petsc/private/petscadvancedmacros.h>
8: namespace Petsc
9: {
11: namespace device
12: {
14: namespace cupm
15: {
17: namespace impl
18: {
// PetscCallCUPMBLAS() - evaluate a CUDA/HIP BLAS call and turn a failed status into a PETSc error
//
// input param:
// ... - the BLAS call; its result is stored (evaluated exactly once) in a cupmBlasError_t
//
// notes:
// Requires cupmBlasError_t, CUPMBLAS_STATUS_SUCCESS, CUPMBLAS_STATUS_NOT_INITIALIZED,
// CUPMBLAS_STATUS_ALLOC_FAILED, cupmBlasName(), cupmBlasGetErrorName() and PETSC_DEVICE_CUPM() to
// be visible at the point of use (PETSC_CUPMBLAS_IMPL_CLASS_HEADER() provides the cupmBlas*
// names). SETERRQ() is assumed to return the error from the enclosing function, so only use this
// inside functions returning PetscErrorCode.
//
// A NOT_INITIALIZED or ALLOC_FAILED status is reported as PETSC_ERR_GPU_RESOURCE, but only once
// PetscDeviceInitialized(PETSC_DEVICE_CUPM()) is true. Every other failure (including those two
// before device initialization) is reported as PETSC_ERR_GPU.
#define PetscCallCUPMBLAS(...) \
  do { \
    const cupmBlasError_t cberr_p_ = __VA_ARGS__; \
    if (PetscUnlikely(cberr_p_ != CUPMBLAS_STATUS_SUCCESS)) { \
      if (((cberr_p_ == CUPMBLAS_STATUS_NOT_INITIALIZED) || (cberr_p_ == CUPMBLAS_STATUS_ALLOC_FAILED)) && PetscDeviceInitialized(PETSC_DEVICE_CUPM())) { \
        SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, \
                "%s error %d (%s). Reports not initialized or alloc failed; " \
                "this indicates the GPU may have run out resources", \
                cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
      } \
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "%s error %d (%s)", cupmBlasName(), static_cast<PetscErrorCode>(cberr_p_), cupmBlasGetErrorName(cberr_p_)); \
    } \
  } while (0)
// Floating-point letters used to assemble CUDA/HIP BLAS routine names:
//
// given cupmBlas<T>axpy() then
// T = PETSC_CUPMBLAS_FP_TYPE
// given cupmBlas<T><u>nrm2() then
// T = PETSC_CUPMBLAS_FP_INPUT_TYPE
// u = PETSC_CUPMBLAS_FP_RETURN_TYPE
//
// Each letter comes in an upper-case (_U) and a lower-case (_L) flavor so that each library
// section below can pick the casing its naming scheme needs. For real scalars the return-type
// letter is empty (e.g. Snrm2, Dnrm2). For complex scalars it is the complex letter, so the
// input/return pair yields e.g. Scnrm2 and Dznrm2.
#if PetscDefined(USE_COMPLEX)
#if PetscDefined(USE_REAL_SINGLE)
#define PETSC_CUPMBLAS_FP_TYPE_U C
#define PETSC_CUPMBLAS_FP_TYPE_L c
#define PETSC_CUPMBLAS_FP_INPUT_TYPE_U S
#define PETSC_CUPMBLAS_FP_INPUT_TYPE_L s
#elif PetscDefined(USE_REAL_DOUBLE)
#define PETSC_CUPMBLAS_FP_TYPE_U Z
#define PETSC_CUPMBLAS_FP_TYPE_L z
#define PETSC_CUPMBLAS_FP_INPUT_TYPE_U D
#define PETSC_CUPMBLAS_FP_INPUT_TYPE_L d
#endif
// complex: the "return" letter is the complex scalar letter (c/z)
#define PETSC_CUPMBLAS_FP_RETURN_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
#define PETSC_CUPMBLAS_FP_RETURN_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
#else
#if PetscDefined(USE_REAL_SINGLE)
#define PETSC_CUPMBLAS_FP_TYPE_U S
#define PETSC_CUPMBLAS_FP_TYPE_L s
#elif PetscDefined(USE_REAL_DOUBLE)
#define PETSC_CUPMBLAS_FP_TYPE_U D
#define PETSC_CUPMBLAS_FP_TYPE_L d
#endif
// real: input letter equals the scalar letter and the "return" letter is intentionally empty
#define PETSC_CUPMBLAS_FP_INPUT_TYPE_U PETSC_CUPMBLAS_FP_TYPE_U
#define PETSC_CUPMBLAS_FP_INPUT_TYPE_L PETSC_CUPMBLAS_FP_TYPE_L
#define PETSC_CUPMBLAS_FP_RETURN_TYPE_U
#define PETSC_CUPMBLAS_FP_RETURN_TYPE_L
#endif // USE_COMPLEX

// Any precision without a letter (e.g. half) is rejected. Quad precision (__float128) builds also
// define no letter, but they are deliberately not rejected here.
#if !defined(PETSC_CUPMBLAS_FP_TYPE_U) && !PetscDefined(USE_REAL___FLOAT128)
#error "Unsupported floating-point type for CUDA/HIP BLAS"
#endif
// PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE_EXACT() - declaration to alias a CUDA/HIP BLAS integral
// constant value
//
// input params:
// OUR_PREFIX   - prefix of the alias
// OUR_SUFFIX   - suffix of the alias
// THEIR_PREFIX - prefix of the variable being aliased
// THEIR_SUFFIX - suffix of the variable being aliased
//
// notes:
// This is a thin forwarder to PETSC_CUPM_ALIAS_INTEGRAL_VALUE_EXACT(); see that macro for the
// precise expansion.
//
// example usage:
// PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE_EXACT(CUPMBLAS,_STATUS_SUCCESS,CUBLAS,_STATUS_SUCCESS) ->
// static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS
#define PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE_EXACT(OUR_PREFIX, OUR_SUFFIX, THEIR_PREFIX, THEIR_SUFFIX) PETSC_CUPM_ALIAS_INTEGRAL_VALUE_EXACT(OUR_PREFIX, OUR_SUFFIX, THEIR_PREFIX, THEIR_SUFFIX)
// PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE() - declaration to alias a CUDA/HIP BLAS integral
// constant value that shares a common suffix between CUDA and HIP
//
// input param:
// COMMON - common suffix of the CUDA/HIP blas variable being aliased
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX_U to be defined as the specific UPPERCASE prefix of the
// variable being aliased. The alias itself always gets the prefix CUPMBLAS.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX_U CUBLAS
// PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_SUCCESS) ->
// static const auto CUPMBLAS_STATUS_SUCCESS = CUBLAS_STATUS_SUCCESS
//
// #define PETSC_CUPMBLAS_PREFIX_U HIPBLAS
// PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_SUCCESS) ->
// static const auto CUPMBLAS_STATUS_SUCCESS = HIPBLAS_STATUS_SUCCESS
#define PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(COMMON) PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE_EXACT(CUPMBLAS, COMMON, PETSC_CUPMBLAS_PREFIX_U, COMMON)
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED() - Helper macro to build a "modified"
// blas function name, i.e. one whose return type does not match the input type
//
// input param:
// func - base suffix of the blas function, e.g. nrm2
//
// notes:
// requires PETSC_CUPMBLAS_FP_INPUT_TYPE to be defined as the blas floating point input type
// letter ("S" for real/complex single, "D" for real/complex double).
//
// requires PETSC_CUPMBLAS_FP_RETURN_TYPE to be defined as the blas floating point output type
// letter ("c" for complex single, "z" for complex double and <absolutely nothing> for real
// single/double).
//
// NVIDIA and AMD are not consistent about which of these letters is upper- and which is
// lower-case, which is why the input and return letters are configured separately.
//
// example usage:
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE S
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Snrm2
//
// #define PETSC_CUPMBLAS_FP_INPUT_TYPE D
// #define PETSC_CUPMBLAS_FP_RETURN_TYPE z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(nrm2) -> Dznrm2
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_MODIFIED(func) PetscConcat(PetscConcat(PETSC_CUPMBLAS_FP_INPUT_TYPE, PETSC_CUPMBLAS_FP_RETURN_TYPE), func)
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE() - Helper macro to build Iamax and Iamin
// because they are both extra special
//
// input param:
// func - base suffix of the blas function, either amax or amin
//
// notes:
// The macro name literally stands for "I" ## "floating point type", which is exactly the name
// it produces.
//
// requires PETSC_CUPMBLAS_FP_TYPE_L to be defined as the lower-case blas floating point type
// letter ("c" for complex single, "z" for complex double, "s" for real single, and "d" for
// real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE_L s
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amax) -> Isamax
//
// #define PETSC_CUPMBLAS_FP_TYPE_L z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(amin) -> Izamin
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_IFPTYPE(func) PetscConcat(I, PetscConcat(PETSC_CUPMBLAS_FP_TYPE_L, func))
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD() - Helper macro to build a "standard"
// blas function name: the scalar letter followed by the routine suffix
//
// input param:
// func - base suffix of the blas function, e.g. axpy, scal
//
// notes:
// requires PETSC_CUPMBLAS_FP_TYPE to be defined as the blas floating-point letter ("C" for
// complex single, "Z" for complex double, "S" for real single, "D" for real double).
//
// example usage:
// #define PETSC_CUPMBLAS_FP_TYPE S
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Saxpy
//
// #define PETSC_CUPMBLAS_FP_TYPE Z
// PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(axpy) -> Zaxpy
#define PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_STANDARD(func) PetscConcat(PETSC_CUPMBLAS_FP_TYPE, func)
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT() - alias a CUDA/HIP blas function whose suffix may
// differ from ours (e.g. our "dot" is their "dotc" for complex scalars)
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD,
//                MODIFIED or IFPTYPE
// our_suffix   - the suffix of the alias function (the alias is named cupmBlasX<our_suffix>)
// their_suffix - the suffix of the function being aliased
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas function
// prefix. requires any other specific definitions required by the specific builder macro to
// also be defined. See PETSC_CUPM_ALIAS_FUNCTION_EXACT() for the exact expansion of the
// function alias.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX cublas
// #define PETSC_CUPMBLAS_FP_TYPE C
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD,dot,dotc) ->
// template <typename... T>
// static constexpr auto cupmBlasXdot(T&&... args) *noexcept and returntype detection*
// {
//   return cublasCdotc(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, our_suffix, their_suffix) \
  PETSC_CUPM_ALIAS_FUNCTION_EXACT(cupmBlasX, our_suffix, PETSC_CUPMBLAS_PREFIX, PetscConcat(PETSC_CUPMBLAS_BUILD_BLAS_FUNCTION_ALIAS_, MACRO_SUFFIX)(their_suffix))
// PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION() - Alias a CUDA/HIP blas function
//
// input params:
// MACRO_SUFFIX - suffix to one of the above blas function builder macros, e.g. STANDARD or
//                IFPTYPE
// suffix       - the common suffix between CUDA and HIP of the alias function
//
// notes:
// see PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(); this macro just calls that one with "suffix"
// as both "our_suffix" and "their_suffix"
#define PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MACRO_SUFFIX, suffix) PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(MACRO_SUFFIX, suffix, suffix)
// PETSC_CUPMBLAS_ALIAS_FUNCTION() - Alias a CUDA/HIP library (non-BLAS-kernel) function
//
// input params:
// suffix - the common suffix between CUDA and HIP of the alias function
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX to be defined as the specific CUDA/HIP blas library
// prefix. see PETSC_CUPM_ALIAS_FUNCTION_EXACT() for the precise expansion of this macro.
//
// example usage:
// #define PETSC_CUPMBLAS_PREFIX hipblas
// PETSC_CUPMBLAS_ALIAS_FUNCTION(Create) ->
// template <typename... T>
// static constexpr auto cupmBlasCreate(T&&... args) *noexcept and returntype detection*
// {
//   return hipblasCreate(std::forward<T>(args)...);
// }
#define PETSC_CUPMBLAS_ALIAS_FUNCTION(suffix) PETSC_CUPM_ALIAS_FUNCTION_EXACT(cupmBlas, suffix, PETSC_CUPMBLAS_PREFIX, suffix)
230: template <DeviceType T>
231: struct BlasInterfaceBase : Interface<T> {
232: PETSC_NODISCARD static constexpr const char *cupmBlasName() noexcept { return T == DeviceType::CUDA ? "cuBLAS" : "hipBLAS"; }
233: };
// PETSC_CUPMBLAS_BASE_CLASS_HEADER() - boilerplate placed at the top of each
// BlasInterfaceImpl<> specialization
//
// input param:
// DEV_TYPE - the DeviceType of the specialization
//
// notes:
// requires PETSC_CUPMBLAS_PREFIX_U to be defined. cupmBlasGetErrorName is aliased to
// Petsc<PREFIX_U>GetErrorName (e.g. PetscCUBLASGetErrorName). The macro has no trailing
// semicolon; the caller supplies it.
#define PETSC_CUPMBLAS_BASE_CLASS_HEADER(DEV_TYPE) \
  using base_type = ::Petsc::device::cupm::impl::BlasInterfaceBase<DEV_TYPE>; \
  using base_type::cupmBlasName; \
  PETSC_CUPM_ALIAS_FUNCTION_EXACT(cupmBlas, GetErrorName, PetscConcat(Petsc, PETSC_CUPMBLAS_PREFIX_U), GetErrorName) \
  PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(interface_type, DEV_TYPE)

// Primary template is only declared; each supported device provides a full specialization below.
template <DeviceType>
struct BlasInterfaceImpl;
244: #if PetscDefined(HAVE_CUDA)
245: #define PETSC_CUPMBLAS_PREFIX cublas
246: #define PETSC_CUPMBLAS_PREFIX_U CUBLAS
247: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
248: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
249: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
250: template <>
251: struct BlasInterfaceImpl<DeviceType::CUDA> : BlasInterfaceBase<DeviceType::CUDA> {
252: PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::CUDA);
254: // typedefs
255: using cupmBlasHandle_t = cublasHandle_t;
256: using cupmBlasError_t = cublasStatus_t;
257: using cupmBlasInt_t = int;
258: using cupmSolverHandle_t = cusolverDnHandle_t;
259: using cupmSolverError_t = cusolverStatus_t;
260: using cupmBlasPointerMode_t = cublasPointerMode_t;
262: // values
263: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_SUCCESS);
264: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_NOT_INITIALIZED);
265: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_ALLOC_FAILED);
266: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_POINTER_MODE_HOST);
267: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_POINTER_MODE_DEVICE);
269: // utility functions
270: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
271: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
272: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
273: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
274: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
275: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
277: // level 1 BLAS
278: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
279: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
280: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
281: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
282: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
283: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
284: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
285: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
287: // level 2 BLAS
288: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
290: // level 3 BLAS
291: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
293: // BLAS extensions
294: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
296: PETSC_NODISCARD static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
297: {
298: if (handle) return 0;
299: for (auto i = 0; i < 3; ++i) {
300: const auto cerr = cusolverDnCreate(&handle);
301: if (PetscLikely(cerr == CUSOLVER_STATUS_SUCCESS)) break;
302: if ((cerr != CUSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUSOLVER_STATUS_ALLOC_FAILED)) cerr;
303: if (i < 2) {
304: PetscSleep(3);
305: continue;
306: }
308: }
309: return 0;
310: }
312: PETSC_NODISCARD static PetscErrorCode SetHandleStream(const cupmSolverHandle_t &handle, const cupmStream_t &stream) noexcept
313: {
314: cupmStream_t cupmStream;
316: cusolverDnGetStream(handle, &cupmStream);
317: if (cupmStream != stream) cusolverDnSetStream(handle, stream);
318: return 0;
319: }
321: PETSC_NODISCARD static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
322: {
323: if (handle) {
324: cusolverDnDestroy(handle);
325: handle = nullptr;
326: }
327: return 0;
328: }
329: };
330: #undef PETSC_CUPMBLAS_PREFIX
331: #undef PETSC_CUPMBLAS_PREFIX_U
332: #undef PETSC_CUPMBLAS_FP_TYPE
333: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
334: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
335: #endif // PetscDefined(HAVE_CUDA)
337: #if PetscDefined(HAVE_HIP)
338: #define PETSC_CUPMBLAS_PREFIX hipblas
339: #define PETSC_CUPMBLAS_PREFIX_U HIPBLAS
340: #define PETSC_CUPMBLAS_FP_TYPE PETSC_CUPMBLAS_FP_TYPE_U
341: #define PETSC_CUPMBLAS_FP_INPUT_TYPE PETSC_CUPMBLAS_FP_INPUT_TYPE_U
342: #define PETSC_CUPMBLAS_FP_RETURN_TYPE PETSC_CUPMBLAS_FP_RETURN_TYPE_L
343: template <>
344: struct BlasInterfaceImpl<DeviceType::HIP> : BlasInterfaceBase<DeviceType::HIP> {
345: PETSC_CUPMBLAS_BASE_CLASS_HEADER(DeviceType::HIP);
347: // typedefs
348: using cupmBlasHandle_t = hipblasHandle_t;
349: using cupmBlasError_t = hipblasStatus_t;
350: using cupmBlasInt_t = int; // rocblas will have its own
351: using cupmSolverHandle_t = hipsolverHandle_t;
352: using cupmSolverError_t = hipsolverStatus_t;
353: using cupmBlasPointerMode_t = hipblasPointerMode_t;
355: // values
356: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_SUCCESS);
357: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_NOT_INITIALIZED);
358: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_STATUS_ALLOC_FAILED);
359: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_POINTER_MODE_HOST);
360: PETSC_CUPMBLAS_ALIAS_INTEGRAL_VALUE(_POINTER_MODE_DEVICE);
362: // utility functions
363: PETSC_CUPMBLAS_ALIAS_FUNCTION(Create)
364: PETSC_CUPMBLAS_ALIAS_FUNCTION(Destroy)
365: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetStream)
366: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetStream)
367: PETSC_CUPMBLAS_ALIAS_FUNCTION(GetPointerMode)
368: PETSC_CUPMBLAS_ALIAS_FUNCTION(SetPointerMode)
370: // level 1 BLAS
371: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, axpy)
372: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, scal)
373: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dot, PetscIfPetscDefined(USE_COMPLEX, dotc, dot))
374: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION_EXACT(STANDARD, dotu, PetscIfPetscDefined(USE_COMPLEX, dotu, dot))
375: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, swap)
376: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, nrm2)
377: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(IFPTYPE, amax)
378: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(MODIFIED, asum)
380: // level 2 BLAS
381: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemv)
383: // level 3 BLAS
384: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, gemm)
386: // BLAS extensions
387: PETSC_CUPMBLAS_ALIAS_BLAS_FUNCTION(STANDARD, geam)
389: PETSC_NODISCARD static PetscErrorCode InitializeHandle(cupmSolverHandle_t &handle) noexcept
390: {
391: if (!handle) hipsolverCreate(&handle);
392: return 0;
393: }
395: PETSC_NODISCARD static PetscErrorCode SetHandleStream(cupmSolverHandle_t handle, cupmStream_t stream) noexcept
396: {
397: hipsolverSetStream(handle, stream);
398: return 0;
399: }
401: PETSC_NODISCARD static PetscErrorCode DestroyHandle(cupmSolverHandle_t &handle) noexcept
402: {
403: if (handle) {
404: hipsolverDestroy(handle);
405: handle = nullptr;
406: }
407: return 0;
408: }
409: };
410: #undef PETSC_CUPMBLAS_PREFIX
411: #undef PETSC_CUPMBLAS_PREFIX_U
412: #undef PETSC_CUPMBLAS_FP_TYPE
413: #undef PETSC_CUPMBLAS_FP_INPUT_TYPE
414: #undef PETSC_CUPMBLAS_FP_RETURN_TYPE
415: #endif // PetscDefined(HAVE_HIP)
// only the BlasInterfaceImpl<> specializations use the base class header
#undef PETSC_CUPMBLAS_BASE_CLASS_HEADER

// PETSC_CUPMBLAS_IMPL_CLASS_HEADER() - pull every alias declared by BlasInterfaceImpl<T> into
// the scope of a class templated on T
//
// input params:
// base_name - name given to the BlasInterfaceImpl<T> alias
// T         - the DeviceType
//
// notes:
// Names in a dependent base class are not found by unqualified lookup inside a template, so each
// type, value and function is re-exported with a using-declaration. It also expands
// PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(cupmInterface_t, T), presumably bringing in the
// generic Interface<T> names (e.g. cupmStream_t) — see that macro. No trailing semicolon; the
// caller supplies it.
#define PETSC_CUPMBLAS_IMPL_CLASS_HEADER(base_name, T) \
  PETSC_CUPM_INHERIT_INTERFACE_TYPEDEFS_USING(cupmInterface_t, T); \
  using base_name = ::Petsc::device::cupm::impl::BlasInterfaceImpl<T>; \
  /* introspection */ \
  using base_name::cupmBlasName; \
  using base_name::cupmBlasGetErrorName; \
  /* types */ \
  using typename base_name::cupmBlasHandle_t; \
  using typename base_name::cupmBlasError_t; \
  using typename base_name::cupmBlasInt_t; \
  using typename base_name::cupmSolverHandle_t; \
  using typename base_name::cupmSolverError_t; \
  using typename base_name::cupmBlasPointerMode_t; \
  /* values */ \
  using base_name::CUPMBLAS_STATUS_SUCCESS; \
  using base_name::CUPMBLAS_STATUS_NOT_INITIALIZED; \
  using base_name::CUPMBLAS_STATUS_ALLOC_FAILED; \
  using base_name::CUPMBLAS_POINTER_MODE_HOST; \
  using base_name::CUPMBLAS_POINTER_MODE_DEVICE; \
  /* utility functions */ \
  using base_name::cupmBlasCreate; \
  using base_name::cupmBlasDestroy; \
  using base_name::cupmBlasGetStream; \
  using base_name::cupmBlasSetStream; \
  using base_name::cupmBlasGetPointerMode; \
  using base_name::cupmBlasSetPointerMode; \
  /* level 1 BLAS */ \
  using base_name::cupmBlasXaxpy; \
  using base_name::cupmBlasXscal; \
  using base_name::cupmBlasXdot; \
  using base_name::cupmBlasXdotu; \
  using base_name::cupmBlasXswap; \
  using base_name::cupmBlasXnrm2; \
  using base_name::cupmBlasXamax; \
  using base_name::cupmBlasXasum; \
  /* level 2 BLAS */ \
  using base_name::cupmBlasXgemv; \
  /* level 3 BLAS */ \
  using base_name::cupmBlasXgemm; \
  /* BLAS extensions */ \
  using base_name::cupmBlasXgeam
461: // The actual interface class
462: template <DeviceType T>
463: struct BlasInterface : BlasInterfaceImpl<T> {
464: PETSC_CUPMBLAS_IMPL_CLASS_HEADER(blasinterface_type, T);
466: PETSC_NODISCARD static PetscErrorCode PetscCUPMBlasSetPointerModeFromPointer(cupmBlasHandle_t handle, const void *ptr) noexcept
467: {
468: auto mtype = PETSC_MEMTYPE_HOST;
470: PetscCUPMGetMemType(ptr, &mtype);
471: cupmBlasSetPointerMode(handle, PetscMemTypeDevice(mtype) ? CUPMBLAS_POINTER_MODE_DEVICE : CUPMBLAS_POINTER_MODE_HOST);
472: return 0;
473: }
474: };
// PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING() - bring the full BLAS interface into a class
// templated on DeviceType
//
// input params:
// base_name - name of the BlasInterface<T> alias; the BlasInterfaceImpl<T> alias is named
//             base_name ## _impl
// T         - the DeviceType
//
// notes:
// No trailing semicolon; the caller supplies it.
#define PETSC_CUPMBLAS_INHERIT_INTERFACE_TYPEDEFS_USING(base_name, T) \
  PETSC_CUPMBLAS_IMPL_CLASS_HEADER(PetscConcat(base_name, _impl), T); \
  using base_name = ::Petsc::device::cupm::impl::BlasInterface<T>; \
  using base_name::PetscCUPMBlasSetPointerModeFromPointer
// Explicit instantiation declarations for each enabled backend: they stop translation units that
// include this header from implicitly instantiating BlasInterface<T>. The matching explicit
// instantiation definitions presumably live in a separate source file — TODO confirm.
#if PetscDefined(HAVE_CUDA)
extern template struct BlasInterface<DeviceType::CUDA>;
#endif

#if PetscDefined(HAVE_HIP)
extern template struct BlasInterface<DeviceType::HIP>;
#endif
489: } // namespace impl
491: } // namespace cupm
493: } // namespace device
495: } // namespace Petsc
497: #endif // defined(__cplusplus)
499: #endif // PETSCCUPMBLASINTERFACE_HPP