Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 5d35a84

Browse files
authored
Merge pull request #48 from saddam213/InstaFlow
InstaFlow Pipeline
2 parents 0ad0bb0 + 7e70a41 commit 5d35a84

File tree

16 files changed

+428
-44
lines changed

16 files changed

+428
-44
lines changed

OnnxStack.Console/appsettings.json

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,46 @@
132132
"OnnxModelPath": "D:\\Repositories\\photon\\vae_decoder\\model.onnx"
133133
}
134134
]
135+
},
136+
{
137+
"Name": "InstaFlow",
138+
"IsEnabled": true,
139+
"PadTokenId": 49407,
140+
"BlankTokenId": 49407,
141+
"TokenizerLimit": 77,
142+
"EmbeddingsLength": 768,
143+
"ScaleFactor": 0.18215,
144+
"PipelineType": "InstaFlow",
145+
"Diffusers": [
146+
"TextToImage"
147+
],
148+
"DeviceId": 0,
149+
"InterOpNumThreads": 0,
150+
"IntraOpNumThreads": 0,
151+
"ExecutionMode": "ORT_SEQUENTIAL",
152+
"ExecutionProvider": "DirectML",
153+
"ModelConfigurations": [
154+
{
155+
"Type": "Tokenizer",
156+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\tokenizer\\model.onnx"
157+
},
158+
{
159+
"Type": "Unet",
160+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\unet\\model.onnx"
161+
},
162+
{
163+
"Type": "TextEncoder",
164+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\text_encoder\\model.onnx"
165+
},
166+
{
167+
"Type": "VaeEncoder",
168+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\vae_encoder\\model.onnx"
169+
},
170+
{
171+
"Type": "VaeDecoder",
172+
"OnnxModelPath": "D:\\Repositories\\InstaFlow-0.9B-ONNX\\vae_decoder\\model.onnx"
173+
}
174+
]
135175
}
136176
]
137177
}

OnnxStack.StableDiffusion/Common/IScheduler.cs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2-
using OnnxStack.StableDiffusion.Enums;
32
using OnnxStack.StableDiffusion.Schedulers;
43
using System;
54
using System.Collections.Generic;
@@ -8,11 +7,6 @@ namespace OnnxStack.StableDiffusion.Common
87
{
98
public interface IScheduler : IDisposable
109
{
11-
/// <summary>
12-
/// Gets the compatible pipeline
13-
/// </summary>
14-
DiffuserPipelineType PipelineType { get; }
15-
1610
/// <summary>
1711
/// Gets the initial noise sigma.
1812
/// </summary>
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core;
4+
using OnnxStack.Core.Config;
5+
using OnnxStack.Core.Model;
6+
using OnnxStack.Core.Services;
7+
using OnnxStack.StableDiffusion.Common;
8+
using OnnxStack.StableDiffusion.Config;
9+
using OnnxStack.StableDiffusion.Enums;
10+
using OnnxStack.StableDiffusion.Helpers;
11+
using OnnxStack.StableDiffusion.Schedulers.InstaFlow;
12+
using System;
13+
using System.Diagnostics;
14+
using System.Linq;
15+
using System.Threading;
16+
using System.Threading.Tasks;
17+
18+
namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow
19+
{
20+
public abstract class InstaFlowDiffuser : DiffuserBase, IDiffuser
21+
{
22+
/// <summary>
23+
/// Initializes a new instance of the <see cref="InstaFlowDiffuser"/> class.
24+
/// </summary>
25+
/// <param name="configuration">The configuration.</param>
26+
/// <param name="onnxModelService">The onnx model service.</param>
27+
public InstaFlowDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<InstaFlowDiffuser> logger)
28+
: base(onnxModelService, promptService, logger) { }
29+
30+
31+
/// <summary>
32+
/// Gets the type of the pipeline.
33+
/// </summary>
34+
public override DiffuserPipelineType PipelineType => DiffuserPipelineType.InstaFlow;
35+
36+
37+
/// <summary>
38+
/// Runs the scheduler steps.
39+
/// </summary>
40+
/// <param name="modelOptions">The model options.</param>
41+
/// <param name="promptOptions">The prompt options.</param>
42+
/// <param name="schedulerOptions">The scheduler options.</param>
43+
/// <param name="promptEmbeddings">The prompt embeddings.</param>
44+
/// <param name="performGuidance">if set to <c>true</c> [perform guidance].</param>
45+
/// <param name="progressCallback">The progress callback.</param>
46+
/// <param name="cancellationToken">The cancellation token.</param>
47+
/// <returns></returns>
48+
protected override async Task<DenseTensor<float>> SchedulerStepAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, DenseTensor<float> promptEmbeddings, bool performGuidance, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
49+
{
50+
// Get Scheduler
51+
using (var scheduler = GetScheduler(schedulerOptions))
52+
{
53+
// Get timesteps
54+
var timesteps = GetTimesteps(schedulerOptions, scheduler);
55+
56+
// Create latent sample
57+
var latents = await PrepareLatentsAsync(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
58+
59+
// Get Model metadata
60+
var metadata = _onnxModelService.GetModelMetadata(modelOptions, OnnxModelType.Unet);
61+
62+
// Get the distilled Timestep
63+
var distilledTimestep = 1.0f / timesteps.Count;
64+
65+
// Loop through the timesteps
66+
var step = 0;
67+
foreach (var timestep in timesteps)
68+
{
69+
step++;
70+
var stepTime = Stopwatch.GetTimestamp();
71+
cancellationToken.ThrowIfCancellationRequested();
72+
73+
// Create input tensor.
74+
var inputLatent = performGuidance ? latents.Repeat(2) : latents;
75+
var inputTensor = scheduler.ScaleInput(inputLatent, timestep);
76+
var timestepTensor = CreateTimestepTensor(inputLatent, timestep);
77+
78+
var outputChannels = performGuidance ? 2 : 1;
79+
var outputDimension = schedulerOptions.GetScaledDimension(outputChannels);
80+
using (var inferenceParameters = new OnnxInferenceParameters(metadata))
81+
{
82+
inferenceParameters.AddInputTensor(inputTensor);
83+
inferenceParameters.AddInputTensor(timestepTensor);
84+
inferenceParameters.AddInputTensor(promptEmbeddings);
85+
inferenceParameters.AddOutputBuffer(outputDimension);
86+
87+
var results = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inferenceParameters);
88+
using (var result = results.First())
89+
{
90+
var noisePred = result.ToDenseTensor();
91+
92+
// Perform guidance
93+
if (performGuidance)
94+
noisePred = PerformGuidance(noisePred, schedulerOptions.GuidanceScale);
95+
96+
// Scheduler Step
97+
latents = scheduler.Step(noisePred, timestep, latents).Result;
98+
99+
latents = noisePred
100+
.MultiplyTensorByFloat(distilledTimestep)
101+
.AddTensors(latents);
102+
}
103+
}
104+
105+
progressCallback?.Invoke(step, timesteps.Count);
106+
_logger?.LogEnd($"Step {step}/{timesteps.Count}", stepTime);
107+
}
108+
109+
// Decode Latents
110+
return await DecodeLatentsAsync(modelOptions, promptOptions, schedulerOptions, latents);
111+
}
112+
}
113+
114+
115+
/// <summary>
116+
/// Creates the timestep tensor.
117+
/// </summary>
118+
/// <param name="latents">The latents.</param>
119+
/// <param name="timestep">The timestep.</param>
120+
/// <returns></returns>
121+
private DenseTensor<float> CreateTimestepTensor(DenseTensor<float> latents, int timestep)
122+
{
123+
var timestepTensor = new DenseTensor<float>(new[] { latents.Dimensions[0] });
124+
timestepTensor.Fill(timestep);
125+
return timestepTensor;
126+
}
127+
128+
129+
/// <summary>
130+
/// Gets the scheduler.
131+
/// </summary>
132+
/// <param name="options">The options.</param>
133+
/// <param name="schedulerConfig">The scheduler configuration.</param>
134+
/// <returns></returns>
135+
protected override IScheduler GetScheduler(SchedulerOptions options)
136+
{
137+
return options.SchedulerType switch
138+
{
139+
SchedulerType.InstaFlow => new InstaFlowScheduler(options),
140+
_ => default
141+
};
142+
}
143+
}
144+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Services;
4+
using OnnxStack.StableDiffusion.Common;
5+
using OnnxStack.StableDiffusion.Config;
6+
using OnnxStack.StableDiffusion.Enums;
7+
using System.Collections.Generic;
8+
using System.Threading.Tasks;
9+
10+
namespace OnnxStack.StableDiffusion.Diffusers.InstaFlow
11+
{
12+
public sealed class TextDiffuser : InstaFlowDiffuser
13+
{
14+
/// <summary>
15+
/// Initializes a new instance of the <see cref="TextDiffuser"/> class.
16+
/// </summary>
17+
/// <param name="configuration">The configuration.</param>
18+
/// <param name="onnxModelService">The onnx model service.</param>
19+
public TextDiffuser(IOnnxModelService onnxModelService, IPromptService promptService, ILogger<InstaFlowDiffuser> logger)
20+
: base(onnxModelService, promptService, logger)
21+
{
22+
}
23+
24+
25+
/// <summary>
26+
/// Gets the type of the diffuser.
27+
/// </summary>
28+
public override DiffuserType DiffuserType => DiffuserType.TextToImage;
29+
30+
31+
/// <summary>
32+
/// Gets the timesteps.
33+
/// </summary>
34+
/// <param name="prompt">The prompt.</param>
35+
/// <param name="options">The options.</param>
36+
/// <param name="scheduler">The scheduler.</param>
37+
/// <returns></returns>
38+
protected override IReadOnlyList<int> GetTimesteps(SchedulerOptions options, IScheduler scheduler)
39+
{
40+
return scheduler.Timesteps;
41+
}
42+
43+
44+
/// <summary>
45+
/// Prepares the latents for inference.
46+
/// </summary>
47+
/// <param name="prompt">The prompt.</param>
48+
/// <param name="options">The options.</param>
49+
/// <param name="scheduler">The scheduler.</param>
50+
/// <returns></returns>
51+
protected override Task<DenseTensor<float>> PrepareLatentsAsync(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps)
52+
{
53+
return Task.FromResult(scheduler.CreateRandomSample(options.GetScaledDimension(), scheduler.InitNoiseSigma));
54+
}
55+
}
56+
}

OnnxStack.StableDiffusion/Enums/DiffuserPipelineType.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
public enum DiffuserPipelineType
44
{
55
StableDiffusion = 0,
6-
LatentConsistency = 10
6+
LatentConsistency = 10,
7+
InstaFlow = 11,
78
}
89
}

OnnxStack.StableDiffusion/Enums/SchedulerType.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ public enum SchedulerType
2222
[Display(Name = "KDPM2")]
2323
KDPM2 = 5,
2424

25-
[Display(Name = "LCM")]
26-
LCM = 20
25+
[Display(Name = "LCM")]
26+
LCM = 20,
27+
28+
[Display(Name = "InstaFlow")]
29+
InstaFlow = 21
2730
}
2831
}

OnnxStack.StableDiffusion/Extensions.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
using Microsoft.ML.OnnxRuntime;
2-
using NumSharp;
32
using OnnxStack.StableDiffusion.Config;
43
using OnnxStack.StableDiffusion.Enums;
54

65
using System;
76
using System.Linq;
8-
using System.Numerics;
9-
using System.Threading.Tasks;
107

118
namespace OnnxStack.StableDiffusion
129
{
@@ -102,20 +99,23 @@ public static SchedulerType[] GetSchedulerTypes(this DiffuserPipelineType pipeli
10299
{
103100
return pipelineType switch
104101
{
105-
DiffuserPipelineType.StableDiffusion => new[]
102+
DiffuserPipelineType.InstaFlow => new[]
103+
{
104+
SchedulerType.InstaFlow
105+
},
106+
DiffuserPipelineType.LatentConsistency => new[]
107+
{
108+
SchedulerType.LCM
109+
},
110+
_ => new[]
106111
{
107112
SchedulerType.LMS,
108113
SchedulerType.Euler,
109114
SchedulerType.EulerAncestral,
110115
SchedulerType.DDPM,
111116
SchedulerType.DDIM,
112117
SchedulerType.KDPM2
113-
},
114-
DiffuserPipelineType.LatentConsistency => new[]
115-
{
116-
SchedulerType.LCM
117-
},
118-
_ => default
118+
}
119119
};
120120
}
121121

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
using Microsoft.Extensions.Logging;
2+
using OnnxStack.Core;
3+
using OnnxStack.StableDiffusion.Common;
4+
using OnnxStack.StableDiffusion.Diffusers;
5+
using OnnxStack.StableDiffusion.Enums;
6+
using System.Collections.Concurrent;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
10+
namespace OnnxStack.StableDiffusion.Pipelines
11+
{
12+
public sealed class InstaFlowPipeline : IPipeline
13+
{
14+
private readonly DiffuserPipelineType _pipelineType;
15+
private readonly ILogger<InstaFlowPipeline> _logger;
16+
private readonly ConcurrentDictionary<DiffuserType, IDiffuser> _diffusers;
17+
18+
/// <summary>
19+
/// Initializes a new instance of the <see cref="InstaFlowPipeline"/> class.
20+
/// </summary>
21+
/// <param name="onnxModelService">The onnx model service.</param>
22+
/// <param name="promptService">The prompt service.</param>
23+
public InstaFlowPipeline(IEnumerable<IDiffuser> diffusers, ILogger<InstaFlowPipeline> logger)
24+
{
25+
_logger = logger;
26+
_pipelineType = DiffuserPipelineType.InstaFlow;
27+
_diffusers = diffusers
28+
.Where(x => x.PipelineType == _pipelineType)
29+
.ToConcurrentDictionary(k => k.DiffuserType, v => v);
30+
}
31+
32+
33+
/// <summary>
34+
/// Gets the type of the pipeline.
35+
/// </summary>
36+
public DiffuserPipelineType PipelineType => _pipelineType;
37+
38+
39+
/// <summary>
40+
/// Gets the diffusers.
41+
/// </summary>
42+
public ConcurrentDictionary<DiffuserType, IDiffuser> Diffusers => _diffusers;
43+
44+
45+
/// <summary>
46+
/// Gets the diffuser.
47+
/// </summary>
48+
/// <param name="diffuserType">Type of the diffuser.</param>
49+
/// <returns></returns>
50+
public IDiffuser GetDiffuser(DiffuserType diffuserType)
51+
{
52+
_diffusers.TryGetValue(diffuserType, out var diffuser);
53+
return diffuser;
54+
}
55+
}
56+
}

0 commit comments

Comments (0)