Use TensorPrimitives in FloatSpan (#3517)

* Use TensorPrimitives in FloatSpan instead of questionable AVX usage.
* Add some basic floatspan unit tests

---------

Co-authored-by: Lorenz Junglas <4759511+lolleko@users.noreply.github.com>
This commit is contained in:
sboxbot
2025-12-01 08:49:24 +00:00
committed by GitHub
parent 1a29f98237
commit ac798c9be1
3 changed files with 180 additions and 288 deletions

View File

@@ -44,6 +44,7 @@
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="10.0.0" />
<PackageReference Include="Svg.Skia" Version="2.0.0.4" />
<PackageReference Include="Azure.Messaging.WebPubSub.Client" Version="1.0.0" />
<PackageReference Include="System.Numerics.Tensors" Version="10.0.0" />
</ItemGroup>
<ItemGroup>

View File

@@ -1,12 +1,11 @@
using System.Numerics;
using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using System.Numerics.Tensors;
namespace Sandbox;
/// <summary>
/// Allows easy SIMD/AVX2 fast math on a span of floats
/// Provides vectorized operations over a span of floats.
/// </summary>
public ref struct FloatSpan
{
@@ -17,333 +16,63 @@ public ref struct FloatSpan
_span = span;
}
/// <summary>
/// Uses SIMD/AVX2 to find the maximum value in a span of floats.
/// </summary>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public float Max()
{
if ( _span.IsEmpty ) return 0.0f;
int i = 0;
float max = float.MinValue;
if ( Avx.IsSupported )
{
var maxVector = Vector256.Create( float.MinValue );
// Get a pointer to the span data
unsafe
{
fixed ( float* ptr = _span )
{
for ( ; i <= _span.Length - 8; i += 8 )
{
var v = Avx.LoadVector256( ptr + i ); // Correct memory load
maxVector = Avx.Max( maxVector, v );
}
}
}
// Reduce maxVector to a single float
max = Math.Max( max, maxVector.GetElement( 0 ) );
max = Math.Max( max, maxVector.GetElement( 1 ) );
max = Math.Max( max, maxVector.GetElement( 2 ) );
max = Math.Max( max, maxVector.GetElement( 3 ) );
max = Math.Max( max, maxVector.GetElement( 4 ) );
max = Math.Max( max, maxVector.GetElement( 5 ) );
max = Math.Max( max, maxVector.GetElement( 6 ) );
max = Math.Max( max, maxVector.GetElement( 7 ) );
}
// Handle remaining elements
for ( ; i < _span.Length; i++ )
max = Math.Max( max, _span[i] );
return max;
return _span.IsEmpty ? 0.0f : TensorPrimitives.Max( _span );
}
/// <summary>
/// Uses SIMD/AVX2 to find the minimum value in a span of floats.
/// </summary>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public float Min()
{
if ( _span.IsEmpty ) return 0.0f;
int i = 0;
float max = float.MaxValue;
if ( Avx.IsSupported )
{
var maxVector = Vector256.Create( float.MaxValue );
// Get a pointer to the span data
unsafe
{
fixed ( float* ptr = _span )
{
for ( ; i <= _span.Length - 8; i += 8 )
{
var v = Avx.LoadVector256( ptr + i ); // Correct memory load
maxVector = Avx.Min( maxVector, v );
}
}
}
max = Math.Min( max, maxVector.GetElement( 0 ) );
max = Math.Min( max, maxVector.GetElement( 1 ) );
max = Math.Min( max, maxVector.GetElement( 2 ) );
max = Math.Min( max, maxVector.GetElement( 3 ) );
max = Math.Min( max, maxVector.GetElement( 4 ) );
max = Math.Min( max, maxVector.GetElement( 5 ) );
max = Math.Min( max, maxVector.GetElement( 6 ) );
max = Math.Min( max, maxVector.GetElement( 7 ) );
}
// Handle remaining elements
for ( ; i < _span.Length; i++ )
max = Math.Min( max, _span[i] );
return max;
return _span.IsEmpty ? 0.0f : TensorPrimitives.Min( _span );
}
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public float Average()
{
if ( _span.IsEmpty ) return 0.0f;
int i = 0;
float sum = 0f;
float len = _span.Length;
if ( Avx.IsSupported )
{
var sumVector = Vector256<float>.Zero;
unsafe
{
fixed ( float* ptr = _span )
{
// Sum using AVX2
for ( ; i <= len - 8; i += 8 )
{
var v = Avx.LoadVector256( ptr + i );
sumVector = Avx.Add( sumVector, v );
}
}
}
// Reduce sumVector to a single float
sum += sumVector.GetElement( 0 );
sum += sumVector.GetElement( 1 );
sum += sumVector.GetElement( 2 );
sum += sumVector.GetElement( 3 );
sum += sumVector.GetElement( 4 );
sum += sumVector.GetElement( 5 );
sum += sumVector.GetElement( 6 );
sum += sumVector.GetElement( 7 );
}
// Handle remaining elements
for ( ; i < len; i++ )
sum += _span[i];
return sum / len;
return _span.IsEmpty ? 0.0f : TensorPrimitives.Average( _span );
}
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public float Sum()
{
if ( _span.IsEmpty ) return 0.0f;
int i = 0;
float sum = 0f;
if ( Avx.IsSupported )
{
var sumVector = Vector256<float>.Zero;
unsafe
{
fixed ( float* ptr = _span )
{
for ( ; i <= _span.Length - 8; i += 8 )
{
var v = Avx.LoadVector256( ptr + i );
sumVector = Avx.Add( sumVector, v );
}
}
}
// Reduce sumVector using horizontal adds
var temp = Avx.HorizontalAdd( sumVector, sumVector );
temp = Avx.HorizontalAdd( temp, temp );
temp = Avx.HorizontalAdd( temp, temp );
sum += temp.GetElement( 0 );
}
for ( ; i < _span.Length; i++ )
sum += _span[i];
return sum;
return _span.IsEmpty ? 0.0f : TensorPrimitives.Sum( _span );
}
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public void Set( float value )
{
int i = 0;
if ( Avx.IsSupported )
{
unsafe
{
fixed ( float* ptr = _span )
{
var v = Vector256.Create( value );
for ( ; i <= _span.Length - 8; i += 8 )
{
Avx.Store( ptr + i, v );
}
}
}
}
for ( ; i < _span.Length; i++ )
_span[i] = value;
_span.Fill( value );
}
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void Set( in Span<float> values )
public readonly void Set( ReadOnlySpan<float> values )
{
if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );
unsafe
{
var size = _span.Length * sizeof( float );
fixed ( float* srcPtr = values, dstPtr = _span )
{
NativeLowLevel.Copy( (IntPtr)srcPtr, (IntPtr)dstPtr, (uint)size );
}
}
values.CopyTo( _span );
}
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void CopyScaled( in Span<float> values, float scale )
public readonly void CopyScaled( ReadOnlySpan<float> values, float scale )
{
if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );
int i = 0;
if ( Avx.IsSupported )
{
var scaleVector = Vector256.Create( scale );
unsafe
{
fixed ( float* srcPtr = values, dstPtr = _span )
{
for ( ; i <= _span.Length - 8; i += 8 )
{
var v = Avx.LoadVector256( srcPtr + i );
v = Avx.Multiply( v, scaleVector );
Avx.Store( dstPtr + i, v );
}
}
}
}
for ( ; i < _span.Length; i++ )
_span[i] = values[i] * scale;
TensorPrimitives.Multiply( values, scale, _span );
}
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void Add( in Span<float> values )
public readonly void Add( ReadOnlySpan<float> values )
{
if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );
int i = 0;
if ( Avx.IsSupported )
{
unsafe
{
fixed ( float* srcPtr = values, dstPtr = _span )
{
for ( ; i <= _span.Length - 8; i += 8 )
{
var v = Avx.LoadVector256( srcPtr + i );
var dst = Avx.LoadVector256( dstPtr + i );
dst = Avx.Add( dst, v );
Avx.Store( dstPtr + i, dst );
}
}
}
}
for ( ; i < _span.Length; i++ )
_span[i] += values[i];
TensorPrimitives.Add( _span, values, _span );
}
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void AddScaled( in Span<float> values, float scale )
public readonly void AddScaled( ReadOnlySpan<float> values, float scale )
{
if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );
int i = 0;
if ( Avx.IsSupported )
{
var scaleVector = Vector256.Create( scale );
unsafe
{
fixed ( float* srcPtr = values, dstPtr = _span )
{
for ( ; i <= _span.Length - 8; i += 8 )
{
var v = Avx.LoadVector256( srcPtr + i );
v = Avx.Multiply( v, scaleVector );
var dst = Avx.LoadVector256( dstPtr + i );
dst = Avx.Add( dst, v );
Avx.Store( dstPtr + i, dst );
}
}
}
}
for ( ; i < _span.Length; i++ )
_span[i] += values[i] * scale;
TensorPrimitives.MultiplyAdd( values, scale, _span, _span );
}
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void Scale( float scale )
{
int i = 0;
if ( Avx.IsSupported )
{
var scaleVector = Vector256.Create( scale );
unsafe
{
fixed ( float* dstPtr = _span )
{
for ( ; i <= _span.Length - 8; i += 8 )
{
var v = Avx.LoadVector256( dstPtr + i );
v = Avx.Multiply( v, scaleVector );
Avx.Store( dstPtr + i, v );
}
}
}
}
for ( ; i < _span.Length; i++ )
_span[i] *= scale;
TensorPrimitives.Multiply( _span, scale, _span );
}
}

View File

@@ -0,0 +1,162 @@
namespace TestSystem.Math;
[TestClass]
public class FloatSpanTest
{
[TestMethod]
public void Max()
{
var data = new float[] { 1.0f, 5.0f, 3.0f, 9.0f, 2.0f };
var span = new FloatSpan( data );
Assert.AreEqual( 9.0f, span.Max() );
}
[TestMethod]
public void Max_Empty()
{
var data = new float[0];
var span = new FloatSpan( data );
Assert.AreEqual( 0.0f, span.Max() );
}
[TestMethod]
public void Min()
{
var data = new float[] { 1.0f, 5.0f, 3.0f, 9.0f, 2.0f };
var span = new FloatSpan( data );
Assert.AreEqual( 1.0f, span.Min() );
}
[TestMethod]
public void Min_Empty()
{
var data = new float[0];
var span = new FloatSpan( data );
Assert.AreEqual( 0.0f, span.Min() );
}
[TestMethod]
public void Average()
{
var data = new float[] { 2.0f, 4.0f, 6.0f, 8.0f };
var span = new FloatSpan( data );
Assert.AreEqual( 5.0f, span.Average() );
}
[TestMethod]
public void Average_Empty()
{
var data = new float[0];
var span = new FloatSpan( data );
Assert.AreEqual( 0.0f, span.Average() );
}
[TestMethod]
public void Sum()
{
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
var span = new FloatSpan( data );
Assert.AreEqual( 10.0f, span.Sum() );
}
[TestMethod]
public void Sum_Empty()
{
var data = new float[0];
var span = new FloatSpan( data );
Assert.AreEqual( 0.0f, span.Sum() );
}
[TestMethod]
public void Set_Value()
{
var data = new float[5];
var span = new FloatSpan( data );
span.Set( 7.5f );
foreach ( var value in data )
{
Assert.AreEqual( 7.5f, value );
}
}
[TestMethod]
public void Set_Span()
{
var data = new float[4];
var span = new FloatSpan( data );
var values = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
span.Set( values );
CollectionAssert.AreEqual( values, data );
}
[TestMethod]
public void CopyScaled()
{
var data = new float[4];
var span = new FloatSpan( data );
var values = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
span.CopyScaled( values, 2.0f );
Assert.AreEqual( 2.0f, data[0] );
Assert.AreEqual( 4.0f, data[1] );
Assert.AreEqual( 6.0f, data[2] );
Assert.AreEqual( 8.0f, data[3] );
}
[TestMethod]
public void Add()
{
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
var span = new FloatSpan( data );
var values = new float[] { 10.0f, 20.0f, 30.0f, 40.0f };
span.Add( values );
Assert.AreEqual( 11.0f, data[0] );
Assert.AreEqual( 22.0f, data[1] );
Assert.AreEqual( 33.0f, data[2] );
Assert.AreEqual( 44.0f, data[3] );
}
[TestMethod]
public void AddScaled()
{
var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
var span = new FloatSpan( data );
var values = new float[] { 10.0f, 20.0f, 30.0f, 40.0f };
span.AddScaled( values, 0.5f );
Assert.AreEqual( 6.0f, data[0] );
Assert.AreEqual( 12.0f, data[1] );
Assert.AreEqual( 18.0f, data[2] );
Assert.AreEqual( 24.0f, data[3] );
}
[TestMethod]
public void Scale()
{
var data = new float[] { 2.0f, 4.0f, 6.0f, 8.0f };
var span = new FloatSpan( data );
span.Scale( 0.5f );
Assert.AreEqual( 1.0f, data[0] );
Assert.AreEqual( 2.0f, data[1] );
Assert.AreEqual( 3.0f, data[2] );
Assert.AreEqual( 4.0f, data[3] );
}
}