Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit b39ce21

Browse files
committed
Refactor Diffusers after new Pipeline implementation
1 parent d9c96f9 commit b39ce21

13 files changed

+259
-226
lines changed

OnnxStack.StableDiffusion/Diffusers/IDiffuser.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
22
using OnnxStack.StableDiffusion.Common;
33
using OnnxStack.StableDiffusion.Config;
4-
using SixLabors.ImageSharp;
54
using System;
65
using System.Threading;
76
using System.Threading.Tasks;

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/ImageDiffuser.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@
1010
using System.Collections.Generic;
1111
using System.Linq;
1212

13-
1413
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
1514
{
16-
public sealed class ImageDiffuser : TextDiffuser
15+
public sealed class ImageDiffuser : LatentConsistencyDiffuser
1716
{
1817
/// <summary>
1918
/// Initializes a new instance of the <see cref="ImageDiffuser"/> class.
@@ -69,6 +68,5 @@ protected override DenseTensor<float> PrepareLatents(IModelOptions model, Prompt
6968
return noisySample;
7069
}
7170
}
72-
7371
}
7472
}
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
using Microsoft.ML.OnnxRuntime;
2+
using Microsoft.ML.OnnxRuntime.Tensors;
3+
using OnnxStack.Core.Config;
4+
using OnnxStack.Core.Services;
5+
using OnnxStack.StableDiffusion.Common;
6+
using OnnxStack.StableDiffusion.Config;
7+
using OnnxStack.StableDiffusion.Enums;
8+
using OnnxStack.StableDiffusion.Helpers;
9+
using OnnxStack.StableDiffusion.Schedulers.LatentConsistency;
10+
using System;
11+
using System.Collections.Generic;
12+
using System.Linq;
13+
using System.Threading;
14+
using System.Threading.Tasks;
15+
16+
namespace OnnxStack.StableDiffusion.Diffusers.LatentConsistency
{
    /// <summary>
    /// Base diffuser for Latent Consistency Model (LCM) pipelines.
    /// Runs the denoising loop against the ONNX UNet, conditioning on the prompt
    /// embeddings and a sinusoidal guidance-scale embedding (LCM folds classifier-free
    /// guidance into the model input instead of doubling the batch).
    /// </summary>
    public abstract class LatentConsistencyDiffuser : IDiffuser
    {
        protected readonly IPromptService _promptService;
        protected readonly IOnnxModelService _onnxModelService;

        /// <summary>
        /// Initializes a new instance of the <see cref="LatentConsistencyDiffuser"/> class.
        /// </summary>
        /// <param name="onnxModelService">The onnx model service.</param>
        /// <param name="promptService">The prompt service.</param>
        public LatentConsistencyDiffuser(IOnnxModelService onnxModelService, IPromptService promptService)
        {
            _promptService = promptService;
            _onnxModelService = onnxModelService;
        }


        /// <summary>
        /// Gets the timesteps to iterate for this diffuser type (e.g. full schedule for
        /// text-to-image, truncated schedule for image-to-image).
        /// </summary>
        /// <param name="prompt">The prompt options.</param>
        /// <param name="options">The scheduler options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <returns>The ordered list of timesteps to denoise over.</returns>
        protected abstract IReadOnlyList<int> GetTimesteps(PromptOptions prompt, SchedulerOptions options, IScheduler scheduler);

        /// <summary>
        /// Prepares the initial latent sample (random noise, or an encoded/noised input image).
        /// </summary>
        /// <param name="model">The model options.</param>
        /// <param name="prompt">The prompt options.</param>
        /// <param name="options">The scheduler options.</param>
        /// <param name="scheduler">The scheduler.</param>
        /// <param name="timesteps">The timesteps.</param>
        /// <returns>The initial latent tensor.</returns>
        protected abstract DenseTensor<float> PrepareLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, IScheduler scheduler, IReadOnlyList<int> timesteps);


        /// <summary>
        /// Runs the stable diffusion denoising loop.
        /// </summary>
        /// <param name="modelOptions">The model options.</param>
        /// <param name="promptOptions">The prompt options. NOTE: <see cref="PromptOptions.NegativePrompt"/> is cleared in place (LCM does not support negative prompting).</param>
        /// <param name="schedulerOptions">The scheduler options. NOTE: <see cref="SchedulerOptions.Seed"/> and <see cref="SchedulerOptions.GuidanceScale"/> are mutated in place.</param>
        /// <param name="progressCallback">Optional progress callback invoked as (step, totalSteps).</param>
        /// <param name="cancellationToken">The cancellation token.</param>
        /// <returns>The decoded image tensor.</returns>
        public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOptions, PromptOptions promptOptions, SchedulerOptions schedulerOptions, Action<int, int> progressCallback = null, CancellationToken cancellationToken = default)
        {
            // Create random seed if none was set
            schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();

            // LCM does not support classifier-free guidance; capture the requested scale
            // for the guidance embedding, then zero it so downstream code takes the
            // unguided (single-batch) path.
            var guidance = schedulerOptions.GuidanceScale;
            schedulerOptions.GuidanceScale = 0f;

            // LCM does not support negative prompting
            promptOptions.NegativePrompt = string.Empty;

            // Get Scheduler
            using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
            {
                // Process prompts
                var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);

                // Get timesteps
                var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);

                // Create latent sample
                var latents = PrepareLatents(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);

                // Get Guidance Scale Embedding
                var guidanceEmbeddings = GetGuidanceScaleEmbedding(guidance);

                // Denoised result
                DenseTensor<float> denoised = null;

                // Loop though the timesteps
                var step = 0;
                foreach (var timestep in timesteps)
                {
                    step++;
                    cancellationToken.ThrowIfCancellationRequested();

                    // Create input tensor.
                    var inputTensor = scheduler.ScaleInput(latents, timestep);

                    // Create Input Parameters
                    var inputParameters = CreateUnetInputParams(modelOptions, inputTensor, promptEmbeddings, guidanceEmbeddings, timestep);

                    // Run Inference
                    using (var inferResult = await _onnxModelService.RunInferenceAsync(modelOptions, OnnxModelType.Unet, inputParameters))
                    {
                        var noisePred = inferResult.FirstElementAs<DenseTensor<float>>();

                        // Scheduler Step; LCM yields both the next latent and the
                        // predicted denoised sample — the final denoised sample is decoded.
                        var schedulerResult = scheduler.Step(noisePred, timestep, latents);

                        latents = schedulerResult.Result;
                        denoised = schedulerResult.SampleData;
                    }

                    progressCallback?.Invoke(step, timesteps.Count);
                }

                // Decode Latents
                return await DecodeLatents(modelOptions, promptOptions, schedulerOptions, denoised);
            }
        }


        /// <summary>
        /// Decodes the latents to image tensors with the VAE decoder.
        /// </summary>
        /// <param name="model">The model options.</param>
        /// <param name="prompt">The prompt options.</param>
        /// <param name="options">The scheduler options.</param>
        /// <param name="latents">The latents to decode.</param>
        /// <returns>The decoded image tensor (batches joined into one tensor).</returns>
        protected virtual async Task<DenseTensor<float>> DecodeLatents(IModelOptions model, PromptOptions prompt, SchedulerOptions options, DenseTensor<float> latents)
        {
            // Scale and decode the image latents with vae.
            latents = latents.MultiplyBy(1.0f / model.ScaleFactor);

            var images = prompt.BatchCount > 1
                ? latents.Split(prompt.BatchCount)
                : new[] { latents };
            var imageTensors = new List<DenseTensor<float>>();
            foreach (var image in images)
            {
                var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.VaeDecoder);
                var inputParameters = CreateInputParameters(NamedOnnxValue.CreateFromTensor(inputNames[0], image));

                // Run inference.
                using (var inferResult = await _onnxModelService.RunInferenceAsync(model, OnnxModelType.VaeDecoder, inputParameters))
                {
                    var resultTensor = inferResult.FirstElementAs<DenseTensor<float>>();

                    // Single batch: short-circuit, no join required.
                    if (prompt.BatchCount == 1)
                        return resultTensor.ToDenseTensor();

                    imageTensors.Add(resultTensor.ToDenseTensor());
                }
            }
            return imageTensors.Join();
        }


        /// <summary>
        /// Creates the Unet input parameters.
        /// </summary>
        /// <param name="model">The model.</param>
        /// <param name="inputTensor">The input (latent) tensor.</param>
        /// <param name="promptEmbeddings">The prompt embeddings.</param>
        /// <param name="guidanceEmbeddings">The guidance-scale embeddings (LCM timestep conditioning).</param>
        /// <param name="timestep">The timestep.</param>
        /// <returns>The named input values in the order the LCM UNet expects.</returns>
        protected virtual IReadOnlyList<NamedOnnxValue> CreateUnetInputParams(IModelOptions model, DenseTensor<float> inputTensor, DenseTensor<float> promptEmbeddings, DenseTensor<float> guidanceEmbeddings, int timestep)
        {
            var inputNames = _onnxModelService.GetInputNames(model, OnnxModelType.Unet);
            return CreateInputParameters(
                 NamedOnnxValue.CreateFromTensor(inputNames[0], inputTensor),
                 NamedOnnxValue.CreateFromTensor(inputNames[1], new DenseTensor<long>(new long[] { timestep }, new int[] { 1 })),
                 NamedOnnxValue.CreateFromTensor(inputNames[2], promptEmbeddings),
                 NamedOnnxValue.CreateFromTensor(inputNames[3], guidanceEmbeddings));
        }


        /// <summary>
        /// Gets the scheduler for the requested scheduler type.
        /// </summary>
        /// <param name="prompt">The prompt options (carries the scheduler type).</param>
        /// <param name="options">The scheduler options.</param>
        /// <returns>The scheduler, or <c>default</c> if the type is not supported by LCM.</returns>
        protected IScheduler GetScheduler(PromptOptions prompt, SchedulerOptions options)
        {
            // NOTE(review): returning default for non-LCM types will NRE later in
            // DiffuseAsync; consider throwing NotSupportedException here instead.
            return prompt.SchedulerType switch
            {
                SchedulerType.LCM => new LCMScheduler(options),
                _ => default
            };
        }


        /// <summary>
        /// Gets the sinusoidal guidance scale embedding, mirroring the diffusers
        /// LCM pipeline's get_guidance_scale_embedding: w = (guidance - 1) * 1000,
        /// embedding = [sin(w * f_i), cos(w * f_i)] for log-spaced frequencies f_i.
        /// </summary>
        /// <param name="guidance">The requested guidance scale.</param>
        /// <param name="embeddingDim">The embedding dimension.</param>
        /// <returns>A [1, embeddingDim] tensor encoding the guidance scale.</returns>
        private DenseTensor<float> GetGuidanceScaleEmbedding(float guidance, int embeddingDim = 256)
        {
            // BUGFIX: the previous implementation computed `scale` but never used it,
            // so the embedding was constant regardless of guidance. Scale the
            // frequency terms as in the reference implementation.
            var scale = (guidance - 1f) * 1000.0f;
            var halfDim = embeddingDim / 2;
            float log = MathF.Log(10000.0f) / (halfDim - 1);
            var emb = Enumerable.Range(0, halfDim)
                .Select(x => scale * MathF.Exp(x * -log))
                .ToArray();
            var embSin = emb.Select(MathF.Sin).ToArray();
            var embCos = emb.Select(MathF.Cos).ToArray();
            var result = new DenseTensor<float>(new[] { 1, 2 * halfDim });
            for (int i = 0; i < halfDim; i++)
            {
                result[0, i] = embSin[i];
                result[0, i + halfDim] = embCos[i];
            }
            return result;
        }


        /// <summary>
        /// Helper for creating the input parameters.
        /// </summary>
        /// <param name="parameters">The parameters.</param>
        /// <returns>The parameters as a read-only list.</returns>
        protected static IReadOnlyList<NamedOnnxValue> CreateInputParameters(params NamedOnnxValue[] parameters)
        {
            return parameters.ToList();
        }
    }
}

0 commit comments

Comments
 (0)