// joint_matrix_bfloat16_modified
#include <cstdint>
#include <cstring>
#include <iostream>

#include <CL/sycl.hpp>
#include <ext/oneapi/experimental/bfloat16.hpp>
#include <ext/oneapi/matrix/matrix.hpp>
// Sub-group size: used as the work-group extent in dimension 1 of the
// nd_range and in [[intel::reqd_sub_group_size]] on the joint_matrix kernel.
#define SG_SZ 8
// Tile sizes for the joint_matrix operands: A is TM x TK (bfloat16),
// B is TK x TN (bfloat16, packed_b layout) and C is TM x TN (float).
#define TM 8
#define TN 8
#define TK 16
public:
T *get_data() { return mat; }
void set_data(T *data) { mat = data; }
big_matrix(T *data) : mat(data) {}
};
sycl::default_selector d_selector;
sycl::queue q(d_selector);
// Print out the device information used for the kernel code.
std::cout << "Running on device: "
<< q.get_device().get_info<sycl::info::device::name>() << "\n";
try
{
q.submit([&](handler &cgh)
{
auto accC = bufC.get_access<access::mode::read_write>(cgh);
auto accA = bufA.get_access<access::mode::read_write>(cgh);
auto accB = bufB.get_access<access::mode::read_write>(cgh);
cgh.parallel_for<class imatrix>(
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
{
// The submatrix API has to be accessed by all the workitems in a
// subgroup these functions will be called once by the subgroup no
// code divergence between the workitems
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<bfloat16, TM, TK> sub_a(sg);
// For B, since current implementation does not support non-packed
// layout, users need to specify the updated VNNI sizes along with
// the packed_b layout. By default, the layout is row_major and size
// is (TK, TN).
joint_matrix<bfloat16, TK, TN, matrix_layout::packed_b> sub_b(sg);
joint_matrix<float, TM, TN> sub_c(sg);
// Widen a bfloat16 bit pattern to the float it encodes.
//
// A bfloat16 value is the high 16 bits of an IEEE-754 binary32 value, so the
// conversion is exact: place the 16 bits in the upper half of a 32-bit word
// and reinterpret that word as a float.
//
// @param x  raw bfloat16 bits. They are passed as `short`, so a set sign bit
//           makes `x` negative. That is fine, because only the low 16 bits
//           are used.
// @return   the float value whose upper 16 bits are `x` and whose lower
//           16 bits are zero.
float make_fp32(short x)
{
  static_assert(sizeof(float) == sizeof(std::uint32_t),
                "make_fp32 assumes a 32-bit float");
  // Zero-extend through uint16_t so the shift operates on a fixed-width
  // unsigned value, rather than relying on `unsigned int` being 32 bits to
  // discard the sign-extension bits.
  const std::uint32_t bits =
      static_cast<std::uint32_t>(static_cast<std::uint16_t>(x)) << 16;
  // Copy the bytes instead of reading through a reinterpret_cast'ed float*.
  // Accessing an integer object through a float pointer breaks strict
  // aliasing and is undefined behaviour.
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}
// Fill the host-side operands with deterministic test data.
//
//  - A / Aref: MATRIX_M x MATRIX_K, element (row, col) = row + col.
//  - B / Bref: iterated as (MATRIX_K / 2) rows by (MATRIX_N * 2) columns,
//    element (row, col) = 2*row + 3*col.
//    NOTE(review): presumably this is the packed (VNNI) layout that the
//    kernel loads as packed_b; confirm against the kernel's load.
//  - C / D: MATRIX_M x MATRIX_N, every element 1.0.
//
// Each bfloat16 matrix (A, B) and its raw-bits reference copy (Aref, Bref)
// are written from the same make_bf16 result.
void initialize_matrices() {
  for (int row = 0; row < MATRIX_M; row++) {
    for (int col = 0; col < MATRIX_K; col++) {
      // Build the bfloat16 from its raw bits, because float -> bfloat16
      // conversion is not available on the host side yet.
      const auto bits = make_bf16(1.0f * (row + col));
      A[row][col] = bfloat16::from_bits(bits);
      Aref[row][col] = bits;
    }
  }

  for (int row = 0; row < MATRIX_K / 2; row++) {
    for (int col = 0; col < MATRIX_N * 2; col++) {
      const auto bits = make_bf16(2.0f * row + 3.0f * col);
      B[row][col] = bfloat16::from_bits(bits);
      Bref[row][col] = bits;
    }
  }

  for (int row = 0; row < MATRIX_M; row++) {
    for (int col = 0; col < MATRIX_N; col++) {
      C[row][col] = 1.0;
      D[row][col] = 1.0;
    }
  }
}
int main()
{
initialize_matrices();
start = std::chrono::steady_clock::now();
matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
MATRIX_N, MATRIX_K / 2);
end = std::chrono::steady_clock::now();
std::cout << "Elapsed time in milliseconds (reference): "
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
<< " ms" << std::endl;