Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 19b1211

Browse files
committed
Support recursive tiling, tidy up logging
1 parent 01d6290 commit 19b1211

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

OnnxStack.ImageUpscaler/Pipelines/ImageUpscalePipeline.cs

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
using Microsoft.Extensions.Logging;
22
using Microsoft.ML.OnnxRuntime.Tensors;
3-
using Newtonsoft.Json.Linq;
43
using OnnxStack.Core;
54
using OnnxStack.Core.Config;
65
using OnnxStack.Core.Image;
@@ -11,7 +10,6 @@
1110
using System.Collections.Generic;
1211
using System.IO;
1312
using System.Linq;
14-
using System.Numerics.Tensors;
1513
using System.Runtime.CompilerServices;
1614
using System.Threading;
1715
using System.Threading.Tasks;
@@ -73,9 +71,9 @@ public async Task UnloadAsync()
7371
/// <returns></returns>
7472
public async Task<DenseTensor<float>> RunAsync(DenseTensor<float> inputImage, CancellationToken cancellationToken = default)
7573
{
76-
var timestamp = _logger?.LogBegin("Upscale image..");
74+
var timestamp = _logger?.LogBegin("Upscale DenseTensor..");
7775
var result = await UpscaleTensorAsync(inputImage, cancellationToken);
78-
_logger?.LogEnd("Upscale image complete.", timestamp);
76+
_logger?.LogEnd("Upscale DenseTensor complete.", timestamp);
7977
return result;
8078
}
8179

@@ -88,9 +86,9 @@ public async Task<DenseTensor<float>> RunAsync(DenseTensor<float> inputImage, Ca
8886
/// <returns></returns>
8987
public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
9088
{
91-
var timestamp = _logger?.LogBegin("Upscale image..");
89+
var timestamp = _logger?.LogBegin("Upscale OnnxImage..");
9290
var result = await UpscaleImageAsync(inputImage, cancellationToken);
93-
_logger?.LogEnd("Upscale image complete.", timestamp);
91+
_logger?.LogEnd("Upscale OnnxImage complete.", timestamp);
9492
return result;
9593
}
9694

@@ -103,7 +101,7 @@ public async Task<OnnxImage> RunAsync(OnnxImage inputImage, CancellationToken ca
103101
/// <returns></returns>
104102
public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken cancellationToken = default)
105103
{
106-
var timestamp = _logger?.LogBegin("Upscale video..");
104+
var timestamp = _logger?.LogBegin("Upscale OnnxVideo..");
107105
var upscaledFrames = new List<OnnxImage>();
108106
foreach (var videoFrame in inputVideo.Frames)
109107
{
@@ -117,7 +115,7 @@ public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken ca
117115
Height = firstFrame.Height,
118116
};
119117

120-
_logger?.LogEnd("Upscale video complete.", timestamp);
118+
_logger?.LogEnd("Upscale OnnxVideo complete.", timestamp);
121119
return new OnnxVideo(videoInfo, upscaledFrames);
122120
}
123121

@@ -130,16 +128,15 @@ public async Task<OnnxVideo> RunAsync(OnnxVideo inputVideo, CancellationToken ca
130128
/// <returns></returns>
131129
public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> imageFrames, [EnumeratorCancellation] CancellationToken cancellationToken = default)
132130
{
133-
var timestamp = _logger?.LogBegin("Upscale video stream..");
131+
var timestamp = _logger?.LogBegin("Upscale OnnxImage stream..");
134132
await foreach (var imageFrame in imageFrames)
135133
{
136134
yield return await UpscaleImageAsync(imageFrame, cancellationToken);
137135
}
138-
_logger?.LogEnd("Upscale video stream complete.", timestamp);
136+
_logger?.LogEnd("Upscale OnnxImage stream complete.", timestamp);
139137
}
140138

141139

142-
143140
/// <summary>
144141
/// Upscales the OnnxImage.
145142
/// </summary>
@@ -149,23 +146,25 @@ public async IAsyncEnumerable<OnnxImage> RunAsync(IAsyncEnumerable<OnnxImage> im
149146
private async Task<OnnxImage> UpscaleImageAsync(OnnxImage inputImage, CancellationToken cancellationToken = default)
150147
{
151148
var inputTensor = inputImage.GetImageTensor(_upscaleModel.NormalizeType, _upscaleModel.Channels);
152-
var outputTensor = await RunInternalAsync(inputTensor, cancellationToken);
149+
var outputTensor = await RunInternalAsync(inputTensor, inputImage.Height, inputImage.Width, cancellationToken);
153150
return new OnnxImage(outputTensor, _upscaleModel.NormalizeType);
154151
}
155152

156153

157154
/// <summary>
158155
/// Upscales the DenseTensor
159156
/// </summary>
160-
/// <param name="inputImage">The input image.</param>
157+
/// <param name="inputTensor">The input Tensor.</param>
161158
/// <param name="cancellationToken">The cancellation token.</param>
162159
/// <returns></returns>
163-
public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inputImage, CancellationToken cancellationToken = default)
160+
public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inputTensor, CancellationToken cancellationToken = default)
164161
{
165162
if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne)
166-
inputImage.NormalizeOneOneToZeroOne();
163+
inputTensor.NormalizeOneOneToZeroOne();
167164

168-
var result = await RunInternalAsync(inputImage, cancellationToken);
165+
var height = inputTensor.Dimensions[2];
166+
var width = inputTensor.Dimensions[3];
167+
var result = await RunInternalAsync(inputTensor, height, width, cancellationToken);
169168

170169
if (_upscaleModel.NormalizeInput && _upscaleModel.NormalizeType == ImageNormalizeType.ZeroToOne)
171170
result.NormalizeZeroOneToOneOne();
@@ -180,9 +179,9 @@ public async Task<DenseTensor<float>> UpscaleTensorAsync(DenseTensor<float> inpu
180179
/// <param name="inputTensor">The input tensor.</param>
181180
/// <param name="cancellationToken">The cancellation token.</param>
182181
/// <returns></returns>
183-
private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> inputTensor, CancellationToken cancellationToken = default)
182+
private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> inputTensor, int height, int width, CancellationToken cancellationToken = default)
184183
{
185-
if (inputTensor.Dimensions[2] <= _upscaleModel.TileSize && inputTensor.Dimensions[3] <= _upscaleModel.TileSize)
184+
if (height <= _upscaleModel.TileSize && width <= _upscaleModel.TileSize)
186185
{
187186
return await RunInferenceAsync(inputTensor, cancellationToken);
188187
}
@@ -193,10 +192,10 @@ private async Task<DenseTensor<float>> RunInternalAsync(DenseTensor<float> input
193192
inputTiles.Width * _upscaleModel.ScaleFactor,
194193
inputTiles.Height * _upscaleModel.ScaleFactor,
195194
inputTiles.Overlap * _upscaleModel.ScaleFactor,
196-
await RunInternalAsync(inputTiles.Tile1, cancellationToken),
197-
await RunInternalAsync(inputTiles.Tile2, cancellationToken),
198-
await RunInternalAsync(inputTiles.Tile3, cancellationToken),
199-
await RunInternalAsync(inputTiles.Tile4, cancellationToken)
195+
await RunInternalAsync(inputTiles.Tile1, inputTiles.Height, inputTiles.Width, cancellationToken),
196+
await RunInternalAsync(inputTiles.Tile2, inputTiles.Height, inputTiles.Width, cancellationToken),
197+
await RunInternalAsync(inputTiles.Tile3, inputTiles.Height, inputTiles.Width, cancellationToken),
198+
await RunInternalAsync(inputTiles.Tile4, inputTiles.Height, inputTiles.Width, cancellationToken)
200199
);
201200
return outputTiles.JoinImageTiles();
202201
}

0 commit comments

Comments
 (0)