22from openvino import Type
33
44from keras .src import backend
5+ from keras .src .backend .openvino .core import OPENVINO_DTYPES
56from keras .src .backend .openvino .core import OpenVINOKerasTensor
67from keras .src .backend .openvino .core import get_ov_output
78
@@ -16,6 +17,23 @@ def relu6(x):
1617 return OpenVINOKerasTensor (ov_opset .clamp (x , 0.0 , 6.0 ).output (0 ))
1718
1819
20+ def celu (x , alpha = 1.0 ):
21+ x = get_ov_output (x )
22+ const_zero = get_ov_output (0.0 , x .get_element_type ())
23+ const_alpha = get_ov_output (alpha , x .get_element_type ())
24+ const_one = get_ov_output (1.0 , x .get_element_type ())
25+ exp_x_div_alpha = ov_opset .exp (ov_opset .divide (x , const_alpha )).output (0 )
26+ negative_branch = ov_opset .multiply (
27+ const_alpha , ov_opset .subtract (exp_x_div_alpha , const_one )
28+ )
29+
30+ celu_x = ov_opset .add (
31+ ov_opset .maximum (x , const_zero ).output (0 ),
32+ ov_opset .minimum (negative_branch , const_zero ).output (0 ),
33+ )
34+ return OpenVINOKerasTensor (celu_x .output (0 ))
35+
36+
1937def sigmoid (x ):
2038 x = get_ov_output (x )
2139 return OpenVINOKerasTensor (ov_opset .sigmoid (x ).output (0 ))
@@ -26,6 +44,39 @@ def tanh(x):
2644 return OpenVINOKerasTensor (ov_opset .tanh (x ).output (0 ))
2745
2846
47+ def tanh_shrink (x ):
48+ x = get_ov_output (x )
49+ return OpenVINOKerasTensor (ov_opset .subtract (x , ov_opset .tanh (x )).output (0 ))
50+
51+
52+ def hard_tanh (x ):
53+ x = get_ov_output (x )
54+ return OpenVINOKerasTensor (ov_opset .clamp (x , - 1.0 , 1.0 ).output (0 ))
55+
56+
57+ def soft_shrink (x , threshold = 0.5 ):
58+ x = get_ov_output (x )
59+ et = x .get_element_type ()
60+ thr = get_ov_output (threshold , et )
61+ zero = get_ov_output (0.0 , et )
62+ abs_x = ov_opset .abs (x )
63+ sub = ov_opset .subtract (abs_x , thr )
64+ shrunk = ov_opset .maximum (sub , zero )
65+ sign = ov_opset .sign (x )
66+ out = ov_opset .multiply (sign , shrunk )
67+ return OpenVINOKerasTensor (out .output (0 ))
68+
69+
70+ def hard_shrink (x , threshold = 0.5 ):
71+ x = get_ov_output (x )
72+ et = x .get_element_type ()
73+ thr = get_ov_output (threshold , et )
74+ zero = get_ov_output (0.0 , et )
75+ cond = ov_opset .greater (ov_opset .abs (x ), thr )
76+ out = ov_opset .select (cond , x , zero )
77+ return OpenVINOKerasTensor (out .output (0 ))
78+
79+
2980def softplus (x ):
3081 x = get_ov_output (x )
3182 return OpenVINOKerasTensor (ov_opset .softplus (x ).output (0 ))
@@ -38,14 +89,15 @@ def softsign(x):
3889
3990def silu (x ):
4091 x = get_ov_output (x )
41- return OpenVINOKerasTensor (
42- ov_opset .multiply (x , ov_opset .sigmoid (x )).output (0 )
43- )
92+ beta = get_ov_output (1.0 , x .get_element_type ())
93+ return OpenVINOKerasTensor (ov_opset .swish (x , beta = beta ).output (0 ))
4494
4595
4696def log_sigmoid (x ):
47- raise NotImplementedError (
48- "`log_sigmoid` is not supported with openvino backend"
97+ x = get_ov_output (x )
98+ neg_x = ov_opset .negative (x )
99+ return OpenVINOKerasTensor (
100+ ov_opset .negative (ov_opset .softplus (neg_x )).output (0 )
49101 )
50102
51103
@@ -58,6 +110,17 @@ def leaky_relu(x, negative_slope=0.2):
58110 return OpenVINOKerasTensor (leaky_relu )
59111
60112
113+ def sparse_sigmoid (x ):
114+ x = get_ov_output (x )
115+ et = x .get_element_type ()
116+ one = get_ov_output (1.0 , et )
117+ neg_one = get_ov_output (- 1.0 , et )
118+ half = get_ov_output (0.5 , et )
119+ y = ov_opset .minimum (ov_opset .maximum (x , neg_one ), one )
120+ out = ov_opset .multiply (half , ov_opset .add (y , one ))
121+ return OpenVINOKerasTensor (out .output (0 ))
122+
123+
61124def hard_sigmoid (x ):
62125 x = get_ov_output (x )
63126 alpha = get_ov_output (1.0 / 6.0 , x .get_element_type ())
@@ -121,15 +184,67 @@ def log_softmax(x, axis=-1):
121184 return OpenVINOKerasTensor (ov_opset .log_softmax (x , axis ).output (0 ))
122185
123186
187+ def squareplus (x , b = 4 ):
188+ x = get_ov_output (x )
189+ et = x .get_element_type ()
190+ b = get_ov_output (b , et )
191+ two = get_ov_output (2.0 , et )
192+ x_squared = ov_opset .multiply (x , x )
193+ inside = ov_opset .add (x_squared , b )
194+ root = ov_opset .sqrt (inside )
195+ summed = ov_opset .add (x , root )
196+ out = ov_opset .divide (summed , two )
197+ return OpenVINOKerasTensor (out .output (0 ))
198+
199+
200+ def sparse_plus (x ):
201+ x = get_ov_output (x )
202+ et = x .get_element_type ()
203+ one = get_ov_output (1.0 , et )
204+ neg_one = get_ov_output (- 1.0 , et )
205+ zero = get_ov_output (0.0 , et )
206+ quarter = get_ov_output (0.25 , et )
207+ x_plus_1 = ov_opset .add (x , one )
208+ quad = ov_opset .multiply (quarter , ov_opset .multiply (x_plus_1 , x_plus_1 ))
209+ leq_than_neg_one = ov_opset .less_equal (x , neg_one )
210+ less_than_one = ov_opset .less (x , one )
211+ out = ov_opset .select (
212+ leq_than_neg_one ,
213+ zero ,
214+ ov_opset .select (less_than_one , quad , x ),
215+ )
216+ return OpenVINOKerasTensor (out .output (0 ))
217+
218+
219+ def threshold (x , threshold , default_value ):
220+ x = get_ov_output (x )
221+ et = x .get_element_type ()
222+ thr = get_ov_output (threshold , et )
223+ dv = get_ov_output (default_value , et )
224+ cond = ov_opset .greater (x , thr )
225+ out = ov_opset .select (cond , x , dv )
226+ return OpenVINOKerasTensor (out .output (0 ))
227+
228+
124229def max_pool (
125230 inputs ,
126231 pool_size ,
127232 strides = None ,
128233 padding = "valid" ,
129234 data_format = None ,
130235):
131- raise NotImplementedError (
132- "`max_pool` is not supported with openvino backend"
236+ num_spatial_dims = (
237+ get_ov_output (inputs ).get_partial_shape ().rank .get_length () - 2
238+ )
239+ kwargs = {"dilations" : [1 ] * num_spatial_dims } # required for ov max_pool
240+ return _pool (
241+ inputs ,
242+ pool_size ,
243+ ov_opset .max_pool ,
244+ strides ,
245+ padding ,
246+ data_format ,
247+ ** kwargs ,
133248 )
134249
135250
@@ -140,11 +255,52 @@ def average_pool(
140255 padding = "valid" ,
141256 data_format = None ,
142257):
143- raise NotImplementedError (
144- "`average_pool` is not supported with openvino backend"
258+ return _pool (
259+ inputs ,
260+ pool_size ,
261+ ov_opset .avg_pool ,
262+ strides ,
263+ padding ,
264+ data_format ,
265+ exclude_pad = True ,
145266 )
146267
147268
269+ def _pool (
270+ inputs ,
271+ pool_size ,
272+ pooling_func ,
273+ strides = None ,
274+ padding = "valid" ,
275+ data_format = None ,
276+ ** kwargs ,
277+ ):
278+ data_format = backend .standardize_data_format (data_format )
279+ inputs = get_ov_output (inputs )
280+
281+ num_spatial_dims = inputs .get_partial_shape ().rank .get_length () - 2
282+ if isinstance (pool_size , int ):
283+ pool_size = [pool_size ] * num_spatial_dims
284+
285+ if strides is None :
286+ strides = pool_size
287+
288+ strides = _adjust_strides_dilation (strides , num_spatial_dims )
289+ pad_mode , pads_begin , pads_end = _adjust_padding (padding )
290+ inputs = _adjust_input (inputs , num_spatial_dims , data_format )
291+ pool_kwargs = {
292+ "kernel_shape" : pool_size ,
293+ "strides" : strides ,
294+ "auto_pad" : pad_mode ,
295+ "pads_begin" : pads_begin ,
296+ "pads_end" : pads_end ,
297+ ** kwargs ,
298+ }
299+ pooled = pooling_func (inputs , ** pool_kwargs ).output (0 )
300+ adjusted_pooled = _adjust_outputs (pooled , num_spatial_dims , data_format )
301+ return OpenVINOKerasTensor (adjusted_pooled )
302+
303+
148304def _adjust_strides_dilation (
149305 x ,
150306 num_spatial_dims ,
@@ -374,9 +530,22 @@ def conv_transpose(
374530
375531
376532def one_hot (x , num_classes , axis = - 1 , dtype = None , sparse = False ):
377- raise NotImplementedError (
378- "`one_hot` is not supported with openvino backend"
379- )
533+ if sparse :
534+ raise ValueError ("`sparse=True` is not supported with openvino backend" )
535+ x = get_ov_output (x )
536+ if dtype is None :
537+ dtype = backend .floatx ()
538+ ov_dtype = OPENVINO_DTYPES [dtype ]
539+ on_value = get_ov_output (1 , ov_dtype )
540+ off_value = get_ov_output (0 , ov_dtype )
541+ one_hot_encoded = ov_opset .one_hot (
542+ x ,
543+ depth = num_classes ,
544+ axis = axis ,
545+ on_value = on_value ,
546+ off_value = off_value ,
547+ ).output (0 )
548+ return OpenVINOKerasTensor (one_hot_encoded )
380549
381550
382551def multi_hot (x , num_classes , axis = - 1 , dtype = None , sparse = False ):
@@ -465,9 +634,15 @@ def batch_normalization(
465634
466635
467636def ctc_loss (target , output , target_length , output_length , mask_index = 0 ):
468- raise NotImplementedError (
469- "`ctc_loss` is not supported with openvino backend"
637+ target = get_ov_output (target )
638+ output = get_ov_output (output )
639+ target_length = get_ov_output (target_length )
640+ output_length = get_ov_output (output_length )
641+ ctc_loss_ = ov_opset .ctc_loss (
642+ output , output_length , target , target_length , blank_index = mask_index
470643 )
644+ ctc_loss_ = ov_opset .convert (ctc_loss_ , OPENVINO_DTYPES [backend .floatx ()])
645+ return OpenVINOKerasTensor (ctc_loss_ .output (0 ))
471646
472647
473648def ctc_decode (
@@ -499,9 +674,46 @@ def dot_product_attention(
499674 flash_attention = None ,
500675 attn_logits_soft_cap = None ,
501676):
502- raise NotImplementedError (
503- "`dot_product_attention` is not supported with openvino backend"
677+ if bias is not None :
678+ raise NotImplementedError (
679+ "`dot_product_attention` with `bias` is not supported "
680+ "with openvino backend"
681+ )
682+ if flash_attention :
683+ raise NotImplementedError (
684+ "`dot_product_attention` with `flash_attention` is not supported "
685+ "with openvino backend"
686+ )
687+ if attn_logits_soft_cap is not None :
688+ raise NotImplementedError (
689+ "`dot_product_attention` with `attn_logits_soft_cap` is not "
690+ "supported with openvino backend"
691+ )
692+ query = get_ov_output (query )
693+ key = get_ov_output (key )
694+ value = get_ov_output (value )
695+ if query .get_element_type () != key .get_element_type ():
696+ ov_type = OPENVINO_DTYPES [backend .floatx ()]
697+ query = ov_opset .convert (query , ov_type ).output (0 )
698+ key = ov_opset .convert (key , ov_type ).output (0 )
699+ if value .get_element_type () != query .get_element_type ():
700+ value = ov_opset .convert (value , query .get_element_type ()).output (0 )
701+ axes_const = ov_opset .constant ([0 , 2 , 1 , 3 ], Type .i32 ).output (0 )
702+
703+ query = ov_opset .transpose (query , axes_const )
704+ key = ov_opset .transpose (key , axes_const )
705+ value = ov_opset .transpose (value , axes_const )
706+ mask = get_ov_output (mask ) if mask is not None else None
707+ scale = (
708+ get_ov_output (scale , query .get_element_type ())
709+ if scale is not None
710+ else None
711+ )
712+ dpa = ov_opset .scaled_dot_product_attention (
713+ query , key , value , attention_mask = mask , scale = scale , causal = is_causal
504714 )
715+ dpa = ov_opset .transpose (dpa , axes_const )
716+ return OpenVINOKerasTensor (dpa .output (0 ))
505717
506718
507719def unfold (input , kernel_size , dilation = 1 , padding = 0 , stride = 1 ):
0 commit comments