Skip to content

Commit 8557cc9

Browse files
committed
try fixing ms backward compatibility
1 parent 9850edc commit 8557cc9

File tree

1 file changed

+33
-33
lines changed

1 file changed

+33
-33
lines changed

pygmtools/mindspore_backend.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,29 @@
66
import math
77

88
import inspect
9-
signature = inspect.signature(mindspore.ops.max)
10-
if 'keep_dims' in signature.parameters:
11-
MS_MAX_UNDERLINE = True
12-
elif 'keepdims' in signature.parameters:
13-
MS_MAX_UNDERLINE = False
9+
import functools
10+
_max_signature = inspect.signature(mindspore.ops.max)
11+
if 'keep_dims' in _max_signature.parameters:
12+
def _ms_max(*args, keep_dims=False, **kwargs):
13+
return mindspore.ops.max(*args, keep_dims=keep_dims, **kwargs)
14+
elif 'keepdims' in _max_signature.parameters:
15+
def _ms_max(*args, keep_dims=False, **kwargs):
16+
max, indices = mindspore.ops.max(*args, keepdims=keep_dims, **kwargs)
17+
return indices, max
1418
else:
15-
raise ValueError('Mindspore function mindspore.ops.max has unknown signature')
19+
raise ValueError('Mindspore function mindspore.ops.max has unsupported signature. It is likely you are working with '
20+
'a new Mindspore version which breaks backward compatibility. Please report your Mindspore version '
21+
'to GitHub issues.')
22+
23+
_logsumexp_signature = inspect.signature(mindspore.ops.logsumexp)
24+
if 'keep_dims' in _logsumexp_signature.parameters:
25+
_ms_logsumexp_keepdim = functools.partial(mindspore.ops.logsumexp, keep_dims=True)
26+
elif 'keepdim' in _logsumexp_signature.parameters:
27+
_ms_logsumexp_keepdim = functools.partial(mindspore.ops.logsumexp, keepdim=True)
28+
else:
29+
raise ValueError('Mindspore function mindspore.ops.logsumexp has unsupported signature. It is likely you are '
30+
'working with a new Mindspore version which breaks backward compatibility. Please report your '
31+
'Mindspore version to GitHub issues.')
1632

1733
#############################################
1834
# Linear Assignment Problem Solvers #
@@ -162,22 +178,14 @@ def sinkhorn(s: mindspore.Tensor, nrows: mindspore.Tensor = None, ncols: mindspo
162178

163179
for i in range(max_iter):
164180
if i % 2 == 0:
165-
if MS_MAX_UNDERLINE:
166-
index, m = mindspore.ops.max(log_s, axis=2, keep_dims=True)
167-
log_sum = mindspore.ops.logsumexp(log_s - m, 2, keep_dims=True) + m
168-
else:
169-
index, m = mindspore.ops.max(log_s, axis=2, keepdims=True)
170-
log_sum = mindspore.ops.logsumexp(log_s - m, 2, keepdim=True) + m
181+
index, m = _ms_max(log_s, axis=2, keep_dims=True)
182+
log_sum = _ms_logsumexp_keepdim(log_s - m, 2) + m
171183
log_s = log_s - mindspore.numpy.where(row_mask, log_sum, mindspore.numpy.zeros_like(log_sum))
172184
if mindspore.ops.isnan(log_s).any():
173185
raise RuntimeError(f'NaN encountered in Sinkhorn iter_num={i}/{max_iter}')
174186
else:
175-
if MS_MAX_UNDERLINE:
176-
index, m = mindspore.ops.max(log_s, axis=1, keep_dims=True)
177-
log_sum = mindspore.ops.logsumexp(log_s - m, 1, keep_dims=True) + m
178-
else:
179-
index, m = mindspore.ops.max(log_s, axis=1, keepdims=True)
180-
log_sum = mindspore.ops.logsumexp(log_s - m, 1, keepdim=True) + m
187+
index, m = _ms_max(log_s, axis=1, keep_dims=True)
188+
log_sum = _ms_logsumexp_keepdim(log_s - m, 1) + m
181189
log_s = log_s - mindspore.numpy.where(col_mask, log_sum, mindspore.numpy.zeros_like(log_sum))
182190
if mindspore.ops.isnan(log_s).any():
183191
raise RuntimeError(f'NaN encountered in Sinkhorn iter_num={i}/{max_iter}')
@@ -195,20 +203,12 @@ def sinkhorn(s: mindspore.Tensor, nrows: mindspore.Tensor = None, ncols: mindspo
195203

196204
for i in range(max_iter):
197205
if i % 2 == 0:
198-
if MS_MAX_UNDERLINE:
199-
index, m = mindspore.ops.max(log_s_b, axis=1, keep_dims=True)
200-
log_sum = mindspore.ops.logsumexp(log_s_b - m, 1, keep_dims=True) + m
201-
else:
202-
index, m = mindspore.ops.max(log_s_b, axis=1, keepdims=True)
203-
log_sum = mindspore.ops.logsumexp(log_s_b - m, 1, keepdim=True) + m
206+
index, m = _ms_max(log_s_b, axis=1, keep_dims=True)
207+
log_sum = _ms_logsumexp_keepdim(log_s_b - m, 1) + m
204208
log_s_b = log_s_b - mindspore.numpy.where(row_mask_b, log_sum, mindspore.numpy.zeros_like(log_sum))
205209
else:
206-
if MS_MAX_UNDERLINE:
207-
index, m = mindspore.ops.max(log_s_b, axis=0, keep_dims=True)
208-
log_sum = mindspore.ops.logsumexp(log_s_b - m, 0, keep_dims=True) + m
209-
else:
210-
index, m = mindspore.ops.max(log_s_b, axis=0, keepdims=True)
211-
log_sum = mindspore.ops.logsumexp(log_s_b - m, 0, keepdim=True) + m
210+
index, m = _ms_max(log_s_b, axis=0, keep_dims=True)
211+
log_sum = _ms_logsumexp_keepdim(log_s_b - m, 0) + m
212212
log_s_b = log_s_b - mindspore.numpy.where(col_mask_b, log_sum, mindspore.numpy.zeros_like(log_sum))
213213

214214
ret_log_s[b, row_slice, col_slice] = log_s_b
@@ -325,7 +325,7 @@ def comp_obj_score(v1, K, v2):
325325
best_v = mindspore.numpy.where(current_obj > best_obj, binary_v, best_v)
326326
best_obj = mindspore.numpy.where(current_obj > best_obj, current_obj, best_obj)
327327

328-
if (mindspore.ops.max(mindspore.ops.abs(last_v_obj - current_obj) / last_v_obj)[1] < 1e-3).any():
328+
if (_ms_max(mindspore.ops.abs(last_v_obj - current_obj) / last_v_obj)[1] < 1e-3).any():
329329
break
330330
last_v = v
331331

@@ -344,9 +344,9 @@ def _check_and_init_gm(K, n1, n2, n1max, n2max, x0):
344344
if n2 is None:
345345
n2 = mindspore.numpy.full((batch_num,), n2max, dtype=mindspore.numpy.int_)
346346
if n1max is None:
347-
n1max = mindspore.ops.max(n1)[1]
347+
n1max = _ms_max(n1)[1]
348348
if n2max is None:
349-
n2max = mindspore.ops.max(n2)[1]
349+
n2max = _ms_max(n2)[1]
350350

351351
if not n1max * n2max == n1n2:
352352
raise ValueError('the input size of K does not match with n1max * n2max!')

0 commit comments

Comments
 (0)