// Single-tile matrix multiplication on a tensor core using the WMMA API,
// verified against a plain host implementation.
//
// Requires a GPU with tensor cores (compute capability 7.0 or newer); build
// with e.g. `nvcc -arch=sm_70`.
//
// NOTE(review): this file was recovered from a copy in which line breaks were
// collapsed and every "<...>" span had been stripped. The constants, the
// kernel body, get/set, the signature of multiply() and the tail of main()
// were still present. Header names, template parameter lists, the fragment
// template arguments, the loop body of multiply(), the whole of compare(),
// the setup part of main() and the launch configuration are reconstructions
// and must be confirmed against the original.

#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <mma.h>

using namespace std;
using namespace nvcuda::wmma;

// Abort with a readable message if a CUDA runtime call fails.
#define CUDA_CHECK(call)                                                      \
    do {                                                                      \
        cudaError_t err_ = (call);                                            \
        if (err_ != cudaSuccess) {                                            \
            fprintf(stderr, "CUDA error %s:%d: %s\n", __FILE__, __LINE__,     \
                    cudaGetErrorString(err_));                                \
            abort();                                                          \
        }                                                                     \
    } while (0)

// Matrix sizes.
// A is MxK, B is KxN, C=AB is MxN.
const int M = 32;
const int N = 8;
const int K = 16;
const int Asize = M * K;
const int Bsize = K * N;
const int Csize = M * N;

// Matrices are stored row major. The stride is the number of columns.
const int Astride = K;
const int Bstride = N;
const int Cstride = N;

// Tile sizes. Simplest case: only one tile for the entire matrix.
// Supported are M=16, N=16, K=16 or M=32, N=8, K=16 or M=8, N=32, K=16
const int Mtile = M;
const int Ntile = N;
const int Ktile = K;

static_assert((Mtile == 16 && Ntile == 16 && Ktile == 16) ||
              (Mtile == 32 && Ntile == 8 && Ktile == 16) ||
              (Mtile == 8 && Ntile == 32 && Ktile == 16),
              "unsupported WMMA tile shape for half inputs");

// Kernel which multiplies only one tile.
//
// Launch layout: must be executed by exactly one full warp (32 threads);
// the *_sync WMMA calls are collective over the warp.
// a, b, c point to the upper left corner of the respective row-major tile
// in device memory.
__global__ void mykernel(half *a, half *b, float *c)
{
    // Fragments (matrix tiles) handled collectively by the warp.
    // matrix_a, matrix_b, accumulator are tag types in the WMMA namespace.
    // Templates are used instead of function arguments to have compile time
    // constants.
    fragment<matrix_a, Mtile, Ntile, Ktile, half, row_major> a_frag;
    fragment<matrix_b, Mtile, Ntile, Ktile, half, row_major> b_frag;
    fragment<accumulator, Mtile, Ntile, Ktile, float> c_frag;

    // Initialize accumulator with zeros.
    fill_fragment(c_frag, 0.0f);

    // Load A and B into fragments.
    // All threads of the warp work together to load the tiles.
    // The memory layout of A and B is part of the fragment type, so no
    // layout argument is passed here; the overload taking a layout_t exists
    // only for accumulator fragments.
    load_matrix_sync(a_frag, a, Astride);
    load_matrix_sync(b_frag, b, Bstride);

    // Do the matrix multiply-accumulate with a tensor core.
    mma_sync(c_frag, a_frag, b_frag, c_frag);

    // Store result back to memory. Accumulator fragments carry no layout in
    // their type, so it is given explicitly.
    store_matrix_sync(c, c_frag, Cstride, mem_row_major);
}

// Access to matrix elements.
// Returns element (i,j) of the row-major matrix a with n columns. The value
// passes through float before being converted back to T.
template <typename T>
T get(int i, int j, int n, T* a)
{
    return (float)a[i * n + j];
}

// Stores value at element (i,j) of the row-major matrix a with n columns.
template <typename T>
void set(int i, int j, int n, T* a, T value)
{
    a[i * n + j] = value;
}

// Matrix multiplication on host: c = a * b with a of size m x k, b of size
// k x n and c of size m x n, all row major and densely packed.
// NOTE(review): the loop body is reconstructed; only the signature and the
// locals i, j, l, sum were present in the recovered text.
void multiply(half * a, half * b, float * c, int m, int n, int k)
{
    int i, j, l;
    float sum;
    for (i = 0; i < m; i++) {
        for (j = 0; j < n; j++) {
            sum = 0.0f;
            for (l = 0; l < k; l++) {
                // Qualified names keep the lookup inside the global
                // namespace despite `using namespace std`.
                // Accumulate in float, like the tensor core does.
                sum += __half2float(::get(i, l, k, a)) *
                       __half2float(::get(l, j, n, b));
            }
            ::set(i, j, n, c, sum);
        }
    }
}

// Compares the device result with the host reference (both Csize floats) and
// prints a summary.
// NOTE(review): reconstructed; only the call `compare(c_host, check)` was
// present in the recovered text. Tolerance and output format are assumptions.
void compare(float* result, float* expected)
{
    float maxError = 0.0f;
    int mismatches = 0;
    for (int idx = 0; idx < Csize; idx++) {
        float err = fabsf(result[idx] - expected[idx]);
        if (err > maxError) {
            maxError = err;
        }
        if (err > 1e-3f * (1.0f + fabsf(expected[idx]))) {
            mismatches++;
        }
    }
    cout << "max abs error: " << maxError
         << ", mismatches: " << mismatches << " of " << Csize << endl;
}

int main()
{
    // NOTE(review): everything up to and including the launch configuration
    // is reconstructed; the recovered text started at the kernel arguments.
    half a_host[Asize];
    half b_host[Bsize];
    float c_host[Csize];
    float check[Csize];

    // Small integer inputs are exactly representable in half precision, so
    // host and device results are expected to agree closely.
    for (int i = 0; i < M; i++) {
        for (int l = 0; l < K; l++) {
            ::set(i, l, Astride, a_host, __float2half((float)((i + l) % 5)));
        }
    }
    for (int l = 0; l < K; l++) {
        for (int j = 0; j < N; j++) {
            ::set(l, j, Bstride, b_host,
                  __float2half((float)((2 * l + j) % 3) - 1.0f));
        }
    }

    half *a_dev = nullptr;
    half *b_dev = nullptr;
    float *c_dev = nullptr;
    CUDA_CHECK(cudaMalloc(&a_dev, Asize * sizeof(half)));
    CUDA_CHECK(cudaMalloc(&b_dev, Bsize * sizeof(half)));
    CUDA_CHECK(cudaMalloc(&c_dev, Csize * sizeof(float)));

    CUDA_CHECK(cudaMemcpy(a_dev, a_host, Asize * sizeof(half),
                          cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(b_dev, b_host, Bsize * sizeof(half),
                          cudaMemcpyHostToDevice));

    // One block with a single warp: the whole matrix is one tile.
    mykernel<<<1, 32>>>(a_dev, b_dev, c_dev);
    CUDA_CHECK(cudaGetLastError());       // launch configuration errors
    CUDA_CHECK(cudaDeviceSynchronize());  // errors raised while executing

    CUDA_CHECK(cudaMemcpy(c_host, c_dev, Csize * sizeof(float),
                          cudaMemcpyDeviceToHost));

    // Check result.
    multiply(a_host, b_host, check, M, N, K);
    compare(c_host, check);

    CUDA_CHECK(cudaFree(a_dev));
    CUDA_CHECK(cudaFree(b_dev));
    CUDA_CHECK(cudaFree(c_dev));
    return 0;
}