
Commit f10e9f7

OpenVINO NN Module Functions (#21803)
* support for celu, one hot, avg pooling
* ctc loss
* revert test change
* Update keras/src/backend/openvino/nn.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update keras/src/backend/openvino/nn.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* fix one hot with sparse check
* simplify pooling
* handle dtype of one_hot
* use swish op for silu
* support for log_sigmoid
* address gemini feedback
* Update keras/src/backend/openvino/nn.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* fix consolidated pool function
* enable testing
* fix max_pool call
* fix dtype for ctc_loss
* fix dtype for swish
* permit nn test to be run
* support for more activation functions, enabling tests that should stay enabled
* support squareplus and sparse_plus
* enable selu test
* support threshold
* selu test
* support scaled dot product attention
* Update keras/src/backend/openvino/nn.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Update keras/src/ops/nn_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* fix formatting
* fix handling of flash_attention
* fix spacing in functions
* fix spacing
* add tolerance to specific tests
* oneliner for atol
* omit tolerance changes

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 1240100 commit f10e9f7

File tree: 4 files changed, +259 −22 lines
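For orientation, the ops implemented below are reached through the public `keras.ops` API rather than called directly; a minimal, hypothetical usage sketch (assuming `KERAS_BACKEND=openvino` is set before Keras is imported, and that these ops execute eagerly as exercised by the newly enabled `nn_test.py`) might look like:

import numpy as np
import keras  # assumes KERAS_BACKEND=openvino in the environment

x = np.random.randn(2, 3).astype("float32")
y = keras.ops.celu(x, alpha=1.0)  # new CELU decomposition
labels = keras.ops.one_hot(np.array([0, 2, 1]), num_classes=3)
images = np.random.randn(1, 8, 8, 3).astype("float32")
pooled = keras.ops.average_pool(images, pool_size=2, strides=2, padding="valid")
print(keras.ops.convert_to_numpy(pooled).shape)  # (1, 4, 4, 3), channels_last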

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 25 additions & 0 deletions
@@ -257,3 +257,28 @@ TestMathErrors::test_stft_invalid_window
 TestMathErrors::test_stft_invalid_window_shape
 LinalgOpsCorrectnessTest::test_cholesky
 LinalgOpsCorrectnessTest::test_cholesky_inverse
+NNOpsDynamicShapeTest::test_binary_crossentropy
+NNOpsDynamicShapeTest::test_categorical_crossentropy
+NNOpsDynamicShapeTest::test_multi_hot_dtype_
+NNOpsCorrectnessTest::test_conv_transpose_
+NNOpsCorrectnessTest::test_ctc_decode
+NNOpsCorrectnessTest::test_multi_hot_
+NNOpsCorrectnessTest::test_binary_crossentropy
+NNOpsCorrectnessTest::test_categorical_crossentropy
+NNOpsCorrectnessTest::test_log_softmax_correctness_with_axis_tuple
+NNOpsCorrectnessTest::test_softmax_correctness_with_axis_tuple
+NNOpsCorrectnessTest::test_separable_conv_
+NNOpsCorrectnessTest::test_glu
+NNOpsCorrectnessTest::test_moments
+NNOpsCorrectnessTest::test_normalize
+NNOpsCorrectnessTest::test_polar_corectness
+NNOpsCorrectnessTest::test_psnr
+NNOpsCorrectnessTest::test_sparse_categorical_crossentropy
+NNOpsCorrectnessTest::test_sparsemax
+NNOpsCorrectnessTest::test_rms_normalization_10.0
+NNOpsDtypeTest::test_ctc_decode
+NNOpsDtypeTest::test_glu_
+NNOpsDtypeTest::test_polar_
+NNOpsDynamicShapeTest::test_glu
+NNOpsBehaviorTest::test_invalid_strategy_ctc_decode
+NNOpsBehaviorTest::test_logit_recovery_binary_crossentropy
keras/src/backend/openvino/excluded_tests.txt

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ keras/src/metrics
 keras/src/models
 keras/src/ops/image_test.py
 keras/src/ops/linalg_test.py
-keras/src/ops/nn_test.py
 keras/src/optimizers
 keras/src/quantizers
 keras/src/random/seed_generator_test.py

keras/src/backend/openvino/nn.py

Lines changed: 228 additions & 16 deletions
@@ -2,6 +2,7 @@
 from openvino import Type
 
 from keras.src import backend
+from keras.src.backend.openvino.core import OPENVINO_DTYPES
 from keras.src.backend.openvino.core import OpenVINOKerasTensor
 from keras.src.backend.openvino.core import get_ov_output
 
@@ -16,6 +17,23 @@ def relu6(x):
     return OpenVINOKerasTensor(ov_opset.clamp(x, 0.0, 6.0).output(0))
 
 
+def celu(x, alpha=1.0):
+    x = get_ov_output(x)
+    const_zero = get_ov_output(0.0, x.get_element_type())
+    const_alpha = get_ov_output(alpha, x.get_element_type())
+    const_one = get_ov_output(1.0, x.get_element_type())
+    exp_x_div_alpha = ov_opset.exp(ov_opset.divide(x, const_alpha)).output(0)
+    negative_branch = ov_opset.multiply(
+        const_alpha, ov_opset.subtract(exp_x_div_alpha, const_one)
+    )
+
+    celu_x = ov_opset.add(
+        ov_opset.maximum(x, const_zero).output(0),
+        ov_opset.minimum(negative_branch, const_zero).output(0),
+    )
+    return OpenVINOKerasTensor(celu_x.output(0))
+
+
 def sigmoid(x):
     x = get_ov_output(x)
     return OpenVINOKerasTensor(ov_opset.sigmoid(x).output(0))
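For reference, the graph composed above follows the usual CELU definition, celu(x) = max(x, 0) + min(alpha * (exp(x / alpha) - 1), 0); a plain-NumPy sketch of the same formula (illustrative only, not backend code):

import numpy as np

def celu_ref(x, alpha=1.0):
    # max(x, 0) + min(alpha * (exp(x / alpha) - 1), 0)
    return np.maximum(x, 0.0) + np.minimum(
        alpha * (np.exp(x / alpha) - 1.0), 0.0
    )

print(celu_ref(np.array([-2.0, -0.5, 0.0, 1.5])))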
@@ -26,6 +44,39 @@ def tanh(x):
     return OpenVINOKerasTensor(ov_opset.tanh(x).output(0))
 
 
+def tanh_shrink(x):
+    x = get_ov_output(x)
+    return OpenVINOKerasTensor(ov_opset.subtract(x, ov_opset.tanh(x)).output(0))
+
+
+def hard_tanh(x):
+    x = get_ov_output(x)
+    return OpenVINOKerasTensor(ov_opset.clamp(x, -1.0, 1.0).output(0))
+
+
+def soft_shrink(x, threshold=0.5):
+    x = get_ov_output(x)
+    et = x.get_element_type()
+    thr = get_ov_output(threshold, et)
+    zero = get_ov_output(0.0, et)
+    abs_x = ov_opset.abs(x)
+    sub = ov_opset.subtract(abs_x, thr)
+    shrunk = ov_opset.maximum(sub, zero)
+    sign = ov_opset.sign(x)
+    out = ov_opset.multiply(sign, shrunk)
+    return OpenVINOKerasTensor(out.output(0))
+
+
+def hard_shrink(x, threshold=0.5):
+    x = get_ov_output(x)
+    et = x.get_element_type()
+    thr = get_ov_output(threshold, et)
+    zero = get_ov_output(0.0, et)
+    cond = ov_opset.greater(ov_opset.abs(x), thr)
+    out = ov_opset.select(cond, x, zero)
+    return OpenVINOKerasTensor(out.output(0))
+
+
 def softplus(x):
     x = get_ov_output(x)
     return OpenVINOKerasTensor(ov_opset.softplus(x).output(0))
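The shrink variants above are composed from elementwise primitives only; a plain-NumPy sketch of the same formulas (illustrative only):

import numpy as np

def soft_shrink_ref(x, threshold=0.5):
    # sign(x) * max(|x| - threshold, 0)
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0.0)

def hard_shrink_ref(x, threshold=0.5):
    # keep x where |x| > threshold, zero elsewhere
    return np.where(np.abs(x) > threshold, x, 0.0)

x = np.array([-1.0, -0.3, 0.0, 0.4, 2.0])
print(soft_shrink_ref(x), hard_shrink_ref(x))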
@@ -38,14 +89,15 @@ def softsign(x):
 
 def silu(x):
     x = get_ov_output(x)
-    return OpenVINOKerasTensor(
-        ov_opset.multiply(x, ov_opset.sigmoid(x)).output(0)
-    )
+    beta = get_ov_output(1.0, x.get_element_type())
+    return OpenVINOKerasTensor(ov_opset.swish(x, beta=beta).output(0))
 
 
 def log_sigmoid(x):
-    raise NotImplementedError(
-        "`log_sigmoid` is not supported with openvino backend"
+    x = get_ov_output(x)
+    neg_x = ov_opset.negative(x)
+    return OpenVINOKerasTensor(
+        ov_opset.negative(ov_opset.softplus(neg_x)).output(0)
     )
 
 
@@ -58,6 +110,17 @@ def leaky_relu(x, negative_slope=0.2):
     return OpenVINOKerasTensor(leaky_relu)
 
 
+def sparse_sigmoid(x):
+    x = get_ov_output(x)
+    et = x.get_element_type()
+    one = get_ov_output(1.0, et)
+    neg_one = get_ov_output(-1.0, et)
+    half = get_ov_output(0.5, et)
+    y = ov_opset.minimum(ov_opset.maximum(x, neg_one), one)
+    out = ov_opset.multiply(half, ov_opset.add(y, one))
+    return OpenVINOKerasTensor(out.output(0))
+
+
 def hard_sigmoid(x):
     x = get_ov_output(x)
     alpha = get_ov_output(1.0 / 6.0, x.get_element_type())
@@ -121,15 +184,67 @@ def log_softmax(x, axis=-1):
     return OpenVINOKerasTensor(ov_opset.log_softmax(x, axis).output(0))
 
 
+def squareplus(x, b=4):
+    x = get_ov_output(x)
+    et = x.get_element_type()
+    b = get_ov_output(b, et)
+    two = get_ov_output(2.0, et)
+    x_squared = ov_opset.multiply(x, x)
+    inside = ov_opset.add(x_squared, b)
+    root = ov_opset.sqrt(inside)
+    summed = ov_opset.add(x, root)
+    out = ov_opset.divide(summed, two)
+    return OpenVINOKerasTensor(out.output(0))
+
+
+def sparse_plus(x):
+    x = get_ov_output(x)
+    et = x.get_element_type()
+    one = get_ov_output(1.0, et)
+    neg_one = get_ov_output(-1.0, et)
+    zero = get_ov_output(0.0, et)
+    quarter = get_ov_output(0.25, et)
+    x_plus_1 = ov_opset.add(x, one)
+    quad = ov_opset.multiply(quarter, ov_opset.multiply(x_plus_1, x_plus_1))
+    leq_than_neg_one = ov_opset.less_equal(x, neg_one)
+    less_than_one = ov_opset.less(x, one)
+    out = ov_opset.select(
+        leq_than_neg_one,
+        zero,
+        ov_opset.select(less_than_one, quad, x),
+    )
+    return OpenVINOKerasTensor(out.output(0))
+
+
+def threshold(x, threshold, default_value):
+    x = get_ov_output(x)
+    et = x.get_element_type()
+    thr = get_ov_output(threshold, et)
+    dv = get_ov_output(default_value, et)
+    cond = ov_opset.greater(x, thr)
+    out = ov_opset.select(cond, x, dv)
+    return OpenVINOKerasTensor(out.output(0))
+
+
 def max_pool(
     inputs,
     pool_size,
     strides=None,
     padding="valid",
     data_format=None,
 ):
-    raise NotImplementedError(
-        "`max_pool` is not supported with openvino backend"
+    num_spatial_dims = (
+        get_ov_output(inputs).get_partial_shape().rank.get_length() - 2
+    )
+    kwargs = {"dilations": [1] * num_spatial_dims}  # required for ov max_pool
+    return _pool(
+        inputs,
+        pool_size,
+        ov_opset.max_pool,
+        strides,
+        padding,
+        data_format,
+        **kwargs,
     )
 
 
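The squareplus and sparse_plus graphs above mirror the standard formulas, squareplus(x) = (x + sqrt(x^2 + b)) / 2 and the piecewise sparse_plus; a plain-NumPy sketch (illustrative only):

import numpy as np

def squareplus_ref(x, b=4.0):
    return (x + np.sqrt(x * x + b)) / 2.0

def sparse_plus_ref(x):
    # 0 for x <= -1, (x + 1)^2 / 4 for -1 < x < 1, x for x >= 1
    quad = 0.25 * (x + 1.0) ** 2
    return np.where(x <= -1.0, 0.0, np.where(x < 1.0, quad, x))

x = np.linspace(-2.0, 2.0, 5)
print(squareplus_ref(x), sparse_plus_ref(x))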

@@ -140,11 +255,52 @@ def average_pool(
     padding="valid",
     data_format=None,
 ):
-    raise NotImplementedError(
-        "`average_pool` is not supported with openvino backend"
+    return _pool(
+        inputs,
+        pool_size,
+        ov_opset.avg_pool,
+        strides,
+        padding,
+        data_format,
+        exclude_pad=True,
     )
 
 
+def _pool(
+    inputs,
+    pool_size,
+    pooling_func,
+    strides=None,
+    padding="valid",
+    data_format=None,
+    **kwargs,
+):
+    data_format = backend.standardize_data_format(data_format)
+    inputs = get_ov_output(inputs)
+
+    num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2
+    if isinstance(pool_size, int):
+        pool_size = [pool_size] * num_spatial_dims
+
+    if strides is None:
+        strides = pool_size
+
+    strides = _adjust_strides_dilation(strides, num_spatial_dims)
+    pad_mode, pads_begin, pads_end = _adjust_padding(padding)
+    inputs = _adjust_input(inputs, num_spatial_dims, data_format)
+    pool_kwargs = {
+        "kernel_shape": pool_size,
+        "strides": strides,
+        "auto_pad": pad_mode,
+        "pads_begin": pads_begin,
+        "pads_end": pads_end,
+        **kwargs,
+    }
+    pooled = pooling_func(inputs, **pool_kwargs).output(0)
+    adjusted_pooled = _adjust_outputs(pooled, num_spatial_dims, data_format)
+    return OpenVINOKerasTensor(adjusted_pooled)
+
+
 def _adjust_strides_dilation(
     x,
     num_spatial_dims,
@@ -374,9 +530,22 @@ def conv_transpose(
 
 
 def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
-    raise NotImplementedError(
-        "`one_hot` is not supported with openvino backend"
-    )
+    if sparse:
+        raise ValueError("`sparse=True` is not supported with openvino backend")
+    x = get_ov_output(x)
+    if dtype is None:
+        dtype = backend.floatx()
+    ov_dtype = OPENVINO_DTYPES[dtype]
+    on_value = get_ov_output(1, ov_dtype)
+    off_value = get_ov_output(0, ov_dtype)
+    one_hot_encoded = ov_opset.one_hot(
+        x,
+        depth=num_classes,
+        axis=axis,
+        on_value=on_value,
+        off_value=off_value,
+    ).output(0)
+    return OpenVINOKerasTensor(one_hot_encoded)
 
 
 def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
@@ -465,9 +634,15 @@ def batch_normalization(
 
 
 def ctc_loss(target, output, target_length, output_length, mask_index=0):
-    raise NotImplementedError(
-        "`ctc_loss` is not supported with openvino backend"
+    target = get_ov_output(target)
+    output = get_ov_output(output)
+    target_length = get_ov_output(target_length)
+    output_length = get_ov_output(output_length)
+    ctc_loss_ = ov_opset.ctc_loss(
+        output, output_length, target, target_length, blank_index=mask_index
     )
+    ctc_loss_ = ov_opset.convert(ctc_loss_, OPENVINO_DTYPES[backend.floatx()])
+    return OpenVINOKerasTensor(ctc_loss_.output(0))
 
 
 def ctc_decode(
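Note the argument reordering above: Keras passes (target, output, target_length, output_length), while the OpenVINO CTC loss op, as called here, takes the logits and logit lengths first, with mask_index forwarded as blank_index. A hypothetical usage sketch through the public API (all values made up):

import numpy as np
import keras  # assumes KERAS_BACKEND=openvino

batch, time_steps, num_classes, max_label_len = 2, 10, 5, 4
target = np.random.randint(1, num_classes, size=(batch, max_label_len))
output = np.random.randn(batch, time_steps, num_classes).astype("float32")
target_length = np.array([4, 3])
output_length = np.array([10, 9])

# mask_index defaults to 0, the blank label excluded from target above.
loss = keras.ops.ctc_loss(target, output, target_length, output_length)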
@@ -499,9 +674,46 @@ def dot_product_attention(
     flash_attention=None,
     attn_logits_soft_cap=None,
 ):
-    raise NotImplementedError(
-        "`dot_product_attention` is not supported with openvino backend"
+    if bias is not None:
+        raise NotImplementedError(
+            "`dot_product_attention` with `bias` is not supported "
+            "with openvino backend"
+        )
+    if flash_attention:
+        raise NotImplementedError(
+            "`dot_product_attention` with `flash_attention` is not supported "
+            "with openvino backend"
+        )
+    if attn_logits_soft_cap is not None:
+        raise NotImplementedError(
+            "`dot_product_attention` with `attn_logits_soft_cap` is not "
+            "supported with openvino backend"
+        )
+    query = get_ov_output(query)
+    key = get_ov_output(key)
+    value = get_ov_output(value)
+    if query.get_element_type() != key.get_element_type():
+        ov_type = OPENVINO_DTYPES[backend.floatx()]
+        query = ov_opset.convert(query, ov_type).output(0)
+        key = ov_opset.convert(key, ov_type).output(0)
+    if value.get_element_type() != query.get_element_type():
+        value = ov_opset.convert(value, query.get_element_type()).output(0)
+    axes_const = ov_opset.constant([0, 2, 1, 3], Type.i32).output(0)
+
+    query = ov_opset.transpose(query, axes_const)
+    key = ov_opset.transpose(key, axes_const)
+    value = ov_opset.transpose(value, axes_const)
+    mask = get_ov_output(mask) if mask is not None else None
+    scale = (
+        get_ov_output(scale, query.get_element_type())
+        if scale is not None
+        else None
+    )
+    dpa = ov_opset.scaled_dot_product_attention(
+        query, key, value, attention_mask=mask, scale=scale, causal=is_causal
     )
+    dpa = ov_opset.transpose(dpa, axes_const)
+    return OpenVINOKerasTensor(dpa.output(0))
 
 
 def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
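The transposes with axes [0, 2, 1, 3] adapt layouts: keras.ops.dot_product_attention takes inputs shaped (batch, sequence, heads, head_dim), whereas the OpenVINO scaled dot-product attention op, as called here, operates on (batch, heads, sequence, head_dim), so inputs are permuted in and the result permuted back. A hypothetical usage sketch (shapes made up):

import numpy as np
import keras  # assumes KERAS_BACKEND=openvino

query = np.random.randn(2, 6, 4, 8).astype("float32")  # (batch, seq, heads, dim)
key = np.random.randn(2, 6, 4, 8).astype("float32")
value = np.random.randn(2, 6, 4, 8).astype("float32")

out = keras.ops.dot_product_attention(query, key, value, is_causal=True)
# bias, flash_attention, and attn_logits_soft_cap raise NotImplementedError
# with this backend, per the checks above.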

keras/src/ops/nn_test.py

Lines changed: 6 additions & 5 deletions
@@ -2492,19 +2492,20 @@ def test_dot_product_attention(
             mask = mask[None, None, ...]
             mask = np.tile(mask, (2, 4, 1, 1))
         if bias is not None:
-            if backend.backend() == "torch":
+            if backend.backend() in ("torch", "openvino"):
                 self.skipTest(
-                    "torch does not support `bias` with `dot_product_attention`"
+                    "torch and openvino do not support `bias` with "
+                    "`dot_product_attention`"
                 )
             bias = np.arange(math.prod(bias_shape), dtype=float).reshape(
                 bias_shape
             )
 
         if flash_attention:
-            if backend.backend() in ("tensorflow", "numpy"):
+            if backend.backend() in ("tensorflow", "numpy", "openvino"):
                 self.skipTest(
-                    "Flash attention is not supported in tensorflow and numpy "
-                    "backends."
+                    "Flash attention is not supported in tensorflow, numpy, "
+                    "and openvino backends."
                 )
             elif backend.backend() == "torch":
                 import torch
