Skip to content

Commit 2126d6b

Browse files
committed
Refactor AI model inference handling and IO bindings
Significantly refactored `AIManager.cs` to improve input and output binding management for model inference. Introduced a new method `InitializeIOBinding` for robust handling of Ort values and tensors, supporting both float and half-precision formats. Updated `PrepareKDTreeData` to remove KD-tree references, now using a reusable list for predictions. Added float conversion methods in `MathUtil.cs` and enhanced `ModelManager.cs` with properties for input names and asynchronous model loading for better performance across different execution providers.
1 parent e0a928c commit 2126d6b

File tree

3 files changed

+384
-29
lines changed

3 files changed

+384
-29
lines changed

Aimmy2/AILogic/AIManager.cs

Lines changed: 258 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,23 @@ public void RequestSizeChange(int newSize)
8686
private float _scaleY => ScreenHeight / (float)IMAGE_SIZE;
8787

8888
// Tensor reuse (model inference)
89-
private DenseTensor<float>? _reusableTensor;
9089
private float[]? _reusableInputArray;
90+
private DenseTensor<float>? _reusableTensor;
9191
private List<NamedOnnxValue>? _reusableInputs;
9292

93+
private ushort[]? _inputU16Buffer;
94+
private ushort[]? _outputU16Buffer;
95+
private float[]? _outputFloatBuffer; // used when output element type == float
96+
97+
private TensorElementType _modelInputElementType = TensorElementType.Float;
98+
private TensorElementType _modelOutputElementType = TensorElementType.Float;
99+
100+
private OrtIoBinding _ioBinding;
101+
private OrtValue _inputOrtValue;
102+
private OrtValue _outputOrtValue;
103+
private bool _ioBindingInitialized = false;
104+
private readonly object _ioBindingLock = new object();
105+
93106
// Benchmarking
94107
private readonly Dictionary<string, BenchmarkData> _benchmarks = new();
95108
private readonly object _benchmarkLock = new();
@@ -249,6 +262,159 @@ public async Task InitializeModel(string modelPath)
249262
}
250263
}
251264
}
265+
// NOTE: this method is convoluted and should be refactored at some point;
266+
// it is functional as-is, so the cleanup is deferred for now.
267+
private void InitializeIOBinding(int imageSize)
268+
{
269+
if (_ioBindingInitialized &&
270+
_reusableTensor != null &&
271+
_reusableTensor.Dimensions[2] == imageSize)
272+
{
273+
return;
274+
}
275+
276+
lock (_ioBindingLock)
277+
{
278+
if (_modelManager.onnxModel == null)
279+
{
280+
return;
281+
}
282+
283+
if (Dictionary.dropdownState["Execution Provider"] == "CPU")
284+
{
285+
_ioBindingInitialized = false;
286+
return; // Skip IO Binding for CPU execution
287+
}
288+
try
289+
{
290+
_ioBinding?.Dispose();
291+
_inputOrtValue?.Dispose();
292+
_outputOrtValue?.Dispose();
293+
}
294+
catch { }
295+
296+
_ioBinding = null;
297+
_inputOrtValue = null;
298+
_outputOrtValue = null;
299+
_ioBindingInitialized = false;
300+
301+
try
302+
{
303+
var inputMeta = _modelManager.onnxModel.InputMetadata;
304+
var outputMeta = _modelManager.onnxModel.OutputMetadata;
305+
306+
if (inputMeta != null &&
307+
inputMeta.TryGetValue(_modelManager.inputName ?? inputMeta.Keys.First(), out var inMeta))
308+
{
309+
_modelInputElementType = inMeta.ElementDataType;
310+
}
311+
312+
313+
if (_modelManager.outputNames?.Count > 0 &&
314+
outputMeta != null &&
315+
outputMeta.TryGetValue(_modelManager.outputNames[0], out var outMeta))
316+
{
317+
_modelOutputElementType = outMeta.ElementDataType;
318+
}
319+
320+
Log(LogLevel.Info, $"Model Input element type: {_modelInputElementType}; Output element type: {_modelOutputElementType}");
321+
322+
_ioBinding = _modelManager.onnxModel.CreateIoBinding();
323+
var memoryInfo = OrtMemoryInfo.DefaultInstance; // let ort handle that or whatever
324+
325+
326+
// INPUT
327+
var inputShape = new long[] { 1, 3, imageSize, imageSize };
328+
int inputLen = (int)(inputShape.Aggregate(1L, (a, b) => a * b));
329+
330+
if (_reusableInputArray == null || _reusableInputArray.Length != inputLen)
331+
{
332+
_reusableInputArray = new float[inputLen]; // still used for preprocessing
333+
_reusableTensor = null;
334+
_reusableInputs = null;
335+
}
336+
337+
switch (_modelInputElementType)
338+
{
339+
// most exported models use float16 input; the float path is kept as a fallback
340+
case TensorElementType.Float:
341+
_inputOrtValue = OrtValue.CreateTensorValueFromMemory<float>(
342+
memoryInfo, _reusableInputArray, inputShape);
343+
_ioBinding.BindInput(_modelManager.inputName ?? "images", _inputOrtValue);
344+
Log(LogLevel.Info, "IOBinding: bound float input");
345+
break;
346+
347+
case TensorElementType.Float16:
348+
_inputU16Buffer ??= new ushort[inputLen];
349+
if (_inputU16Buffer.Length != inputLen)
350+
_inputU16Buffer = new ushort[inputLen];
351+
352+
_inputOrtValue = OrtValue.CreateTensorValueFromMemory<ushort>(
353+
memoryInfo, _inputU16Buffer, inputShape);
354+
_ioBinding.BindInput(_modelManager.inputName ?? "images", _inputOrtValue);
355+
Log(LogLevel.Info, "IOBinding: bound float16 input (ushort buffer)");
356+
break;
357+
358+
default:
359+
// _ioBindingInitialized stays false here; the inference step detects this and falls back to the non-binding path
360+
throw new NotSupportedException(
361+
$"Unsupported model input element type: {_modelInputElementType}"); // yikes
362+
}
363+
364+
// OUTPUT
365+
if (_modelManager.outputNames == null || _modelManager.outputNames.Count == 0)
366+
throw new InvalidOperationException("Model output names are not defined.");
367+
368+
var outputShape = new long[] { 1, _modelManager.NUM_CLASSES + 4, _modelManager.NUM_DETECTIONS };
369+
int outputLen = (int)outputShape.Aggregate(1L, (a, b) => a * b);
370+
switch (_modelOutputElementType)
371+
{
372+
case TensorElementType.Float:
373+
_outputFloatBuffer ??= new float[outputLen];
374+
if (_outputFloatBuffer.Length != outputLen)
375+
_outputFloatBuffer = new float[outputLen];
376+
377+
_outputOrtValue = OrtValue.CreateTensorValueFromMemory<float>(
378+
memoryInfo, _outputFloatBuffer, outputShape);
379+
_ioBinding.BindOutput(_modelManager.outputNames[0], _outputOrtValue);
380+
Log(LogLevel.Info, "IOBinding: bound float output");
381+
break;
382+
383+
case TensorElementType.Float16:
384+
_outputU16Buffer ??= new ushort[outputLen];
385+
if (_outputU16Buffer.Length != outputLen)
386+
_outputU16Buffer = new ushort[outputLen];
387+
388+
_outputOrtValue = OrtValue.CreateTensorValueFromMemory<ushort>(
389+
memoryInfo, _outputU16Buffer, outputShape);
390+
_ioBinding.BindOutput(_modelManager.outputNames[0], _outputOrtValue);
391+
Log(LogLevel.Info, "IOBinding: bound float16 output (ushort buffer)");
392+
break;
393+
394+
default:
395+
throw new NotSupportedException(
396+
$"Unsupported model output element type: {_modelOutputElementType}");
397+
}
398+
399+
_ioBindingInitialized = true;
400+
Log(LogLevel.Info, "IO Binding initialized successfully");
401+
}
402+
catch (Exception ex)
403+
{
404+
// cleanup on failure, keep _ioBindingInitialized false
405+
try { _ioBinding?.Dispose(); } catch { }
406+
try { _inputOrtValue?.Dispose(); } catch { }
407+
try { _outputOrtValue?.Dispose(); } catch { }
408+
409+
_ioBinding = null;
410+
_inputOrtValue = null;
411+
_outputOrtValue = null;
412+
_ioBindingInitialized = false;
413+
414+
Log(LogLevel.Error, $"Failed to initialize IO Binding: {ex.Message}");
415+
}
416+
}
417+
}
252418
private void StartAILoop(InferenceSession? onnxModel)
253419
{
254420
if (onnxModel?.OutputMetadata != null && onnxModel.OutputMetadata.Count > 0)
@@ -627,30 +793,96 @@ private void HandlePredictions(KalmanPrediction kalmanPrediction, Prediction clo
627793
BitmapToFloatArrayInPlace(frame, inputArray, IMAGE_SIZE);
628794
}
629795

630-
// Reuse tensor and inputs - recreate if size changed
631-
if (_reusableTensor == null || _reusableTensor.Dimensions[2] != IMAGE_SIZE)
632-
{
633-
_reusableTensor = new DenseTensor<float>(inputArray, new int[] { 1, 3, IMAGE_SIZE, IMAGE_SIZE });
634-
_reusableInputs = new List<NamedOnnxValue> { NamedOnnxValue.CreateFromTensor("images", _reusableTensor) };
635-
}
636-
//else
637-
//{
638-
// // Directly copy into existing DenseTensor buffer
639-
// inputArray.AsSpan().CopyTo(_reusableTensor.Buffer.Span);
640-
//}
641-
642796
if (_modelManager.onnxModel == null)
643797
{
644798
frame.Dispose();
645799
return null; // Model not loaded, exit early
646800
}
647801

648-
//IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results;
802+
if (!_ioBindingInitialized ||
803+
_reusableTensor == null ||
804+
_reusableTensor.Dimensions[2] != IMAGE_SIZE)
805+
{
806+
using (Benchmark("IOBindingInitialization"))
807+
{
808+
InitializeIOBinding(IMAGE_SIZE);
809+
}
810+
}
811+
649812
Tensor<float>? outputTensor = null;
650813
using (Benchmark("ModelInference"))
651814
{
652-
using var results = _modelManager.onnxModel.Run(_reusableInputs, _modelManager.outputNames, _modelManager.modelOptions);
653-
outputTensor = results[0].AsTensor<float>();
815+
try
816+
{
817+
if (_ioBindingInitialized && _inputOrtValue != null && _outputOrtValue != null)
818+
{
819+
// convert float input to half-precision bits manually, since this path does not use .NET's System.Half directly
820+
if (_modelInputElementType == TensorElementType.Float16)
821+
{
822+
int len = _reusableInputArray!.Length;
823+
var inputU16 = _inputU16Buffer!;
824+
for (int i = 0; i < len; i++)
825+
{
826+
// NOTE(review): no clamping is actually applied here despite some models expecting [0,1] input — confirm the preprocessing step already normalizes values
827+
float v = _reusableInputArray[i];
828+
inputU16[i] = FloatToHalfBits(v);
829+
}
830+
}
831+
832+
//run inference as per IO Binding
833+
_modelManager.onnxModel.RunWithBinding(_modelManager.modelOptions, _ioBinding);
834+
835+
// output element type dispatch; written as if/else rather than a switch
836+
if (_modelOutputElementType == TensorElementType.Float)
837+
{
838+
// get model output
839+
outputTensor = new DenseTensor<float>(_outputFloatBuffer, new int[] { 1, _modelManager.NUM_CLASSES + 4, _modelManager.NUM_DETECTIONS });
840+
}
841+
else if (_modelOutputElementType == TensorElementType.Float16) // usually f16
842+
{
843+
// convert ushort half-bits -> float[] into a temp array
844+
var outU16 = _outputU16Buffer!;
845+
var outFloat = new float[outU16.Length];
846+
for (int i = 0; i < outU16.Length; i++)
847+
outFloat[i] = HalfBitsToFloat(outU16[i]);
848+
849+
outputTensor = new DenseTensor<float>(
850+
outFloat,
851+
new int[] { 1, _modelManager.NUM_CLASSES + 4, _modelManager.NUM_DETECTIONS }
852+
);
853+
}
854+
else
855+
{ // yikes
856+
throw new NotSupportedException($"Unsupported model output element type: {_modelOutputElementType}");
857+
}
858+
}
859+
else
860+
{
861+
// run it without i/o binding
862+
if (_reusableTensor == null || _reusableTensor.Dimensions[2] != IMAGE_SIZE)
863+
{
864+
_reusableTensor = new DenseTensor<float>(_reusableInputArray, new int[] { 1, 3, IMAGE_SIZE, IMAGE_SIZE });
865+
866+
if (_reusableInputs == null)
867+
_reusableInputs = new List<NamedOnnxValue>(1);
868+
869+
_reusableInputs.Clear();
870+
_reusableInputs.Add(NamedOnnxValue.CreateFromTensor(_modelManager.inputName ?? "images", _reusableTensor));
871+
}
872+
else
873+
{
874+
_reusableInputArray.AsSpan().CopyTo(_reusableTensor.Buffer.Span);
875+
}
876+
877+
using var results = _modelManager.onnxModel.Run(_reusableInputs, _modelManager.outputNames, _modelManager.modelOptions);
878+
outputTensor = results[0].AsTensor<float>();
879+
}
880+
}
881+
catch (Exception ex)
882+
{
883+
Log(LogLevel.Error, $"Inference error: {ex.Message}");
884+
_ioBindingInitialized = false; // Reset IO Binding on error
885+
}
654886
}
655887

656888
if (outputTensor == null)
@@ -667,16 +899,18 @@ private void HandlePredictions(KalmanPrediction kalmanPrediction, Prediction clo
667899
float fovMinY = (IMAGE_SIZE - FovSize) / 2.0f;
668900
float fovMaxY = (IMAGE_SIZE + FovSize) / 2.0f;
669901

902+
// the KD-tree lookup has been replaced with a simple reusable prediction list
670903
//List<double[]> KDpoints;
671904
List<Prediction> KDPredictions;
672-
using (Benchmark("PrepareKDTreeData"))
905+
using (Benchmark("PrepareKDTreeData")) // not really kd tree data anymore
673906
{
674-
KDPredictions = PrepareKDTreeData(outputTensor, detectionBox, fovMinX, fovMaxX, fovMinY, fovMaxY);
907+
KDPredictions = PrepareKDTreeData(outputTensor, detectionBox, fovMinX, fovMaxX, fovMinY, fovMaxY);
675908
}
676909

677910
if (KDPredictions.Count == 0)
678911
{
679912
SaveFrame(frame);
913+
frame.Dispose();
680914
return null;
681915
}
682916

@@ -770,17 +1004,17 @@ private void HandlePredictions(KalmanPrediction kalmanPrediction, Prediction clo
7701004
return bestCandidate;
7711005
}
7721006

773-
1007+
private readonly List<Prediction> _kdPredictions = new(8400);
7741008
private List<Prediction> PrepareKDTreeData(
7751009
Tensor<float> outputTensor,
7761010
Rectangle detectionBox,
7771011
float fovMinX, float fovMaxX, float fovMinY, float fovMaxY)
7781012
{
1013+
_kdPredictions.Clear();
7791014
float minConfidence = (float)Dictionary.sliderSettings["AI Minimum Confidence"] / 100.0f;
7801015
string selectedClass = Dictionary.dropdownState["Target Class"];
7811016
int selectedClassId = -1;
7821017

783-
7841018
int numDetections = _modelManager.NUM_DETECTIONS;
7851019
int numClasses = _modelManager.NUM_CLASSES;
7861020
var modelClasses = _modelManager.modelClasses;
@@ -795,7 +1029,7 @@ private List<Prediction> PrepareKDTreeData(
7951029

7961030

7971031
//var KDpoints = new List<double[]>(_modelManager.NUM_DETECTIONS); // Pre-allocate with estimated capacity
798-
var KDpredictions = new List<Prediction>(numDetections);
1032+
//var KDpredictions = new List<Prediction>(numDetections);
7991033

8001034
for (int i = 0; i < numDetections; i++)
8011035
{
@@ -848,17 +1082,17 @@ private List<Prediction> PrepareKDTreeData(
8481082
Confidence = bestConfidence,
8491083
ClassId = bestClassId,
8501084
ClassName = modelClasses.GetValueOrDefault(bestClassId, $"Class_{bestClassId}"),
851-
CenterXTranslated = x_center / IMAGE_SIZE,
1085+
CenterXTranslated = x_center / IMAGE_SIZE,
8521086
CenterYTranslated = y_center / IMAGE_SIZE,
8531087
ScreenCenterX = detectionBox.Left + x_center,
8541088
ScreenCenterY = detectionBox.Top + y_center
8551089
};
8561090

8571091
//KDpoints.Add(new double[] { x_center, y_center });
858-
KDpredictions.Add(prediction);
1092+
_kdPredictions.Add(prediction);
8591093
}
8601094

861-
return KDpredictions;
1095+
return _kdPredictions;
8621096
}
8631097
private void UpdateDetectionBox(Prediction target, Rectangle detectionBox)
8641098
{

0 commit comments

Comments
 (0)