 import math
 
 import inspect
-signature = inspect.signature(mindspore.ops.max)
-if 'keep_dims' in signature.parameters:
-    MS_MAX_UNDERLINE = True
-elif 'keepdims' in signature.parameters:
-    MS_MAX_UNDERLINE = False
+import functools
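+# mindspore.ops.max renamed its keyword argument (keep_dims -> keepdims) and changed the
+# order of its return values between releases; detect the installed signature once and
+# expose a wrapper that always accepts keep_dims and returns (indices, values).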
+_max_signature = inspect.signature(mindspore.ops.max)
+if 'keep_dims' in _max_signature.parameters:
+    def _ms_max(*args, keep_dims=False, **kwargs):
+        return mindspore.ops.max(*args, keep_dims=keep_dims, **kwargs)
+elif 'keepdims' in _max_signature.parameters:
+    def _ms_max(*args, keep_dims=False, **kwargs):
+        max, indices = mindspore.ops.max(*args, keepdims=keep_dims, **kwargs)
+        return indices, max
 else:
-    raise ValueError('Mindspore function mindspore.ops.max has unknown signature')
+    raise ValueError('Mindspore function mindspore.ops.max has unsupported signature. It is likely you are working with '
+                     'a new Mindspore version which breaks backward compatibility. Please report your Mindspore version '
+                     'to GitHub issues.')
+
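+# Same approach for mindspore.ops.logsumexp, whose keyword was renamed keep_dims -> keepdim.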
+_logsumexp_signature = inspect.signature(mindspore.ops.logsumexp)
+if 'keep_dims' in _logsumexp_signature.parameters:
+    _ms_logsumexp_keepdim = functools.partial(mindspore.ops.logsumexp, keep_dims=True)
+elif 'keepdim' in _logsumexp_signature.parameters:
+    _ms_logsumexp_keepdim = functools.partial(mindspore.ops.logsumexp, keepdim=True)
+else:
+    raise ValueError('Mindspore function mindspore.ops.logsumexp has unsupported signature. It is likely you are '
+                     'working with a new Mindspore version which breaks backward compatibility. Please report your '
+                     'Mindspore version to GitHub issues.')
 
 #############################################
 #     Linear Assignment Problem Solvers     #
@@ -162,22 +178,14 @@ def sinkhorn(s: mindspore.Tensor, nrows: mindspore.Tensor = None, ncols: mindspo
 
         for i in range(max_iter):
             if i % 2 == 0:
-                if MS_MAX_UNDERLINE:
-                    index, m = mindspore.ops.max(log_s, axis=2, keep_dims=True)
-                    log_sum = mindspore.ops.logsumexp(log_s - m, 2, keep_dims=True) + m
-                else:
-                    index, m = mindspore.ops.max(log_s, axis=2, keepdims=True)
-                    log_sum = mindspore.ops.logsumexp(log_s - m, 2, keepdim=True) + m
+                index, m = _ms_max(log_s, axis=2, keep_dims=True)
+                log_sum = _ms_logsumexp_keepdim(log_s - m, 2) + m
                 log_s = log_s - mindspore.numpy.where(row_mask, log_sum, mindspore.numpy.zeros_like(log_sum))
                 if mindspore.ops.isnan(log_s).any():
                     raise RuntimeError(f'NaN encountered in Sinkhorn iter_num={i}/{max_iter}')
             else:
-                if MS_MAX_UNDERLINE:
-                    index, m = mindspore.ops.max(log_s, axis=1, keep_dims=True)
-                    log_sum = mindspore.ops.logsumexp(log_s - m, 1, keep_dims=True) + m
-                else:
-                    index, m = mindspore.ops.max(log_s, axis=1, keepdims=True)
-                    log_sum = mindspore.ops.logsumexp(log_s - m, 1, keepdim=True) + m
+                index, m = _ms_max(log_s, axis=1, keep_dims=True)
+                log_sum = _ms_logsumexp_keepdim(log_s - m, 1) + m
                 log_s = log_s - mindspore.numpy.where(col_mask, log_sum, mindspore.numpy.zeros_like(log_sum))
                 if mindspore.ops.isnan(log_s).any():
                     raise RuntimeError(f'NaN encountered in Sinkhorn iter_num={i}/{max_iter}')
@@ -195,20 +203,12 @@ def sinkhorn(s: mindspore.Tensor, nrows: mindspore.Tensor = None, ncols: mindspo
 
             for i in range(max_iter):
                 if i % 2 == 0:
-                    if MS_MAX_UNDERLINE:
-                        index, m = mindspore.ops.max(log_s_b, axis=1, keep_dims=True)
-                        log_sum = mindspore.ops.logsumexp(log_s_b - m, 1, keep_dims=True) + m
-                    else:
-                        index, m = mindspore.ops.max(log_s_b, axis=1, keepdims=True)
-                        log_sum = mindspore.ops.logsumexp(log_s_b - m, 1, keepdim=True) + m
+                    index, m = _ms_max(log_s_b, axis=1, keep_dims=True)
+                    log_sum = _ms_logsumexp_keepdim(log_s_b - m, 1) + m
                     log_s_b = log_s_b - mindspore.numpy.where(row_mask_b, log_sum, mindspore.numpy.zeros_like(log_sum))
                 else:
-                    if MS_MAX_UNDERLINE:
-                        index, m = mindspore.ops.max(log_s_b, axis=0, keep_dims=True)
-                        log_sum = mindspore.ops.logsumexp(log_s_b - m, 0, keep_dims=True) + m
-                    else:
-                        index, m = mindspore.ops.max(log_s_b, axis=0, keepdims=True)
-                        log_sum = mindspore.ops.logsumexp(log_s_b - m, 0, keepdim=True) + m
+                    index, m = _ms_max(log_s_b, axis=0, keep_dims=True)
+                    log_sum = _ms_logsumexp_keepdim(log_s_b - m, 0) + m
                     log_s_b = log_s_b - mindspore.numpy.where(col_mask_b, log_sum, mindspore.numpy.zeros_like(log_sum))
 
             ret_log_s[b, row_slice, col_slice] = log_s_b
@@ -325,7 +325,7 @@ def comp_obj_score(v1, K, v2):
         best_v = mindspore.numpy.where(current_obj > best_obj, binary_v, best_v)
         best_obj = mindspore.numpy.where(current_obj > best_obj, current_obj, best_obj)
 
-        if (mindspore.ops.max(mindspore.ops.abs(last_v_obj - current_obj) / last_v_obj)[1] < 1e-3).any():
+        if (_ms_max(mindspore.ops.abs(last_v_obj - current_obj) / last_v_obj)[1] < 1e-3).any():
             break
         last_v = v
 
@@ -344,9 +344,9 @@ def _check_and_init_gm(K, n1, n2, n1max, n2max, x0):
     if n2 is None:
         n2 = mindspore.numpy.full((batch_num,), n2max, dtype=mindspore.numpy.int_)
     if n1max is None:
-        n1max = mindspore.ops.max(n1)[1]
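+        # _ms_max keeps the older (indices, values) return order, so [1] selects the max value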
+        n1max = _ms_max(n1)[1]
     if n2max is None:
-        n2max = mindspore.ops.max(n2)[1]
+        n2max = _ms_max(n2)[1]
 
     if not n1max * n2max == n1n2:
         raise ValueError('the input size of K does not match with n1max * n2max!')