@@ -24,16 +24,16 @@ void ComputeJob(
2424 const T* bias_data,
2525 const ptrdiff_t task_idx,
2626 const int64_t norm_size,
27- IAllocatorUniquePtr< float >& scale_float_uptr ,
28- IAllocatorUniquePtr< float >& bias_float_uptr ,
27+ const float * scale_float_ptr ,
28+ const float * bias_float_ptr ,
2929 float epsilon,
3030 bool simplified,
3131 T* Y_data,
3232 U* mean_data,
3333 U* inv_std_dev_data,
3434 AllocatorPtr alloc) {
35- ORT_UNUSED_PARAMETER (scale_float_uptr ); // only used in MLFloat16 overload
36- ORT_UNUSED_PARAMETER (bias_float_uptr ); // only used in MLFloat16 overload
35+ ORT_UNUSED_PARAMETER (scale_float_ptr ); // only used in MLFloat16 overload
36+ ORT_UNUSED_PARAMETER (bias_float_ptr ); // only used in MLFloat16 overload
3737 ORT_UNUSED_PARAMETER (alloc);
3838
3939 const T* p_input = X_data + task_idx * norm_size;
@@ -82,14 +82,17 @@ void ComputeJob(
8282 const MLFloat16* bias_data,
8383 const ptrdiff_t task_idx,
8484 const int64_t norm_size,
85- IAllocatorUniquePtr< float >& scale_float_uptr ,
86- IAllocatorUniquePtr< float >& bias_float_uptr ,
85+ const float * scale_float_ptr ,
86+ const float * bias_float_ptr ,
8787 float epsilon,
8888 bool simplified,
8989 MLFloat16* Y_data,
9090 U* mean_data,
9191 U* inv_std_dev_data,
9292 AllocatorPtr alloc) {
93+ ORT_UNUSED_PARAMETER (scale_data); // only used in float/double overload
94+ ORT_UNUSED_PARAMETER (bias_data); // only used in float/double overload
95+
9396 const MLFloat16* p_input = X_data + task_idx * norm_size;
9497 MLFloat16* p_output = Y_data + task_idx * norm_size;
9598
@@ -117,22 +120,10 @@ void ComputeJob(
117120 mean_square = sqrt (mean_square / norm_size - mean * mean + epsilon);
118121 }
119122
120- if (!scale_float_uptr) {
121- scale_float_uptr = std::move (input_float_uptr); // overwrite input with scale values, since they have the same size
122- MlasConvertHalfToFloatBuffer (scale_data, scale_float_uptr.get (), num_elems);
123- }
124-
125- if (bias_data && !bias_float_uptr) {
126- bias_float_uptr = IAllocator::MakeUniquePtr<float >(alloc, num_elems);
127- MlasConvertHalfToFloatBuffer (bias_data, bias_float_uptr.get (), num_elems);
128- }
129-
130- const float * scale_float_ptr = scale_float_uptr.get ();
131- const float * bias_float_ptr = bias_float_uptr.get ();
132123 for (size_t h = 0 ; h < num_elems; h++) {
133124 if (simplified) {
134125 output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
135- } else if (nullptr == bias_data ) {
126+ } else if (nullptr == bias_float_ptr ) {
136127 output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
137128 } else {
138129 output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
@@ -166,7 +157,13 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I
166157} // namespace
167158
168159LayerNormImpl::LayerNormImpl (const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op)
169- : OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op}, scale_fp32_(nullptr ), bias_fp32_(nullptr ) {
160+ : OpKernel(op_kernel_info),
161+ simplified_{simplified},
162+ contrib_op_{contrib_op},
163+ prepacked_scale_fp32_data_ (nullptr ),
164+ prepacked_scale_fp32_size_ (0 ),
165+ prepacked_bias_fp32_data_ (nullptr ),
166+ prepacked_bias_fp32_size_ (0 ) {
170167 ORT_ENFORCE (op_kernel_info.GetAttr (" axis" , &axis_).IsOK ());
171168 ORT_ENFORCE (op_kernel_info.GetAttr <float >(" epsilon" , &epsilon_).IsOK ());
172169}
@@ -175,15 +172,15 @@ template <typename T, typename U>
175172Status LayerNormImpl::ComputeImpl (OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const {
176173 // Inputs
177174 const Tensor* X = p_ctx->Input <Tensor>(0 );
178- const Tensor* scale = p_ctx->Input <Tensor>(1 );
179- const Tensor* bias = p_ctx->Input <Tensor>(2 );
175+ const Tensor* scale = prepacked_scale_fp32_data_ ? nullptr : p_ctx->Input <Tensor>(1 );
176+ const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input <Tensor>(2 );
180177 const T* X_data = X->Data <T>();
181- const T* scale_data = scale->Data <T>();
178+ const T* scale_data = scale ? scale ->Data <T>() : nullptr ;
182179 const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data <T>();
183180
184181 const TensorShape& x_shape = X->Shape ();
185- const TensorShape& scale_shape = scale->Shape ();
186- const TensorShape& bias_shape = bias->Shape ();
182+ size_t scale_size = scale ? static_cast < size_t >(scale ->Shape (). Size ()) : prepacked_scale_fp32_size_ ;
183+ size_t bias_size = bias ? static_cast < size_t >(bias ->Shape (). Size ()) : prepacked_bias_fp32_size_ ;
187184 Tensor* Y = p_ctx->Output (0 , x_shape);
188185 T* Y_data = Y->MutableData <T>();
189186
@@ -218,7 +215,7 @@ Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, flo
218215
219216 AllocatorPtr alloc;
220217 ORT_RETURN_IF_ERROR (p_ctx->GetTempSpaceAllocator (&alloc));
221- return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape , bias_data, bias_shape , Y_data, mean_data,
218+ return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size , bias_data, bias_size , Y_data, mean_data,
222219 inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
223220}
224221
@@ -237,9 +234,11 @@ Status LayerNormImpl::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr
237234
238235 is_packed = false ;
239236 if (input_idx == 1 ) { // scale
240- ConvertMLFloat16ToFloatIfNeeded (tensor, alloc, scale_fp32_, is_packed);
237+ prepacked_scale_fp32_size_ = static_cast <size_t >(tensor.Shape ().Size ());
238+ ConvertMLFloat16ToFloatIfNeeded (tensor, alloc, prepacked_scale_fp32_data_, is_packed);
241239 } else if (input_idx == 2 ) { // bias
242- ConvertMLFloat16ToFloatIfNeeded (tensor, alloc, bias_fp32_, is_packed);
240+ prepacked_bias_fp32_size_ = static_cast <size_t >(tensor.Shape ().Size ());
241+ ConvertMLFloat16ToFloatIfNeeded (tensor, alloc, prepacked_bias_fp32_data_, is_packed);
243242 }
244243
245244 return Status::OK ();
@@ -250,9 +249,9 @@ Status LayerNormImpl::ComputeWithoutContext(
250249 const T* X_data,
251250 const TensorShape& x_shape,
252251 const T* scale_data,
253- const TensorShape& scale_shape ,
252+ size_t scale_size ,
254253 const T* bias_data,
255- const TensorShape& bias_shape ,
254+ size_t bias_size ,
256255 T* Y_data,
257256 U* mean_data,
258257 U* inv_std_dev_data,
@@ -264,19 +263,34 @@ Status LayerNormImpl::ComputeWithoutContext(
264263 int64_t norm_count = x_shape.SizeToDimension (onnxruntime::narrow<size_t >(axis));
265264 int64_t norm_size = x_shape.SizeFromDimension (onnxruntime::narrow<size_t >(axis));
266265
267- const auto scale_size = scale_shape.Size ();
268- const auto bias_size = (bias_data) ? bias_shape.Size () : 0 ;
269- if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
266+ if (static_cast <int64_t >(scale_size) != norm_size || (bias_data && static_cast <int64_t >(bias_size) != norm_size)) {
270267 return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
271268 " Size of X.shape()[axis:] == " , norm_size,
272269 " . Size of scale and bias (if provided) must match this. Got scale size of " ,
273270 scale_size, " and bias size of " , bias_size);
274271 }
275272
273+ IAllocatorUniquePtr<float > scale_fp32;
274+ IAllocatorUniquePtr<float > bias_fp32;
275+ if constexpr (std::is_same_v<T, MLFloat16>) {
276+ if (prepacked_scale_fp32_data_ == nullptr ) {
277+ const size_t num_elems = static_cast <size_t >(norm_size);
278+ scale_fp32 = IAllocator::MakeUniquePtr<float >(alloc, num_elems);
279+ MlasConvertHalfToFloatBuffer (scale_data, scale_fp32.get (), num_elems);
280+ }
281+ if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
282+ const size_t num_elems = static_cast <size_t >(norm_size);
283+ bias_fp32 = IAllocator::MakeUniquePtr<float >(alloc, num_elems);
284+ MlasConvertHalfToFloatBuffer (bias_data, bias_fp32.get (), num_elems);
285+ }
286+ }
287+
276288 concurrency::ThreadPool::TryBatchParallelFor (
277289 thread_pool, static_cast <int32_t >(norm_count),
278290 [&](ptrdiff_t task_idx) {
279- ComputeJob (X_data, scale_data, bias_data, task_idx, norm_size, scale_fp32_, bias_fp32_,
291+ ComputeJob (X_data, scale_data, bias_data, task_idx, norm_size,
292+ prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get () : scale_fp32.get (),
293+ prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get () : bias_fp32.get (),
280294 epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
281295 },
282296 0 );
0 commit comments