From dd15cc03bb90912ed1136e1bc9028493d378132c Mon Sep 17 00:00:00 2001 From: Linus Hamlin Date: Thu, 18 Jun 2026 16:34:13 +0200 Subject: [PATCH 1/3] Add ShouldEarlyExitOnNan --- ...TensorPrimitives.IIndexOfMinMaxOperator.cs | 97 +++++++++++++------ .../netcore/TensorPrimitives.IndexOfMax.cs | 1 + .../TensorPrimitives.IndexOfMaxMagnitude.cs | 1 + .../netcore/TensorPrimitives.IndexOfMin.cs | 1 + .../TensorPrimitives.IndexOfMinMagnitude.cs | 1 + 5 files changed, 74 insertions(+), 27 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfMinMaxOperator.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfMinMaxOperator.cs index 9887a65da80968..b2963f97320316 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfMinMaxOperator.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IIndexOfMinMaxOperator.cs @@ -11,6 +11,7 @@ public static unsafe partial class TensorPrimitives { private interface IIndexOfMinMaxOperator { + static abstract bool ShouldEarlyExitOnNan { get; } static abstract T Aggregate(Vector128 value); static abstract T Aggregate(Vector256 value); static abstract T Aggregate(Vector512 value); @@ -60,7 +61,7 @@ private static int IndexOfMinMaxFallback(ReadOnlySpan x) { T result = x[0]; int resultIndex = 0; - if (T.IsNaN(result)) + if (TOperator.ShouldEarlyExitOnNan && T.IsNaN(result)) { return resultIndex; } @@ -68,7 +69,7 @@ private static int IndexOfMinMaxFallback(ReadOnlySpan x) for (int i = 1; i < x.Length; i++) { T current = x[i]; - if (T.IsNaN(current)) + if (TOperator.ShouldEarlyExitOnNan && T.IsNaN(current)) { return i; } @@ -79,7 +80,7 @@ private static int IndexOfMinMaxFallback(ReadOnlySpan x) } } - return resultIndex; + return !TOperator.ShouldEarlyExitOnNan && T.IsNaN(result) ? -1 : resultIndex; } private static int IndexOfMinMaxVectorized128Size4Plus(ReadOnlySpan x) @@ -91,7 +92,7 @@ private static int IndexOfMinMaxVectorized128Size4Plus(ReadO // Initialize result by reading first vector and quick return if possible. Vector128 result = Vector128.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector128 nanMask = IsNaN(result); if (nanMask != Vector128.Zero) @@ -124,7 +125,7 @@ private static int IndexOfMinMaxVectorized128Size4Plus(ReadO } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector128 nanMask = IsNaN(current); if (nanMask != Vector128.Zero) @@ -145,8 +146,14 @@ private static int IndexOfMinMaxVectorized128Size4Plus(ReadO { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector128 aggMask = ~Vector128.Equals(result.As(), Vector128.Create(aggResult).As()); Vector128 aggIndex = resultIndex | aggMask; + return int.CreateTruncating(HorizontalAggregate>(aggIndex)); } } @@ -158,7 +165,7 @@ private static int IndexOfMinMaxVectorized128Size2(ReadOnlySpan // Initialize result by reading first vector and quick return if possible. Vector128 result = Vector128.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector128 nanMask = IsNaN(result); if (nanMask != Vector128.Zero) @@ -192,7 +199,7 @@ private static int IndexOfMinMaxVectorized128Size2(ReadOnlySpan } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && typeof(T) == typeof(float) || typeof(T) == typeof(double)) { Vector128 nanMask = IsNaN(current); if (nanMask != Vector128.Zero) @@ -216,8 +223,12 @@ private static int IndexOfMinMaxVectorized128Size2(ReadOnlySpan { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); - Vector128 aggMask = ~Vector128.Equals(result.AsInt16(), Vector128.Create(aggResult).AsInt16()); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector128 aggMask = ~Vector128.Equals(result.AsInt16(), Vector128.Create(aggResult).AsInt16()); (Vector128 mask1, Vector128 mask2) = Vector128.Widen(aggMask); Vector128 aggIndex = resultIndex1 | mask1.AsUInt32(); aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); @@ -233,7 +244,7 @@ private static int IndexOfMinMaxVectorized128Size1(ReadOnlySpan // Initialize result by reading first vector and quick return if possible. Vector128 result = Vector128.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && typeof(T) == typeof(float) || typeof(T) == typeof(double)) { Vector128 nanMask = IsNaN(result); if (nanMask != Vector128.Zero) @@ -269,7 +280,7 @@ private static int IndexOfMinMaxVectorized128Size1(ReadOnlySpan } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector128 nanMask = IsNaN(current); if (nanMask != Vector128.Zero) @@ -299,8 +310,12 @@ private static int IndexOfMinMaxVectorized128Size1(ReadOnlySpan { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); - Vector128 aggMask = ~Vector128.Equals(result.AsSByte(), Vector128.Create(aggResult).AsSByte()); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector128 aggMask = ~Vector128.Equals(result.AsSByte(), Vector128.Create(aggResult).AsSByte()); (Vector128 lowerMask, Vector128 upperMask) = Vector128.Widen(aggMask); (Vector128 mask1, Vector128 mask2) = Vector128.Widen(lowerMask); (Vector128 mask3, Vector128 mask4) = Vector128.Widen(upperMask); @@ -322,7 +337,7 @@ private static int IndexOfMinMaxVectorized256Size4Plus(ReadO // Initialize result by reading first vector and quick return if possible. Vector256 result = Vector256.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector256 nanMask = IsNaN(result); if (nanMask != Vector256.Zero) @@ -355,7 +370,7 @@ private static int IndexOfMinMaxVectorized256Size4Plus(ReadO } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector256 nanMask = IsNaN(current); if (nanMask != Vector256.Zero) @@ -376,8 +391,14 @@ private static int IndexOfMinMaxVectorized256Size4Plus(ReadO { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector256 aggMask = ~Vector256.Equals(result.As(), Vector256.Create(aggResult).As()); Vector256 aggIndex = resultIndex | aggMask; + return int.CreateTruncating(HorizontalAggregate>(aggIndex)); } } @@ -389,7 +410,7 @@ private static int IndexOfMinMaxVectorized256Size2(ReadOnlySpan // Initialize result by reading first vector and quick return if possible. Vector256 result = Vector256.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector256 nanMask = IsNaN(result); if (nanMask != Vector256.Zero) @@ -423,7 +444,7 @@ private static int IndexOfMinMaxVectorized256Size2(ReadOnlySpan } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector256 nanMask = IsNaN(current); if (nanMask != Vector256.Zero) @@ -447,8 +468,12 @@ private static int IndexOfMinMaxVectorized256Size2(ReadOnlySpan { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); - Vector256 aggMask = ~Vector256.Equals(result.AsInt16(), Vector256.Create(aggResult).AsInt16()); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector256 aggMask = ~Vector256.Equals(result.AsInt16(), Vector256.Create(aggResult).AsInt16()); (Vector256 mask1, Vector256 mask2) = Vector256.Widen(aggMask); Vector256 aggIndex = resultIndex1 | mask1.AsUInt32(); aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); @@ -464,7 +489,7 @@ private static int IndexOfMinMaxVectorized256Size1(ReadOnlySpan // Initialize result by reading first vector and quick return if possible. Vector256 result = Vector256.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector256 nanMask = IsNaN(result); if (nanMask != Vector256.Zero) @@ -500,7 +525,7 @@ private static int IndexOfMinMaxVectorized256Size1(ReadOnlySpan } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector256 nanMask = IsNaN(current); if (nanMask != Vector256.Zero) @@ -530,8 +555,12 @@ private static int IndexOfMinMaxVectorized256Size1(ReadOnlySpan { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); - Vector256 aggMask = ~Vector256.Equals(result.AsSByte(), Vector256.Create(aggResult).AsSByte()); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector256 aggMask = ~Vector256.Equals(result.AsSByte(), Vector256.Create(aggResult).AsSByte()); (Vector256 lowerMask, Vector256 upperMask) = Vector256.Widen(aggMask); (Vector256 mask1, Vector256 mask2) = Vector256.Widen(lowerMask); (Vector256 mask3, Vector256 mask4) = Vector256.Widen(upperMask); @@ -553,7 +582,7 @@ private static int IndexOfMinMaxVectorized512Size4Plus(ReadO // Initialize result by reading first vector and quick return if possible. Vector512 result = Vector512.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector512 nanMask = IsNaN(result); if (nanMask != Vector512.Zero) @@ -586,7 +615,7 @@ private static int IndexOfMinMaxVectorized512Size4Plus(ReadO } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector512 nanMask = IsNaN(current); if (nanMask != Vector512.Zero) @@ -607,8 +636,14 @@ private static int IndexOfMinMaxVectorized512Size4Plus(ReadO { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector512 aggMask = ~Vector512.Equals(result.As(), Vector512.Create(aggResult).As()); Vector512 aggIndex = resultIndex | aggMask; + return int.CreateTruncating(HorizontalAggregate>(aggIndex)); } } @@ -620,7 +655,7 @@ private static int IndexOfMinMaxVectorized512Size2(ReadOnlySpan // Initialize result by reading first vector and quick return if possible. Vector512 result = Vector512.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector512 nanMask = IsNaN(result); if (nanMask != Vector512.Zero) @@ -654,7 +689,7 @@ private static int IndexOfMinMaxVectorized512Size2(ReadOnlySpan } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector512 nanMask = IsNaN(current); if (nanMask != Vector512.Zero) @@ -678,8 +713,12 @@ private static int IndexOfMinMaxVectorized512Size2(ReadOnlySpan { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); - Vector512 aggMask = ~Vector512.Equals(result.AsInt16(), Vector512.Create(aggResult).AsInt16()); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector512 aggMask = ~Vector512.Equals(result.AsInt16(), Vector512.Create(aggResult).AsInt16()); (Vector512 mask1, Vector512 mask2) = Vector512.Widen(aggMask); Vector512 aggIndex = resultIndex1 | mask1.AsUInt32(); aggIndex = MinOperator.Invoke(aggIndex, resultIndex2 | mask2.AsUInt32()); @@ -695,7 +734,7 @@ private static int IndexOfMinMaxVectorized512Size1(ReadOnlySpan // Initialize result by reading first vector and quick return if possible. Vector512 result = Vector512.Create(x); - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector512 nanMask = IsNaN(result); if (nanMask != Vector512.Zero) @@ -731,7 +770,7 @@ private static int IndexOfMinMaxVectorized512Size1(ReadOnlySpan } // Quick return if possible. - if (typeof(T) == typeof(float) || typeof(T) == typeof(double)) + if (TOperator.ShouldEarlyExitOnNan && (typeof(T) == typeof(float) || typeof(T) == typeof(double))) { Vector512 nanMask = IsNaN(current); if (nanMask != Vector512.Zero) @@ -761,8 +800,12 @@ private static int IndexOfMinMaxVectorized512Size1(ReadOnlySpan { // Where result does not bitwise-equal the aggregate min/max value; replace indices with uint.MaxValue. Then find the min index. T aggResult = TOperator.Aggregate(result); - Vector512 aggMask = ~Vector512.Equals(result.AsSByte(), Vector512.Create(aggResult).AsSByte()); + if (!TOperator.ShouldEarlyExitOnNan && T.IsNaN(aggResult)) + { + return -1; + } + Vector512 aggMask = ~Vector512.Equals(result.AsSByte(), Vector512.Create(aggResult).AsSByte()); (Vector512 lowerMask, Vector512 upperMask) = Vector512.Widen(aggMask); (Vector512 mask1, Vector512 mask2) = Vector512.Widen(lowerMask); (Vector512 mask3, Vector512 mask4) = Vector512.Widen(upperMask); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs index 7ea6f2f349a797..858986c5503d21 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMax.cs @@ -29,6 +29,7 @@ public static int IndexOfMax(ReadOnlySpan x) /// Returns the index of MathF.Max(x, y) internal readonly struct IndexOfMaxOperator : IIndexOfMinMaxOperator where T : INumber { + public static bool ShouldEarlyExitOnNan => true; public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitude.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitude.cs index 5ca77310c5fa3d..aa303fc09b5774 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitude.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitude.cs @@ -28,6 +28,7 @@ public static int IndexOfMaxMagnitude(ReadOnlySpan x) internal readonly struct IndexOfMaxMagnitudeOperator : IIndexOfMinMaxOperator where T : INumber { + public static bool ShouldEarlyExitOnNan => true; public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMin.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMin.cs index 135ca0ac294cc6..6bb77616b4da6d 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMin.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMin.cs @@ -28,6 +28,7 @@ public static int IndexOfMin(ReadOnlySpan x) /// Returns the index of MathF.Min(x, y) internal readonly struct IndexOfMinOperator : IIndexOfMinMaxOperator where T : INumber { + public static bool ShouldEarlyExitOnNan => true; public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitude.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitude.cs index 437c9537e6962e..e13a5dd1915e38 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitude.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitude.cs @@ -28,6 +28,7 @@ public static int IndexOfMinMagnitude(ReadOnlySpan x) internal readonly struct IndexOfMinMagnitudeOperator : IIndexOfMinMaxOperator where T : INumber { + public static bool ShouldEarlyExitOnNan => true; public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); From dc34d4da0b1fbcc9d38cf735e458ae5633c77c4d Mon Sep 17 00:00:00 2001 From: Linus Hamlin Date: Thu, 18 Jun 2026 16:34:27 +0200 Subject: [PATCH 2/3] Add new methods --- .../ref/System.Numerics.Tensors.netcore.cs | 4 + .../src/System.Numerics.Tensors.csproj | 4 + ...sorPrimitives.IndexOfMaxMagnitudeNumber.cs | 152 ++++++++++++++++++ .../TensorPrimitives.IndexOfMaxNumber.cs | 104 ++++++++++++ ...sorPrimitives.IndexOfMinMagnitudeNumber.cs | 149 +++++++++++++++++ .../TensorPrimitives.IndexOfMinNumber.cs | 105 ++++++++++++ 6 files changed, 518 insertions(+) create mode 100644 src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitudeNumber.cs create mode 100644 src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxNumber.cs create mode 100644 src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitudeNumber.cs create mode 100644 src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinNumber.cs diff --git a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs index 920fddf76ec9f4..d8ce86e5720ce7 100644 --- a/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/ref/System.Numerics.Tensors.netcore.cs @@ -989,9 +989,13 @@ public static void Ieee754Remainder(System.ReadOnlySpan x, T y, System.Spa public static void Ieee754Remainder(T x, System.ReadOnlySpan y, System.Span destination) where T : System.Numerics.IFloatingPointIeee754 { } public static void ILogB(System.ReadOnlySpan x, System.Span destination) where T : System.Numerics.IFloatingPointIeee754 { } public static void Increment(System.ReadOnlySpan x, System.Span destination) where T : System.Numerics.IIncrementOperators { } + public static int IndexOfMaxMagnitudeNumber(System.ReadOnlySpan x) where T : System.Numerics.INumber { throw null; } public static int IndexOfMaxMagnitude(System.ReadOnlySpan x) where T : System.Numerics.INumber { throw null; } + public static int IndexOfMaxNumber(System.ReadOnlySpan x) where T : System.Numerics.INumber { throw null; } public static int IndexOfMax(System.ReadOnlySpan x) where T : System.Numerics.INumber { throw null; } + public static int IndexOfMinMagnitudeNumber(System.ReadOnlySpan x) where T : System.Numerics.INumber { throw null; } public static int IndexOfMinMagnitude(System.ReadOnlySpan x) where T : System.Numerics.INumber { throw null; } + public static int IndexOfMinNumber(System.ReadOnlySpan x) where T : System.Numerics.INumber { throw null; } public static int IndexOfMin(System.ReadOnlySpan x) where T : System.Numerics.INumber { throw null; } public static bool IsCanonicalAll(System.ReadOnlySpan x) where T : System.Numerics.INumberBase { throw null; } public static bool IsCanonicalAny(System.ReadOnlySpan x) where T : System.Numerics.INumberBase { throw null; } diff --git a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj index ca14c969a3494d..c5c90cf7e47796 100644 --- a/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj +++ b/src/libraries/System.Numerics.Tensors/src/System.Numerics.Tensors.csproj @@ -111,8 +111,12 @@ + + + + diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitudeNumber.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitudeNumber.cs new file mode 100644 index 00000000000000..58d2a1310e75d7 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxMagnitudeNumber.cs @@ -0,0 +1,152 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using static System.Numerics.Tensors.TensorOperation; + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + /// Searches for the index of the non-NaN number with the largest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The index of the element in with the largest magnitude (absolute value), or -1 if is empty or only contain NaN-values. + /// + /// + /// The determination of the maximum magnitude matches the IEEE 754:2019 `maximumMagnitudeNumber` function. NaN-values are ignored. + /// If two values have the same magnitude and one is positive and the other is negative, the positive value is considered to have the larger magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMaxMagnitudeNumber(ReadOnlySpan x) + where T : INumber => + IndexOfMinMaxCore>(x); + + internal readonly struct IndexOfMaxMagnitudeNumberOperator : IIndexOfMinMaxOperator where T : INumber + { + public static bool ShouldEarlyExitOnNan => false; + public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool Compare(T x, T y) + { + if (T.IsNaN(x)) + { + return false; + } + else if (T.IsNaN(y)) + { + return true; + } + + // Don't use T.Abs since it can throw OverflowException. + T result = T.MaxMagnitude(x, y); + if (result == x) + { + if (result == y) + { + // x and y are equal in magnitude + return T.IsPositive(x) && T.IsNegative(y); + } + else + { + // x == result && y != result means x has larger magnitude than y. + return true; + } + } + else + { + return false; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Compare(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector128 equalResult = Vector128.IsPositive(x) & Vector128.IsNegative(y); + Vector128 notNanResult = Vector128.GreaterThan(xMag, yMag) | (Vector128.Equals(xMag, yMag) & equalResult); + return notNanResult | Vector128.IsNaN(y); // notNanResult will be false if x is NaN + } + else if (typeof(T) == typeof(sbyte) + || typeof(T) == typeof(short) + || typeof(T) == typeof(int) + || typeof(T) == typeof(long) + || typeof(T) == typeof(nint)) + { + // Consider overflows (when IsNegative(Abs(x))) from Abs(MinValue) which implies maximum magnitude. + Vector128 equalResult = Vector128.IsPositive(x) & Vector128.IsNegative(y); + Vector128 nonOverflowResult = Vector128.GreaterThan(xMag, yMag) | (Vector128.Equals(xMag, yMag) & equalResult); + return Vector128.AndNot(nonOverflowResult | Vector128.IsNegative(xMag), Vector128.IsNegative(yMag)); + } + else + { + return Vector128.GreaterThan(xMag, yMag); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Compare(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector256 equalResult = Vector256.IsPositive(x) & Vector256.IsNegative(y); + Vector256 notNanResult = Vector256.GreaterThan(xMag, yMag) | (Vector256.Equals(xMag, yMag) & equalResult); + return notNanResult | Vector256.IsNaN(y); // notNanResult will be false if x is NaN + } + else if (typeof(T) == typeof(sbyte) + || typeof(T) == typeof(short) + || typeof(T) == typeof(int) + || typeof(T) == typeof(long) + || typeof(T) == typeof(nint)) + { + // Consider overflows (when IsNegative(Abs(x))) from Abs(MinValue) which implies maximum magnitude. + Vector256 equalResult = Vector256.IsPositive(x) & Vector256.IsNegative(y); + Vector256 nonOverflowResult = Vector256.GreaterThan(xMag, yMag) | (Vector256.Equals(xMag, yMag) & equalResult); + return Vector256.AndNot(nonOverflowResult | Vector256.IsNegative(xMag), Vector256.IsNegative(yMag)); + } + else + { + return Vector256.GreaterThan(xMag, yMag); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Compare(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector512 equalResult = Vector512.IsPositive(x) & Vector512.IsNegative(y); + Vector512 notNanResult = Vector512.GreaterThan(xMag, yMag) | (Vector512.Equals(xMag, yMag) & equalResult); + return notNanResult | Vector512.IsNaN(y); // notNanResult will be false if x is NaN + } + else if (typeof(T) == typeof(sbyte) + || typeof(T) == typeof(short) + || typeof(T) == typeof(int) + || typeof(T) == typeof(long) + || typeof(T) == typeof(nint)) + { + // Consider overflows (when IsNegative(Abs(x))) from Abs(MinValue) which implies maximum magnitude. + Vector512 equalResult = Vector512.IsPositive(x) & Vector512.IsNegative(y); + Vector512 nonOverflowResult = Vector512.GreaterThan(xMag, yMag) | (Vector512.Equals(xMag, yMag) & equalResult); + return Vector512.AndNot(nonOverflowResult | Vector512.IsNegative(xMag), Vector512.IsNegative(yMag)); + } + else + { + return Vector512.GreaterThan(xMag, yMag); + } + } + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxNumber.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxNumber.cs new file mode 100644 index 00000000000000..75c02577db69d7 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMaxNumber.cs @@ -0,0 +1,104 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + /// Searches for the index of the largest non-NaN number in the specified tensor. + /// The tensor, represented as a span. + /// The index of the maximum element in , or -1 if is empty or only contain NaN-values. + /// + /// + /// The determination of the maximum element matches the IEEE 754:2019 `maximumNumber` function. NaN-values are ignored. + /// Positive 0 is considered greater than negative 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMaxNumber(ReadOnlySpan x) + where T : INumber => + IndexOfMinMaxCore>(x); + + internal readonly struct IndexOfMaxNumberOperator : IIndexOfMinMaxOperator where T : INumber + { + public static bool ShouldEarlyExitOnNan => false; + public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool Compare(T x, T y) + { + if (T.IsNaN(x)) + { + return false; + } + else if (T.IsNaN(y)) + { + return true; + } + + if (x == y) + { + return T.IsPositive(x) && T.IsNegative(y); + } + else + { + return x > y; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Compare(Vector128 x, Vector128 y) + { + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector128 equalResult = Vector128.IsPositive(x) & Vector128.IsNegative(y); + Vector128 notNanResult = Vector128.GreaterThan(x, y) | (Vector128.Equals(x, y) & equalResult); + return notNanResult | Vector128.IsNaN(y); // notNanResult will be false if x is NaN + } + else + { + return Vector128.GreaterThan(x, y); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Compare(Vector256 x, Vector256 y) + { + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector256 equalResult = Vector256.IsPositive(x) & Vector256.IsNegative(y); + Vector256 notNanResult = Vector256.GreaterThan(x, y) | (Vector256.Equals(x, y) & equalResult); + return notNanResult | Vector256.IsNaN(y); // notNanResult will be false if x is NaN + } + else + { + return Vector256.GreaterThan(x, y); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Compare(Vector512 x, Vector512 y) + { + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector512 equalResult = Vector512.IsPositive(x) & Vector512.IsNegative(y); + Vector512 notNanResult = Vector512.GreaterThan(x, y) | (Vector512.Equals(x, y) & equalResult); + return notNanResult | Vector512.IsNaN(y); // notNanResult will be false if x is NaN + } + else + { + return Vector512.GreaterThan(x, y); + } + } + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitudeNumber.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitudeNumber.cs new file mode 100644 index 00000000000000..10e55c2cc0552a --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinMagnitudeNumber.cs @@ -0,0 +1,149 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using static System.Numerics.Tensors.TensorOperation; + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + /// Searches for the index of the non-NaN number with the smallest magnitude in the specified tensor. + /// The tensor, represented as a span. + /// The index of the element in with the smallest magnitude (absolute value), or -1 if is empty or only contain NaN-values. + /// + /// + /// The determination of the minimum magnitude matches the IEEE 754:2019 `minimumMagnitudeNumber` function. NaN-values are ignored. + /// If two values have the same magnitude and one is positive and the other is negative, the negative value is considered to have the smaller magnitude. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMinMagnitudeNumber(ReadOnlySpan x) + where T : INumber => + IndexOfMinMaxCore>(x); + + internal readonly struct IndexOfMinMagnitudeNumberOperator : IIndexOfMinMaxOperator where T : INumber + { + public static bool ShouldEarlyExitOnNan => false; + public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool Compare(T x, T y) + { + if (T.IsNaN(x)) + { + return false; + } + else if (T.IsNaN(y)) + { + return true; + } + + // Don't use T.Abs since it can throw OverflowException. + T result = T.MinMagnitude(x, y); + if (result == x) + { + if (result == y) + { + // x and y are equal in magnitude + return T.IsNegative(x) && T.IsPositive(y); + } + else + { + // x == result && y != result means x has lesser magnitude than y. + return true; + } + } + else + { + return false; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Compare(Vector128 x, Vector128 y) + { + Vector128 xMag = Vector128.Abs(x), yMag = Vector128.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector128 equalResult = Vector128.IsNegative(x) & Vector128.IsPositive(y); + Vector128 notNanResult = Vector128.LessThan(xMag, yMag) | (Vector128.Equals(xMag, yMag) & equalResult); + return notNanResult | Vector128.IsNaN(y); // notNanResult will be false if x is NaN + } + else if (typeof(T) == typeof(sbyte) + || typeof(T) == typeof(short) + || typeof(T) == typeof(int) + || typeof(T) == typeof(long) + || typeof(T) == typeof(nint)) + { + Vector128 equalResult = Vector128.IsNegative(x) & Vector128.IsPositive(y); + Vector128 nonOverflowResult = Vector128.LessThan(xMag, yMag) | (Vector128.Equals(xMag, yMag) & equalResult); + return Vector128.AndNot(nonOverflowResult | Vector128.IsNegative(yMag), Vector128.IsNegative(xMag)); + } + else + { + return Vector128.LessThan(xMag, yMag); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Compare(Vector256 x, Vector256 y) + { + Vector256 xMag = Vector256.Abs(x), yMag = Vector256.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector256 equalResult = Vector256.IsNegative(x) & Vector256.IsPositive(y); + Vector256 notNanResult = Vector256.LessThan(xMag, yMag) | (Vector256.Equals(xMag, yMag) & equalResult); + return notNanResult | Vector256.IsNaN(y); // notNanResult will be false if x is NaN + } + else if (typeof(T) == typeof(sbyte) + || typeof(T) == typeof(short) + || typeof(T) == typeof(int) + || typeof(T) == typeof(long) + || typeof(T) == typeof(nint)) + { + Vector256 equalResult = Vector256.IsNegative(x) & Vector256.IsPositive(y); + Vector256 nonOverflowResult = Vector256.LessThan(xMag, yMag) | (Vector256.Equals(xMag, yMag) & equalResult); + return Vector256.AndNot(nonOverflowResult | Vector256.IsNegative(yMag), Vector256.IsNegative(xMag)); + } + else + { + return Vector256.LessThan(xMag, yMag); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Compare(Vector512 x, Vector512 y) + { + Vector512 xMag = Vector512.Abs(x), yMag = Vector512.Abs(y); + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector512 equalResult = Vector512.IsNegative(x) & Vector512.IsPositive(y); + Vector512 notNanResult = Vector512.LessThan(xMag, yMag) | (Vector512.Equals(xMag, yMag) & equalResult); + return notNanResult | Vector512.IsNaN(y); // notNanResult will be false if x is NaN + } + else if (typeof(T) == typeof(sbyte) + || typeof(T) == typeof(short) + || typeof(T) == typeof(int) + || typeof(T) == typeof(long) + || typeof(T) == typeof(nint)) + { + Vector512 equalResult = Vector512.IsNegative(x) & Vector512.IsPositive(y); + Vector512 nonOverflowResult = Vector512.LessThan(xMag, yMag) | (Vector512.Equals(xMag, yMag) & equalResult); + return Vector512.AndNot(nonOverflowResult | Vector512.IsNegative(yMag), Vector512.IsNegative(xMag)); + } + else + { + return Vector512.LessThan(xMag, yMag); + } + } + } + } +} diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinNumber.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinNumber.cs new file mode 100644 index 00000000000000..e4957382751cd4 --- /dev/null +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.IndexOfMinNumber.cs @@ -0,0 +1,105 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using static System.Numerics.Tensors.TensorOperation; + +namespace System.Numerics.Tensors +{ + public static partial class TensorPrimitives + { + /// Searches for the index of the smallest non-NaN number in the specified tensor. + /// The tensor, represented as a span. + /// The index of the minimum element in , or -1 if is empty or only contain NaN-values. + /// + /// + /// The determination of the minimum element matches the IEEE 754:2019 `minimumNumber` function. NaN-values are ignored. + /// Negative 0 is considered smaller than positive 0. + /// + /// + /// This method may call into the underlying C runtime or employ instructions specific to the current architecture. Exact results may differ between different + /// operating systems or architectures. + /// + /// + public static int IndexOfMinNumber(ReadOnlySpan x) + where T : INumber => + IndexOfMinMaxCore>(x); + + internal readonly struct IndexOfMinNumberOperator : IIndexOfMinMaxOperator where T : INumber + { + public static bool ShouldEarlyExitOnNan => false; + public static T Aggregate(Vector128 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector256 x) => HorizontalAggregate>(x); + public static T Aggregate(Vector512 x) => HorizontalAggregate>(x); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool Compare(T x, T y) + { + if (T.IsNaN(x)) + { + return false; + } + else if (T.IsNaN(y)) + { + return true; + } + + if (x == y) + { + return T.IsNegative(x) && T.IsPositive(y); + } + else + { + return x < y; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector128 Compare(Vector128 x, Vector128 y) + { + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector128 equalResult = Vector128.IsNegative(x) & Vector128.IsPositive(y); + Vector128 notNanResult = Vector128.LessThan(x, y) | (Vector128.Equals(x, y) & equalResult); + return notNanResult | Vector128.IsNaN(y); // notNanResult will be false if x is NaN + } + else + { + return Vector128.LessThan(x, y); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector256 Compare(Vector256 x, Vector256 y) + { + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector256 equalResult = Vector256.IsNegative(x) & Vector256.IsPositive(y); + Vector256 notNanResult = Vector256.LessThan(x, y) | (Vector256.Equals(x, y) & equalResult); + return notNanResult | Vector256.IsNaN(y); // notNanResult will be false if x is NaN + } + else + { + return Vector256.LessThan(x, y); + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Vector512 Compare(Vector512 x, Vector512 y) + { + if (typeof(T) == typeof(double) || typeof(T) == typeof(float)) + { + Vector512 equalResult = Vector512.IsNegative(x) & Vector512.IsPositive(y); + Vector512 notNanResult = Vector512.LessThan(x, y) | (Vector512.Equals(x, y) & equalResult); + return notNanResult | Vector512.IsNaN(y); // notNanResult will be false if x is NaN + } + else + { + return Vector512.LessThan(x, y); + } + } + } + } +} From eaf0ef36a5695a456b5d02700d4398c06def3518 Mon Sep 17 00:00:00 2001 From: Linus Hamlin Date: Thu, 18 Jun 2026 16:34:35 +0200 Subject: [PATCH 3/3] Add unit tests --- .../tests/TensorPrimitives.Generic.cs | 377 +++++++++++++++++- .../tests/TensorPrimitivesTests.cs | 3 + 2 files changed, 379 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs index f9ed5e03bb1869..031ebc69753184 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitives.Generic.cs @@ -2725,7 +2725,8 @@ public unsafe abstract class GenericNumberTensorPrimitivesTests : TensorPrimi protected override T SumOfSquares(ReadOnlySpan x) => TensorPrimitives.SumOfSquares(x); protected override T ConvertFromSingle(float f) => T.CreateTruncating(f); - protected override bool IsFloatingPoint => typeof(T) == typeof(Half) || base.IsFloatingPoint; + protected override bool IsFloatingPoint => typeof(T) == typeof(NFloat) || typeof(T) == typeof(Half) || base.IsFloatingPoint; + protected override bool IsUnsignedInteger => typeof(T) == typeof(UInt128) || typeof(T) == typeof(nuint) || base.IsUnsignedInteger; protected override T NextRandom() { @@ -2856,6 +2857,380 @@ public void ScalarSpanDestination_ThrowsForOverlappingInputsWithOutputs(ScalarSp } #endregion + #region IndexOfMaxNumber + [Fact] + public void IndexOfMaxNumber_ReturnsNegative1OnEmpty() + { + Assert.Equal(-1, TensorPrimitives.IndexOfMaxNumber(ReadOnlySpan.Empty)); + } + + [Fact] + public void IndexOfMaxNumber_ReturnsNegative1OnOnlyNaNs() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NaN); + Assert.Equal(-1, TensorPrimitives.IndexOfMaxNumber(x.Span)); + }); + } + + [Fact] + public void IndexOfMaxNumber_AllLengths() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = Enumerable.Max(MemoryMarshal.ToEnumerable(x.Memory)); + int actual = TensorPrimitives.IndexOfMaxNumber(x.Span); + Assert.True(actual == expected || (actual < expected && x[actual].Equals(x[expected])), $"{tensorLength} {actual} {expected} {string.Join(",", MemoryMarshal.ToEnumerable(x.Memory))}"); + } + }); + } + + [Fact] + public void IndexOfMaxNumber_NaNsNotReturned() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NaN); + x[expected] = One; + x[tensorLength - 1] = One; + Assert.Equal(expected, TensorPrimitives.IndexOfMaxNumber(x.Span)); + } + }); + } + + [Fact] + public void IndexOfMaxNumber_Negative0LesserThanPositive0() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NegativeZero); + x[expected] = Zero; + x[tensorLength - 1] = Zero; + Assert.Equal(expected, TensorPrimitives.IndexOfMaxNumber(x.Span)); + } + }); + } + + [Fact] + public void IndexOfMaxNumber_IndexAboveMaxValue() + { + var size = IndexOfSizeExceedingMaxValue(); + if (size == null) return; + + using BoundedMemory x = CreateTensor(size.Value); + x.Span.Fill(One); + x.Span[size.Value - 1] = ConvertFromSingle(2); + Assert.Equal(size.Value - 1, TensorPrimitives.IndexOfMaxNumber(x.Span)); + } + #endregion + + #region IndexOfMaxMagnitudeNumber + [Fact] + public void IndexOfMaxMagnitudeNumber_ReturnsNegative1OnEmpty() + { + Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitudeNumber(ReadOnlySpan.Empty)); + } + + [Fact] + public void IndexOfMaxMagnitudeNumber_ReturnsNegative1OnOnlyNaNs() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NaN); + Assert.Equal(-1, TensorPrimitives.IndexOfMaxMagnitudeNumber(x.Span)); + }); + } + + [Fact] + public void IndexOfMaxMagnitudeNumber_AllLengths() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + T max = x[0]; + for (int i = 0; i < x.Length; i++) + { + max = T.MaxMagnitudeNumber(max, x[i]); + } + x[expected] = max; + + int actual = TensorPrimitives.IndexOfMaxMagnitudeNumber(x.Span); + + if (actual != expected) + { + Assert.True(actual < expected || Comparer.Default.Compare(x[actual], x[expected]) > 0, $"{tensorLength} {actual} {expected} {string.Join(",", MemoryMarshal.ToEnumerable(x.Memory))}"); + if (IsFloatingPoint) + { + AssertEqualTolerance(x[expected], x[actual], Zero); + } + else + { + Assert.Equal(x[expected], x[actual]); + } + } + } + }); + } + + [Fact] + public void IndexOfMaxMagnitudeNumber_NaNsNotReturned() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NaN); + x[expected] = One; + x[tensorLength - 1] = One; + Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitudeNumber(x.Span)); + } + }); + } + + [Fact] + public void IndexOfMaxMagnitudeNumber_Negative1LesserThanPositive1() + { + if (IsUnsignedInteger) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NegativeOne); + x[expected] = One; + x[tensorLength - 1] = One; + Assert.Equal(expected, TensorPrimitives.IndexOfMaxMagnitudeNumber(x.Span)); + } + }); + } + + [Fact] + public void IndexOfMaxMagnitudeNumber_IndexAboveMaxValue() + { + var size = IndexOfSizeExceedingMaxValue(); + if (size == null) return; + + using BoundedMemory x = CreateTensor(size.Value); + x.Span.Fill(One); + x.Span[size.Value - 1] = ConvertFromSingle(2); + Assert.Equal(size.Value - 1, TensorPrimitives.IndexOfMaxMagnitudeNumber(x.Span)); + } + #endregion + + #region IndexOfMinNumber + [Fact] + public void IndexOfMinNumber_ReturnsNegative1OnEmpty() + { + Assert.Equal(-1, TensorPrimitives.IndexOfMinNumber(ReadOnlySpan.Empty)); + } + + [Fact] + public void IndexOfMinNumber_ReturnsNegative1OnOnlyNaNs() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NaN); + Assert.Equal(-1, TensorPrimitives.IndexOfMinNumber(x.Span)); + }); + } + + [Fact] + public void IndexOfMinNumber_AllLengths() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + x[expected] = Enumerable.Min(MemoryMarshal.ToEnumerable(x.Memory)); + int actual = TensorPrimitives.IndexOfMinNumber(x.Span); + Assert.True(actual == expected || (actual < expected && x[actual].Equals(x[expected])), $"{tensorLength} {actual} {expected} {string.Join(",", MemoryMarshal.ToEnumerable(x.Memory))}"); + } + }); + } + + [Fact] + public void IndexOfMinNumber_NaNsNotReturned() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NaN); + x[expected] = One; + x[tensorLength - 1] = One; + Assert.Equal(expected, TensorPrimitives.IndexOfMinNumber(x.Span)); + } + }); + } + + [Fact] + public void IndexOfMinNumber_Negative0LesserThanPositive0() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(Zero); + x[expected] = NegativeZero; + x[tensorLength - 1] = NegativeZero; + Assert.Equal(expected, TensorPrimitives.IndexOfMinNumber(x.Span)); + } + }); + } + + [Fact] + public void IndexOfMinNumber_IndexAboveMaxValue() + { + var size = IndexOfSizeExceedingMaxValue(); + if (size == null) return; + + using BoundedMemory x = CreateTensor(size.Value); + x.Span.Fill(ConvertFromSingle(2)); + x.Span[size.Value - 1] = One; + Assert.Equal(size.Value - 1, TensorPrimitives.IndexOfMinNumber(x.Span)); + } + #endregion + + #region IndexOfMinMagnitudeNumber + [Fact] + public void IndexOfMinMagnitudeNumber_ReturnsNegative1OnEmpty() + { + Assert.Equal(-1, TensorPrimitives.IndexOfMinMagnitudeNumber(ReadOnlySpan.Empty)); + } + + [Fact] + public void IndexOfMinMagnitudeNumber_ReturnsNegative1OnOnlyNaNs() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NaN); + Assert.Equal(-1, TensorPrimitives.IndexOfMinMagnitudeNumber(x.Span)); + }); + } + + [Fact] + public void IndexOfMinMagnitudeNumber_AllLengths() + { + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateAndFillTensor(tensorLength); + + T min = x[0]; + for (int i = 0; i < x.Length; i++) + { + min = T.MinMagnitudeNumber(min, x[i]); + } + x[expected] = min; + + int actual = TensorPrimitives.IndexOfMinMagnitudeNumber(x.Span); + + if (actual != expected) + { + Assert.True(actual < expected || Comparer.Default.Compare(x[actual], x[expected]) > 0, $"{tensorLength} {actual} {expected} {string.Join(",", MemoryMarshal.ToEnumerable(x.Memory))}"); + if (IsFloatingPoint) + { + AssertEqualTolerance(x[expected], x[actual], Zero); + } + else + { + Assert.Equal(x[expected], x[actual]); + } + } + } + }); + } + + [Fact] + public void IndexOfMinMagnitudeNumber_NaNsNotReturned() + { + if (!IsFloatingPoint) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(NaN); + x[expected] = One; + x[tensorLength - 1] = One; + Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitudeNumber(x.Span)); + } + }); + } + + [Fact] + public void IndexOfMinMagnitudeNumber_Negative1LesserThanPositive1() + { + if (IsUnsignedInteger) return; + + Assert.All(Helpers.TensorLengths, tensorLength => + { + foreach (int expected in new[] { 0, tensorLength / 2, tensorLength - 1 }) + { + using BoundedMemory x = CreateTensor(tensorLength); + x.Span.Fill(One); + x[expected] = NegativeOne; + x[tensorLength - 1] = NegativeOne; + Assert.Equal(expected, TensorPrimitives.IndexOfMinMagnitudeNumber(x.Span)); + } + }); + } + + [Fact] + public void IndexOfMinMagnitudeNumber_IndexAboveMaxValue() + { + var size = IndexOfSizeExceedingMaxValue(); + if (size == null) return; + + using BoundedMemory x = CreateTensor(size.Value); + x.Span.Fill(ConvertFromSingle(2)); + x.Span[size.Value - 1] = One; + Assert.Equal(size.Value - 1, TensorPrimitives.IndexOfMinMagnitudeNumber(x.Span)); + } + #endregion + #region IsXx public static IEnumerable SpanDestinationIsFunctionsToTest() { diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index 55dc2be94487d9..8407232ed1bf59 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -105,6 +105,9 @@ public abstract class TensorPrimitivesTests where T : unmanaged, IEquatable x); protected virtual bool IsFloatingPoint => typeof(T) == typeof(float) || typeof(T) == typeof(double); + protected virtual bool IsUnsignedInteger => + typeof(T) == typeof(byte) || typeof(T) == typeof(ushort) || typeof(T) == typeof(uint) || + typeof(T) == typeof(ulong) || typeof(T) == typeof(char); protected virtual int? IndexOfSizeExceedingMaxValue() => (typeof(T) == typeof(byte) || typeof(T) == typeof(sbyte)) ? Helpers.SizeGreaterThanByte :