mirror of
https://github.com/Facepunch/sbox-public.git
synced 2025-12-23 22:48:07 -05:00
Use TensorPrimitives in FloatSpan (#3517)
* Use TensorPrimitives in FloatSpan instead of questionable AVX usage.
* Add some basic FloatSpan unit tests.

Co-authored-by: Lorenz Junglas <4759511+lolleko@users.noreply.github.com>
This commit is contained in:
@@ -44,6 +44,7 @@
|
||||
<PackageReference Include="Microsoft.Extensions.Caching.Memory" Version="10.0.0" />
|
||||
<PackageReference Include="Svg.Skia" Version="2.0.0.4" />
|
||||
<PackageReference Include="Azure.Messaging.WebPubSub.Client" Version="1.0.0" />
|
||||
<PackageReference Include="System.Numerics.Tensors" Version="10.0.0" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
using System.Numerics;
|
||||
using System;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.Intrinsics;
|
||||
using System.Runtime.Intrinsics.X86;
|
||||
using System.Numerics.Tensors;
|
||||
|
||||
namespace Sandbox;
|
||||
|
||||
/// <summary>
|
||||
/// Provides vectorized operations over a span of floats.
|
||||
/// </summary>
|
||||
public ref struct FloatSpan
|
||||
{
|
||||
@@ -17,333 +16,63 @@ public ref struct FloatSpan
|
||||
_span = span;
|
||||
}
|
||||
|
||||
/// <summary>
/// Finds the maximum value in the span.
/// </summary>
/// <returns>
/// The largest element, or <c>0</c> when the span is empty.
/// </returns>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public float Max()
{
	// TensorPrimitives.Max rejects empty input, so the "empty => 0" contract is handled here.
	return _span.IsEmpty ? 0.0f : TensorPrimitives.Max( _span );
}
|
||||
|
||||
/// <summary>
/// Finds the minimum value in the span.
/// </summary>
/// <returns>
/// The smallest element, or <c>0</c> when the span is empty.
/// </returns>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public float Min()
{
	// TensorPrimitives.Min rejects empty input, so the "empty => 0" contract is handled here.
	return _span.IsEmpty ? 0.0f : TensorPrimitives.Min( _span );
}
|
||||
|
||||
/// <summary>
/// Computes the arithmetic mean of the values in the span.
/// </summary>
/// <returns>
/// The mean of all elements, or <c>0</c> when the span is empty.
/// </returns>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public float Average()
{
	// An empty span has no mean; return 0 instead of handing empty input to TensorPrimitives.
	return _span.IsEmpty ? 0.0f : TensorPrimitives.Average( _span );
}
|
||||
|
||||
/// <summary>
/// Computes the sum of the values in the span.
/// </summary>
/// <returns>
/// The sum of all elements, or <c>0</c> when the span is empty.
/// </returns>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public float Sum()
{
	// The reduction is delegated to TensorPrimitives rather than a hand-written
	// horizontal add; the empty case is answered directly.
	return _span.IsEmpty ? 0.0f : TensorPrimitives.Sum( _span );
}
|
||||
|
||||
/// <summary>
/// Assigns <paramref name="value"/> to every element of the span.
/// </summary>
/// <param name="value">The value written to each element.</param>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public void Set( float value )
{
	_span.Fill( value );
}
|
||||
|
||||
/// <summary>
/// Copies <paramref name="values"/> into this span, overwriting its contents.
/// </summary>
/// <param name="values">Source values; must have exactly the same length as this span.</param>
/// <exception cref="ArgumentException">
/// <paramref name="values"/> and this span differ in length.
/// </exception>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void Set( ReadOnlySpan<float> values )
{
	// Lengths must match exactly: a shorter source is rejected rather than partially copied.
	if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );

	values.CopyTo( _span );
}
|
||||
|
||||
/// <summary>
/// Writes each element of <paramref name="values"/> multiplied by <paramref name="scale"/>
/// into this span, overwriting its contents.
/// </summary>
/// <param name="values">Source values; must have exactly the same length as this span.</param>
/// <param name="scale">Factor applied to every source element.</param>
/// <exception cref="ArgumentException">
/// <paramref name="values"/> and this span differ in length.
/// </exception>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void CopyScaled( ReadOnlySpan<float> values, float scale )
{
	if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );

	// this[i] = values[i] * scale
	TensorPrimitives.Multiply( values, scale, _span );
}
|
||||
|
||||
/// <summary>
/// Adds <paramref name="values"/> element-wise to this span, in place.
/// </summary>
/// <param name="values">Values to add; must have exactly the same length as this span.</param>
/// <exception cref="ArgumentException">
/// <paramref name="values"/> and this span differ in length.
/// </exception>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void Add( ReadOnlySpan<float> values )
{
	if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );

	// this[i] = this[i] + values[i]; this span is both the first operand and the destination.
	TensorPrimitives.Add( _span, values, _span );
}
|
||||
|
||||
/// <summary>
/// Adds <paramref name="values"/> multiplied by <paramref name="scale"/> element-wise
/// to this span, in place.
/// </summary>
/// <param name="values">Values to scale and add; must have exactly the same length as this span.</param>
/// <param name="scale">Factor applied to every element of <paramref name="values"/> before adding.</param>
/// <exception cref="ArgumentException">
/// <paramref name="values"/> and this span differ in length.
/// </exception>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void AddScaled( ReadOnlySpan<float> values, float scale )
{
	if ( _span.Length != values.Length ) throw new ArgumentException( "Source and destination spans must be the same length." );

	// this[i] = ( values[i] * scale ) + this[i]; this span is both the addend and the destination.
	TensorPrimitives.MultiplyAdd( values, scale, _span, _span );
}
|
||||
|
||||
/// <summary>
/// Multiplies every element of this span by <paramref name="scale"/>, in place.
/// </summary>
/// <param name="scale">Factor applied to every element.</param>
[MethodImpl( MethodImplOptions.AggressiveInlining )]
public readonly void Scale( float scale )
{
	// this[i] = this[i] * scale; this span is both the input and the destination.
	TensorPrimitives.Multiply( _span, scale, _span );
}
|
||||
}
|
||||
|
||||
162
engine/Sandbox.Test/System/Math/FloatSpan.cs
Normal file
162
engine/Sandbox.Test/System/Math/FloatSpan.cs
Normal file
@@ -0,0 +1,162 @@
|
||||
namespace TestSystem.Math;
|
||||
|
||||
/// <summary>
/// Basic unit tests for the reductions and in-place operations of <c>FloatSpan</c>.
/// </summary>
[TestClass]
public class FloatSpanTest
{
	/// <summary>Max returns the largest element of a populated span.</summary>
	[TestMethod]
	public void Max()
	{
		var input = new float[] { 1.0f, 5.0f, 3.0f, 9.0f, 2.0f };
		var floats = new FloatSpan( input );

		var result = floats.Max();

		Assert.AreEqual( 9.0f, result );
	}

	/// <summary>Max of an empty span is defined as zero.</summary>
	[TestMethod]
	public void Max_Empty()
	{
		var floats = new FloatSpan( new float[0] );

		var result = floats.Max();

		Assert.AreEqual( 0.0f, result );
	}

	/// <summary>Min returns the smallest element of a populated span.</summary>
	[TestMethod]
	public void Min()
	{
		var input = new float[] { 1.0f, 5.0f, 3.0f, 9.0f, 2.0f };
		var floats = new FloatSpan( input );

		var result = floats.Min();

		Assert.AreEqual( 1.0f, result );
	}

	/// <summary>Min of an empty span is defined as zero.</summary>
	[TestMethod]
	public void Min_Empty()
	{
		var floats = new FloatSpan( new float[0] );

		var result = floats.Min();

		Assert.AreEqual( 0.0f, result );
	}

	/// <summary>Average returns the arithmetic mean of a populated span.</summary>
	[TestMethod]
	public void Average()
	{
		var input = new float[] { 2.0f, 4.0f, 6.0f, 8.0f };
		var floats = new FloatSpan( input );

		var result = floats.Average();

		Assert.AreEqual( 5.0f, result );
	}

	/// <summary>Average of an empty span is defined as zero.</summary>
	[TestMethod]
	public void Average_Empty()
	{
		var floats = new FloatSpan( new float[0] );

		var result = floats.Average();

		Assert.AreEqual( 0.0f, result );
	}

	/// <summary>Sum adds up every element of a populated span.</summary>
	[TestMethod]
	public void Sum()
	{
		var input = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
		var floats = new FloatSpan( input );

		var result = floats.Sum();

		Assert.AreEqual( 10.0f, result );
	}

	/// <summary>Sum of an empty span is zero.</summary>
	[TestMethod]
	public void Sum_Empty()
	{
		var floats = new FloatSpan( new float[0] );

		var result = floats.Sum();

		Assert.AreEqual( 0.0f, result );
	}

	/// <summary>Set( float ) writes the value into every element of the backing array.</summary>
	[TestMethod]
	public void Set_Value()
	{
		var buffer = new float[5];
		var floats = new FloatSpan( buffer );

		floats.Set( 7.5f );

		for ( var index = 0; index < buffer.Length; index++ )
		{
			Assert.AreEqual( 7.5f, buffer[index] );
		}
	}

	/// <summary>Set( span ) copies the source values into the backing array.</summary>
	[TestMethod]
	public void Set_Span()
	{
		var buffer = new float[4];
		var source = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
		var floats = new FloatSpan( buffer );

		floats.Set( source );

		CollectionAssert.AreEqual( source, buffer );
	}

	/// <summary>CopyScaled overwrites the destination with the scaled source values.</summary>
	[TestMethod]
	public void CopyScaled()
	{
		var buffer = new float[4];
		var source = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
		var floats = new FloatSpan( buffer );

		floats.CopyScaled( source, 2.0f );

		CollectionAssert.AreEqual( new float[] { 2.0f, 4.0f, 6.0f, 8.0f }, buffer );
	}

	/// <summary>Add accumulates the source values onto the existing contents.</summary>
	[TestMethod]
	public void Add()
	{
		var buffer = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
		var source = new float[] { 10.0f, 20.0f, 30.0f, 40.0f };
		var floats = new FloatSpan( buffer );

		floats.Add( source );

		CollectionAssert.AreEqual( new float[] { 11.0f, 22.0f, 33.0f, 44.0f }, buffer );
	}

	/// <summary>AddScaled accumulates the scaled source values onto the existing contents.</summary>
	[TestMethod]
	public void AddScaled()
	{
		var buffer = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
		var source = new float[] { 10.0f, 20.0f, 30.0f, 40.0f };
		var floats = new FloatSpan( buffer );

		floats.AddScaled( source, 0.5f );

		CollectionAssert.AreEqual( new float[] { 6.0f, 12.0f, 18.0f, 24.0f }, buffer );
	}

	/// <summary>Scale multiplies every element of the backing array in place.</summary>
	[TestMethod]
	public void Scale()
	{
		var buffer = new float[] { 2.0f, 4.0f, 6.0f, 8.0f };
		var floats = new FloatSpan( buffer );

		floats.Scale( 0.5f );

		CollectionAssert.AreEqual( new float[] { 1.0f, 2.0f, 3.0f, 4.0f }, buffer );
	}
}
|
||||
Reference in New Issue
Block a user