Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit ea63f96

Browse files
committed
FeatureExtractor normalization
1 parent 19b1211 commit ea63f96

File tree

7 files changed

+117
-55
lines changed

7 files changed

+117
-55
lines changed

OnnxStack.Console/Examples/ControlNetFeatureExample.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public async Task RunAsync()
3535
var inputImage = await OnnxImage.FromFileAsync("D:\\Repositories\\OnnxStack\\Assets\\Samples\\Img2Img_Start.bmp");
3636

3737
// Create Annotation pipeline
38-
var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutputTensor: true);
38+
var annotationPipeline = FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutput: true);
3939

4040
// Create Depth Image
4141
var controlImage = await annotationPipeline.RunAsync(inputImage);

OnnxStack.Console/Examples/FeatureExtractorExample.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public async Task RunAsync()
3737
{
3838
FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\canny.onnx"),
3939
FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\hed.onnx"),
40-
FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutputTensor: true, inputResizeMode: ImageResizeMode.Stretch),
40+
FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\controlnet_onnx\\annotators\\depth.onnx", sampleSize: 512, normalizeOutput: true, inputResizeMode: ImageResizeMode.Stretch),
4141
FeatureExtractorPipeline.CreatePipeline("D:\\Repositories\\RMBG-1.4\\onnx\\model.onnx", sampleSize: 1024, setOutputToInputAlpha: true, inputResizeMode: ImageResizeMode.Stretch)
4242
};
4343

OnnxStack.Core/Image/Extensions.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ public static class Extensions
1616
/// <param name="imageTensor">The image tensor.</param>
1717
/// <returns></returns>
1818
public static OnnxImage ToImageMask(this DenseTensor<float> imageTensor)
19+
{
20+
return new OnnxImage(imageTensor.FromMaskTensor());
21+
}
22+
23+
24+
/// <summary>
25+
/// Convert from single channel mask tensor to Rgba32 (Greyscale)
26+
/// </summary>
27+
/// <param name="imageTensor">The image tensor.</param>
28+
/// <returns></returns>
29+
public static Image<Rgba32> FromMaskTensor(this DenseTensor<float> imageTensor)
1930
{
2031
var width = imageTensor.Dimensions[3];
2132
var height = imageTensor.Dimensions[2];
@@ -28,7 +39,7 @@ public static OnnxImage ToImageMask(this DenseTensor<float> imageTensor)
2839
result[x, y] = new L8((byte)(imageTensor[0, 0, y, x] * 255.0f));
2940
}
3041
}
31-
return new OnnxImage(result.CloneAs<Rgba32>());
42+
return result.CloneAs<Rgba32>();
3243
}
3344
}
3445

OnnxStack.Core/Image/OnnxImage.cs

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,27 +64,35 @@ public OnnxImage(DenseTensor<float> imageTensor, ImageNormalizeType normalizeTyp
6464
{
6565
var height = imageTensor.Dimensions[2];
6666
var width = imageTensor.Dimensions[3];
67-
var hasTransparency = imageTensor.Dimensions[1] == 4;
68-
_imageData = new Image<Rgba32>(width, height);
69-
for (var y = 0; y < height; y++)
67+
var channels = imageTensor.Dimensions[1];
68+
if (channels == 1)
7069
{
71-
for (var x = 0; x < width; x++)
70+
_imageData = imageTensor.FromMaskTensor();
71+
}
72+
else
73+
{
74+
var hasTransparency = channels == 4;
75+
_imageData = new Image<Rgba32>(width, height);
76+
for (var y = 0; y < height; y++)
7277
{
73-
if (normalizeType == ImageNormalizeType.ZeroToOne)
74-
{
75-
_imageData[x, y] = new Rgba32(
76-
DenormalizeZeroToOneToByte(imageTensor, 0, y, x),
77-
DenormalizeZeroToOneToByte(imageTensor, 1, y, x),
78-
DenormalizeZeroToOneToByte(imageTensor, 2, y, x),
79-
hasTransparency ? DenormalizeZeroToOneToByte(imageTensor, 3, y, x) : byte.MaxValue);
80-
}
81-
else
78+
for (var x = 0; x < width; x++)
8279
{
83-
_imageData[x, y] = new Rgba32(
84-
DenormalizeOneToOneToByte(imageTensor, 0, y, x),
85-
DenormalizeOneToOneToByte(imageTensor, 1, y, x),
86-
DenormalizeOneToOneToByte(imageTensor, 2, y, x),
87-
hasTransparency ? DenormalizeOneToOneToByte(imageTensor, 3, y, x) : byte.MaxValue);
80+
if (normalizeType == ImageNormalizeType.ZeroToOne)
81+
{
82+
_imageData[x, y] = new Rgba32(
83+
DenormalizeZeroToOneToByte(imageTensor, 0, y, x),
84+
DenormalizeZeroToOneToByte(imageTensor, 1, y, x),
85+
DenormalizeZeroToOneToByte(imageTensor, 2, y, x),
86+
hasTransparency ? DenormalizeZeroToOneToByte(imageTensor, 3, y, x) : byte.MaxValue);
87+
}
88+
else
89+
{
90+
_imageData[x, y] = new Rgba32(
91+
DenormalizeOneToOneToByte(imageTensor, 0, y, x),
92+
DenormalizeOneToOneToByte(imageTensor, 1, y, x),
93+
DenormalizeOneToOneToByte(imageTensor, 2, y, x),
94+
hasTransparency ? DenormalizeOneToOneToByte(imageTensor, 3, y, x) : byte.MaxValue);
95+
}
8896
}
8997
}
9098
}

OnnxStack.FeatureExtractor/Common/FeatureExtractorModel.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,18 @@ public FeatureExtractorModel(FeatureExtractorModelConfig configuration)
1717

1818
public int OutputChannels => _configuration.OutputChannels;
1919
public int SampleSize => _configuration.SampleSize;
20-
public bool NormalizeOutputTensor => _configuration.NormalizeOutputTensor;
20+
public bool NormalizeOutput => _configuration.NormalizeOutput;
2121
public bool SetOutputToInputAlpha => _configuration.SetOutputToInputAlpha;
2222
public ImageResizeMode InputResizeMode => _configuration.InputResizeMode;
23-
public ImageNormalizeType InputNormalization => _configuration.NormalizeInputTensor;
23+
public ImageNormalizeType NormalizeType => _configuration.NormalizeType;
24+
public bool NormalizeInput => _configuration.NormalizeInput;
2425

2526
public static FeatureExtractorModel Create(FeatureExtractorModelConfig configuration)
2627
{
2728
return new FeatureExtractorModel(configuration);
2829
}
2930

30-
public static FeatureExtractorModel Create(string modelFile, int sampleSize = 0, int outputChannels = 1, bool normalizeOutputTensor = false, ImageNormalizeType normalizeInputTensor = ImageNormalizeType.ZeroToOne, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
31+
public static FeatureExtractorModel Create(string modelFile, int sampleSize = 0, int outputChannels = 1, ImageNormalizeType normalizeType = ImageNormalizeType.ZeroToOne, bool normalizeInput = true, bool normalizeOutput = false, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML)
3132
{
3233
var configuration = new FeatureExtractorModelConfig
3334
{
@@ -38,12 +39,12 @@ public static FeatureExtractorModel Create(string modelFile, int sampleSize = 0,
3839
IntraOpNumThreads = 0,
3940
OnnxModelPath = modelFile,
4041

41-
4242
SampleSize = sampleSize,
4343
OutputChannels = outputChannels,
44-
NormalizeOutputTensor = normalizeOutputTensor,
44+
NormalizeType = normalizeType,
45+
NormalizeInput = normalizeInput,
46+
NormalizeOutput = normalizeOutput,
4547
SetOutputToInputAlpha = setOutputToInputAlpha,
46-
NormalizeInputTensor = normalizeInputTensor,
4748
InputResizeMode = inputResizeMode
4849
};
4950
return new FeatureExtractorModel(configuration);

OnnxStack.FeatureExtractor/Common/FeatureExtractorModelConfig.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ public record FeatureExtractorModelConfig : OnnxModelConfig
77
{
88
public int SampleSize { get; set; }
99
public int OutputChannels { get; set; }
10-
public bool NormalizeOutputTensor { get; set; }
10+
public bool NormalizeOutput { get; set; }
1111
public bool SetOutputToInputAlpha { get; set; }
1212
public ImageResizeMode InputResizeMode { get; set; }
13-
public ImageNormalizeType NormalizeInputTensor { get; set; }
13+
public ImageNormalizeType NormalizeType { get; set; }
14+
public bool NormalizeInput { get; set; }
1415
}
1516
}

OnnxStack.FeatureExtractor/Pipelines/FeatureExtractorPipeline.cs

Lines changed: 67 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,30 @@ public async Task UnloadAsync()
6262
}
6363

6464

65+
/// <summary>
66+
/// Generates the feature extractor output tensor from an input tensor
67+
/// </summary>
68+
/// <param name="inputTensor">The input tensor.</param>
69+
/// <returns></returns>
70+
public async Task<DenseTensor<float>> RunAsync(DenseTensor<float> inputTensor, CancellationToken cancellationToken = default)
71+
{
72+
var timestamp = _logger?.LogBegin("Extracting DenseTensor feature...");
73+
var result = await ExtractTensorAsync(inputTensor, cancellationToken);
74+
_logger?.LogEnd("Extracting DenseTensor feature complete.", timestamp);
75+
return result;
76+
}
77+
78+
6579
/// <summary>
6680
/// Generates the feature extractor image
6781
/// </summary>
6882
/// <param name="inputImage">The input image.</param>
6983
/// <returns></returns>
7084
public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
7185
{
72-
var timestamp = _logger?.LogBegin("Extracting image feature...");
73-
var result = await RunInternalAsync(inputImage, cancellationToken);
74-
_logger?.LogEnd("Extracting image feature complete.", timestamp);
86+
var timestamp = _logger?.LogBegin("Extracting OnnxImage feature...");
87+
var result = await ExtractImageAsync(inputImage, cancellationToken);
88+
_logger?.LogEnd("Extracting OnnxImage feature complete.", timestamp);
7589
return result;
7690
}
7791

@@ -83,13 +97,13 @@ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken ca
8397
/// <returns></returns>
8498
public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancellationToken = default)
8599
{
86-
var timestamp = _logger?.LogBegin("Extracting video features...");
100+
var timestamp = _logger?.LogBegin("Extracting OnnxVideo features...");
87101
var featureFrames = new List<OnnxImage>();
88102
foreach (var videoFrame in video.Frames)
89103
{
90104
featureFrames.Add(await RunAsync(videoFrame, cancellationToken));
91105
}
92-
_logger?.LogEnd("Extracting video features complete.", timestamp);
106+
_logger?.LogEnd("Extracting OnnxVideo features complete.", timestamp);
93107
return new OnnxVideo(video.Info, featureFrames);
94108
}
95109

@@ -102,28 +116,62 @@ public async Task<OnnxVideo> RunAsync(OnnxVideo video, CancellationToken cancell
102116
/// <returns></returns>
103117
public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
104118
{
105-
var timestamp = _logger?.LogBegin("Extracting video stream features...");
119+
var timestamp = _logger?.LogBegin("Extracting OnnxImage stream features...");
106120
await foreach (var imageFrame in imageFrames)
107121
{
108-
yield return await RunInternalAsync(imageFrame, cancellationToken);
122+
yield return await ExtractImageAsync(imageFrame, cancellationToken);
109123
}
110-
_logger?.LogEnd("Extracting video stream features complete.", timestamp);
124+
_logger?.LogEnd("Extracting OnnxImage stream features complete.", timestamp);
111125
}
112126

113127

114128
/// <summary>
115-
/// Runs the pipeline
129+
/// Extracts the feature to OnnxImage.
116130
/// </summary>
117131
/// <param name="inputImage">The input image.</param>
118132
/// <param name="cancellationToken">The cancellation token.</param>
119133
/// <returns></returns>
120-
private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
134+
private async Task<OnnxImage> ExtractImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
121135
{
122136
var originalWidth = inputImage.Width;
123137
var originalHeight = inputImage.Height;
124138
var inputTensor = _featureExtractorModel.SampleSize <= 0
125-
? await inputImage.GetImageTensorAsync(_featureExtractorModel.InputNormalization)
126-
: await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, _featureExtractorModel.InputNormalization, resizeMode: _featureExtractorModel.InputResizeMode);
139+
? await inputImage.GetImageTensorAsync(_featureExtractorModel.NormalizeType)
140+
: await inputImage.GetImageTensorAsync(_featureExtractorModel.SampleSize, _featureExtractorModel.SampleSize, _featureExtractorModel.NormalizeType, resizeMode: _featureExtractorModel.InputResizeMode);
141+
142+
var outputTensor = await RunInternalAsync(inputTensor, cancellationToken);
143+
var imageResult = new OnnxImage(outputTensor, _featureExtractorModel.NormalizeType);
144+
145+
if (_featureExtractorModel.InputResizeMode == ImageResizeMode.Stretch && (imageResult.Width != originalWidth || imageResult.Height != originalHeight))
146+
imageResult.Resize(originalHeight, originalWidth, _featureExtractorModel.InputResizeMode);
147+
148+
return imageResult;
149+
}
150+
151+
152+
/// <summary>
153+
/// Extracts the feature to DenseTensor.
154+
/// </summary>
155+
/// <param name="inputTensor">The input tensor.</param>
156+
/// <param name="cancellationToken">The cancellation token.</param>
157+
/// <returns></returns>
158+
public async Task<DenseTensor<float>> ExtractTensorAsync(DenseTensor<float> inputTensor, CancellationToken cancellationToken = default)
159+
{
160+
if (_featureExtractorModel.NormalizeInput && _featureExtractorModel.NormalizeType == ImageNormalizeType.ZeroToOne)
161+
inputTensor.NormalizeOneOneToZeroOne();
162+
163+
return await RunInternalAsync(inputTensor, cancellationToken);
164+
}
165+
166+
167+
/// <summary>
168+
/// Runs the pipeline
169+
/// </summary>
170+
/// <param name="inputTensor">The input tensor.</param>
171+
/// <param name="cancellationToken">The cancellation token.</param>
172+
/// <returns></returns>
173+
private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> inputTensor, CancellationToken cancellationToken = default)
174+
{
127175
var metadata = await _featureExtractorModel.GetMetadataAsync();
128176
cancellationToken.ThrowIfCancellationRequested();
129177
var outputShape = new[] { 1, _featureExtractorModel.OutputChannels, inputTensor.Dimensions[2], inputTensor.Dimensions[3] };
@@ -139,21 +187,13 @@ private async Task<OnnxImage> RunInternalAsync(OnnxImage inputImage, Cancellatio
139187
cancellationToken.ThrowIfCancellationRequested();
140188

141189
var outputTensor = inferenceResult.ToDenseTensor(outputShape);
142-
if (_featureExtractorModel.NormalizeOutputTensor)
190+
if (_featureExtractorModel.NormalizeOutput)
143191
outputTensor.NormalizeMinMax();
144192

145-
var imageResult = default(OnnxImage);
146193
if (_featureExtractorModel.SetOutputToInputAlpha)
147-
imageResult = new OnnxImage(AddAlphaChannel(inputTensor, outputTensor), _featureExtractorModel.InputNormalization);
148-
else if (_featureExtractorModel.OutputChannels >= 3)
149-
imageResult = new OnnxImage(outputTensor, _featureExtractorModel.InputNormalization);
150-
else
151-
imageResult = outputTensor.ToImageMask();
152-
153-
if (_featureExtractorModel.InputResizeMode == ImageResizeMode.Stretch && (imageResult.Width != originalWidth || imageResult.Height != originalHeight))
154-
imageResult.Resize(originalHeight, originalWidth, _featureExtractorModel.InputResizeMode);
194+
return AddAlphaChannel(inputTensor, outputTensor);
155195

156-
return imageResult;
196+
return outputTensor;
157197
}
158198
}
159199
}
@@ -200,7 +240,7 @@ public static FeatureExtractorPipeline CreatePipeline(FeatureExtractorModelSet m
200240
/// <param name="executionProvider">The execution provider.</param>
201241
/// <param name="logger">The logger.</param>
202242
/// <returns></returns>
203-
public static FeatureExtractorPipeline CreatePipeline(string modelFile, int sampleSize = 0, int outputChannels = 1, bool normalizeOutputTensor = false, ImageNormalizeType normalizeInputTensor = ImageNormalizeType.ZeroToOne, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
243+
public static FeatureExtractorPipeline CreatePipeline(string modelFile, int sampleSize = 0, int outputChannels = 1, ImageNormalizeType normalizeType = ImageNormalizeType.ZeroToOne, bool normalizeInput = true, bool normalizeOutput = false, ImageResizeMode inputResizeMode = ImageResizeMode.Crop, bool setOutputToInputAlpha = false, int deviceId = 0, ExecutionProvider executionProvider = ExecutionProvider.DirectML, ILogger logger = default)
204244
{
205245
var name = Path.GetFileNameWithoutExtension(modelFile);
206246
var configuration = new FeatureExtractorModelSet
@@ -214,9 +254,10 @@ public static FeatureExtractorPipeline CreatePipeline(string modelFile, int samp
214254
OnnxModelPath = modelFile,
215255
SampleSize = sampleSize,
216256
OutputChannels = outputChannels,
217-
NormalizeOutputTensor = normalizeOutputTensor,
257+
NormalizeOutput = normalizeOutput,
258+
NormalizeInput = normalizeInput,
259+
NormalizeType = normalizeType,
218260
SetOutputToInputAlpha = setOutputToInputAlpha,
219-
NormalizeInputTensor = normalizeInputTensor,
220261
InputResizeMode = inputResizeMode
221262
}
222263
};

0 commit comments

Comments
 (0)