@@ -30,6 +30,19 @@ def _ms_max(*args, keep_dims=False, **kwargs):
3030 'working with a new Mindspore version which breaks backward compatibility. Please report your '
3131 'Mindspore version to GitHub issues.' )
3232
33+ _norm_signature = inspect .signature (mindspore .ops .norm )
34+ if 'axis' in _norm_signature .parameters :
35+ def _ms_norm (* args , axis = None , ** kwargs ):
36+ return mindspore .ops .norm (* args , axis = axis , ** kwargs )
37+ elif 'dim' in _norm_signature .parameters :
38+ def _ms_norm (* args , axis = None , ** kwargs ):
39+ return mindspore .ops .norm (* args , dim = axis , ** kwargs )
40+ else :
41+ raise ValueError ('Mindspore function mindspore.ops.norm has unsupported signature. It is likely you are '
42+ 'working with a new Mindspore version which breaks backward compatibility. Please report your '
43+ 'Mindspore version to GitHub issues.' )
44+
45+
3346#############################################
3447# Linear Assignment Problem Solvers #
3548#############################################
@@ -261,15 +274,15 @@ def rrwm(K: mindspore.Tensor, n1: mindspore.Tensor, n2: mindspore.Tensor, n1max,
261274 # random walk
262275 v = mindspore .ops .BatchMatMul ()(K , v )
263276 last_v = v
264- n = mindspore . ops . norm (v , axis = 1 , p = 1 , keep_dims = True )
277+ n = _ms_norm (v , axis = 1 , p = 1 , keep_dims = True )
265278 v = v / n
266279
267280 # reweighted jump
268281 s = v .view (batch_num , int (n2max ), int (n1max )).swapaxes (1 , 2 )
269282 s = beta * s / s .max (axis = 1 , keepdims = True ).max (axis = 2 , keepdims = True )
270283 v = alpha * sinkhorn (s , n1 , n2 , max_iter = sk_iter , batched_operation = True ).swapaxes (1 , 2 ).reshape (batch_num , n1n2 , 1 ) + \
271284 (1 - alpha ) * v
272- n = mindspore . ops . norm (v , axis = 1 , p = 1 , keep_dims = True )
285+ n = _ms_norm (v , axis = 1 , p = 1 , keep_dims = True )
273286 v = mindspore .ops .matmul (v , 1 / n )
274287
275288 if (v - last_v ).sum ().sqrt () < 1e-5 :
@@ -287,7 +300,7 @@ def sm(K: mindspore.Tensor, n1: mindspore.Tensor, n2: mindspore.Tensor, n1max, n
287300 v = vlast = v0
288301 for i in range (max_iter ):
289302 v = mindspore .ops .BatchMatMul ()(K , v )
290- n = mindspore . ops . norm (v , axis = 1 , p = 2 )
303+ n = _ms_norm (v , axis = 1 , p = 2 )
291304 v = mindspore .ops .matmul (v , (1 / n ).view (batch_num , 1 , 1 ))
292305 if (v - vlast ).sum ().sqrt () < 1e-5 :
293306 break
0 commit comments