217 const size_type dim = m_A.block.dimension();
219 volatile VectorScalar *
const sh =
220 kokkos_impl_cuda_shared_memory<VectorScalar>();
221 volatile VectorScalar *
const sh_y0 =
222 sh + blockDim.x*threadIdx.y;
223 volatile VectorScalar *
const sh_a0 =
224 sh + blockDim.x*blockDim.y + MAX_COL*threadIdx.y;
225 volatile VectorScalar *
const sh_x0 =
226 sh + blockDim.x*blockDim.y + MAX_COL*blockDim.y + MAX_COL*threadIdx.y;
228 (
size_type*)(sh + blockDim.x*blockDim.y + 2*MAX_COL*blockDim.y) + MAX_COL*threadIdx.y;
231 const size_type row = blockIdx.x*blockDim.y + threadIdx.y;
232 if (row < m_A.graph.row_map.extent(0)-1) {
233 const size_type colBeg = m_A.graph.row_map[ row ];
234 const size_type colEnd = m_A.graph.row_map[ row + 1 ];
237 const TensorScalar c0 = m_A.block.value(0);
238 const TensorScalar c1 = m_A.block.value(1);
239 const TensorScalar c2 = m_A.block.value(2);
242 VectorScalar y0 = 0.0;
245 for (
size_type lcol = threadIdx.x; lcol < colEnd-colBeg;
247 sh_col[lcol] = m_A.graph.entries( lcol+colBeg );
250 for (
size_type stoch_row = threadIdx.x; stoch_row < dim;
251 stoch_row += blockDim.x) {
253 VectorScalar yi = 0.0;
261 for (
size_type col_offset = colBeg; col_offset < colEnd;
265 const size_type lcol = col_offset-colBeg;
269 const MatrixScalar ai = m_A.values( stoch_row, col_offset );
270 const VectorScalar xi = m_x( stoch_row, col );
273 if (stoch_row == 0) {
279 const MatrixScalar a0 = sh_a0[lcol];
280 const VectorScalar x0 = sh_x0[lcol];
283 if (stoch_row == 0) y0 += (c0-3.0*c1-c2)*a0*x0;
285 yi += c1*(a0*xi + ai*x0) + c2*ai*xi;
290 m_y( stoch_row, row ) = yi;
298 sh_y0[ threadIdx.x ] = y0;
299 if ( threadIdx.x + 16 < blockDim.x )
300 sh_y0[threadIdx.x] += sh_y0[threadIdx.x+16];
301 if ( threadIdx.x + 8 < blockDim.x )
302 sh_y0[threadIdx.x] += sh_y0[threadIdx.x+ 8];
303 if ( threadIdx.x + 4 < blockDim.x )
304 sh_y0[threadIdx.x] += sh_y0[threadIdx.x+ 4];
305 if ( threadIdx.x + 2 < blockDim.x )
306 sh_y0[threadIdx.x] += sh_y0[threadIdx.x+ 2];
307 if ( threadIdx.x + 1 < blockDim.x )
308 sh_y0[threadIdx.x] += sh_y0[threadIdx.x+ 1];
311 if ( threadIdx.x == 0 ) m_y( 0, row ) += sh_y0[0];