@@ -1062,6 +1062,8 @@ namespace dml
1062
1062
return output;
1063
1063
}
1064
1064
1065
+ #if DML_TARGET_VERSION >= 0x3100
1066
+
1065
1067
inline Expression ClipGrad (Expression input, Expression inputGradient, float min, float max)
1066
1068
{
1067
1069
detail::GraphBuilder* builder = input.Impl ()->GetGraphBuilder ();
@@ -1084,6 +1086,8 @@ namespace dml
1084
1086
return output;
1085
1087
}
1086
1088
1089
+ #endif // DML_TARGET_VERSION >= 0x3100
1090
+
1087
1091
inline Expression Cos (Expression input, const Optional<DML_SCALE_BIAS>& scaleBias = NullOpt)
1088
1092
{
1089
1093
return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_COS, DML_ELEMENT_WISE_COS_OPERATOR_DESC>(input, scaleBias);
@@ -1254,11 +1258,15 @@ namespace dml
1254
1258
return detail::ElementWiseUnary<DML_OPERATOR_ELEMENT_WISE_SQRT, DML_ELEMENT_WISE_SQRT_OPERATOR_DESC>(input, scaleBias);
1255
1259
}
1256
1260
1261
+ #if DML_TARGET_VERSION >= 0x3100
1262
+
1257
1263
inline Expression DifferenceSquare (Expression a, Expression b)
1258
1264
{
1259
1265
return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE, DML_ELEMENT_WISE_DIFFERENCE_SQUARE_OPERATOR_DESC>(a, b);
1260
1266
}
1261
1267
1268
+ #endif // DML_TARGET_VERSION >= 0x3100
1269
+
1262
1270
inline Expression Subtract (Expression a, Expression b)
1263
1271
{
1264
1272
return detail::ElementWiseBinary<DML_OPERATOR_ELEMENT_WISE_SUBTRACT, DML_ELEMENT_WISE_SUBTRACT_OPERATOR_DESC>(a, b);
@@ -2659,6 +2667,8 @@ namespace dml
2659
2667
return output;
2660
2668
}
2661
2669
2670
+ #if DML_TARGET_VERSION >= 0x3100
2671
+
2662
2672
struct BatchNormalizationGradOutputs
2663
2673
{
2664
2674
Expression gradient;
@@ -2684,29 +2694,31 @@ namespace dml
2684
2694
TensorDesc outputScaleGradientTensor (meanTensor.dataType , meanTensor.sizes , builder->GetTensorPolicy ());
2685
2695
TensorDesc outputBiasGradientTensor (meanTensor.dataType , meanTensor.sizes , builder->GetTensorPolicy ());
2686
2696
2687
- DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC bng_desc = {};
2688
- bng_desc .InputTensor = inputTensor.AsPtr <DML_TENSOR_DESC>();
2689
- bng_desc .InputGradientTensor = inputGradientTensor.AsPtr <DML_TENSOR_DESC>();
2690
- bng_desc .MeanTensor = meanTensor.AsPtr <DML_TENSOR_DESC>();
2691
- bng_desc .VarianceTensor = varianceTensor.AsPtr <DML_TENSOR_DESC>();
2692
- bng_desc .ScaleTensor = scaleTensor.AsPtr <DML_TENSOR_DESC>();
2693
- bng_desc .Epsilon = epsilon;
2697
+ DML_BATCH_NORMALIZATION_GRAD_OPERATOR_DESC desc = {};
2698
+ desc .InputTensor = inputTensor.AsPtr <DML_TENSOR_DESC>();
2699
+ desc .InputGradientTensor = inputGradientTensor.AsPtr <DML_TENSOR_DESC>();
2700
+ desc .MeanTensor = meanTensor.AsPtr <DML_TENSOR_DESC>();
2701
+ desc .VarianceTensor = varianceTensor.AsPtr <DML_TENSOR_DESC>();
2702
+ desc .ScaleTensor = scaleTensor.AsPtr <DML_TENSOR_DESC>();
2703
+ desc .Epsilon = epsilon;
2694
2704
2695
- bng_desc .OutputGradientTensor = outputGradientTensor.AsPtr <DML_TENSOR_DESC>();
2696
- bng_desc .OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr <DML_TENSOR_DESC>();
2697
- bng_desc .OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr <DML_TENSOR_DESC>();
2705
+ desc .OutputGradientTensor = outputGradientTensor.AsPtr <DML_TENSOR_DESC>();
2706
+ desc .OutputScaleGradientTensor = outputScaleGradientTensor.AsPtr <DML_TENSOR_DESC>();
2707
+ desc .OutputBiasGradientTensor = outputBiasGradientTensor.AsPtr <DML_TENSOR_DESC>();
2698
2708
2699
2709
dml::detail::NodeOutput* const inputs[] = { input.Impl (), inputGradient.Impl (), mean.Impl (), variance.Impl (), scale.Impl () };
2700
- dml::detail::NodeID node = builder->CreateOperatorNode (DML_OPERATOR_BATCH_NORMALIZATION_GRAD, &bng_desc , inputs);
2710
+ dml::detail::NodeID node = builder->CreateOperatorNode (DML_OPERATOR_BATCH_NORMALIZATION_GRAD, &desc , inputs);
2701
2711
2702
2712
BatchNormalizationGradOutputs outputValues;
2703
- outputValues.gradient = builder->CreateNodeOutput (node, 0 , *bng_desc .OutputGradientTensor );
2704
- outputValues.scaleGradient = builder->CreateNodeOutput (node, 1 , *bng_desc .OutputScaleGradientTensor );
2705
- outputValues.biasGradient = builder->CreateNodeOutput (node, 2 , *bng_desc .OutputBiasGradientTensor );
2713
+ outputValues.gradient = builder->CreateNodeOutput (node, 0 , *desc .OutputGradientTensor );
2714
+ outputValues.scaleGradient = builder->CreateNodeOutput (node, 1 , *desc .OutputScaleGradientTensor );
2715
+ outputValues.biasGradient = builder->CreateNodeOutput (node, 2 , *desc .OutputBiasGradientTensor );
2706
2716
2707
2717
return outputValues;
2708
2718
}
2709
2719
2720
+ #endif // DML_TARGET_VERSION >= 0x3100
2721
+
2710
2722
inline Expression MeanVarianceNormalization (
2711
2723
Expression input,
2712
2724
Optional<Expression> scale,
0 commit comments