Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 3d9cd5a

Browse files
committed
Simplify prompt guidance variable
1 parent b39ce21 commit 3d9cd5a

File tree

6 files changed

+14
-18
lines changed

6 files changed

+14
-18
lines changed

OnnxStack.StableDiffusion/Common/IPromptService.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ namespace OnnxStack.StableDiffusion.Common
66
{
77
public interface IPromptService
88
{
9-
Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions);
9+
Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled);
1010
Task<int[]> DecodeTextAsync(IModelOptions model, string inputText);
1111
Task<float[]> EncodeTokensAsync(IModelOptions model, int[] tokenizedInput);
1212
}

OnnxStack.StableDiffusion/Diffusers/LatentConsistency/LatentConsistencyDiffuser.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,14 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
6565
// Create random seed if none was set
6666
schedulerOptions.Seed = schedulerOptions.Seed > 0 ? schedulerOptions.Seed : Random.Shared.Next();
6767

68-
// LCM does not support classifier-free guidance
69-
var guidance = schedulerOptions.GuidanceScale;
70-
schedulerOptions.GuidanceScale = 0f;
71-
7268
// LCM does not support negative prompting
7369
promptOptions.NegativePrompt = string.Empty;
7470

7571
// Get Scheduler
7672
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
7773
{
7874
// Process prompts
79-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);
75+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, false);
8076

8177
// Get timesteps
8278
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
@@ -85,7 +81,7 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
8581
var latents = PrepareLatents(modelOptions, promptOptions, schedulerOptions, scheduler, timesteps);
8682

8783
// Get Guidance Scale Embedding
88-
var guidanceEmbeddings = GetGuidanceScaleEmbedding(guidance);
84+
var guidanceEmbeddings = GetGuidanceScaleEmbedding(schedulerOptions.GuidanceScale);
8985

9086
// Denoised result
9187
DenseTensor<float> denoised = null;

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintDiffuser.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
4242
// Get Scheduler
4343
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
4444
{
45-
// Process prompts
46-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);
47-
4845
// Should we perform classifier free guidance
4946
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
5047

48+
// Process prompts
49+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
50+
5151
// Get timesteps
5252
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
5353

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/InpaintLegacyDiffuser.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,12 @@ public override async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelO
3535
// Get Scheduler
3636
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
3737
{
38-
// Process prompts
39-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);
40-
4138
// Should we perform classifier free guidance
4239
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
4340

41+
// Process prompts
42+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
43+
4444
// Get timesteps
4545
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
4646

OnnxStack.StableDiffusion/Diffusers/StableDiffusion/StableDiffusionDiffuser.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,12 @@ public virtual async Task<DenseTensor<float>> DiffuseAsync(IModelOptions modelOp
6868
// Get Scheduler
6969
using (var scheduler = GetScheduler(promptOptions, schedulerOptions))
7070
{
71-
// Process prompts
72-
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, schedulerOptions);
73-
7471
// Should we perform classifier free guidance
7572
var performGuidance = schedulerOptions.GuidanceScale > 1.0f;
7673

74+
// Process prompts
75+
var promptEmbeddings = await _promptService.CreatePromptAsync(modelOptions, promptOptions, performGuidance);
76+
7777
// Get timesteps
7878
var timesteps = GetTimesteps(promptOptions, schedulerOptions, scheduler);
7979

OnnxStack.StableDiffusion/Services/PromptService.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ public PromptService(IOnnxModelService onnxModelService)
3535
/// <param name="prompt">The prompt.</param>
3636
/// <param name="negativePrompt">The negative prompt.</param>
3737
/// <returns>Tensor containing all text embeds generated from the prompt and negative prompt</returns>
38-
public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, SchedulerOptions schedulerOptions)
38+
public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, PromptOptions promptOptions, bool isGuidanceEnabled)
3939
{
4040
// Tokenize Prompt and NegativePrompt
4141
var promptTokens = await DecodeTextAsync(model, promptOptions.Prompt);
@@ -55,7 +55,7 @@ public async Task<DenseTensor<float>> CreatePromptAsync(IModelOptions model, Pro
5555

5656
// If we are doing guided diffusion, concatenate the negative prompt embeddings
5757
// If not we ingore the negative prompt embeddings
58-
if (schedulerOptions.GuidanceScale > 1)
58+
if (isGuidanceEnabled)
5959
return negativePromptEmbeddings.Concatenate(promptEmbeddings);
6060

6161
return promptEmbeddings;

0 commit comments

Comments
 (0)