From 58e4e6be240c7f588fdec8bcbb3c68fb1b31eab2 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Mon, 11 May 2026 13:14:02 +0100 Subject: [PATCH] Improve `matrix * vector` code generation --- src/llvm_backend_expr.cpp | 88 ++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 52 deletions(-) diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 598ab6d21..7f45b89dd 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -697,7 +697,7 @@ gb_internal bool lb_is_matrix_simdable(Type *t) { // it's not aligned well enough to use the vector instructions return false; } - if ((mt->Matrix.row_count & 1) ^ (mt->Matrix.column_count & 1)) { + if ((mt->Matrix.row_count & 1) && (mt->Matrix.column_count & 1)) { return false; } if (mt->Matrix.is_row_major) { @@ -976,35 +976,6 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, unsigned outer_columns = cast(unsigned)yt->Matrix.column_count; if (!xt->Matrix.is_row_major && lb_is_matrix_simdable(xt)) { - - // if (LLVMIsALoadInst(lhs.value) && LLVMIsALoadInst(rhs.value)) { - // auto do_u32 = [](lbProcedure *p, u32 val) -> LLVMValueRef { - // return LLVMConstInt(lb_type(p->module, t_u32), val, false); - // }; - - // LLVMValueRef llvm_stride = do_u32(p, inner); - // LLVMValueRef llvm_false = LLVMConstInt(lb_type(p->module, t_llvm_bool), false, false); - - // LLVMValueRef lhs_args[] = {LLVMGetOperand(lhs.value, 0), llvm_stride, llvm_false, do_u32(p, outer_rows), do_u32(p, inner)}; - // LLVMValueRef rhs_args[] = {LLVMGetOperand(rhs.value, 0), llvm_stride, llvm_false, do_u32(p, inner), do_u32(p, outer_columns)}; - // LLVMTypeRef types[] = {lb_type(p->module, elem)}; - - // LLVMValueRef lhs_loaded = lb_call_intrinsic(p, "llvm.matrix.column.major.load", lhs_args, gb_count_of(lhs_args), types, gb_count_of(types)); - // LLVMValueRef rhs_loaded = lb_call_intrinsic(p, "llvm.matrix.column.major.load", rhs_args, gb_count_of(rhs_args), types, gb_count_of(types)); - - // LLVMValueRef mul_args[] = {lhs_loaded, rhs_loaded, do_u32(p, outer_rows), do_u32(p, inner), do_u32(p, outer_columns)}; - // LLVMValueRef lhs_mul_rhs = lb_call_intrinsic(p, "llvm.matrix.multiply", mul_args, gb_count_of(mul_args), types, gb_count_of(types)); - - // lbAddr res = lb_add_local_generated(p, type, false); - - // LLVMValueRef store_args[] = {res.addr.value, lhs_mul_rhs, llvm_stride, llvm_false, do_u32(p, inner), do_u32(p, outer_columns)}; - // lb_call_intrinsic(p, "llvm.matrix.column.major.store", store_args, gb_count_of(store_args), types, gb_count_of(types)); - - // return lb_addr_load(p, res); - // } - - - unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); @@ -1042,23 +1013,13 @@ gb_internal lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, mask_elems[j] = llvm_basic_shuffle(p, y_columns[i], mask); } for (unsigned j = 0; j < N; j++) { - if (is_type_float(elem)) { - temp_muls[j] = LLVMBuildFMul(p->builder, mask_elems[j], x_columns[j], ""); - // LLVMSetFastMathFlags(temp_muls[j], LLVMFastMathAll); - } else { - temp_muls[j] = LLVMBuildMul(p->builder, mask_elems[j], x_columns[j], ""); - } + temp_muls[j] = llvm_vector_mul(p, mask_elems[j], x_columns[j]); } unsigned k = N; while (k > 1) { unsigned half = k/2; for (unsigned j = 0; j < half; j++) { - if (is_type_float(elem)) { - temp_muls[j] = LLVMBuildFAdd(p->builder, temp_muls[2*j + 0], temp_muls[2*j + 1], ""); - // LLVMSetFastMathFlags(temp_muls[j], LLVMFastMathAll); - } else { - temp_muls[j] = LLVMBuildAdd(p->builder, temp_muls[2*j + 0], temp_muls[2*j + 1], ""); - } + temp_muls[j] = llvm_vector_add(p, temp_muls[2*j + 0], temp_muls[2*j + 1]); } if ((k&1) != 0) { @@ -1207,23 +1168,46 @@ gb_internal lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbVal m_columns[column_index] = column; } - for (unsigned row_index = 0; row_index < column_count; row_index++) { - LLVMValueRef value = LLVMBuildExtractValue(p->builder, rhs.value, row_index, ""); - LLVMValueRef row = llvm_vector_broadcast(p, value, row_count); - v_rows[row_index] = row; + if (LLVMIsALoadInst(rhs.value)) { + LLVMValueRef rhs_ptr = LLVMGetOperand(rhs.value, 0); + LLVMTypeRef vector_type = LLVMVectorType(lb_type(p->module, elem), cast(unsigned)vector_count); + LLVMValueRef rhs_vector = LLVMBuildLoad2(p->builder, vector_type, rhs_ptr, ""); + LLVMSetAlignment(rhs_vector, cast(unsigned)type_align_of(type)); + + for (unsigned i = 0; i < column_count; i++) { + LLVMValueRef mask = llvm_mask_same(p->module, i, row_count); + v_rows[i] = llvm_basic_shuffle(p, rhs_vector, mask); + } + } else { + for (unsigned row_index = 0; row_index < column_count; row_index++) { + LLVMValueRef value = LLVMBuildExtractValue(p->builder, rhs.value, row_index, ""); + LLVMValueRef row = llvm_vector_broadcast(p, value, row_count); + v_rows[row_index] = row; + } + } + + auto temps = slice_make(permanent_allocator(), column_count); + for (unsigned i = 0; i < column_count; i++) { + temps[i] = llvm_vector_mul(p, m_columns[i], v_rows[i]); } GB_ASSERT(column_count > 0); - LLVMValueRef vector = nullptr; - for (i64 i = 0; i < column_count; i++) { - if (i == 0) { - vector = llvm_vector_mul(p, m_columns[i], v_rows[i]); - } else { - vector = llvm_vector_mul_add(p, m_columns[i], v_rows[i], vector); + unsigned k = column_count; + while (k > 1) { + unsigned half = k/2; + for (unsigned j = 0; j < half; j++) { + temps[j] = llvm_vector_add(p, temps[2*j + 0], temps[2*j + 1]); } + + if ((k&1) != 0) { + temps[half] = temps[k-1]; + } + k = (k+1)/2; } + LLVMValueRef vector = temps[0]; + return lb_matrix_cast_vector_to_type(p, vector, type); }