Skip to content
This repository was archived by the owner on Nov 27, 2024. It is now read-only.

Commit 736ac4f

Browse files
committed
Tidy up some math, Code review
1 parent 906cb21 commit 736ac4f

File tree

6 files changed

+220
-162
lines changed

6 files changed

+220
-162
lines changed

OnnxStack.StableDiffusion/Config/SchedulerOptions.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,15 @@ public class SchedulerOptions
6363
public BetaScheduleType BetaSchedule { get; set; } = BetaScheduleType.ScaledLinear;
6464
public int StepsOffset { get; set; } = 0;
6565
public bool UseKarrasSigmas { get; set; } = false;
66-
public VarianceType VarianceType { get; internal set; } = VarianceType.FixedSmall;
66+
public VarianceType VarianceType { get; set; } = VarianceType.FixedSmall;
6767
public float SampleMaxValue { get; set; } = 1.0f;
68-
public bool Thresholding { get; internal set; } = false;
69-
public bool ClipSample { get; internal set; } = false;
70-
public float ClipSampleRange { get; internal set; } = 1f;
71-
public PredictionType PredictionType { get; internal set; } = PredictionType.Epsilon;
68+
public bool Thresholding { get; set; } = false;
69+
public bool ClipSample { get; set; } = false;
70+
public float ClipSampleRange { get; set; } = 1f;
71+
public PredictionType PredictionType { get; set; } = PredictionType.Epsilon;
7272
public AlphaTransformType AlphaTransformType { get; set; } = AlphaTransformType.Cosine;
7373
public float MaximumBeta { get; set; } = 0.999f;
7474

75-
75+
7676
}
7777
}

OnnxStack.StableDiffusion/Helpers/TensorHelper.cs

Lines changed: 149 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,102 @@
11
using Microsoft.ML.OnnxRuntime.Tensors;
2-
using OnnxStack.StableDiffusion.Config;
32
using System;
43
using System.Linq;
54

65
namespace OnnxStack.StableDiffusion.Helpers
76
{
87
public static class TensorHelper
98
{
9+
/// <summary>
10+
/// Creates a new tensor.
11+
/// </summary>
12+
/// <typeparam name="T"></typeparam>
13+
/// <param name="data">The data.</param>
14+
/// <param name="dimensions">The dimensions.</param>
15+
/// <returns></returns>
1016
public static DenseTensor<T> CreateTensor<T>(T[] data, ReadOnlySpan<int> dimensions)
1117
{
1218
return new DenseTensor<T>(data, dimensions);
1319
}
1420

15-
public static DenseTensor<float> DivideTensorByFloat(this DenseTensor<float> data, float value, ReadOnlySpan<int> dimensions)
21+
22+
/// <summary>
23+
/// Divides the tensor by float.
24+
/// </summary>
25+
/// <param name="tensor">The data.</param>
26+
/// <param name="value">The value.</param>
27+
/// <param name="dimensions">The dimensions.</param>
28+
/// <returns></returns>
29+
public static DenseTensor<float> DivideTensorByFloat(this DenseTensor<float> tensor, float value, ReadOnlySpan<int> dimensions)
1630
{
1731
var divTensor = new DenseTensor<float>(dimensions);
18-
for (int i = 0; i < data.Length; i++)
32+
for (int i = 0; i < tensor.Length; i++)
1933
{
20-
divTensor.SetValue(i, data.GetValue(i) / value);
34+
divTensor.SetValue(i, tensor.GetValue(i) / value);
2135
}
2236
return divTensor;
2337
}
2438

25-
public static DenseTensor<float> DivideTensorByFloat(this DenseTensor<float> data, float value)
39+
40+
/// <summary>
41+
/// Divides the tensor by float.
42+
/// </summary>
43+
/// <param name="tensor">The data.</param>
44+
/// <param name="value">The value.</param>
45+
/// <returns></returns>
46+
public static DenseTensor<float> DivideTensorByFloat(this DenseTensor<float> tensor, float value)
2647
{
27-
var divTensor = new DenseTensor<float>(data.Dimensions);
28-
for (int i = 0; i < data.Length; i++)
48+
var divTensor = new DenseTensor<float>(tensor.Dimensions);
49+
for (int i = 0; i < tensor.Length; i++)
2950
{
30-
divTensor.SetValue(i, data.GetValue(i) / value);
51+
divTensor.SetValue(i, tensor.GetValue(i) / value);
3152
}
3253
return divTensor;
3354
}
3455

35-
public static DenseTensor<float> MultipleTensorByFloat(this DenseTensor<float> data, float value)
56+
57+
/// <summary>
58+
/// Multiplies the tensor by a float.
59+
/// </summary>
60+
/// <param name="tensor">The data.</param>
61+
/// <param name="value">The value.</param>
62+
/// <returns></returns>
63+
public static DenseTensor<float> MultipleTensorByFloat(this DenseTensor<float> tensor, float value)
3664
{
37-
var mullTensor = new DenseTensor<float>(data.Dimensions);
38-
for (int i = 0; i < data.Length; i++)
65+
var mullTensor = new DenseTensor<float>(tensor.Dimensions);
66+
for (int i = 0; i < tensor.Length; i++)
3967
{
40-
mullTensor.SetValue(i, data.GetValue(i) * value);
68+
mullTensor.SetValue(i, tensor.GetValue(i) * value);
4169
}
4270
return mullTensor;
4371
}
4472

45-
public static DenseTensor<float> AddTensors(this DenseTensor<float> sample, DenseTensor<float> sumTensor)
73+
74+
/// <summary>
75+
/// Adds the tensors.
76+
/// </summary>
77+
/// <param name="tensor">The sample.</param>
78+
/// <param name="sumTensor">The sum tensor.</param>
79+
/// <returns></returns>
80+
public static DenseTensor<float> AddTensors(this DenseTensor<float> tensor, DenseTensor<float> sumTensor)
4681
{
47-
var addTensor = new DenseTensor<float>(sample.Dimensions);
48-
for (var i = 0; i < sample.Length; i++)
82+
var addTensor = new DenseTensor<float>(tensor.Dimensions);
83+
for (var i = 0; i < tensor.Length; i++)
4984
{
50-
addTensor.SetValue(i, sample.GetValue(i) + sumTensor.GetValue(i));
85+
addTensor.SetValue(i, tensor.GetValue(i) + sumTensor.GetValue(i));
5186
}
5287
return addTensor;
5388
}
5489

55-
public static Tuple<DenseTensor<float>, DenseTensor<float>> SplitTensor(this DenseTensor<float> tensorToSplit, ReadOnlySpan<int> dimensions, int scaledHeight, int scaledWidth)
90+
91+
/// <summary>
92+
/// Splits the tensor.
93+
/// </summary>
94+
/// <param name="tensor">The tensor to split.</param>
95+
/// <param name="dimensions">The dimensions.</param>
96+
/// <param name="scaledHeight">The scaled height.</param>
97+
/// <param name="scaledWidth">The scaled width.</param>
98+
/// <returns></returns>
99+
public static Tuple<DenseTensor<float>, DenseTensor<float>> SplitTensor(this DenseTensor<float> tensor, ReadOnlySpan<int> dimensions, int scaledHeight, int scaledWidth)
56100
{
57101
var tensor1 = new DenseTensor<float>(dimensions);
58102
var tensor2 = new DenseTensor<float>(dimensions);
@@ -64,21 +108,28 @@ public static Tuple<DenseTensor<float>, DenseTensor<float>> SplitTensor(this Den
64108
{
65109
for (int l = 0; l < scaledWidth; l++)
66110
{
67-
tensor1[i, j, k, l] = tensorToSplit[i, j, k, l];
68-
tensor2[i, j, k, l] = tensorToSplit[i, j + 4, k, l];
111+
tensor1[i, j, k, l] = tensor[i, j, k, l];
112+
tensor2[i, j, k, l] = tensor[i, j + 4, k, l];
69113
}
70114
}
71115
}
72116
}
73117
return new Tuple<DenseTensor<float>, DenseTensor<float>>(tensor1, tensor2);
74118
}
75119

76-
public static DenseTensor<float> SumTensors(this DenseTensor<float>[] tensorArray, ReadOnlySpan<int> dimensions)
120+
121+
/// <summary>
122+
/// Sums the tensors.
123+
/// </summary>
124+
/// <param name="tensors">The tensor array.</param>
125+
/// <param name="dimensions">The dimensions.</param>
126+
/// <returns></returns>
127+
public static DenseTensor<float> SumTensors(this DenseTensor<float>[] tensors, ReadOnlySpan<int> dimensions)
77128
{
78129
var sumTensor = new DenseTensor<float>(dimensions);
79-
for (int m = 0; m < tensorArray.Length; m++)
130+
for (int m = 0; m < tensors.Length; m++)
80131
{
81-
var tensorToSum = tensorArray[m];
132+
var tensorToSum = tensors[m];
82133
for (var i = 0; i < tensorToSum.Length; i++)
83134
{
84135
sumTensor.SetValue(i, sumTensor.GetValue(i) + tensorToSum.GetValue(i));
@@ -87,51 +138,79 @@ public static DenseTensor<float> SumTensors(this DenseTensor<float>[] tensorArra
87138
return sumTensor;
88139
}
89140

90-
public static DenseTensor<float> Duplicate(this DenseTensor<float> data, ReadOnlySpan<int> dimensions)
141+
142+
/// <summary>
143+
/// Duplicates the specified tensor.
144+
/// </summary>
145+
/// <param name="tensor">The data.</param>
146+
/// <param name="dimensions">The dimensions.</param>
147+
/// <returns></returns>
148+
public static DenseTensor<float> Duplicate(this DenseTensor<float> tensor, ReadOnlySpan<int> dimensions)
91149
{
92-
var dupTensor = data.Concat(data).ToArray();
150+
var dupTensor = tensor.Concat(tensor).ToArray();
93151
return CreateTensor(dupTensor, dimensions);
94152
}
95153

96-
public static DenseTensor<float> SubtractTensors(this DenseTensor<float> sample, DenseTensor<float> subTensor, ReadOnlySpan<int> dimensions)
154+
155+
/// <summary>
156+
/// Subtracts the tensors.
157+
/// </summary>
158+
/// <param name="tensor">The tensor.</param>
159+
/// <param name="subTensor">The sub tensor.</param>
160+
/// <param name="dimensions">The dimensions.</param>
161+
/// <returns></returns>
162+
public static DenseTensor<float> SubtractTensors(this DenseTensor<float> tensor, DenseTensor<float> subTensor, ReadOnlySpan<int> dimensions)
97163
{
98164
var result = new DenseTensor<float>(dimensions);
99-
for (var i = 0; i < sample.Length; i++)
165+
for (var i = 0; i < tensor.Length; i++)
100166
{
101-
result.SetValue(i, sample.GetValue(i) - subTensor.GetValue(i));
167+
result.SetValue(i, tensor.GetValue(i) - subTensor.GetValue(i));
102168
}
103169
return result;
104170
}
105171

106-
public static DenseTensor<float> SubtractTensors(this DenseTensor<float> sample, DenseTensor<float> subTensor)
172+
173+
/// <summary>
174+
/// Subtracts the tensors.
175+
/// </summary>
176+
/// <param name="tensor">The sample.</param>
177+
/// <param name="subTensor">The sub tensor.</param>
178+
/// <returns></returns>
179+
public static DenseTensor<float> SubtractTensors(this DenseTensor<float> tensor, DenseTensor<float> subTensor)
107180
{
108-
return sample.SubtractTensors(subTensor, sample.Dimensions);
181+
return tensor.SubtractTensors(subTensor, tensor.Dimensions);
109182
}
110183

111184

112-
113185
/// <summary>
114186
/// Reorders the tensor.
115187
/// </summary>
116-
/// <param name="inputTensor">The input tensor.</param>
188+
/// <param name="tensor">The input tensor.</param>
117189
/// <returns></returns>
118-
public static DenseTensor<float> ReorderTensor(this DenseTensor<float> inputTensor, ReadOnlySpan<int> dimensions)
190+
public static DenseTensor<float> ReorderTensor(this DenseTensor<float> tensor, ReadOnlySpan<int> dimensions)
119191
{
120192
//reorder from batch channel height width to batch height width channel
121193
var inputImagesTensor = new DenseTensor<float>(dimensions);
122-
for (int y = 0; y < inputTensor.Dimensions[2]; y++)
194+
for (int y = 0; y < tensor.Dimensions[2]; y++)
123195
{
124-
for (int x = 0; x < inputTensor.Dimensions[3]; x++)
196+
for (int x = 0; x < tensor.Dimensions[3]; x++)
125197
{
126-
inputImagesTensor[0, y, x, 0] = inputTensor[0, 0, y, x];
127-
inputImagesTensor[0, y, x, 1] = inputTensor[0, 1, y, x];
128-
inputImagesTensor[0, y, x, 2] = inputTensor[0, 2, y, x];
198+
inputImagesTensor[0, y, x, 0] = tensor[0, 0, y, x];
199+
inputImagesTensor[0, y, x, 1] = tensor[0, 1, y, x];
200+
inputImagesTensor[0, y, x, 2] = tensor[0, 2, y, x];
129201
}
130202
}
131203
return inputImagesTensor;
132204
}
133205

134206

207+
/// <summary>
208+
/// Performs classifier free guidance
209+
/// </summary>
210+
/// <param name="noisePred">The noise pred.</param>
211+
/// <param name="noisePredText">The noise pred text.</param>
212+
/// <param name="guidanceScale">The guidance scale.</param>
213+
/// <returns></returns>
135214
public static DenseTensor<float> PerformGuidance(this DenseTensor<float> noisePred, DenseTensor<float> noisePredText, double guidanceScale)
136215
{
137216
for (int i = 0; i < noisePred.Dimensions[0]; i++)
@@ -151,16 +230,47 @@ public static DenseTensor<float> PerformGuidance(this DenseTensor<float> noisePr
151230
}
152231

153232

233+
/// <summary>
234+
/// Clips the specified tensor values to the specified minimum/maximum.
235+
/// </summary>
236+
/// <param name="tensor">The tensor.</param>
237+
/// <param name="minValue">The minimum value.</param>
238+
/// <param name="maxValue">The maximum value.</param>
239+
/// <returns></returns>
154240
public static DenseTensor<float> Clip(this DenseTensor<float> tensor, float minValue, float maxValue)
155241
{
242+
var clipTensor = new DenseTensor<float>(tensor.Dimensions);
156243
for (int i = 0; i < tensor.Length; i++)
157244
{
158-
tensor.SetValue(i, Math.Clamp(tensor.GetValue(i), minValue, maxValue));
245+
clipTensor.SetValue(i, Math.Clamp(tensor.GetValue(i), minValue, maxValue));
159246
}
160-
return tensor;
247+
return clipTensor;
161248
}
162249

163250

251+
/// <summary>
252+
/// Computes the absolute values of the Tensor
253+
/// </summary>
254+
/// <param name="tensor">The tensor.</param>
255+
/// <returns></returns>
256+
public static DenseTensor<float> Abs(this DenseTensor<float> tensor)
257+
{
258+
var absTensor = new DenseTensor<float>(tensor.Dimensions);
259+
for (int i = 0; i < tensor.Length; i++)
260+
{
261+
absTensor.SetValue(i, Math.Abs(tensor.GetValue(i)));
262+
}
263+
return absTensor;
264+
}
265+
266+
267+
/// <summary>
268+
/// Generate a random Tensor from a normal distribution with mean 0 and variance 1
269+
/// </summary>
270+
/// <param name="random">The random.</param>
271+
/// <param name="dimensions">The dimensions.</param>
272+
/// <param name="initNoiseSigma">The initialize noise sigma.</param>
273+
/// <returns></returns>
164274
public static DenseTensor<float> GetRandomTensor(Random random, ReadOnlySpan<int> dimensions, float initNoiseSigma = 1f)
165275
{
166276
var latents = new DenseTensor<float>(dimensions);

OnnxStack.StableDiffusion/Schedulers/DDPMScheduler.cs

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,14 @@ protected override int[] SetTimesteps()
9191
else if (Options.TimestepSpacing == TimestepSpacingType.Leading)
9292
{
9393
var stepRatio = Options.TrainTimesteps / Options.InferenceSteps;
94-
timestepsArray = np.arange(0, Options.InferenceSteps) * stepRatio;
94+
timestepsArray = np.arange(0, (float)Options.InferenceSteps) * stepRatio;
9595
timestepsArray = np.around(timestepsArray)["::1"];
9696
timestepsArray += Options.StepsOffset;
9797
}
9898
else if (Options.TimestepSpacing == TimestepSpacingType.Trailing)
9999
{
100-
var stepRatio = Options.TrainTimesteps / Options.InferenceSteps;
101-
timestepsArray = np.arange(Options.TrainTimesteps, 0, -stepRatio);
100+
var stepRatio = Options.TrainTimesteps / (Options.InferenceSteps - 1);
101+
timestepsArray = np.arange((float)Options.TrainTimesteps, 0, -stepRatio)["::-1"];
102102
timestepsArray = np.around(timestepsArray);
103103
timestepsArray -= 1;
104104
}
@@ -177,20 +177,17 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
177177
// pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
178178
var alphaSqrt = (float)Math.Sqrt(alphaProdT);
179179
var betaSqrt = (float)Math.Sqrt(betaProdT);
180-
predOriginalSample = new DenseTensor<float>((int)sample.Length);
181-
for (int i = 0; i < sample.Length - 1; i++)
182-
{
183-
predOriginalSample.SetValue(i, alphaSqrt * sample.GetValue(i) - betaSqrt * modelOutput.GetValue(i));
184-
}
180+
predOriginalSample = sample
181+
.MultipleTensorByFloat(alphaSqrt)
182+
.SubtractTensors(modelOutput.MultipleTensorByFloat(betaSqrt));
185183
}
186184

187185

188186
//# 3. Clip or threshold "predicted x_0"
189187
if (Options.Thresholding)
190188
{
191189
// TODO: https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddpm.py#L322
192-
//predOriginalSample = ThresholdSample(predOriginalSample);
193-
throw new NotImplementedException("DDPMScheduler Thresholding currently not implemented");
190+
predOriginalSample = ThresholdSample(predOriginalSample);
194191
}
195192
else if (Options.ClipSample)
196193
{
@@ -199,7 +196,7 @@ public override DenseTensor<float> Step(DenseTensor<float> modelOutput, int time
199196

200197
//# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
201198
//# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
202-
float predOriginalSampleCoeff = ((float)Math.Sqrt(alphaProdTPrev) * currentBetaT) / betaProdT;
199+
float predOriginalSampleCoeff = (float)Math.Sqrt(alphaProdTPrev) * currentBetaT / betaProdT;
203200
float currentSampleCoeff = (float)Math.Sqrt(currentAlphaT) * betaProdTPrev / betaProdT;
204201

205202

@@ -314,13 +311,6 @@ private float GetVariance(int timestep, float predictedVariance = 0f)
314311
private DenseTensor<float> ThresholdSample(DenseTensor<float> input, float dynamicThresholdingRatio = 0.995f, float sampleMaxValue = 1f)
315312
{
316313
var sample = new NDArray(input.ToArray(), new Shape(input.Dimensions.ToArray()));
317-
318-
// Ensure the data type is float32 or float64
319-
if (sample.dtype != typeof(float) && sample.dtype != typeof(double))
320-
{
321-
sample = sample.astype(NPTypeCode.Single); // Upcast for quantile calculation
322-
}
323-
324314
var batch_size = sample.shape[0];
325315
var channels = sample.shape[1];
326316
var height = sample.shape[2];

0 commit comments

Comments
 (0)