66subroutine solve_tridi_cpu_&
67 &precision_and_suffix &
70 ldq, nblk, matrixcols, mpi_comm_all, mpi_comm_rows, &
71 mpi_comm_cols, wantdebug, success, max_threads )
84 use solve_tridi_col_gpu
86#include "../../src/general/precision_kinds.F90"
88 integer(kind=ik),
intent(in) :: na, nev, ldq, nblk, matrixCols, &
89 mpi_comm_all, mpi_comm_rows, mpi_comm_cols
91 integer(kind=c_intptr_t) :: d_dev, e_dev, q_dev
92#ifndef SOLVE_TRIDI_GPU_BUILD
93 real(kind=real_datatype),
intent(inout) :: d(na), e(na)
94#ifdef USE_ASSUMED_SIZE
95 real(kind=real_datatype),
intent(inout) :: q(ldq,*)
97 real(kind=real_datatype),
intent(inout) :: q(ldq,matrixcols)
99#else /* SOLVE_TRIDI_GPU_BUILD */
100 real(kind=real_datatype) :: d(na), e(na)
101 real(kind=real_datatype) :: q(ldq,matrixcols)
102#endif /* SOLVE_TRIDI_GPU_BUILD */
104 logical,
intent(in) :: wantDebug
106 integer(kind=c_int) :: debug
108 integer(kind=ik) :: i, j, n, np, nc, nev1, l_cols, l_rows
109 integer(kind=ik) :: my_prow, my_pcol, np_rows, np_cols
110 integer(kind=MPI_KIND) :: mpierr, my_prowMPI, my_pcolMPI, np_rowsMPI, np_colsMPI
111 integer(kind=ik),
allocatable :: limits(:), l_col(:), p_col(:), l_col_bc(:), p_col_bc(:)
113 integer(kind=ik) :: istat
114 character(200) :: errorMessage
115 character(20) :: gpuString
116 integer(kind=ik),
intent(in) :: max_threads
118 integer(kind=c_intptr_t) :: num
119 integer(kind=c_intptr_t),
parameter :: size_of_datatype = size_of_&
123 integer(kind=c_intptr_t),
parameter :: size_of_datatype_real = size_of_&
126 integer(kind=c_intptr_t) :: gpuHandle, my_stream
127 integer(kind=c_intptr_t) :: limits_dev
128 logical :: successGPU
131 if (wantdebug) debug = 1
134#ifdef SOLVE_TRIDI_GPU_BUILD
144 call obj%timer%start(
"solve_tridi" // precision_suffix // gpustring)
146 call obj%timer%start(
"mpi_communication")
147 call mpi_comm_rank(int(mpi_comm_rows,kind=mpi_kind) ,my_prowmpi, mpierr)
148 call mpi_comm_size(int(mpi_comm_rows,kind=mpi_kind) ,np_rowsmpi, mpierr)
149 call mpi_comm_rank(int(mpi_comm_cols,kind=mpi_kind) ,my_pcolmpi, mpierr)
150 call mpi_comm_size(int(mpi_comm_cols,kind=mpi_kind) ,np_colsmpi, mpierr)
152 my_prow = int(my_prowmpi,kind=c_int)
153 np_rows = int(np_rowsmpi,kind=c_int)
154 my_pcol = int(my_pcolmpi,kind=c_int)
155 np_cols = int(np_colsmpi,kind=c_int)
158 call obj%timer%stop(
"mpi_communication")
163 l_rows = local_index(na, my_prow, np_rows, nblk, -1)
164 l_cols = local_index(na, my_pcol, np_cols, nblk, -1)
168#ifdef WITH_GPU_STREAMS
169 my_stream = obj%gpu_setup%my_stream
170 successgpu = gpu_memset_async(q_dev, 0, l_rows*l_cols*size_of_datatype_real, my_stream)
171 check_memset_gpu(
"solve_tridi: tmp_dev", successgpu)
173 successgpu = gpu_memset(q_dev, 0, l_rows*l_cols*size_of_datatype_real)
174 check_memset_gpu(
"solve_tridi: tmp_dev", successgpu)
177 if (l_rows .ne. ldq)
then
178 print *,
"oh shit ldq:",l_rows,ldq
181 if (l_cols .ne. matrixcols)
then
182 print *,
"oh shit matrixCols:",l_cols,matrixcols
188 q(1:l_rows, 1:l_cols) = 0.0_rk
196 allocate(limits(0:np_cols), stat=istat, errmsg=errormessage)
197 check_allocate(
"solve_tridi: limits", istat, errormessage)
201 nc = local_index(na, np, np_cols, nblk, -1)
208 call obj%timer%stop(
"solve_tridi" // precision_suffix)
209 if (wantdebug)
write(error_unit,*)
'ELPA1_solve_tridi: ERROR: Problem contains processor column with zero width'
213 limits(np+1) = limits(np) + nc
221 num = (np_cols) * size_of_int
222 successgpu = gpu_malloc(limits_dev, num)
223 check_alloc_gpu(
"solve_tridi limits_dev: ", successgpu)
225 num = (np_cols) * size_of_int
226#ifdef WITH_GPU_STREAMS
227 my_stream = obj%gpu_setup%my_stream
228 successgpu = gpu_memcpy_async(limits_dev, int(loc(limits(1)),kind=c_intptr_t), &
229 num, gpumemcpyhosttodevice, my_stream)
230 check_memcpy_gpu(
"solve_tridi limits_dev: ", successgpu)
232 successgpu = gpu_memcpy(limits_dev, int(loc(limits(1)),kind=c_intptr_t), &
233 num, gpumemcpyhosttodevice)
234 check_memcpy_gpu(
"solve_tridi: limits_dev", successgpu)
238 my_stream = obj%gpu_setup%my_stream
239 call gpu_update_d (precision_char, d_dev, e_dev, limits_dev, np_cols, na, debug, my_stream)
241 successgpu = gpu_free(limits_dev)
242 check_dealloc_gpu(
"solve_tridi: limits_dev", successgpu)
247 d(n) = d(n)-abs(e(n))
248 d(n+1) = d(n+1)-abs(e(n))
260 nev1 = min(nev,l_cols)
265 call solve_tridi_col_gpu_&
266 &precision_and_suffix &
267 (obj, l_cols, nev1, nc, d_dev +(nc+1-1)*size_of_datatype_real, &
268 e_dev + (nc+1-1)*size_of_datatype_real, q_dev, ldq, nblk, &
269 matrixcols, mpi_comm_rows, wantdebug, success, max_threads)
271 call solve_tridi_col_cpu_&
272 &precision_and_suffix &
273 (obj, l_cols, nev1, nc, d(nc+1), e(nc+1), q, ldq, nblk, &
274 matrixcols, mpi_comm_rows, wantdebug, success, max_threads)
277 if (.not.(success))
then
278 call obj%timer%stop(
"solve_tridi" // precision_suffix // gpustring)
287 deallocate(limits, stat=istat, errmsg=errormessage)
288 check_deallocate(
"solve_tridi: limits", istat, errormessage)
290 call obj%timer%stop(
"solve_tridi" // precision_suffix // gpustring)
298 allocate(l_col(na), stat=istat, errmsg=errormessage)
299 check_allocate(
"solve_tridi: l_col", istat, errormessage)
301 allocate(p_col(na), stat=istat, errmsg=errormessage)
302 check_allocate(
"solve_tridi: p_col", istat, errormessage)
306 nc = local_index(na, np, np_cols, nblk, -1)
316 allocate(l_col_bc(na), stat=istat, errmsg=errormessage)
317 check_allocate(
"solve_tridi: l_col_bc", istat, errormessage)
319 allocate(p_col_bc(na), stat=istat, errmsg=errormessage)
320 check_allocate(
"solve_tridi: p_col_bc", istat, errormessage)
325 do i = 0, na-1, nblk*np_cols
328 if (i+j*nblk+n <= min(nev,na))
then
329 p_col_bc(i+j*nblk+n) = j
330 l_col_bc(i+j*nblk+n) = i/np_cols + n
339 num = na * size_of_datatype_real
340#ifdef WITH_GPU_STREAMS
341 my_stream = obj%gpu_setup%my_stream
342 call gpu_memcpy_async_and_stream_synchronize &
343 (
"solve_tridi d_dev -> d", d_dev, 0_c_intptr_t, &
345 1, num, gpumemcpydevicetohost, my_stream, .false., .true., .false.)
347 successgpu = gpu_memcpy(int(loc(d(1)),kind=c_intptr_t), d_dev, &
348 num, gpumemcpydevicetohost)
349 check_memcpy_gpu(
"solve_tridi: 1: d_dev", successgpu)
352 num = na * size_of_datatype_real
353#ifdef WITH_GPU_STREAMS
354 my_stream = obj%gpu_setup%my_stream
355 call gpu_memcpy_async_and_stream_synchronize &
356 (
"solve_tridi e_dev -> e", e_dev, 0_c_intptr_t, &
358 1, num, gpumemcpydevicetohost, my_stream, .false., .true., .false.)
360 successgpu = gpu_memcpy(int(loc(e(1)),kind=c_intptr_t), e_dev, &
361 num, gpumemcpydevicetohost)
362 check_memcpy_gpu(
"solve_tridi: 1: d_dev", successgpu)
370 call merge_recursive_gpu_&
372 (obj, 0, np_cols, ldq, matrixcols, nblk, &
373 l_col, p_col, l_col_bc, p_col_bc, limits, &
374 np_cols, na, q_dev, d, e, &
375 mpi_comm_all, mpi_comm_rows, mpi_comm_cols,&
376 usegpu, wantdebug, success, max_threads)
378 call merge_recursive_cpu_&
380 (obj, 0, np_cols, ldq, matrixcols, nblk, &
381 l_col, p_col, l_col_bc, p_col_bc, limits, &
382 np_cols, na, q, d, e, &
383 mpi_comm_all, mpi_comm_rows, mpi_comm_cols,&
384 usegpu, wantdebug, success, max_threads)
387 if (.not.(success))
then
388 call obj%timer%stop(
"solve_tridi" // precision_suffix // gpustring)
392 deallocate(limits,l_col,p_col,l_col_bc,p_col_bc, stat=istat, errmsg=errormessage)
393 check_deallocate(
"solve_tridi: limits, l_col, p_col, l_col_bc, p_col_bc", istat, errormessage)
398 num = na * size_of_datatype_real
399#ifdef WITH_GPU_STREAMS
400 my_stream = obj%gpu_setup%my_stream
401 call gpu_memcpy_async_and_stream_synchronize &
402 (
"solve_trid d -> d_dev", d_dev, 0_c_intptr_t, &
404 1, num, gpumemcpyhosttodevice, my_stream, .false., .false., .false.)
406 successgpu = gpu_memcpy(d_dev, int(loc(d(1)),kind=c_intptr_t), &
407 num, gpumemcpyhosttodevice)
408 check_memcpy_gpu(
"solve_tridi: d_dev", successgpu)
411 num = na * size_of_datatype_real
412#ifdef WITH_GPU_STREAMS
413 my_stream = obj%gpu_setup%my_stream
414 call gpu_memcpy_async_and_stream_synchronize &
415 (
"solve_tridi e_dev -> e", e_dev, 0_c_intptr_t, &
417 1, num, gpumemcpyhosttodevice, my_stream, .false., .false., .false.)
419 successgpu = gpu_memcpy(e_dev, int(loc(e(1)),kind=c_intptr_t), &
420 num, gpumemcpyhosttodevice)
421 check_memcpy_gpu(
"solve_tridi: e_dev", successgpu)
425 call obj%timer%stop(
"solve_tridi" // precision_suffix // gpustring)