mirror of
https://github.com/odin-lang/Odin.git
synced 2026-05-19 14:25:10 -04:00
Add a minor optimization for row_major * row_major
This commit is contained in:
@@ -672,7 +672,7 @@ gb_internal lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lh
|
||||
}
|
||||
}
|
||||
|
||||
gb_internal bool lb_is_matrix_simdable(Type *t) {
|
||||
gb_internal bool lb_is_matrix_simdable(Type *t, bool ignore_layout=false) {
|
||||
Type *mt = base_type(t);
|
||||
GB_ASSERT(mt->kind == Type_Matrix);
|
||||
|
||||
@@ -701,8 +701,10 @@ gb_internal bool lb_is_matrix_simdable(Type *t) {
|
||||
return false;
|
||||
}
|
||||
if (mt->Matrix.is_row_major) {
|
||||
// TODO(bill): make #row_major matrices work with SIMD
|
||||
return false;
|
||||
if (!ignore_layout) {
|
||||
// TODO(bill): make #row_major matrices work with SIMD
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (elem->kind == Type_Basic) {
|
||||
@@ -959,6 +961,10 @@ gb_internal lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b,
|
||||
gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
|
||||
// TODO(bill): Handle edge case for f16 types on x86(-64) platforms
|
||||
|
||||
auto const do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef {
|
||||
return LLVMConstInt(lb_type(p->module, t_u32), val, false);
|
||||
};
|
||||
|
||||
Type *xt = base_type(lhs.type);
|
||||
Type *yt = base_type(rhs.type);
|
||||
|
||||
@@ -975,114 +981,179 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs,
|
||||
unsigned inner = cast(unsigned)xt->Matrix.column_count;
|
||||
unsigned outer_columns = cast(unsigned)yt->Matrix.column_count;
|
||||
|
||||
if (!xt->Matrix.is_row_major && lb_is_matrix_simdable(xt)) {
|
||||
unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
|
||||
unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
|
||||
if (lb_is_matrix_simdable(xt, true)) {
|
||||
if (!xt->Matrix.is_row_major) { // #column_major
|
||||
unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
|
||||
unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
|
||||
|
||||
LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs);
|
||||
LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs);
|
||||
LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs);
|
||||
LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs);
|
||||
|
||||
if (outer_rows == outer_columns && outer_rows == inner && (inner & 1) == 0) {
|
||||
// square matrix calculation
|
||||
unsigned N = outer_columns;
|
||||
if (outer_rows == outer_columns && outer_rows == inner && (inner & 1) == 0) {
|
||||
// square matrix calculation
|
||||
unsigned N = outer_columns;
|
||||
|
||||
auto x_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
auto x_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, inner);
|
||||
LLVMValueRef column = llvm_basic_shuffle(p, x_vector, mask);
|
||||
x_columns[i] = column;
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, N);
|
||||
LLVMValueRef column = llvm_basic_shuffle(p, x_vector, mask);
|
||||
x_columns[i] = column;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, N);
|
||||
LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask);
|
||||
y_columns[i] = column;
|
||||
}
|
||||
|
||||
|
||||
auto z_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
for (unsigned j = 0; j < N; j++) {
|
||||
LLVMValueRef mask = llvm_mask_same(p->module, j, N);
|
||||
mask_elems[j] = llvm_basic_shuffle(p, y_columns[i], mask);
|
||||
}
|
||||
z_columns[i] = llvm_vector_mul_pairwise_reduce_add(p, mask_elems, x_columns);
|
||||
}
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, false);
|
||||
LLVMValueRef dest_ptr = res.addr.value;
|
||||
|
||||
LLVMTypeRef dest_ptr_type = LLVMPointerType(LLVMTypeOf(z_columns[0]), 0);
|
||||
dest_ptr = LLVMBuildPointerCast(p->builder, dest_ptr, dest_ptr_type, "");
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
LLVMValueRef indices[] = {do_u32(p, i)};
|
||||
LLVMValueRef dst = LLVMBuildInBoundsGEP2(p->builder, LLVMTypeOf(z_columns[0]), dest_ptr, indices, 1, "");
|
||||
LLVMBuildStore(p->builder, z_columns[i], dst);
|
||||
}
|
||||
|
||||
return lb_addr_load(p, res);
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
|
||||
auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
|
||||
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
|
||||
|
||||
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), inner);
|
||||
for (unsigned i = 0; i < outer_rows; i++) {
|
||||
for (unsigned j = 0; j < inner; j++) {
|
||||
unsigned offset = x_stride*j + i;
|
||||
mask_elems[j] = do_u32(p, offset);
|
||||
}
|
||||
|
||||
// transpose mask
|
||||
LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner);
|
||||
LLVMValueRef row = llvm_basic_shuffle(p, x_vector, mask);
|
||||
x_rows[i] = row;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < outer_columns; i++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner);
|
||||
LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask);
|
||||
y_columns[i] = column;
|
||||
}
|
||||
|
||||
|
||||
auto z_columns = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
|
||||
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
auto temp_muls = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
for (unsigned j = 0; j < N; j++) {
|
||||
LLVMValueRef mask = llvm_mask_same(p->module, j, N);
|
||||
mask_elems[j] = llvm_basic_shuffle(p, y_columns[i], mask);
|
||||
lbAddr res = lb_add_local_generated(p, type, false);
|
||||
for_array(i, x_rows) {
|
||||
LLVMValueRef x_row = x_rows[i];
|
||||
for_array(j, y_columns) {
|
||||
LLVMValueRef y_column = y_columns[j];
|
||||
LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column);
|
||||
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
||||
LLVMBuildStore(p->builder, elem, dst.value);
|
||||
}
|
||||
for (unsigned j = 0; j < N; j++) {
|
||||
temp_muls[j] = llvm_vector_mul(p, mask_elems[j], x_columns[j]);
|
||||
}
|
||||
return lb_addr_load(p, res);
|
||||
} else { // #row_major
|
||||
unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt);
|
||||
unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt);
|
||||
|
||||
LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs);
|
||||
LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs);
|
||||
|
||||
if (outer_rows == outer_columns && outer_rows == inner && (inner & 1) == 0) {
|
||||
// square matrix calculation
|
||||
unsigned N = outer_columns;
|
||||
|
||||
auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
auto y_rows = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, N);
|
||||
LLVMValueRef column = llvm_basic_shuffle(p, x_vector, mask);
|
||||
x_rows[i] = column;
|
||||
}
|
||||
unsigned k = N;
|
||||
while (k > 1) {
|
||||
unsigned half = k/2;
|
||||
for (unsigned j = 0; j < half; j++) {
|
||||
temp_muls[j] = llvm_vector_add(p, temp_muls[2*j + 0], temp_muls[2*j + 1]);
|
||||
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, N);
|
||||
LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask);
|
||||
y_rows[i] = column;
|
||||
}
|
||||
|
||||
|
||||
auto z_rows = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), N);
|
||||
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
for (unsigned j = 0; j < N; j++) {
|
||||
LLVMValueRef mask = llvm_mask_same(p->module, j, N);
|
||||
mask_elems[j] = llvm_basic_shuffle(p, x_rows[i], mask);
|
||||
}
|
||||
|
||||
if ((k&1) != 0) {
|
||||
temp_muls[half] = temp_muls[k-1];
|
||||
}
|
||||
k = (k+1)/2;
|
||||
z_rows[i] = llvm_vector_mul_pairwise_reduce_add(p, mask_elems, y_rows);
|
||||
}
|
||||
|
||||
z_columns[i] = temp_muls[0];
|
||||
lbAddr res = lb_add_local_generated(p, type, false);
|
||||
LLVMValueRef dest_ptr = res.addr.value;
|
||||
|
||||
LLVMTypeRef dest_ptr_type = LLVMPointerType(LLVMTypeOf(z_rows[0]), 0);
|
||||
dest_ptr = LLVMBuildPointerCast(p->builder, dest_ptr, dest_ptr_type, "");
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
LLVMValueRef indices[] = {do_u32(p, i)};
|
||||
LLVMValueRef dst = LLVMBuildInBoundsGEP2(p->builder, LLVMTypeOf(z_rows[0]), dest_ptr, indices, 1, "");
|
||||
LLVMBuildStore(p->builder, z_rows[i], dst);
|
||||
}
|
||||
|
||||
return lb_addr_load(p, res);
|
||||
}
|
||||
|
||||
auto do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef {
|
||||
return LLVMConstInt(lb_type(p->module, t_u32), val, false);
|
||||
};
|
||||
auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
|
||||
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
|
||||
|
||||
for (unsigned i = 0; i < outer_rows; i++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, x_stride*i, inner);
|
||||
LLVMValueRef row = llvm_basic_shuffle(p, x_vector, mask);
|
||||
x_rows[i] = row;
|
||||
}
|
||||
|
||||
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), inner);
|
||||
for (unsigned i = 0; i < outer_columns; i++) {
|
||||
for (unsigned j = 0; j < inner; j++) {
|
||||
unsigned offset = x_stride*j + i;
|
||||
mask_elems[j] = do_u32(p, offset);
|
||||
}
|
||||
|
||||
// transpose mask
|
||||
LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner);
|
||||
LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask);
|
||||
y_columns[i] = column;
|
||||
}
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, false);
|
||||
LLVMValueRef dest_ptr = res.addr.value;
|
||||
|
||||
LLVMTypeRef dest_ptr_type = LLVMPointerType(LLVMTypeOf(z_columns[0]), 0);
|
||||
dest_ptr = LLVMBuildPointerCast(p->builder, dest_ptr, dest_ptr_type, "");
|
||||
for (unsigned i = 0; i < N; i++) {
|
||||
LLVMValueRef indices[] = {do_u32(p, i)};
|
||||
LLVMValueRef dst = LLVMBuildInBoundsGEP2(p->builder, LLVMTypeOf(z_columns[0]), dest_ptr, indices, 1, "");
|
||||
LLVMBuildStore(p->builder, z_columns[i], dst);
|
||||
for_array(i, x_rows) {
|
||||
LLVMValueRef x_row = x_rows[i];
|
||||
for_array(j, y_columns) {
|
||||
LLVMValueRef y_column = y_columns[j];
|
||||
LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column);
|
||||
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
||||
LLVMBuildStore(p->builder, elem, dst.value);
|
||||
}
|
||||
}
|
||||
|
||||
return lb_addr_load(p, res);
|
||||
}
|
||||
|
||||
|
||||
auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows);
|
||||
auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns);
|
||||
|
||||
auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), inner);
|
||||
for (unsigned i = 0; i < outer_rows; i++) {
|
||||
for (unsigned j = 0; j < inner; j++) {
|
||||
unsigned offset = x_stride*j + i;
|
||||
mask_elems[j] = lb_const_int(p->module, t_u32, offset).value;
|
||||
}
|
||||
|
||||
// transpose mask
|
||||
LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner);
|
||||
LLVMValueRef row = llvm_basic_shuffle(p, x_vector, mask);
|
||||
x_rows[i] = row;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < outer_columns; i++) {
|
||||
LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner);
|
||||
LLVMValueRef column = llvm_basic_shuffle(p, y_vector, mask);
|
||||
y_columns[i] = column;
|
||||
}
|
||||
|
||||
lbAddr res = lb_add_local_generated(p, type, false);
|
||||
for_array(i, x_rows) {
|
||||
LLVMValueRef x_row = x_rows[i];
|
||||
for_array(j, y_columns) {
|
||||
LLVMValueRef y_column = y_columns[j];
|
||||
LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column);
|
||||
lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
|
||||
LLVMBuildStore(p->builder, elem, dst.value);
|
||||
}
|
||||
}
|
||||
return lb_addr_load(p, res);
|
||||
}
|
||||
|
||||
if (!xt->Matrix.is_row_major) {
|
||||
@@ -1186,28 +1257,7 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal
|
||||
}
|
||||
}
|
||||
|
||||
auto temps = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
|
||||
for (unsigned i = 0; i < column_count; i++) {
|
||||
temps[i] = llvm_vector_mul(p, m_columns[i], v_rows[i]);
|
||||
}
|
||||
|
||||
GB_ASSERT(column_count > 0);
|
||||
|
||||
unsigned k = column_count;
|
||||
while (k > 1) {
|
||||
unsigned half = k/2;
|
||||
for (unsigned j = 0; j < half; j++) {
|
||||
temps[j] = llvm_vector_add(p, temps[2*j + 0], temps[2*j + 1]);
|
||||
}
|
||||
|
||||
if ((k&1) != 0) {
|
||||
temps[half] = temps[k-1];
|
||||
}
|
||||
k = (k+1)/2;
|
||||
}
|
||||
|
||||
LLVMValueRef vector = temps[0];
|
||||
|
||||
LLVMValueRef vector = llvm_vector_mul_pairwise_reduce_add(p, m_columns, v_rows);
|
||||
return lb_matrix_cast_vector_to_type(p, vector, type);
|
||||
}
|
||||
|
||||
|
||||
@@ -2230,6 +2230,30 @@ gb_internal LLVMValueRef llvm_vector_mul(lbProcedure *p, LLVMValueRef a, LLVMVal
|
||||
return LLVMBuildFMul(p->builder, a, b, "");
|
||||
}
|
||||
|
||||
// Multiplies a[i] with b[i] for every index via llvm_vector_mul and then folds
// the products together with llvm_vector_add, combining adjacent pairs on each
// pass until a single value is left.
//
// p:    procedure context forwarded to llvm_vector_mul / llvm_vector_add
// a, b: operand slices; they must have the same, non-zero length
//
// Returns the value left in slot 0 once the reduction has finished. For a
// single-element input that is simply the product a[0]*b[0].
gb_internal LLVMValueRef llvm_vector_mul_pairwise_reduce_add(lbProcedure *p, Slice<LLVMValueRef> const &a, Slice<LLVMValueRef> const &b) {
	GB_ASSERT(a.count == b.count);
	// With no operands there is no temps[0] to hand back; the inlined version
	// of this reduction asserted a non-zero count for the same reason.
	GB_ASSERT(a.count > 0);

	// Convert the count once so every loop below compares unsigned to unsigned.
	unsigned count = cast(unsigned)a.count;

	auto temps = slice_make<LLVMValueRef>(temporary_allocator(), a.count);
	for (unsigned i = 0; i < count; i++) {
		temps[i] = llvm_vector_mul(p, a[i], b[i]);
	}

	// k is the number of live partial results stored in temps[0..k).
	unsigned k = count;
	while (k > 1) {
		unsigned half = k/2;
		// Combine neighbours (2j, 2j+1) into slot j. Slot j is only read again
		// when j == 0, and in that case it is read before it is overwritten.
		for (unsigned j = 0; j < half; j++) {
			temps[j] = llvm_vector_add(p, temps[2*j + 0], temps[2*j + 1]);
		}

		// An odd count leaves one unpaired value at the end; carry it over
		// unchanged so it takes part in the next pass.
		if ((k&1) != 0) {
			temps[half] = temps[k-1];
		}
		k = (k+1)/2;
	}

	return temps[0];
}
|
||||
|
||||
|
||||
gb_internal LLVMValueRef llvm_vector_dot(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) {
|
||||
return llvm_vector_reduce_add(p, llvm_vector_mul(p, a, b));
|
||||
|
||||
Reference in New Issue
Block a user